diff --git a/qlib/contrib/model/pytorch_tcn_ts.py b/qlib/contrib/model/pytorch_tcn_ts.py index a6cc38885c3..444b6043a3f 100755 --- a/qlib/contrib/model/pytorch_tcn_ts.py +++ b/qlib/contrib/model/pytorch_tcn_ts.py @@ -294,4 +294,4 @@ def __init__(self, num_input, output_size, num_channels, kernel_size, dropout): def forward(self, x): output = self.tcn(x) output = self.linear(output[:, :, -1]) - return output.squeeze() + return output.squeeze(-1) diff --git a/tests/model/test_pytorch_tcn_ts.py b/tests/model/test_pytorch_tcn_ts.py new file mode 100644 index 00000000000..58a3a8f5838 --- /dev/null +++ b/tests/model/test_pytorch_tcn_ts.py @@ -0,0 +1,10 @@ +import torch + +from qlib.contrib.model.pytorch_tcn_ts import TCNModel + + +def test_tcn_ts_model_keeps_batch_dimension_for_single_sample(): + model = TCNModel(num_input=3, output_size=1, num_channels=[4], kernel_size=2, dropout=0.0) + + assert model(torch.randn(1, 3, 8)).shape == (1,) + assert model(torch.randn(2, 3, 8)).shape == (2,)