Skip to content

Commit adc4781

Browse files
committed
added 3D support for pi_init in RGAKAN
1 parent 15ccbff commit adc4781

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

jaxkan/models/RGAKAN.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -258,17 +258,29 @@ def _pi_init(self, ref):
258258
>>> C = model._pi_init(ref)
259259
"""
260260
# Get collocation points for the spatiotemporal domain to impose initial condition
261-
t = ref['t'].flatten()[::10] # Downsampled temporal - shape (Nt, )
262-
x = ref['x'].flatten() # spatial - shape (Nx, )
263-
tt, xx = jnp.meshgrid(t, x, indexing="ij")
264-
265-
# collocation inputs - shape (batch, 2), batch = Nt*Nx
266-
inputs = jnp.hstack([tt.flatten()[:, None], xx.flatten()[:, None]])
267-
268-
# Get Y for inputs
269-
u_0 = ref['usol'][0, :] # initial condition - shape (Nx, )
270-
Y = jnp.tile(u_0.flatten(), (t.shape[0], 1)) # shape (Nt, Nx)
271-
Y = Y.flatten().reshape(-1, 1) # shape (batch, 1)
261+
t = ref['t'].flatten()[::10] # Downsampled temporal - shape (Nt, )
262+
263+
# Check if we have 3D data (t, x, y) or 2D data (t, x)
264+
if 'y' in ref:
265+
downsample = 10
266+
x = ref['x'].flatten()[::downsample] # spatial - shape (Nx, )
267+
y = ref['y'].flatten()[::downsample] # shape (Ny, )
268+
tt, xx, yy = jnp.meshgrid(t, x, y, indexing="ij")
269+
# collocation inputs - shape (batch, 3), batch = Nt*Nx*Ny
270+
inputs = jnp.hstack([tt.flatten()[:, None], xx.flatten()[:, None], yy.flatten()[:, None]])
271+
# Get Y for inputs - initial condition at t=0
272+
u_0 = ref['usol'][0, ::downsample, ::downsample] # shape (Nx, Ny)
273+
Y = jnp.tile(u_0.flatten(), (t.shape[0], 1)) # shape (Nt, Nx*Ny)
274+
Y = Y.flatten().reshape(-1, 1) # shape (batch, 1)
275+
else:
276+
x = ref['x'].flatten() # spatial - shape (Nx, )
277+
tt, xx = jnp.meshgrid(t, x, indexing="ij")
278+
# collocation inputs - shape (batch, 2), batch = Nt*Nx
279+
inputs = jnp.hstack([tt.flatten()[:, None], xx.flatten()[:, None]])
280+
# Get Y for inputs
281+
u_0 = ref['usol'][0, :] # initial condition - shape (Nx, )
282+
Y = jnp.tile(u_0.flatten(), (t.shape[0], 1)) # shape (Nt, Nx)
283+
Y = Y.flatten().reshape(-1, 1) # shape (batch, 1)
272284

273285
# Get Φ - essentially do a full forward pass up until the final layer
274286
if self.PE:

0 commit comments

Comments
 (0)