Commit 7bf27fe

[ET-VK][ez] Allow bool tensors to be lowered to ET-VK and add uint8(bool) dtype variants for several compute shaders (pytorch#15316)
Title says it all!

Differential Revision: [D84716458](https://our.internmc.facebook.com/intern/diff/D84716458/)

[ghstack-poisoned]
1 parent ff6deb2 commit 7bf27fe

File tree

8 files changed: +90 −5 lines changed

- backends/vulkan/partitioner/vulkan_partitioner.py
- backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml
- backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml
- backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.yaml
- backends/vulkan/runtime/graph/ops/glsl/transfer_texture.yaml
- backends/vulkan/runtime/graph/ops/glsl/view.yaml
- backends/vulkan/test/tester.py
- backends/vulkan/utils.py

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -59,6 +59,7 @@ def __init__(
         texture_limits: utils.ImageExtents,
         buffer_limit: int,
         require_dynamic_shape: bool = False,
+        skip_bool_tensors: bool = False,
         operator_blocklist: Optional[Set[OpKey]] = None,
         operator_allowlist: Optional[Set[OpKey]] = None,
         fusable_subgraphs: Optional[List[PatternMatch]] = None,
@@ -69,6 +70,7 @@ def __init__(
         self.texture_limits: utils.ImageExtents = texture_limits
         self.buffer_limit = buffer_limit
         self.require_dynamic_shapes = require_dynamic_shape
+        self.skip_bool_tensors = skip_bool_tensors
         self.operator_blocklist: Set[OpKey] = (
             operator_blocklist if operator_blocklist is not None else set()
         )
@@ -117,6 +119,11 @@ def op_node_is_compatible(  # noqa: C901: Function is too complex
             return False, "no operator implementation"
         features = get_op_features(target)

+        # bool tensors are internally represented with int8 buffers, which may not be
+        # supported by some GPUs. Therefore, provide the option to skip these tensors.
+        if self.skip_bool_tensors and utils.op_contains_bool_tensor(node):
+            return False, f"op {utils.node_io_str(node)} contains bool tensor"
+
         # Get the possible tensor representations for each tensor participating in
         # this operator. Then check that all tensors are representable as either a
         # buffer or texture.
@@ -398,6 +405,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
             texture_limits,
             buffer_limit,
             require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
+            skip_bool_tensors=self.options.get("skip_bool_tensors", False),
             operator_blocklist=self.operator_blocklist,
             operator_allowlist=self.operator_allowlist,
             fusable_subgraphs=fusable_subgraphs,
```
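
The new option travels through the partitioner's compile-options dict (read via `self.options.get("skip_bool_tensors", False)` above). A minimal usage sketch follows; the module and export flow here are illustrative assumptions, not part of this diff:

```python
import torch

from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import to_edge_transform_and_lower


class MaskedZero(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mask = x > 0  # the comparison yields a bool tensor
        return torch.where(mask, x, torch.zeros_like(x))


ep = torch.export.export(MaskedZero(), (torch.randn(4, 8),))

# With skip_bool_tensors=True, op_node_is_compatible() rejects any op that
# touches a bool tensor, so those ops stay out of the Vulkan delegate and
# fall back to the default (CPU) path.
edge = to_edge_transform_and_lower(
    ep,
    partitioner=[VulkanPartitioner({"skip_bool_tensors": True})],
)
```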

backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -6,5 +6,6 @@ permute_buffer:
       - VALUE: half
       - VALUE: float
       - VALUE: int32
+      - VALUE: uint8
   shader_variants:
     - NAME: permute_buffer
```

backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -6,5 +6,6 @@ permute_texture:
       - VALUE: half
       - VALUE: float
       - VALUE: int32
+      - VALUE: uint8
   shader_variants:
     - NAME: permute_texture3d
```

backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -8,6 +8,7 @@ transfer_buffer:
       - VALUE: half
       - VALUE: float
       - VALUE: int32
+      - VALUE: uint8
   shader_variants:
     - NAME: select_buffer
       OP_NAME: select
```

backends/vulkan/runtime/graph/ops/glsl/transfer_texture.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -8,6 +8,7 @@ transfer_texture:
       - VALUE: half
       - VALUE: float
       - VALUE: int32
+      - VALUE: uint8
   shader_variants:
     - NAME: select_texture3d
       OP_NAME: select
```

backends/vulkan/runtime/graph/ops/glsl/view.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -8,5 +8,6 @@ view:
       - VALUE: half
       - VALUE: float
       - VALUE: int32
+      - VALUE: uint8
   shader_variants:
     - NAME: view
```
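
All five YAML changes add the same `uint8` value to each shader's dtype variant list. The premise, checkable in plain PyTorch, is that bool tensors occupy one byte per element, so a uint8-typed shader variant can service their storage directly. A small sketch, independent of the ET-VK runtime:

```python
import torch

mask = torch.tensor([True, False, True])
assert mask.element_size() == 1        # bool is stored as one byte per element
as_bytes = mask.view(torch.uint8)      # zero-copy reinterpretation of the storage
assert as_bytes.tolist() == [1, 0, 1]
```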

backends/vulkan/test/tester.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -44,8 +44,9 @@ def __init__(

 class Partition(BaseStages.Partition):
     def __init__(self, partitioner: Optional[Partitioner] = None):
+        vk_compile_spec = {"skip_bool_tensors": True}
         super().__init__(
-            partitioner=partitioner or VulkanPartitioner(),
+            partitioner=partitioner or VulkanPartitioner(vk_compile_spec),
         )
@@ -55,6 +56,10 @@ def __init__(
         partitioners: Optional[List[Partitioner]] = None,
         edge_compile_config: Optional[EdgeCompileConfig] = None,
     ):
+        if partitioners is None:
+            vk_compile_spec = {"skip_bool_tensors": True}
+            partitioners = [VulkanPartitioner(vk_compile_spec)]
+
         super().__init__(
             default_partitioner_cls=VulkanPartitioner,
             partitioners=partitioners,
```
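
Both test stages now default to `skip_bool_tensors=True`, so partitioning in tests exercises the conservative path. A test that wants bool tensors to remain eligible for delegation could presumably override the default by passing an explicit partitioner; a sketch, with the module path assumed from the file paths above:

```python
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.backends.vulkan.test.tester import Partition

# Explicitly opting back in to delegating ops that touch bool tensors.
stage = Partition(partitioner=VulkanPartitioner({"skip_bool_tensors": False}))
```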

backends/vulkan/utils.py

Lines changed: 71 additions & 4 deletions
```diff
@@ -259,6 +259,47 @@ def tensor_node_is_bool(node: torch.fx.Node) -> bool:
     return False


+def ndim_of(node: Any) -> Optional[int]:
+    """
+    Returns the number of dimensions of the tensor produced by the given node
+    """
+    if not is_single_tensor_node(node):
+        return None
+
+    return node.meta["val"].ndim
+
+
+def is_unsqueezed_vector(node: torch.fx.Node) -> bool:
+    """
+    Returns True if the node's tensor has all dimensions equal to 1 except for
+    the last dimension.
+    """
+    if not is_single_tensor_node(node):
+        return False
+
+    tensor = node.meta["val"]
+    assert isinstance(tensor, FakeTensor)
+
+    if len(tensor.shape) < 1:
+        return False
+    # All dims except last are 1, last can be any size
+    return all(dim == 1 for dim in tensor.shape[:-1])
+
+
+def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
+    """
+    Returns true if the operator used to compute the given node contains a bool tensor
+    """
+    if is_tensor_node(node) and tensor_node_is_bool(node):
+        return True
+
+    for arg_node in node.args:
+        # pyre-ignore[6]
+        if is_tensor_node(arg_node) and tensor_node_is_bool(arg_node):
+            return True
+
+    return False
+
+
 def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]:
     primary_arg_idx: Optional[int] = None
     for i, arg_node in enumerate(node.args):
@@ -568,6 +609,16 @@ def make_intersect(self, other: "TensorRepSet") -> "TensorRepSet":
             self.valid_texture_layouts & other.valid_texture_layouts,
         )

+    def make_union(self, other: "TensorRepSet") -> "TensorRepSet":
+        """
+        Merge this TensorRepSet with another TensorRepSet, returning a new
+        TensorRepSet with the union of the two.
+        """
+        return TensorRepSet(
+            self.valid_buffer_layouts | other.valid_buffer_layouts,
+            self.valid_texture_layouts | other.valid_texture_layouts,
+        )
+
     def is_compatible(self, storage: TensorRepr) -> bool:
         """
         Check if this TensorRepr is compatible with the given TensorRepSet.
@@ -693,10 +744,6 @@ def make_filtered_tensor_repset(
     if len(tensor_val.shape) > 4:
         return TensorRepSet(tensor_repset.valid_buffer_layouts, set())

-    # Bool tensors are currently not supported
-    if tensor_val.dtype == torch.bool:
-        return NO_STORAGE
-
     return TensorRepSet(tensor_repset.valid_buffer_layouts, valid_texture_layouts)
@@ -1230,6 +1277,26 @@ def is_in_8bit_range(tensor: torch.Tensor) -> bool:
     ##


+def normalize_dims(dims: Union[int, List[int]], ndim: int) -> Union[int, List[int]]:
+    """
+    Normalize dimension indices to be non-negative and within [0, ndim).
+    Accepts a single int or a list of ints.
+    """
+    if isinstance(dims, int):
+        if dims < 0:
+            dims += ndim
+
+        return dims
+
+    normalized = []
+    for d in dims:
+        if d < 0:
+            d += ndim
+        normalized.append(d)
+
+    return normalized
+
+
 def nchw_dim_to_whcn_dim(nchw_dim: int, ndim: int) -> int:
     # Handle negative indices for nchw_dim
     if nchw_dim < 0:
```
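
Of the new helpers, `normalize_dims` is a pure function, so its contract can be pinned down with a standalone sketch (same logic as the body added above, condensed):

```python
from typing import List, Union


def normalize_dims(dims: Union[int, List[int]], ndim: int) -> Union[int, List[int]]:
    # Map negative dim indices into [0, ndim), for a single int or a list.
    if isinstance(dims, int):
        return dims + ndim if dims < 0 else dims
    return [d + ndim if d < 0 else d for d in dims]


assert normalize_dims(-1, 4) == 3            # last dim of a rank-4 tensor
assert normalize_dims([0, -2], 4) == [0, 2]  # mixed positive and negative indices
```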
