Convolution over feature dim #135
base: main
Changes from 3 commits

@@ -78,14 +78,24 @@ def create_returnn_layer_dict(self, input: Tensor) -> Dict[str, Any]:
     assert all(p == 0 for p in self.padding)  # not implemented otherwise
     assert all(p == 0 for p in self.output_padding)  # not implemented otherwise
     assert self.padding_mode == "zeros"  # not implemented otherwise

+    from pytorch_to_returnn.naming import Naming
+    naming = Naming.get_instance()
+    input_tensor = naming.tensors[input]
+    in_dim = input_tensor.returnn_data.dim_tags[input_tensor.returnn_axis_from_torch_axis[1]]
+    in_spatial_dims = [
+      input_tensor.returnn_data.dim_tags[input_tensor.returnn_axis_from_torch_axis[dim + len(input.shape)]]
+      for dim in range(-self.nd, 0)]
+
     d = {
       "class": "conv", "from": self._get_input_layer_name(input),
       "activation": None,
       "with_bias": self.bias is not None,
       "n_out": self.out_channels,
       "filter_size": self.kernel_size,
       "padding": "valid",
-      "in_spatial_dims": [self._get_input_axis_to_returnn(input, dim) for dim in range(-self.nd, 0)],
+      "in_spatial_dims": in_spatial_dims,
+      "in_dim": in_dim,
     }
     if any(s != 1 for s in self.stride):
       d["strides"] = self.stride
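
For illustration, here is a minimal, self-contained sketch of the dim-tag selection done in the added lines above. It does not use the real pytorch_to_returnn or RETURNN objects: the input layout, the sizes, and the plain dict/list stand-ins for `returnn_axis_from_torch_axis` and the dim tags are all made up.

```python
# Hypothetical sketch: a Conv1d input with torch layout (B, C_in, T)
# that RETURNN stores as (B, T, C_in). Dim tags are plain strings here.
nd = 1                                             # Conv1d
torch_shape = (8, 40, 100)                         # (B, C_in, T), made-up sizes
returnn_axis_from_torch_axis = {0: 0, 1: 2, 2: 1}  # torch axis -> RETURNN axis
returnn_dim_tags = ["B", "T", "C_in"]              # stand-ins for RETURNN Dim tags

# Same selection logic as in the diff above:
in_dim = returnn_dim_tags[returnn_axis_from_torch_axis[1]]
in_spatial_dims = [
  returnn_dim_tags[returnn_axis_from_torch_axis[dim + len(torch_shape)]]
  for dim in range(-nd, 0)]

print(in_dim)           # C_in
print(in_spatial_dims)  # ['T']
```

Passing dim tags rather than axis descriptions like "F" or "T" is what allows the convolution to run over the RETURNN feature dim; see the last comment thread on this page.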

@@ -121,18 +131,53 @@ def import_params_torch_to_returnn(self, *, layer: LayerBase, torch_module: _Con
   def _get_output_shape_from_returnn(self, inputs_flat: List[Tensor], layer: LayerBase
                                      ) -> Tuple[Tuple[int, ...], Dict[int, int]]:
     """
-    The basic returnn_axis_from_torch_axis should be correct, however, if the size of a dynamic axis changes (e.g. due
-    to strides and/or padding), this is not covered in the base method and we fix it here.
+    If the size of a dynamic axis changes (e.g. due to strides and/or padding), this is not covered in the base method
+    and we fix it here. Also, the basic returnn_axis_from_torch_axis fails if the RETURNN input feature dim is used as a
+    spatial dim for convolution. We try to cover this here and use the basic implementation as a fallback.
     """
-    torch_shape, returnn_axis_from_torch_axis = super(_ConvNd, self)._get_output_shape_from_returnn(
-      inputs_flat=inputs_flat, layer=layer)
+    assert len(inputs_flat) == 1
+    torch_shape = list(inputs_flat[0].shape)
+    torch_shape[1] = self.out_channels
+    for idx in range(self.nd):
+      torch_ax = idx + 2
+      torch_shape[torch_ax] = (torch_shape[torch_ax] + 2 * self.padding[idx] - self.dilation[idx] * (
+        self.kernel_size[idx] - 1) - 1) // self.stride[idx] + 1
+
+    from pytorch_to_returnn.naming import Naming

Member
I don't understand why you do it so complicated. This code here should be quite short, straightforward, and not using any heuristics. Your code here is full of heuristics, checking whether you can map all axes, etc. You don't need that. We know exactly how it must map.

Contributor (Author)
All dims that did not change are trivial, and the channel dim can be done as I do it here. How would you do it for the spatial dims? Just assume that the order of spatial dims is the same as in …

Contributor (Author)
I added an update which does the mapping as I described above.

Member
What order? Of the RETURNN output? We don't need to guess anything here. We know everything exactly.

Member
Except BCHW vs BHWC, but you can just check where out_dim is.

Contributor (Author)
Ok, then that is exactly what I do now, right?
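
As an aside, here is a tiny self-contained sketch of the "check where out_dim is" idea from this thread. The shapes and sizes are made up and no RETURNN code is involved: the torch conv output is always channels-first, so the only open question about the RETURNN output layout is where the output channel axis ended up, and that axis can be found by its static size.

```python
# Hypothetical illustration; mirrors the len(out_channel) == 1 check in the diff below.
out_channels = 64
returnn_out_shape = (8, 25, 64)   # e.g. (B, T', C_out) if RETURNN keeps channels last; made-up sizes
candidates = [axis for axis, size in enumerate(returnn_out_shape) if size == out_channels]
assert len(candidates) == 1       # only unambiguous if no other static axis happens to have this size
out_channel_axis = candidates[0]  # 2 here (channels last); for (8, 64, 25) it would be 1 (channels first)
print(out_channel_axis)
```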
+    naming = Naming.get_instance()
+    input_tensor = naming.tensors[inputs_flat[0]]
+    in_data = input_tensor.returnn_data
+    out_data = layer.output
+    assert in_data.batch_ndim == out_data.batch_ndim
+
+    mapping_out_to_in = {}
+    if in_data.batch_dim_axis is not None and out_data.batch_dim_axis is not None:
+      mapping_out_to_in[out_data.batch_dim_axis] = in_data.batch_dim_axis
+    if in_data.time_dim_axis and out_data.time_dim_axis:
+      mapping_out_to_in[out_data.time_dim_axis] = in_data.time_dim_axis
+    in_channel = input_tensor.returnn_axis_from_torch_axis[1]
+    out_channel = [
+      dim for dim in layer.output.get_static_axes() if layer.output.dim_tags[dim].dimension == self.out_channels]
+    if len(out_channel) == 1:
+      mapping_out_to_in[out_channel[0]] = in_channel
+
+    if len(mapping_out_to_in) == in_data.batch_ndim - 1:
+      # only one axis is missing, just take remaining axis
+      remaining_in = set(range(in_data.batch_ndim)).difference(set(mapping_out_to_in.values()))
+      remaining_out = set(range(in_data.batch_ndim)).difference(set(mapping_out_to_in.keys()))
+      assert len(remaining_in) == 1 and len(remaining_out) == 1
+      mapping_out_to_in[remaining_out.pop()] = remaining_in.pop()
+
+    if len(mapping_out_to_in) == in_data.batch_ndim:
+      # found all axes, so we can proceed
+      returnn_axis_from_torch_axis = {}
+      for returnn_out_axis, returnn_in_axis in mapping_out_to_in.items():
+        torch_axis = input_tensor.torch_axis_from_returnn_axis[returnn_in_axis]  # torch does not change order for conv
+        returnn_axis_from_torch_axis[torch_axis] = returnn_out_axis
+    else:
+      # did not find all axes, so fall back to (possibly faulty) default mapping
+      _, returnn_axis_from_torch_axis = super(_ConvNd, self)._get_output_shape_from_returnn(
+        inputs_flat=inputs_flat, layer=layer)
+    return tuple(torch_shape), returnn_axis_from_torch_axis
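
To make the torch-side shape computation in this patch concrete, here is a small worked example of the standard convolution output-size formula used in the loop over spatial dims above. The helper name and the Conv1d hyperparameters are made up; nothing here touches pytorch_to_returnn.

```python
def conv_out_size(in_size, kernel_size, stride=1, padding=0, dilation=1):
  # out = (in + 2*padding - dilation*(kernel_size - 1) - 1) // stride + 1
  return (in_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# Made-up example: input length 100, kernel 3, stride 2, no padding -> 49 output frames.
print(conv_out_size(100, kernel_size=3, stride=2))  # 49
```

This is why the patch can compute the torch output shape directly from the input shape and the module hyperparameters, without reading sizes back from the RETURNN layer output.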

Why don't you use `_get_input_axis_to_returnn` for that?

Because that would return an axis description like `"T"` or `"F"`. This would be mapped to a dim tag in the `ConvLayer` construction. However, in case we do convolution over the feature dim, `"F"` would be mapped to the `in_dim`, so the new feature dim and not the old one, which does not work.