Skip to content

Commit c5f6ebe

Browse files
authored
Fix LSTM visualization (#65)
1 parent 45f5161 commit c5f6ebe

File tree

3 files changed

+57
-2
lines changed

3 files changed

+57
-2
lines changed

tests/test_layered.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
5858
return CustomModel()
5959

6060

61+
@pytest.fixture()
62+
def lstm_model() -> nn.Module:
63+
"""Define a simple LSTM model for testing."""
64+
65+
class LSTMModel(nn.Module):
66+
"""A simple LSTM model."""
67+
68+
def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
69+
super().__init__()
70+
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
71+
72+
def forward(self, x: torch.Tensor) -> torch.Tensor:
73+
"""Forward pass."""
74+
out, _ = self.lstm(x)
75+
return out
76+
77+
# Create an instance of the LSTM model
78+
return LSTMModel(input_size=10, hidden_size=20, num_layers=2)
79+
80+
6181
def test_sequential_model_layered_view_runs(sequential_model: nn.Sequential) -> None:
6282
"""Test layered view on sequential model."""
6383
_ = layered_view(sequential_model, input_shape=(1, 3, 224, 224))
@@ -71,3 +91,8 @@ def test_module_list_model_layered_view_runs(module_list_model: nn.ModuleList) -
7191
def test_custom_model_layered_view_runs(custom_model: nn.Module) -> None:
7292
"""Test layered view on custom model."""
7393
_ = layered_view(custom_model, input_shape=(1, 3, 224, 224))
94+
95+
96+
def test_lstm_model_layered_view_runs(lstm_model: nn.Module) -> None:
97+
"""Test layered view on lstm model."""
98+
_ = layered_view(lstm_model, input_shape=(1, 10, 10))

tests/test_lenet_style.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
5858
return CustomModel()
5959

6060

61+
@pytest.fixture()
62+
def lstm_model() -> nn.Module:
63+
"""Define a simple LSTM model for testing."""
64+
65+
class LSTMModel(nn.Module):
66+
"""A simple LSTM model."""
67+
68+
def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
69+
super().__init__()
70+
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
71+
72+
def forward(self, x: torch.Tensor) -> torch.Tensor:
73+
"""Forward pass."""
74+
out, _ = self.lstm(x)
75+
return out
76+
77+
# Create an instance of the LSTM model
78+
return LSTMModel(input_size=10, hidden_size=20, num_layers=2)
79+
80+
6181
def test_sequential_model_lenet_view_runs(sequential_model: nn.Sequential) -> None:
6282
"""Test lenet view on sequential model."""
6383
_ = lenet_view(sequential_model, input_shape=(1, 3, 224, 224))
@@ -71,3 +91,8 @@ def test_module_list_model_lenet_view_runs(module_list_model: nn.ModuleList) ->
7191
def test_custom_model_lenet_view_runs(custom_model: nn.Module) -> None:
7292
"""Test lenet view on custom model."""
7393
_ = lenet_view(custom_model, input_shape=(1, 3, 224, 224))
94+
95+
96+
def test_lstm_model_layered_view_runs(lstm_model: nn.Module) -> None:
97+
"""Test layered view on lstm model."""
98+
_ = lenet_view(lstm_model, input_shape=(1, 10, 10))

visualtorch/utils/layer_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,13 @@ def hook(
185185
m_key = "%s-%i" % (class_name, module_idx + 1)
186186
layers[m_key] = OrderedDict()
187187
layers[m_key]["module"] = module
188-
if isinstance(out, list | tuple):
189-
layers[m_key]["output_shape"] = tuple((-1,) + o.size()[1:] for o in out)
188+
if isinstance(out, tuple):
189+
if hasattr(out[0], "size"):
190+
layers[m_key]["output_shape"] = out[0].size()
191+
else:
192+
layers[m_key]["output_shape"] = tuple(o.size() for o in out if hasattr(o, "size"))
193+
elif isinstance(out, list):
194+
layers[m_key]["output_shape"] = tuple(o.size() for o in out if hasattr(o, "size"))
190195
else:
191196
layers[m_key]["output_shape"] = out.size()
192197

0 commit comments

Comments
 (0)