@@ -105,6 +105,8 @@ def fake_reshape(
         reduced_tensor = self.from_tensor(true_tensor, static_shapes=True).sum(
             axis=tuple(sorted(sh)), keepdim=True
         )
+        if len(reduced_tensor.shape) == 0 == len(new_shape):
+            return reduced_tensor
         return reduced_tensor.expand(*new_shape)

     def make_fake(self, x: Any) -> Optional["FakeTensor"]:  # noqa: F821
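Note: the new guard in fake_reshape short-circuits the case where both the reduced tensor and the target shape are zero-dimensional, returning the scalar tensor as-is instead of calling expand(). A minimal sketch of that guard on plain tensors (the wrapper name expand_or_passthrough is hypothetical and only illustrates the pattern; the real code works on fake tensors produced by from_tensor):

import torch


def expand_or_passthrough(reduced_tensor: torch.Tensor, new_shape: tuple) -> torch.Tensor:
    # Mirror of the added guard: both the reduced tensor and the target shape
    # are zero-dimensional, so return the scalar tensor unchanged.
    if len(reduced_tensor.shape) == 0 == len(new_shape):
        return reduced_tensor
    return reduced_tensor.expand(*new_shape)


print(expand_or_passthrough(torch.tensor(3.0), ()))      # tensor(3.)
print(expand_or_passthrough(torch.tensor(3.0), (2, 3)))  # 2x3 tensor filled with 3.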
@@ -157,7 +159,9 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
             )
         if type(x) is dict:
             return {
-                k: self.make_fake_with_dynamic_dimensions(v, dynamic_shapes=dynamic_shapes[k])
+                k: self.make_fake_with_dynamic_dimensions(
+                    v, dynamic_shapes=dynamic_shapes[k] if dynamic_shapes else None
+                )
                 for k, v in x.items()
             }
         if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
@@ -231,7 +235,7 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:

         x = torch.empty(tuple(new_shape), dtype=x.dtype, device=x.device)

-        t = self.fake_reshape(x, dynamic_shapes)  # type: ignore[arg-type]
+        t = self.fake_reshape(x, dynamic_shapes) if dynamic_shapes else x  # type: ignore[arg-type]
         assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
         assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
         return t
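The dict branch and the final fake_reshape call now tolerate dynamic_shapes=None: instead of indexing dynamic_shapes[k] (which fails when the argument is None), the recursion forwards None, and fake_reshape is skipped entirely when no dynamic shapes are supplied. A minimal, self-contained sketch of that None-propagation pattern (the helper name propagate_dynamic_shapes is hypothetical, not part of this diff):

from typing import Any, Optional


def propagate_dynamic_shapes(x: Any, dynamic_shapes: Optional[Any]) -> Any:
    # Mirror the structure of x; when dynamic_shapes is None, recurse with None
    # instead of indexing into it (which would raise a TypeError).
    if isinstance(x, dict):
        return {
            k: propagate_dynamic_shapes(
                v, dynamic_shapes[k] if dynamic_shapes else None
            )
            for k, v in x.items()
        }
    # Leaf value: return the shape spec that applies to it, possibly None.
    return dynamic_shapes


# No dynamic shapes supplied: every leaf receives None rather than crashing.
print(propagate_dynamic_shapes({"a": 1, "b": {"c": 2}}, None))
# -> {'a': None, 'b': {'c': None}}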