[BUG] pytorch-forecasting#1752 Fixing #1786
RUPESH-KUMAR01 wants to merge into sktime:main
Conversation
fkiraly left a comment
Great - could we add a test to ensure we have fixed the bug?
Changes made to include the test for the bug:

Updated code snippet:

predictions = net.predict(
    val_dataloader,
    return_index=True,
    return_x=True,
    return_y=True,
    fast_dev_run=2,  # 🔹 Now runs for two batches
    trainer_kwargs=trainer_kwargs,
)
if isinstance(predictions.output, torch.Tensor):
    assert predictions.output.shape == predictions.y[0].shape, (
        "shape of predictions should match shape of targets"
    )
else:
    for i in range(len(predictions.output)):
        assert predictions.output[i].shape == predictions.y[0][i].shape, (
            "shape of predictions should match shape of targets"
        )

I am not familiar with tests, but while debugging I found where the tests failed with my previous approach and modified the function with extra constraints. If these changes are sufficient, I will add the modifications to the PR.
The fix is pending?
I tested the code initially before making changes to the tests, and it did not fail. The tests I modified add more constraints to the initial tests. The test modifications ensure that the batch size and timestamps of both the output and y match.
Description
This PR addresses issue #1752. concat_sequences concatenates the timesteps of each batch, but the goal is to concatenate the batches, not the timesteps.
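To illustrate the distinction, here is a minimal, self-contained sketch (the function name concat_batches and the tensor shapes are illustrative assumptions, not the library's actual implementation): prediction tensors of shape (batch_size, timesteps) should be joined along the batch dimension (dim=0), whereas the buggy behaviour corresponds to joining along the time dimension (dim=1).

```python
import torch

def concat_batches(sequences):
    # Intended behaviour: stack batches, preserving the timestep count.
    return torch.cat(sequences, dim=0)

# Two batches of 4 samples, each with 10 timesteps.
batches = [torch.zeros(4, 10), torch.zeros(4, 10)]

wrong = torch.cat(batches, dim=1)  # shape (4, 20): timesteps were concatenated
right = concat_batches(batches)    # shape (8, 10): batches were concatenated

print(tuple(wrong.shape), tuple(right.shape))  # → (4, 20) (8, 10)
```

The shape check in the test above catches exactly this: concatenating along the wrong dimension changes the timestep count of the output, so it no longer matches the shape of y.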
Checklist
Run pre-commit install. To run the hooks independently of a commit, execute pre-commit run --all-files.