@@ -259,6 +259,47 @@ def tensor_node_is_bool(node: torch.fx.Node) -> bool:
259259 return False
260260
261261
262+ def ndim_of (node : Any ) -> Optional [int ]:
263+ """
264+ Returns the number of dimensions of the tensor produced by the given node
265+ """
266+ if not is_single_tensor_node (node ):
267+ return None
268+
269+ return node .meta ["val" ].ndim
270+
271+
272+ def is_unsqueezed_vector (node : torch .fx .Node ) -> bool :
273+ """
274+ Returns True if the node's tensor has all dimensions equal to 1 except for the last dimension.
275+ """
276+ if not is_single_tensor_node (node ):
277+ return False
278+
279+ tensor = node .meta ["val" ]
280+ assert isinstance (tensor , FakeTensor )
281+
282+ if len (tensor .shape ) < 1 :
283+ return False
284+ # All dims except last are 1, last can be any size
285+ return all (dim == 1 for dim in tensor .shape [:- 1 ])
286+
287+
288+ def op_contains_bool_tensor (node : torch .fx .Node ) -> bool :
289+ """
290+ Returns true if the operator used to compute the given node contains a bool tensor
291+ """
292+ if is_tensor_node (node ) and tensor_node_is_bool (node ):
293+ return True
294+
295+ for arg_node in node .args :
296+ # pyre-ignore[6]
297+ if is_tensor_node (arg_node ) and tensor_node_is_bool (arg_node ):
298+ return True
299+
300+ return False
301+
302+
262303def get_primary_arg_idx (self , node : torch .fx .Node ) -> Optional [int ]:
263304 primary_arg_idx : Optional [int ] = None
264305 for i , arg_node in enumerate (node .args ):
@@ -568,6 +609,16 @@ def make_intersect(self, other: "TensorRepSet") -> "TensorRepSet":
568609 self .valid_texture_layouts & other .valid_texture_layouts ,
569610 )
570611
612+ def make_union (self , other : "TensorRepSet" ) -> "TensorRepSet" :
613+ """
614+ Merge this TensorRepSet with another TensorRepSet, returning a new TensorRepSet
615+ with the union of the two.
616+ """
617+ return TensorRepSet (
618+ self .valid_buffer_layouts | other .valid_buffer_layouts ,
619+ self .valid_texture_layouts | other .valid_texture_layouts ,
620+ )
621+
571622 def is_compatible (self , storage : TensorRepr ) -> bool :
572623 """
573624 Check if this TensorRepr is compatible with the given TensorRepSet.
@@ -693,10 +744,6 @@ def make_filtered_tensor_repset(
693744 if len (tensor_val .shape ) > 4 :
694745 return TensorRepSet (tensor_repset .valid_buffer_layouts , set ())
695746
696- # Bool tensors are currently not supported
697- if tensor_val .dtype == torch .bool :
698- return NO_STORAGE
699-
700747 return TensorRepSet (tensor_repset .valid_buffer_layouts , valid_texture_layouts )
701748
702749
@@ -1230,6 +1277,26 @@ def is_in_8bit_range(tensor: torch.Tensor) -> bool:
12301277##
12311278
12321279
1280+ def normalize_dims (dims : Union [int , List [int ]], ndim : int ) -> Union [int , List [int ]]:
1281+ """
1282+ Normalize dimension indices to be non-negative and within [0, ndim).
1283+ Accepts a single int or a list of ints.
1284+ """
1285+ if isinstance (dims , int ):
1286+ if dims < 0 :
1287+ dims += ndim
1288+
1289+ return dims
1290+
1291+ normalized = []
1292+ for d in dims :
1293+ if d < 0 :
1294+ d += ndim
1295+ normalized .append (d )
1296+
1297+ return normalized
1298+
1299+
12331300def nchw_dim_to_whcn_dim (nchw_dim : int , ndim : int ) -> int :
12341301 # Handle negative indices for nchw_dim
12351302 if nchw_dim < 0 :
0 commit comments