Skip to content

Commit d02e4bf

Browse files
committed
from . imports
1 parent 09a0be2 commit d02e4bf

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

torch_cfd/finite_differences.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
import typing
3838
from typing import Optional, Sequence, Tuple
39-
from . import grids
39+
from torch_cfd import grids
4040
import numpy as np
4141
import torch
4242

torch_cfd/forcings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import torch
2121
import torch.nn as nn
2222

23-
from . import grids
23+
from torch_cfd import grids
2424

2525
Grid = grids.Grid
2626
GridArray = grids.GridArray

torch_cfd/initial_conditions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import torch
2323
import torch.fft as fft
2424

25-
from . import grids, pressure
25+
from torch_cfd import grids, pressure
2626

2727
GridArray = grids.GridArray
2828
GridArrayVector = grids.GridArrayVector

torch_cfd/pressure.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121

2222
import torch
2323

24-
from . import grids, fast_diagonalization as solver, finite_differences as fd
24+
from torch_cfd import grids
25+
from torch_cfd import fast_diagonalization as solver
26+
from torch_cfd import finite_differences as fdm
2527

2628

2729
Array = grids.Array
@@ -91,8 +93,8 @@ def solve_fast_diag(
9193
# only matmul implementation supports non-circulant matrices
9294
implementation = "matmul"
9395
grid = grids.consistent_grid(*v)
94-
rhs = fd.divergence(v)
95-
laplacians = list(map(fd.laplacian_matrix, grid.shape, grid.step))
96+
rhs = fdm.divergence(v)
97+
laplacians = list(map(fdm.laplacian_matrix, grid.shape, grid.step))
9698
laplacians = [lap.to(grid.device) for lap in laplacians]
9799
rhs_transformed = _rhs_transform(rhs, pressure_bc)
98100
pinv = solver.pseudoinverse(
@@ -126,6 +128,6 @@ def projection(
126128

127129
q = solve(v, q0, pressure_bc)
128130
q = pressure_bc.impose_bc(q)
129-
q_grad = fd.forward_difference(q)
131+
q_grad = fdm.forward_difference(q)
130132
v_projected = tuple(u.bc.impose_bc(u.array - q_g) for u, q_g in zip(v, q_grad))
131133
return v_projected

0 commit comments

Comments
 (0)