
Commit b617632

Author: Victor Li (committed)
Changed ff_dim_t to use nonnegative_int, added relative_ff_dim_t that uses int
1 parent: 670fb62

File tree: 66 files changed (+724 / -333 lines)

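For context on the change described above, here is a minimal, hedged sketch of the types involved. These definitions are illustrative assumptions only (the actual FlexFlow definitions live elsewhere in the repository and may differ in detail), but they match how the hunks below use them: ff_dim_t now wraps a nonnegative_int that exposes get_value(), while the newly added relative_ff_dim_t (not visible in the hunks shown here) keeps a plain int so it can hold negative, from-the-end indices.

// Illustrative sketch only -- not the actual FlexFlow definitions.
#include <cassert>

// nonnegative_int: an int wrapper that rejects negative values at construction
// and exposes the raw value via get_value().
class nonnegative_int {
public:
  explicit nonnegative_int(int value) : value_(value) {
    assert(value >= 0 && "nonnegative_int cannot hold a negative value");
  }
  int get_value() const {
    return value_;
  }

private:
  int value_;
};

// ff_dim_t: a tensor dimension index; after this commit it can only
// represent indices >= 0.
struct ff_dim_t {
  nonnegative_int value;
};

// relative_ff_dim_t: a dimension index that may be negative (counting from
// the end), so it stays a plain int.
struct relative_ff_dim_t {
  int value;
};

Construction therefore goes from ff_dim_t(0) to ff_dim_t{nonnegative_int{0}}, and code that needs the raw integer calls .get_value(); that is the mechanical change repeated across the diffs below.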

lib/kernels/src/legion_dim.cc

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ legion_dim_t add_to_legion_dim(legion_dim_t legion_dim, int value) {
 }
 
 legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, int num_dimensions) {
-  return legion_dim_t(num_dimensions - ff_dim.value - 1);
+  return legion_dim_t(num_dimensions - ff_dim.value.get_value() - 1);
 }
 
 } // namespace FlexFlow
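As a quick sanity check of the arithmetic above (a hypothetical 4-dimensional tensor, not taken from the diff): the conversion reverses the dimension order between the two indexing conventions.

// legion_dim = num_dimensions - ff_dim - 1
// e.g. for num_dimensions = 4:
//   ff_dim 0 -> legion_dim 3,  ff_dim 1 -> legion_dim 2,  ff_dim 3 -> legion_dim 0
int const num_dimensions = 4;
ff_dim_t const ff_index = ff_dim_t{nonnegative_int{1}};
int const legion_index = num_dimensions - ff_index.value.get_value() - 1; // == 2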

lib/kernels/test/src/test_concat_kernel.cc

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ TEST_SUITE(FF_TEST_SUITE) {
   TEST_CASE("Test concat kernel forward and backward") {
     size_t num_inputs = 3;
     size_t size_per_input = 100;
-    ff_dim_t concat_axis = ff_dim_t(0);
+    ff_dim_t concat_axis = ff_dim_t{nonnegative_int{0}};
 
     ManagedPerDeviceFFHandle managed_handle{};
     ManagedFFStream managed_stream{};

lib/kernels/test/src/test_transpose_kernel.cc

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,8 @@ TEST_SUITE(FF_TEST_SUITE) {
   TEST_CASE("Test Transpose Kernel Operations") {
     std::size_t num_dims = 2;
 
-    std::vector<ff_dim_t> perm = {ff_dim_t(0), ff_dim_t(1)};
+    std::vector<ff_dim_t> perm = {ff_dim_t{nonnegative_int{0}},
+                                  ff_dim_t{nonnegative_int{1}}};
 
     ManagedPerDeviceFFHandle managed_handle{};
     ManagedFFStream managed_stream{};
Lines changed: 3 additions & 2 deletions
@@ -1,14 +1,15 @@
 #include "local-execution/legion_tensor_shape.h"
+#include "kernels/legion_dim.h"
 #include "op-attrs/tensor_shape.h"
 
 namespace FlexFlow {
 
 legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, size_t num_dims) {
-  return legion_dim_t(num_dims - ff_dim.value - 1);
+  return legion_dim_t(num_dims - ff_dim.value.get_value() - 1);
 }
 
 legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, TensorShape const &shape) {
-  return legion_dim_t(num_dims(shape) - ff_dim.value - 1);
+  return legion_dim_from_ff_dim(ff_dim, num_dims(shape));
 }
 
 } // namespace FlexFlow

lib/local-execution/src/ops/linear.cc

Lines changed: 6 additions & 6 deletions
@@ -66,8 +66,8 @@ static DeviceSpecificDeviceStates
   auto input = acc.get_tensor<Permissions::RO>(INPUT);
   auto weight = acc.get_tensor<Permissions::RO>(WEIGHT);
   auto output = acc.get_tensor<Permissions::WO>(OUTPUT);
-  int out_dim = output.shape.at(ff_dim_t{0});
-  int batch_size = output.shape.at(ff_dim_t{1});
+  int out_dim = output.shape.at(ff_dim_t{nonnegative_int{0}});
+  int batch_size = output.shape.at(ff_dim_t{nonnegative_int{1}});
 
   float *one_ptr;
 
@@ -96,8 +96,8 @@ static std::optional<float> forward_task_impl(TaskArgumentAccessor const &acc) {
   ProfilingSettings profiling = acc.get_argument<ProfilingSettings>(PROFILING);
   auto attrs = acc.get_argument<LinearAttrs>(ATTRS);
 
-  int in_dim = input.shape.at(ff_dim_t{0}) + 1;
-  int out_dim = output.shape.at(ff_dim_t{0}) + 1;
+  int in_dim = input.shape.at(ff_dim_t{nonnegative_int{0}}) + 1;
+  int out_dim = output.shape.at(ff_dim_t{nonnegative_int{0}}) + 1;
   int batch_size = output.shape.get_volume() / out_dim;
 
   float const *bias_ptr = NULL;
 
@@ -140,8 +140,8 @@ static std::optional<float>
     bias_ptr = bias.get_float_ptr();
   }
 
-  int in_dim = input.shape.at(ff_dim_t{0}) + 1;
-  int out_dim = output.shape.at(ff_dim_t{0}) + 1;
+  int in_dim = input.shape.at(ff_dim_t{nonnegative_int{0}}) + 1;
+  int out_dim = output.shape.at(ff_dim_t{nonnegative_int{0}}) + 1;
   int batch_size = output.shape.get_volume() / out_dim;
 
   return profile(backward_kernel,

lib/local-execution/src/ops/pool_2d.cc

Lines changed: 8 additions & 8 deletions
@@ -30,14 +30,14 @@ static DeviceSpecificDeviceStates
   auto input = acc.get_tensor<Permissions::RO>(INPUT);
   auto output = acc.get_tensor<Permissions::WO>(OUTPUT);
 
-  int input_w = input.shape.at(ff_dim_t(0)) + 1;
-  int input_h = input.shape.at(ff_dim_t(1)) + 1;
-  int input_c = input.shape.at(ff_dim_t(2)) + 1;
-  int input_n = input.shape.at(ff_dim_t(3)) + 1;
-  int output_w = output.shape.at(ff_dim_t(0)) + 1;
-  int output_h = output.shape.at(ff_dim_t(1)) + 1;
-  int output_c = output.shape.at(ff_dim_t(2)) + 1;
-  int output_n = output.shape.at(ff_dim_t(3)) + 1;
+  int input_w = input.shape.at(ff_dim_t{nonnegative_int{0}}) + 1;
+  int input_h = input.shape.at(ff_dim_t{nonnegative_int{1}}) + 1;
+  int input_c = input.shape.at(ff_dim_t{nonnegative_int{2}}) + 1;
+  int input_n = input.shape.at(ff_dim_t{nonnegative_int{3}}) + 1;
+  int output_w = output.shape.at(ff_dim_t{nonnegative_int{0}}) + 1;
+  int output_h = output.shape.at(ff_dim_t{nonnegative_int{1}}) + 1;
+  int output_c = output.shape.at(ff_dim_t{nonnegative_int{2}}) + 1;
+  int output_n = output.shape.at(ff_dim_t{nonnegative_int{3}}) + 1;
 
   printf("init pool (input): n(%d) c(%d) h(%d) "
          "w(%d)\n",

lib/local-execution/src/ops/reverse.cc

Lines changed: 7 additions & 7 deletions
@@ -53,11 +53,11 @@ static std::optional<float> forward_task_impl(TaskArgumentAccessor const &acc) {
   coord_t in_blk_size = 1, reverse_dim_size = 1, num_out_blks = 1;
   for (int i = 0; i < output.shape.get_dim(); i++) {
     if (i < axis.value) {
-      in_blk_size *= output.shape.at(ff_dim_t(i));
+      in_blk_size *= output.shape.at(ff_dim_t{nonnegative_int{i}});
     } else if (i == axis.value) {
-      reverse_dim_size = output.shape.at(ff_dim_t(i));
+      reverse_dim_size = output.shape.at(ff_dim_t{nonnegative_int{i}});
     } else {
-      num_out_blks *= output.shape.at(ff_dim_t(i));
+      num_out_blks *= output.shape.at(ff_dim_t{nonnegative_int{i}});
     }
   }
 
@@ -79,15 +79,15 @@ static std::optional<float>
   auto output_grad = acc.get_tensor_grad<Permissions::RO>(OUTPUT);
   auto attrs = acc.get_argument<ReverseAttrs>(ATTRS);
 
-  int axis = input_grad.shape.get_dim() - attrs.axis.value - 1;
+  int axis = input_grad.shape.get_dim() - attrs.axis.value.get_value() - 1;
   coord_t in_blk_size = 1, reverse_dim_size = 1, num_out_blks = 1;
   for (int i = 0; i < input_grad.shape.get_dim(); i++) {
     if (i < axis) {
-      in_blk_size *= input_grad.shape.at(ff_dim_t(i));
+      in_blk_size *= input_grad.shape.at(ff_dim_t{nonnegative_int{i}});
     } else if (i == axis) {
-      reverse_dim_size = input_grad.shape.at(ff_dim_t(i));
+      reverse_dim_size = input_grad.shape.at(ff_dim_t{nonnegative_int{i}});
     } else {
-      num_out_blks *= input_grad.shape.at(ff_dim_t(i));
+      num_out_blks *= input_grad.shape.at(ff_dim_t{nonnegative_int{i}});
     }
   }
 
lib/local-execution/src/ops/softmax.cc

Lines changed: 7 additions & 2 deletions
@@ -64,8 +64,13 @@ static DeviceSpecificDeviceStates
   int output_c = output.shape.at(legion_dim_t(2));
   int output_n = output.shape.at(legion_dim_t(3));
 
-  SoftmaxPerDeviceState per_device_state = init_kernel(
-      handle, attrs.dim.value, output_n, output_c, output_h, output_w);
+  SoftmaxPerDeviceState per_device_state =
+      init_kernel(handle,
+                  attrs.dim.value.get_value(),
+                  output_n,
+                  output_c,
+                  output_h,
+                  output_w);
 
   return DeviceSpecificDeviceStates{
       DeviceSpecific<SoftmaxPerDeviceState>::create(per_device_state)};
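The extra .get_value() calls in this and the other hunks follow one pattern: wherever an API still takes the dimension as a plain int (as init_kernel does here), the nonnegative_int inside ff_dim_t is unwrapped explicitly at the call site. A hedged sketch of that boundary, reusing the illustrative types from the top of this page (softmax_init_kernel_stub is a hypothetical stand-in, not the real kernel signature):

// Hypothetical stand-in for a kernel entry point that still takes an int.
void softmax_init_kernel_stub(int dim) {
  (void)dim; // sketch only
}

// Unwrap the typed dimension exactly once, at the int-typed API boundary.
void init_softmax(ff_dim_t softmax_dim) {
  softmax_init_kernel_stub(softmax_dim.value.get_value());
}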

lib/local-execution/src/ops/split.cc

Lines changed: 6 additions & 7 deletions
@@ -47,11 +47,11 @@ OpTaskInvocation backward(SplitAttrs const &attrs) {
 void calc_block_size(coord_t &num_blocks,
                      coord_t &block_size,
                      ArrayShape const &array_shape,
-                     int axis) {
+                     ff_dim_t axis) {
   num_blocks = 1;
   block_size = 1;
   for (int d = 0; d < array_shape.num_elements(); d++) {
-    if (d <= axis) {
+    if (d <= axis.value.get_value()) {
       block_size *= array_shape.at(legion_dim_t(d));
     } else {
       num_blocks *= array_shape.at(legion_dim_t(d));
 
@@ -66,12 +66,12 @@ static std::optional<float> forward_task_impl(TaskArgumentAccessor const &acc) {
   auto attrs = acc.get_argument<SplitAttrs>(ATTRS);
 
   coord_t num_blocks, in_block_size, out_block_size[MAX_NUM_OUTPUTS];
-  calc_block_size(num_blocks, in_block_size, input.shape, attrs.axis.value);
+  calc_block_size(num_blocks, in_block_size, input.shape, attrs.axis);
 
   for (int i = 0; i < attrs.splits.size(); i++) {
     coord_t out_num_blocks;
     calc_block_size(
-        out_num_blocks, out_block_size[i], output.shape, attrs.axis.value);
+        out_num_blocks, out_block_size[i], output.shape, attrs.axis);
   }
   float *output_float_ptr = output.get_float_ptr();
   return profile(forward_kernel,
 
@@ -94,12 +94,11 @@ static std::optional<float>
   auto attrs = acc.get_argument<SplitAttrs>(ATTRS);
 
   coord_t num_blocks, in_block_size, out_block_size[MAX_NUM_OUTPUTS];
-  calc_block_size(
-      num_blocks, in_block_size, input_grad.shape, attrs.axis.value);
+  calc_block_size(num_blocks, in_block_size, input_grad.shape, attrs.axis);
   for (int i = 0; i < attrs.splits.size(); i++) {
     coord_t out_num_blocks;
     calc_block_size(
-        out_num_blocks, out_block_size[i], output_grad.shape, attrs.axis.value);
+        out_num_blocks, out_block_size[i], output_grad.shape, attrs.axis);
   }
   float const *output_grad_ptr = output_grad.get_float_ptr();
   return profile(backward_kernel,
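The change to calc_block_size moves the unwrapping into the helper itself: callers now pass attrs.axis as a typed ff_dim_t, and the helper calls .get_value() once, instead of every call site passing attrs.axis.value. A hedged sketch of a helper with the same shape, using the illustrative types from the top of this page and a plain vector standing in for ArrayShape:

#include <cstdint>
#include <vector>

using coord_t = int64_t;

// Split a shape into (num_blocks, block_size) around a typed axis: dimensions
// up to and including the axis contribute to block_size, the rest to num_blocks.
void calc_block_size_sketch(coord_t &num_blocks,
                            coord_t &block_size,
                            std::vector<coord_t> const &dim_sizes, // stand-in for ArrayShape
                            ff_dim_t axis) {
  num_blocks = 1;
  block_size = 1;
  for (int d = 0; d < static_cast<int>(dim_sizes.size()); d++) {
    if (d <= axis.value.get_value()) {
      block_size *= dim_sizes[d];
    } else {
      num_blocks *= dim_sizes[d];
    }
  }
}

Passing the wrapper through keeps the nonnegative guarantee attached to the value until the last possible moment, which appears to be the motivation for the signature change.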
