Python: Solving ODEs with Deep Learning in TensorFlow

This post explains how to implement TensorFlow code that solves ODE problems with a neural network. This approach is commonly known as 'Physics-Informed Neural Networks' (PINNs) and is sometimes loosely referred to as a 'Neural ODE.'


Neural ODEs in TensorFlow




I borrowed content from https://i-systems.github.io/tutorial/KSNVE/220525/01_PINN.html and made some modifications to it.


Solving ODE with Neural Networks


Given an initial condition \(u(0)=u_0\), consider a system of ordinary differential equations for \(t\in[0,1]\) \[\begin{align} u^\prime = f(u,t) \end{align}\] To solve this, we approximate the solution by a neural network: \[\begin{align} \text{NN}(t) \approx u(t) \end{align}\] The loss function consists of two parts: the ODE residual at a set of collocation points \(t_i\) and the initial-condition error \[\begin{align} L(\omega) = \sum_{i} \left(\frac{d \text{NN}(t_i)}{dt}-f(\text{NN}(t_i),t_i)\right)^2 + (\text{NN}(0)-u_0)^2 \end{align}\] where \(\omega\) denotes the parameters of the neural network NN that approximates \(u\).
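To make this loss concrete before the worked example below, here is a minimal TensorFlow sketch of how the two terms can be computed for a generic right-hand side f(u, t). The helper name pinn_loss and the callable f are illustrative names of my own, not part of the borrowed tutorial, whose full implementation for a specific ODE appears later in this post.

import numpy as np
import tensorflow as tf

def pinn_loss(net, f, t_batch, u0):
    # discretized loss: mean squared ODE residual + initial-condition penalty
    t = tf.convert_to_tensor(np.reshape(t_batch, (-1, 1)), dtype=tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(t)
        u = net(t)                          # NN(t), the approximate solution
    du_dt = tape.gradient(u, t)             # dNN/dt via automatic differentiation
    residual = du_dt - f(u, t)              # how well the ODE u' = f(u, t) is satisfied
    ic_error = net(tf.zeros((1, 1))) - u0   # NN(0) - u_0
    return tf.reduce_mean(tf.square(residual)) + tf.reduce_mean(tf.square(ic_error))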


Example ODE


Consider the following ODE \[\begin{align} \frac{du}{dt} = \cos 2 \pi t, \quad u(0) = 1 \end{align}\] The exact solution is \(u(t) = \frac{1}{2\pi}\sin 2\pi t + 1\).
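As a quick sanity check, integrating the right-hand side and applying the initial condition recovers exactly this expression: \[\begin{align} u(t) = u(0) + \int_0^t \cos 2\pi s \, ds = 1 + \frac{1}{2\pi}\sin 2\pi t \end{align}\]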

The following figure outlines the calculation method.


An important point is that the input variable corresponds to the time t.


Python Jupyter Notebook Code


The ODE above can be solved with MLP layers using a tanh activation function. It is important to emphasize that, because the loss involves the derivative of the network output with respect to its input, the parameters are learned by applying gradient descent directly in a custom training loop. This differs from the conventional use of the compile() and fit() methods, which are typically aimed at minimizing only a standard loss such as the mean squared error (MSE).

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import random
 
# NN(t) = u(t)
NN = tf.keras.models.Sequential([
    tf.keras.layers.Input((1,)),
    tf.keras.layers.Dense(units = 100, activation = 'tanh'),
    tf.keras.layers.Dense(units = 100, activation = 'tanh'),
    tf.keras.layers.Dense(units = 100, activation = 'tanh'),
    tf.keras.layers.Dense(units = 1)
])
 
NN.summary()
optm = tf.keras.optimizers.Adam(learning_rate = 0.001)
 
# ODE loss
def ode_system(t, net):
    t = t.reshape(-1,1)
    t = tf.constant(t, dtype = tf.float32)    
    t_0 = tf.zeros((1,1))
    one = tf.ones((1,1))   
    
    # u(t) and du/dt
    with tf.GradientTape() as tape:
        
        # tape.watch(t) is used to explicitly declare a tensor t 
        # to be watched by a gradient tape. 
        # This is used in scenarios where you want to compute gradients 
        # with respect to a specific tensor t. When you watch a tensor, 
        # it allows TensorFlow's automatic differentiation system 
        # to keep track of operations involving t so that gradients 
        # can be computed during a subsequent call to tape.gradient.
        tape.watch(t)
             
        # u(t): Approximate NN solution
        u = net(t)
        
    # Compute the gradient of 'u' with respect to 't' (outside the tape context)
    u_t = tape.gradient(u, t)
    
    # two losses
    ode_loss = u_t - tf.math.cos(2*np.pi*t) 
    IC_loss = net(t_0) - one
    
    # total loss
    mean_loss_ode = tf.reduce_mean(tf.square(ode_loss))
    scalar_loss_IC = tf.reduce_mean(tf.square(IC_loss))
    total_loss = mean_loss_ode + scalar_loss_IC
 
    return total_loss
 
# training data (time): 30 random points in [0, 2]
train_t = (np.random.rand(30)*2).reshape(-1, 1)
train_loss_record = []
 
# learn the parameters using gradient descent
for itr in range(3000):
    with tf.GradientTape() as tape:
        # training ODE loss
        train_loss = ode_system(train_t, NN)
    train_loss_record.append(train_loss.numpy())

    # gradient of the loss with respect to the network parameters
    grad_w = tape.gradient(train_loss, NN.trainable_variables)

    # update the parameters
    optm.apply_gradients(zip(grad_w, NN.trainable_variables))

    if itr % 1000 == 0:
        print(train_loss.numpy())
        


Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 100)               200       
                                                                  
 dense_1 (Dense)             (None, 100)               10100     
                                                                  
 dense_2 (Dense)             (None, 100)               10100     
                                                                  
 dense_3 (Dense)             (None, 1)                 101       
                                                                 
=================================================================
Total params: 20,501
Trainable params: 20,501
Non-trainable params: 0
_________________________________________________________________
1.4987797
0.00016946398
0.00011333717
 
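One detail worth noting (my reading of the code, not something the borrowed tutorial spells out): ode_system computes du/dt with an inner GradientTape, and the training loop wraps that call in a second, outer GradientTape that differentiates the total loss with respect to the network weights. Nesting tapes like this is how TensorFlow differentiates through a quantity that is itself a gradient. A minimal standalone illustration of the mechanism, computing a second derivative of y = x^3 at x = 2:

import tensorflow as tf

x = tf.constant(2.0)
with tf.GradientTape() as outer_tape:
    outer_tape.watch(x)
    with tf.GradientTape() as inner_tape:
        inner_tape.watch(x)
        y = x ** 3
    dy_dx = inner_tape.gradient(y, x)       # 3*x**2 = 12.0
d2y_dx2 = outer_tape.gradient(dy_dx, x)     # 6*x   = 12.0
print(dy_dx.numpy(), d2y_dx2.numpy())

In the PINN loop the outer tape differentiates with respect to the weights rather than x, but the nesting plays the same role.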


When we plot the training loss and the final solution using test data, it becomes evident that the approximate neural network solution closely matches the exact solution.

# draw training loss
plt.figure(figsize = (6,5))
 
plt.subplot(211)  # 2 rows, 1 column, 1st subplot
plt.plot(train_loss_record)
plt.title('Training loss')
 
# compare the approximate NN solution with the exact one
test_t = np.linspace(0, 2, 100)
 
train_u = np.sin(2*np.pi*train_t)/(2*np.pi) + 1
true_u = np.sin(2*np.pi*test_t)/(2*np.pi) + 1
pred_u = NN.predict(test_t.reshape(-1, 1)).ravel()
 
#plt.figure(figsize = (6,3))
plt.subplot(212)  # 2 rows, 1 column, 2nd subplot
plt.plot(train_t, train_u, 'ok', label = 'Train')
plt.plot(test_t, true_u, '-k',label = 'True')
plt.plot(test_t, pred_u, '--r', label = 'Pred')
plt.legend(fontsize = 11, frameon=False)
plt.xlabel('t', fontsize = 15)
plt.ylabel('u', fontsize = 15)
plt.title('Comparison: NN and exact solution')
 
plt.tight_layout()  # Helps prevent overlapping labels and titles
plt.show()
 
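To put a number on how close the two curves are, one could append a line like the following to the cell above; this check is my own addition and simply reuses the pred_u and true_u arrays defined there.

# maximum absolute deviation between the NN prediction and the exact solution
print('max abs error:', np.max(np.abs(pred_u - true_u)))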

I believe this is a valuable technique for learning solutions of differential equations directly from the governing equations and data, and that it can be applied to a variety of financial problems involving ODE components.

