Skip to content

Commit 03d4dc0

Browse files
authored
Add default args for _aten_conv2d (#9623)
Add default args for _aten_conv2d, which would otherwise fail in the following code snippet ```python import torch from torch.export import export_for_training import torchax from torchax import interop from torch.utils import _pytree as pytree import jax from torchax.ops import mappings class Simple(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=4, bias=False) def forward(self, x): x = self.conv1(x) return x model = Simple() exported = export_for_training(model, (torch.randn(1, 3, 224, 224),)) def make_shape_struct(x): return jax.ShapeDtypeStruct(x.shape, mappings.t2j_dtype(x.dtype)) def map_nth(v, i): def f(t): if isinstance(t, torch.Tensor): return t[i : i + 1] return t return pytree.tree_map(f, v) env = torchax.default_env() with env: model = exported.module().to("jax") def func_to_export(x): # hard code weights in model return model(x) example_inputs_jax = pytree.tree_map_only( torch.Tensor, lambda x: x.to("jax"), map_nth(exported.example_inputs, 0) ) res = jax.jit(interop.jax_view(func_to_export)).lower(*example_inputs_jax[0]) # TypeError: _aten_conv2d() missing 5 required positional arguments: 'bias', 'stride', 'padding', 'dilation', and 'groups' ``` cc @qihqi
1 parent 0fc62aa commit 03d4dc0

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

torchax/torchax/ops/jaten.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,11 +1017,11 @@ def _aten_bucketize(input,
10171017
def _aten_conv2d(
10181018
input,
10191019
weight,
1020-
bias,
1021-
stride,
1022-
padding,
1023-
dilation,
1024-
groups,
1020+
bias=None,
1021+
stride=[1, 1],
1022+
padding=[0, 0],
1023+
dilation=[1, 1],
1024+
groups=1,
10251025
):
10261026
return _aten_convolution(
10271027
input,

0 commit comments

Comments
 (0)