@@ -47,11 +47,11 @@ OpTaskInvocation backward(SplitAttrs const &attrs) {
4747void calc_block_size (coord_t &num_blocks,
4848 coord_t &block_size,
4949 ArrayShape const &array_shape,
50- int axis) {
50+ ff_dim_t axis) {
5151 num_blocks = 1 ;
5252 block_size = 1 ;
5353 for (int d = 0 ; d < array_shape.num_elements (); d++) {
54- if (d <= axis) {
54+ if (d <= axis. value . get_value () ) {
5555 block_size *= array_shape.at (legion_dim_t (d));
5656 } else {
5757 num_blocks *= array_shape.at (legion_dim_t (d));
@@ -66,12 +66,12 @@ static std::optional<float> forward_task_impl(TaskArgumentAccessor const &acc) {
6666 auto attrs = acc.get_argument <SplitAttrs>(ATTRS);
6767
6868 coord_t num_blocks, in_block_size, out_block_size[MAX_NUM_OUTPUTS];
69- calc_block_size (num_blocks, in_block_size, input.shape , attrs.axis . value );
69+ calc_block_size (num_blocks, in_block_size, input.shape , attrs.axis );
7070
7171 for (int i = 0 ; i < attrs.splits .size (); i++) {
7272 coord_t out_num_blocks;
7373 calc_block_size (
74- out_num_blocks, out_block_size[i], output.shape , attrs.axis . value );
74+ out_num_blocks, out_block_size[i], output.shape , attrs.axis );
7575 }
7676 float *output_float_ptr = output.get_float_ptr ();
7777 return profile (forward_kernel,
@@ -94,12 +94,11 @@ static std::optional<float>
9494 auto attrs = acc.get_argument <SplitAttrs>(ATTRS);
9595
9696 coord_t num_blocks, in_block_size, out_block_size[MAX_NUM_OUTPUTS];
97- calc_block_size (
98- num_blocks, in_block_size, input_grad.shape , attrs.axis .value );
97+ calc_block_size (num_blocks, in_block_size, input_grad.shape , attrs.axis );
9998 for (int i = 0 ; i < attrs.splits .size (); i++) {
10099 coord_t out_num_blocks;
101100 calc_block_size (
102- out_num_blocks, out_block_size[i], output_grad.shape , attrs.axis . value );
101+ out_num_blocks, out_block_size[i], output_grad.shape , attrs.axis );
103102 }
104103 float const *output_grad_ptr = output_grad.get_float_ptr ();
105104 return profile (backward_kernel,
0 commit comments