Solving the heat equation with tensorflow

I’m new to using TensorFlow.
I’m trying to take advantage of its parallelization capabilities to solve PDEs.
I wrote the attached code to solve a simple heat equation on a cubic domain, using finite differences and then tested the execution with CPU and GPU.

I found the GPU very fast for the time iterations (16 seconds vs 64 seconds on the CPU). However, the initialisation time on the GPU is enormous compared with the CPU one (making the total execution time larger than on the CPU); indeed:

CPU (Tensorflow version is: 2.6.0):
initialisation of DXYZ, elapsed: 0.003410 sec
U variable, elapsed: 0.001667 sec
s2 tensor, elapsed: 0.000479 sec

total initialization: 0.005556 sec
solution, elapsed: 64.972245 sec
TOTAL, elapsed: 64.977802 sec

GPU(Tensorflow version is: 2.4.1):
initialisation of DXYZ, elapsed: 428.317861 sec
U variable, elapsed: 0.199900 sec
s2 tensor, elapsed: 0.000174 sec

total initialization: 428.517935 sec
solution, elapsed: 16.511286 sec
TOTAL, elapsed: 445.029221 sec

My suspicion is that this depends on the graph construction: the first time TensorFlow is invoked is to build three tf.constant tensors of size 1, and this operation takes the vast majority of the time.
Is this normal behaviour, or does it depend on how I installed TensorFlow? I installed both versions with Anaconda (install tensorflow using anaconda).

Is there a way to speed up the process?

Thanks,

Cesare

import numpy as np
import time
import tensorflow as tf
from tensorflow.python.client import timeline
# Make tf.function-decorated callables execute eagerly (op by op) instead of
# as traced graphs.
tf.config.run_functions_eagerly(True)

# Report whether TensorFlow lists at least one physical GPU.
device_label = 'GPU device' if tf.config.list_physical_devices('GPU') else 'CPU device'
print(device_label)

print('Tensorflow version is: %s' % tf.__version__)
  


class ResultWriter:
    """Accumulate 3-D snapshots in memory and dump them to one ``.npy`` file.

    Every entry of ``config`` is copied onto the instance as an attribute,
    overriding the defaults set in ``__init__`` (``width``, ``height``,
    ``depth``, ``samples``, ``dt_per_plot``, ``not_saved``, ``prefix_name``).
    Keys the writer does not use are stored as attributes as well.
    """

    def __init__(self, config=None):
        """Allocate the frame buffer.

        Args:
            config: optional mapping of attribute overrides.  ``None`` (the
                default) or an empty mapping keeps all defaults.
        """
        self.width        = 1
        self.height       = 1
        self.depth        = 1
        self.samples      = 1
        self.dt_per_plot  = 1
        self.not_saved    = True
        self.prefix_name  = 'cube3D'

        # ``None`` replaces the former shared mutable ``{}`` default.
        for key, val in (config or {}).items():
            setattr(self, key, val)

        # A producer that stores a frame whenever ``step % dt_per_plot == 0``
        # for ``step`` in ``range(samples)`` emits ceil(samples / dt_per_plot)
        # frames; floor division under-allocated by one frame whenever
        # ``samples`` was not a multiple of ``dt_per_plot``.
        n = int(-(-self.samples // self.dt_per_plot))
        self.cube    = np.zeros([n, self.height, self.width, self.depth], dtype=np.float32)
        self.counter = 0

    def imshow(self, VolData):
        """Store ``VolData`` as the next frame and advance the frame counter.

        Raises:
            IndexError: if more frames are stored than were allocated.
        """
        self.cube[self.counter, :, :, :] = VolData
        self.counter = self.counter + 1

    def wait(self):
        """Save the buffer if that has not happened yet and mark it saved."""
        if self.not_saved:
            self.save()
        self.not_saved = False

    def save(self):
        """Write the whole buffer with ``np.save``.

        The file name is built from ``prefix_name`` and the height, width and
        depth; ``np.save`` appends ``.npy`` when the name lacks it.
        """
        fname = '{0}_{1}_{2}_{3}'.format(self.prefix_name, self.height, self.width, self.depth)
        print('saving file {0}'.format(fname))
        np.save(fname, self.cube)

    def __del__(self):
        # ``__init__`` may have raised before these attributes existed, so
        # look them up defensively instead of failing inside the finalizer.
        if getattr(self, 'not_saved', False) and hasattr(self, 'cube'):
            self.save()



@tf.function
def laplace(X0,DX,DY,DZ):
    """Central-difference Laplacian of a rank-3 field.

    ``X0`` is padded by one cell on each side of all three axes in
    'symmetric' mode, then the three-point second difference is taken along
    axis 0, 1 and 2 and divided by ``DX*DX``, ``DY*DY`` and ``DZ*DZ``
    respectively.  The result has the same shape as ``X0``.
    """
    padded = tf.pad(
        X0,
        paddings=tf.constant([[1,1], [1,1], [1,1]]),
        mode='symmetric',
    )
    # Interior of the padded field, i.e. the stencil's central point.
    centre = padded[1:-1,1:-1,1:-1]

    d2_axis0 = (padded[0:-2,1:-1,1:-1] - 2.0*centre + padded[2:,1:-1,1:-1]) / (DX*DX)
    d2_axis1 = (padded[1:-1,0:-2,1:-1] - 2.0*centre + padded[1:-1,2:,1:-1]) / (DY*DY)
    d2_axis2 = (padded[1:-1,1:-1,0:-2] - 2.0*centre + padded[1:-1,1:-1,2:]) / (DZ*DZ)
    return d2_axis0 + d2_axis1 + d2_axis2

class HeatEquation:
    """
    The heat equation model: explicit finite-difference time stepping on a
    3-D box.

    Every entry of the mapping passed to the constructor becomes an instance
    attribute.  ``dx``, ``dy`` and ``dz`` default to 1.0.  ``run`` also reads
    ``height``, ``width``, ``depth``, ``samples``, ``dt``, ``diff``,
    ``s2_time`` and ``dt_per_plot``; these have no defaults and must be
    supplied.
    """

    def __init__(self, props):
        """Store the configuration and create the grid-spacing constants.

        Args:
            props: mapping of attribute name to value.
        """
        self.dx    = 1.0
        self.dy    = 1.0
        self.dz    = 1.0
        # Read the argument itself; the previous code iterated over the
        # module-level ``config`` global and ignored ``props``.
        for key, val in props.items():
            setattr(self, key, val)

        # NOTE(review): when these are the first tensors created in the
        # process, this timing presumably also contains TensorFlow's one-time
        # device start-up cost -- confirm by running a warm-up op beforehand.
        then = time.time()
        self.DX    = tf.constant(self.dx)
        self.DY    = tf.constant(self.dy)
        self.DZ    = tf.constant(self.dz)
        elapsed = (time.time() - then)
        tf.print('initialisation of DXYZ, elapsed: %f sec' % elapsed)
        # Kept separately so that repeated ``run`` calls do not accumulate
        # their own set-up time into each other's totals.
        self._tinit_constants = elapsed
        self.tinit = elapsed

    def run(self, im=None):
        """Integrate for ``samples`` time steps and print timing information.

        This method is deliberately not wrapped in ``tf.function``: it
        converts tensors to NumPy, measures wall-clock time and calls back
        into a Python writer, none of which can be part of a traced graph.

        Args:
            im: optional writer exposing ``imshow(array)`` and ``wait()``.
                When given, the field is handed to ``imshow`` every
                ``dt_per_plot`` steps (starting with step 0) and ``wait`` is
                called once at the end.
        """
        shape = [self.height, self.width, self.depth]

        # Initial field: zero everywhere except index 1 along the second axis.
        u_init = np.full(shape, 0.0, dtype=np.float32)
        u_init[:, 1, :] = 1.0

        # Source applied once at t = s2_time: ones in the block covering the
        # first half of axis 0 and of axis 1, over the full extent of axis 2.
        s2_init = np.full(shape, 0.0, dtype=np.float32)
        s2_init[:self.height//2, :self.width//2, :] = 1.0

        tinit = self._tinit_constants

        then = time.time()
        U = tf.Variable(u_init, name="U")
        elapsed = (time.time() - then)
        tf.print('U variable, elapsed: %f sec' % elapsed)
        tinit = tinit + elapsed

        then = time.time()
        s2 = tf.constant(s2_init, name="s2")
        elapsed = (time.time() - then)
        tf.print('s2 tensor, elapsed: %f sec' % elapsed)
        tinit = tinit + elapsed

        self.tinit = tinit
        tf.print('total initialization: %f sec' % self.tinit)

        # The host-side arrays are no longer needed once the tensors exist.
        del u_init, s2_init

        padmode = 'symmetric'
        paddings = tf.constant([[1,1], [1,1], [1,1]])

        # Loop invariants, computed once.
        s2_step = int(self.s2_time / self.dt)
        step_factor = self.diff * self.dt

        then = time.time()
        # A plain Python counter avoids creating a tensor per iteration and
        # evaluating tensor comparisons just to drive host-side branching.
        for i in range(int(self.samples)):
            # Overwrite the outermost layer with a mirror of the interior.
            U0 = tf.pad(U[1:-1,1:-1,1:-1], paddings=paddings, mode=padmode, name='boundary')
            U = U0 + step_factor * laplace(U0, self.DX, self.DY, self.DZ)

            if i == s2_step:
                U = tf.maximum(U, s2)
            # Hand a frame to the writer every dt_per_plot steps.
            if im and i % self.dt_per_plot == 0:
                im.imshow(U.numpy())
        elapsed = (time.time() - then)
        print('solution, elapsed: %f sec' % elapsed)
        print('TOTAL, elapsed: %f sec' % (elapsed+self.tinit))

        if im:
            im.wait()   # lets the writer finalise (ResultWriter saves here)






#######################################################################################

if __name__ == '__main__':
    # Separator line printed before and after the configuration dump.
    rule = '======================================================================='
    print(rule)

    # Simulation settings; both the model and the writer copy these entries
    # onto themselves as attributes.
    config = dict(
        width=64,
        height=64,
        depth=64,
        dt=0.1,
        dt_per_plot=10,
        diff=1.5,
        samples=10000,
        s2_time=210,
    )

    print('config:')
    print('\n'.join('%s\t%s' % entry for entry in config.items()))

    print(rule)
    model = HeatEquation(config)
    im = ResultWriter(config)
    model.run(im)
    # Drop the last reference to the writer.
    im = None