Convolution over feature dim #134

@vieting

I have a case where the convolution is done over a dim that RETURNN considers the feature dim. In PyTorch, however, a new dim is created before the convolution, and this new dim is supposed to be the feature dim. Of course, RETURNN cannot know this directly. But I think that if a convolution is done over the feature dim and another static dim exists, we can argue that this other dim should be treated as the feature dim, and then do the convolution. What do you think, @albertz?
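To make the proposal concrete, here is a minimal sketch of the heuristic in plain Python. It is not RETURNN code: the function name `pick_feature_axis`, its arguments, and the choice to raise an error when more than one static dim qualifies are all assumptions made for illustration.

```python
from typing import Optional, Sequence


def pick_feature_axis(
    static_sizes: Sequence[Optional[int]],
    feature_axis: int,
    spatial_axes: Sequence[int],
    batch_axis: int = 0,
) -> int:
    """Hypothetical helper: choose the axis to treat as feature dim for a conv.

    static_sizes holds one entry per axis, with None for dynamic axes
    (batch, time). If the current feature axis is not convolved over, it is
    kept. Otherwise the single remaining static axis is used instead.
    """
    if feature_axis not in spatial_axes:
        return feature_axis  # the usual case, nothing to reinterpret

    candidates = [
        axis
        for axis, size in enumerate(static_sizes)
        if size is not None and axis != batch_axis and axis not in spatial_axes
    ]
    # Assumption: refuse to guess when the choice is not unique.
    if len(candidates) != 1:
        raise ValueError(
            "cannot infer feature axis, candidates: %r" % (candidates,))
    return candidates[0]


# Layout taken from the error message in this issue:
# [B, 'Unflatten_split_dims0'(1), T, F(5)], conv over axes [2, 3].
sizes = [None, 1, None, 5]
print(pick_feature_axis(sizes, feature_axis=3, spatial_axes=[2, 3]))  # 1
```

For the layout in the error message, this would select axis 1 (the dim created by `unsqueeze`) as the feature dim, which matches what PyTorch does.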

A test to show the error:

def test_vgg():
  n_batch, n_time, n_feat = 3, 20, 5  # B, T, F

  def model_func(wrapped_import, inputs: torch.Tensor):
    if typing.TYPE_CHECKING or not wrapped_import:
      import torch
    else:
      torch = wrapped_import("torch")

    class VggBlock(torch.nn.Module):
      def __init__(self, n_in, n_out, kernel_size, pool_size=None, stride: typing.Union[int, typing.Tuple] = 1):
        super().__init__()
        self.conv = torch.nn.Conv2d(n_in, n_out, kernel_size, stride=stride)
        self.activation = torch.nn.SiLU()
        self.pooling = torch.nn.MaxPool2d(pool_size) if pool_size is not None else None

      def forward(self, x):
        # ignore padding here
        x = self.conv(x)
        x = self.activation(x)
        if self.pooling is not None:
          x = self.pooling(x)
        return x

    x = inputs.unsqueeze(1)  # (B, 1, T, F)
    # VGG block:
    vgg_blocks = torch.nn.Sequential(
      VggBlock(1, 32, (3, 3)),
    )
    x = vgg_blocks(x)
    x = x.transpose(2, 3).flatten(1, 2)  # (B, 32 * F', T'), i.e. (B, F, T) layout, with T' and F' reduced by the conv
    return x

  rnd = numpy.random.RandomState(42)
  x = rnd.normal(0., 1., (n_batch, n_time, n_feat)).astype("float32")
  converter = verify_torch_and_convert_to_returnn(
    model_func, inputs=x, returnn_dummy_input_shape=x.shape, validate_allclose_kwargs=dict(rtol=0, atol=5e-3),
    inputs_data_kwargs={"shape": (None, n_feat), "batch_dim_axis": 0, "time_dim_axis": 1, "feature_dim_axis": 2})

Error message:

...
  File ".../returnn/returnn/tf/layers/basic.py", line 5416, in ConvLayer.transform_input
    line: assert input_data.feature_dim_axis not in axes, "invalid in_spatial_dims %s" % (in_spatial_dims,)
    locals:
      input_data = <local> Data{'Unflatten_output', [B,'Unflatten_split_dims0'(1),T|'time:data'[B],F|F'feature:data'(5)]}
      input_data.feature_dim_axis = <local> 3
      axes = <local> [2, 3]
      in_spatial_dims = <local> [Dim{'time:data'[B]}, Dim{F'feature:data'(5)}]
AssertionError: invalid in_spatial_dims [Dim{'time:data'[B]}, Dim{F'feature:data'(5)}]
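For reference, the shapes that the PyTorch side of the test is expected to produce can be traced by hand. The sketch below does this in plain Python using the standard output-size formula for an unpadded, undilated convolution. The helper names `conv_out_size` and `trace_shapes` are made up for this illustration.

```python
from typing import Tuple


def conv_out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of a conv along one axis (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def trace_shapes(
    n_batch: int, n_time: int, n_feat: int,
    n_out: int = 32, kernel: Tuple[int, int] = (3, 3),
) -> Tuple[int, int, int]:
    """Follow the tensor shape through model_func from the test."""
    # inputs.unsqueeze(1): (B, T, F) -> (B, 1, T, F)
    b, c, t, f = n_batch, 1, n_time, n_feat
    # Conv2d(1, n_out, kernel): the new dim of size 1 is the channel dim,
    # T and F are the two spatial dims. SiLU keeps the shape.
    c = n_out
    t = conv_out_size(t, kernel[0])
    f = conv_out_size(f, kernel[1])
    # transpose(2, 3): (B, C, T', F') -> (B, C, F', T')
    # flatten(1, 2):   (B, C, F', T') -> (B, C * F', T')
    return (b, c * f, t)


print(trace_shapes(3, 20, 5))  # (3, 96, 18)
```

With the test's values (B=3, T=20, F=5), PyTorch convolves over both T and F, which is exactly the step where RETURNN hits the assertion because F is still marked as the feature dim.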
