Skip to content

Commit 09a0be2

Browse files
committed
fix some typing hints
1 parent 0a02b25 commit 09a0be2

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

torch_cfd/equations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def implicit_solve(self, vort_hat, dt):
450450
def step(self, *args, **kwargs):
451451
return self.forward(*args, **kwargs)
452452

453-
def forward(self, vort_hat, dt, steps=1):
453+
def forward(self, vort_hat, dt, steps=1) -> Tuple[torch.Tensor, torch.Tensor]:
454454
"""
455455
vort_hat: (B, kx, ky) or (n_t, kx, ky) or (kx, ky)
456456
- if rfft2 is used then the shape is (*, nx, ny//2+1)

torch_cfd/grids.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def clone(self, *args, **kwargs):
285285
def to(self, *args, **kwargs):
286286
return GridArray(self.data.to(*args, **kwargs), self.offset, self.grid)
287287

288-
_HANDLED_TYPES = (numbers.Number, Array)
288+
_HANDLED_TYPES = (numbers.Number, torch.Tensor)
289289

290290
@classmethod
291291
def __torch_function__(self, ufunc, types, args=(), kwargs=None):
@@ -483,7 +483,7 @@ def enforce_edge_bc(self, *args) -> GridVariable:
483483
GridVariableVector = Tuple[GridVariable, ...]
484484

485485

486-
class GridArrayTensor(Array):
486+
class GridArrayTensor(torch.Tensor):
487487
"""A numpy array of GridArrays, representing a physical tensor field.
488488
489489
Packing tensor coordinates into a numpy array of dtype object is useful

torch_cfd/initial_conditions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def spectral_filter(
9393
v: torch.Tensor,
9494
grid: grids.Grid,
9595
) -> torch.Tensor:
96-
"""Filter an Array with white noise to match a prescribed spectral density."""
96+
"""Filter a torch.Tensor with white noise to match a prescribed spectral density."""
9797
k = _angular_frequency_magnitude(grid)
9898
filters = torch.where(k > 0, spectral_density(k), 0.0)
9999
# The output signal can safely be assumed to be real if our input signal was

0 commit comments

Comments
 (0)