1
1
# SPDX-License-Identifier: Apache-2.0
2
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
- from typing import Annotated , Any , Union , get_args , get_origin , get_type_hints
3
+ from typing import (Annotated , Any , Optional , Union , get_args , get_origin ,
4
+ get_type_hints )
4
5
5
6
import torch
6
7
11
12
12
13
class TensorShape :
13
14
14
- def __init__ (self ,
15
- * dims : Union [int , str ],
16
- dynamic_dims : set [str , ...] = None ) -> None :
15
+ def __init__ (
16
+ self ,
17
+ * dims : Union [int , str ],
18
+ dynamic_dims : Optional [set [str ]] = None ,
19
+ ) -> None :
20
+ super ().__init__ ()
21
+
17
22
self .dims = dims
18
23
self .dynamic_dims = dynamic_dims if dynamic_dims else set ()
19
24
@@ -44,11 +49,15 @@ def __str__(self) -> str:
44
49
45
50
class TensorSchema :
46
51
47
- def __init__ (self ,
48
- * ,
49
- validate : bool = True ,
50
- resolve_bindings : dict [str , int ] = None ,
51
- ** kwargs : Any ) -> None :
52
+ def __init__ (
53
+ self ,
54
+ * ,
55
+ validate : bool = True ,
56
+ resolve_bindings : Optional [dict [str , int ]] = None ,
57
+ ** kwargs : Any ,
58
+ ) -> None :
59
+ super ().__init__ ()
60
+
52
61
self ._resolve_bindings = resolve_bindings if resolve_bindings else {}
53
62
54
63
for key , value in kwargs .items ():
@@ -57,16 +66,19 @@ def __init__(self,
57
66
if validate :
58
67
self .validate ()
59
68
60
- def __getitem__ (self , item ) -> Any :
61
- return getattr (self , item )
69
+ def __getitem__ (self , key : str ) -> Any :
70
+ return getattr (self , key )
62
71
63
- def get (self , item , default = None ) -> Any :
64
- return getattr (self , item , default )
72
+ def get (self , key : str , default : Any = None ) -> Any :
73
+ return getattr (self , key , default )
65
74
66
- def _match_shape_with_dynamic (self , actual : tuple [int , ...],
67
- reference : tuple [int , ...],
68
- expected_shape : tuple [Union [int , str ], ...],
69
- dynamic_dims : set [str , ...]) -> bool :
75
+ def _match_shape_with_dynamic (
76
+ self ,
77
+ actual : tuple [int , ...],
78
+ reference : tuple [int , ...],
79
+ expected_shape : tuple [Union [int , str ], ...],
80
+ dynamic_dims : set [str ],
81
+ ) -> bool :
70
82
if len (actual ) != len (reference ) or len (actual ) > len (expected_shape ):
71
83
return False
72
84
@@ -84,10 +96,12 @@ def _match_shape_with_dynamic(self, actual: tuple[int, ...],
84
96
return True
85
97
86
98
def _validate_nested_tensors (
87
- self , value : Union [list [torch .Tensor , ...],
88
- tuple [torch .Tensor , ...]], field_name : str ,
89
- expected_shape : tuple [Union [int , str ], ...],
90
- dynamic_dims : set [str , ...]) -> tuple [int , ...]:
99
+ self ,
100
+ value : Union [list [torch .Tensor ], tuple [torch .Tensor , ...]],
101
+ field_name : str ,
102
+ expected_shape : tuple [Union [int , str ], ...],
103
+ dynamic_dims : set [str ],
104
+ ) -> tuple [int , ...]:
91
105
"""Validate a list/tuple of tensors and return the actual shape."""
92
106
# Ensure all tensors in the list have the same
93
107
# shape, besides dynamic dimensions
@@ -110,12 +124,14 @@ def _validate_nested_tensors(
110
124
# shape = (len(list), *tensor.shape)
111
125
return (len (value ), ) + first .shape
112
126
113
- def _validate_tensor_shape_expected (self , actual_shape : tuple [int , ...],
114
- expected_shape : tuple [Union [int , str ],
115
- ...],
116
- field_name : str , shape_env : dict [str ,
117
- int ],
118
- dynamic_dims : set [str , ...]) -> None :
127
+ def _validate_tensor_shape_expected (
128
+ self ,
129
+ actual_shape : tuple [int , ...],
130
+ expected_shape : tuple [Union [int , str ], ...],
131
+ field_name : str ,
132
+ shape_env : dict [str , int ],
133
+ dynamic_dims : set [str ],
134
+ ) -> None :
119
135
"""Validate that the actual tensor shape matches the expected shape."""
120
136
121
137
if len (actual_shape ) != len (expected_shape ):
0 commit comments