Skip to content

Commit 54c840b

Browse files
authored
[INTERPRETER][NFC] Rename tensor_shape -> block_shape in interpreter (#5195)
`tensor_shape` is a confusing name and doesn't match the block pointer's semantics. `block_shape` is much clearer.
1 parent aaf64d6 commit 54c840b

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

python/triton/runtime/interpreter.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(self, data, dtype):
2121
'''
2222
data: numpy array
2323
dtype: triton type, either pointer_type or scalar_type.
24-
we don't store block_type here because the shape information is already availale in the data field
24+
we don't store block_type here because the shape information is already available in the data field
2525
attr: a dictionary of attributes
2626
'''
2727
self.data = data
@@ -46,24 +46,23 @@ def set_attr(self, key, value):
4646

4747
class BlockPointerHandle:
4848

49-
def __init__(self, base, shape, strides, offsets, tensor_shape, order):
49+
def __init__(self, base, shape, strides, offsets, block_shape, order):
5050
self.base = base
5151
self.shape = shape
5252
self.strides = strides
5353
self.offsets = offsets
54-
self.tensor_shape = tensor_shape
54+
self.block_shape = block_shape
5555
self.order = order
5656

5757
def materialize_pointers(self, boundary_check):
5858
dtype_tt = self.base.get_element_ty()
5959
n_bytes = dtype_tt.primitive_bitwidth // 8
60-
tensor_shape = self.tensor_shape
61-
ptrs = np.broadcast_to(self.base.data, self.tensor_shape)
62-
masks = np.ones(self.tensor_shape, dtype=bool)
63-
for dim in range(len(tensor_shape)):
64-
bcast_dims = [1] * len(tensor_shape)
65-
bcast_dims[dim] = tensor_shape[dim]
66-
off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims)
60+
ptrs = np.broadcast_to(self.base.data, self.block_shape)
61+
masks = np.ones(self.block_shape, dtype=bool)
62+
for dim in range(len(self.block_shape)):
63+
bcast_dims = [1] * len(self.block_shape)
64+
bcast_dims[dim] = self.block_shape[dim]
65+
off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
6766
ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
6867
if dim in boundary_check:
6968
masks = np.logical_and(masks, off < self.shape[dim].data)
@@ -655,17 +654,17 @@ def create_barrier(self):
655654
# Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter
656655
pass
657656

658-
def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order):
657+
def create_make_block_ptr(self, base, shape, strides, offsets, block_shape, order):
659658
# Create new offsets to avoid modifying the original
660659
new_offsets = [offset.clone() for offset in offsets]
661-
return BlockPointerHandle(base, shape, strides, new_offsets, tensor_shape, order)
660+
return BlockPointerHandle(base, shape, strides, new_offsets, block_shape, order)
662661

663662
def create_advance(self, ptr, offsets):
664663
if len(ptr.offsets) != len(offsets):
665664
raise ValueError("len(ptr.offsets) != len(offsets)")
666665
# Create new offsets to avoid modifying the original
667666
new_offsets = [offset.clone() for offset in ptr.offsets]
668-
ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.tensor_shape, ptr.order)
667+
ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.block_shape, ptr.order)
669668
for i in range(len(offsets)):
670669
ret.offsets[i].data += offsets[i].data
671670
return ret

0 commit comments

Comments (0)