@@ -21,7 +21,7 @@ def __init__(self, data, dtype):
2121 '''
2222 data: numpy array
2323 dtype: triton type, either pointer_type or scalar_type.
24- we don't store block_type here because the shape information is already availale in the data field
24+ we don't store block_type here because the shape information is already available in the data field
2525 attr: a dictionary of attributes
2626 '''
2727 self .data = data
@@ -46,24 +46,23 @@ def set_attr(self, key, value):
4646
4747class BlockPointerHandle :
4848
49- def __init__ (self , base , shape , strides , offsets , tensor_shape , order ):
49+ def __init__ (self , base , shape , strides , offsets , block_shape , order ):
5050 self .base = base
5151 self .shape = shape
5252 self .strides = strides
5353 self .offsets = offsets
54- self .tensor_shape = tensor_shape
54+ self .block_shape = block_shape
5555 self .order = order
5656
5757 def materialize_pointers (self , boundary_check ):
5858 dtype_tt = self .base .get_element_ty ()
5959 n_bytes = dtype_tt .primitive_bitwidth // 8
60- tensor_shape = self .tensor_shape
61- ptrs = np .broadcast_to (self .base .data , self .tensor_shape )
62- masks = np .ones (self .tensor_shape , dtype = bool )
63- for dim in range (len (tensor_shape )):
64- bcast_dims = [1 ] * len (tensor_shape )
65- bcast_dims [dim ] = tensor_shape [dim ]
66- off = (self .offsets [dim ].data + np .arange (tensor_shape [dim ])).reshape (bcast_dims )
60+ ptrs = np .broadcast_to (self .base .data , self .block_shape )
61+ masks = np .ones (self .block_shape , dtype = bool )
62+ for dim in range (len (self .block_shape )):
63+ bcast_dims = [1 ] * len (self .block_shape )
64+ bcast_dims [dim ] = self .block_shape [dim ]
65+ off = (self .offsets [dim ].data + np .arange (self .block_shape [dim ])).reshape (bcast_dims )
6766 ptrs = ptrs + (n_bytes * off * self .strides [dim ].data ).astype (np .uint64 )
6867 if dim in boundary_check :
6968 masks = np .logical_and (masks , off < self .shape [dim ].data )
@@ -655,17 +654,17 @@ def create_barrier(self):
655654 # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter
656655 pass
657656
658- def create_make_block_ptr (self , base , shape , strides , offsets , tensor_shape , order ):
657+ def create_make_block_ptr (self , base , shape , strides , offsets , block_shape , order ):
659658 # Create new offsets to avoid modifying the original
660659 new_offsets = [offset .clone () for offset in offsets ]
661- return BlockPointerHandle (base , shape , strides , new_offsets , tensor_shape , order )
660+ return BlockPointerHandle (base , shape , strides , new_offsets , block_shape , order )
662661
663662 def create_advance (self , ptr , offsets ):
664663 if len (ptr .offsets ) != len (offsets ):
665664 raise ValueError ("len(ptr.offsets) != len(offsets)" )
666665 # Create new offsets to avoid modifying the original
667666 new_offsets = [offset .clone () for offset in ptr .offsets ]
668- ret = BlockPointerHandle (ptr .base , ptr .shape , ptr .strides , new_offsets , ptr .tensor_shape , ptr .order )
667+ ret = BlockPointerHandle (ptr .base , ptr .shape , ptr .strides , new_offsets , ptr .block_shape , ptr .order )
669668 for i in range (len (offsets )):
670669 ret .offsets [i ].data += offsets [i ].data
671670 return ret
0 commit comments