@@ -258,17 +258,29 @@ def _pi_init(self, ref):
258258 >>> C = model._pi_init(ref)
259259 """
260260 # Get collocation points for the spatiotemporal domain to impose initial condition
261- t = ref ['t' ].flatten ()[::10 ] # Downsampled temporal - shape (Nt, )
262- x = ref ['x' ].flatten () # spatial - shape (Nx, )
263- tt , xx = jnp .meshgrid (t , x , indexing = "ij" )
264-
265- # collocation inputs - shape (batch, 2), batch = Nt*Nx
266- inputs = jnp .hstack ([tt .flatten ()[:, None ], xx .flatten ()[:, None ]])
267-
268- # Get Y for inputs
269- u_0 = ref ['usol' ][0 , :] # initial condition - shape (Nx, )
270- Y = jnp .tile (u_0 .flatten (), (t .shape [0 ], 1 )) # shape (Nt, Nx)
271- Y = Y .flatten ().reshape (- 1 , 1 ) # shape (batch, 1)
261+ t = ref ['t' ].flatten ()[::10 ] # Downsampled temporal - shape (Nt, )
262+
263+ # Check if we have 3D data (t, x, y) or 2D data (t, x)
264+ if 'y' in ref :
265+ downsample = 10
266+ x = ref ['x' ].flatten ()[::downsample ] # spatial - shape (Nx, )
267+ y = ref ['y' ].flatten ()[::downsample ] # shape (Ny, )
268+ tt , xx , yy = jnp .meshgrid (t , x , y , indexing = "ij" )
269+ # collocation inputs - shape (batch, 3), batch = Nt*Nx*Ny
270+ inputs = jnp .hstack ([tt .flatten ()[:, None ], xx .flatten ()[:, None ], yy .flatten ()[:, None ]])
271+ # Get Y for inputs - initial condition at t=0
272+ u_0 = ref ['usol' ][0 , ::downsample , ::downsample ] # shape (Nx, Ny)
273+ Y = jnp .tile (u_0 .flatten (), (t .shape [0 ], 1 )) # shape (Nt, Nx*Ny)
274+ Y = Y .flatten ().reshape (- 1 , 1 ) # shape (batch, 1)
275+ else :
276+ x = ref ['x' ].flatten () # spatial - shape (Nx, )
277+ tt , xx = jnp .meshgrid (t , x , indexing = "ij" )
278+ # collocation inputs - shape (batch, 2), batch = Nt*Nx
279+ inputs = jnp .hstack ([tt .flatten ()[:, None ], xx .flatten ()[:, None ]])
280+ # Get Y for inputs
281+ u_0 = ref ['usol' ][0 , :] # initial condition - shape (Nx, )
282+ Y = jnp .tile (u_0 .flatten (), (t .shape [0 ], 1 )) # shape (Nt, Nx)
283+ Y = Y .flatten ().reshape (- 1 , 1 ) # shape (batch, 1)
272284
273285 # Get Φ - essentially do a full forward pass up until the final layer
274286 if self .PE :
0 commit comments