Skip to content

Commit 883d61c

Browse files
committed
change input of NS2dSpectral to initialized class
1 parent 28cc4a9 commit 883d61c

File tree

4 files changed

+6
-9
lines changed

4 files changed

+6
-9
lines changed

fno/data_gen/data_gen_Kolmogorov2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def main(args):
129129
drag=0.1,
130130
smooth=True,
131131
forcing_fn=forcing_fn,
132-
solver=RK4CrankNicolsonStepper,
132+
solver=RK4CrankNicolsonStepper(),
133133
).to(device)
134134

135135
num_batches = total_samples // batch_size

fno/data_gen/data_gen_McWilliams2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def main(args):
113113
drag=0,
114114
smooth=True,
115115
forcing_fn=None,
116-
solver=RK4CrankNicolsonStepper,
116+
solver=RK4CrankNicolsonStepper(),
117117
).to(device)
118118

119119
num_batches = total_samples // batch_size

fno/data_gen/data_gen_fno.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,12 @@ def main(args):
168168
device=device,
169169
dtype=torch.float64,
170170
)
171-
171+
step_fn = IMEXStepper(order=2)
172172
ns2d = NavierStokes2DSpectral(
173173
viscosity=visc,
174174
grid=grid,
175175
smooth=True,
176176
forcing_fn=forcing_fn,
177-
solver=IMEXStepper,
178-
order=2,
179177
).to(device)
180178

181179
if os.path.exists(data_filepath) and not force_rerun:

torch_cfd/equations.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -379,17 +379,16 @@ def __init__(
379379
drag: float = 0.0,
380380
smooth: bool = True,
381381
forcing_fn: Optional[Callable] = None,
382-
solver: Optional[Callable] = RK4CrankNicolsonStepper,
383-
requires_grad: bool = False,
384-
**solver_kwargs,
382+
solver: IMEXStepper = None,
383+
**kwargs,
385384
):
386385
super().__init__()
387386
self.viscosity = viscosity
388387
self.grid = grid
389388
self.drag = drag
390389
self.smooth = smooth
391390
self.forcing_fn = forcing_fn
392-
self.solver = solver(requires_grad=requires_grad, **solver_kwargs)
391+
self.solver = solver
393392
self._initialize()
394393

395394
def _initialize(self):

0 commit comments

Comments
 (0)