Skip to content

Commit 68b254d

Browse files
authored
Fix TensorSchema validation test for symbolic dims (#22366)
Signed-off-by: Benji Beck <[email protected]>
1 parent 8c50d62 commit 68b254d

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

tests/standalone_tests/test_tensor_schema.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import pytest
55
import torch
66

7-
from vllm.model_executor.models.fuyu import FuyuImagePatchInputs
87
from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs
8+
from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs
99
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
1010

1111

@@ -129,23 +129,27 @@ def test_tensor_schema_with_invalid_resolve_binding_dims():
129129

130130

131131
def test_tensor_schema_with_list_of_symbolic_dim():
132-
flat_data = torch.stack([torch.randn(768) for _ in range(3)]) # (bn=3, fn)
133-
patches_per_image = [64, 64, 64] # len = bn = 3
134-
135-
FuyuImagePatchInputs(
136-
flat_data=flat_data,
137-
patches_per_image=patches_per_image,
132+
input_features = torch.randn(3, 10, 160) # (b=3, fi=10, 160)
133+
input_features_mask = torch.randn(3, 8) # (b=3, fo=8)
134+
audio_embed_sizes = [8, 8, 8] # len = b = 3
135+
136+
GraniteSpeechAudioInputs(
137+
input_features=input_features,
138+
input_features_mask=input_features_mask,
139+
audio_embed_sizes=audio_embed_sizes,
138140
)
139141

140142

141143
def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length():
142-
flat_data = torch.stack([torch.randn(768) for _ in range(4)]) # (bn=4, fn)
143-
patches_per_image = [64, 64, 64] # len = 3 ≠ bn
144-
145-
with pytest.raises(ValueError, match="expected 'bn'=4, got 3"):
146-
FuyuImagePatchInputs(
147-
flat_data=flat_data,
148-
patches_per_image=patches_per_image,
144+
input_features = torch.randn(4, 10, 160) # (b=4, fi=10, 160)
145+
input_features_mask = torch.randn(4, 8) # (b=4, fo=8)
146+
audio_embed_sizes = [8, 8, 8] # len = 3 ≠ b
147+
148+
with pytest.raises(ValueError, match="expected 'b'=4, got 3"):
149+
GraniteSpeechAudioInputs(
150+
input_features=input_features,
151+
input_features_mask=input_features_mask,
152+
audio_embed_sizes=audio_embed_sizes,
149153
)
150154

151155

0 commit comments

Comments
 (0)