Skip to content

Commit f5be059

Browse files
prathikrrohan11235813
authored and committed
add kernel tests for ops that changed in opset18 (#19767)
### Description <!-- Describe your changes. --> - [x] Pad operator has introduced a new input called "axes" which specifies which axis to pad. But it defaults to input_rank if axes is not provided which was the behavior before the opset upgrade. - [x] ReduceMean - [x] ReduceL2 - [x] ReduceLogSumExp - [x] ReduceSum - Reduction ops all had the axes attribute switched to an input and a new attribute called "noop_with_empty_axes" was added to define what to do when axes is not specified. - [x] Resize has had two new attributes introduced: antialias and keep_aspect_ratio_policy. From Operators.md I've gathered: "Antialiasing is achieved by stretching the resampling filter by a factor max(1, 1 / scale), which means that when downsampling, more input pixels contribute to an output pixel." keep_aspect_ratio_policy "describes how to interpret the `sizes` input with regard to keeping the original aspect ratio of the input." there are a couple enum-type options that specify different policies and what to do in each case. - NOTE: Baiju already included opset18 tests in microsoft/onnxruntime#17772 - [x] ScatterElements/ScatterND has had a new attribute introduced called "reduction." This specifies the type of reduction to apply: none (default), add, mul, max, min. - [x] Split introduced a new attribute called "num_outputs" which specifies how many outputs to split the input tensor into. This is in contrast to the previous, default behavior of specifying a "split" input which defines the size of each resultant tensor of the output. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 69ee5f7 commit f5be059

File tree

2 files changed

+55
-12
lines changed

2 files changed

+55
-12
lines changed

orttraining/orttraining/core/graph/gradient_builder.cc

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
11121112

11131113
ArgDef grad = GO(0);
11141114
if (!keepdims) {
1115+
size_t numInputs = GetSrcNodeInputSize();
11151116
if (attributes.find("axes") != attributes.end()) {
11161117
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
11171118
grad = IA("Unqueezed_Grad");
@@ -1122,6 +1123,9 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
11221123
result.push_back(axes_values_node);
11231124
result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {grad}));
11241125
}
1126+
} else if (numInputs == 2) { // optional input 'axes' is available as input I(1)
1127+
grad = IA("Unqueezed_Grad");
1128+
result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {grad}));
11251129
}
11261130
}
11271131

@@ -1152,12 +1156,21 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceLogSumExpGradient) {
11521156
}
11531157

11541158
ArgDef grad = GO(0);
1155-
if (!keepdims && attributes.find("axes") != attributes.end()) {
1156-
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
1157-
grad = IA("Unsqueezed_Grad");
1158-
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
1159+
if (!keepdims) {
1160+
size_t numInputs = GetSrcNodeInputSize();
1161+
if (attributes.find("axes") != attributes.end()) {
1162+
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
1163+
grad = IA("Unsqueezed_Grad");
11591164

1160-
result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
1165+
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
1166+
1167+
result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
1168+
} else if (numInputs == 2) { // optional input 'axes' is available as input I(1)
1169+
grad = IA("Unsqueezed_Grad");
1170+
result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {grad}));
1171+
1172+
result.push_back(NodeDef("Unsqueeze", {O(0), I(1)}, {IA("Unsqueezed_Output")}));
1173+
}
11611174
result.push_back(NodeDef("Sub", {I(0), IA("Unsqueezed_Output")}, {IA("Self_Sub_Result")}));
11621175
} else {
11631176
result.push_back(NodeDef("Sub", {I(0), O(0)}, {IA("Self_Sub_Result")}));
@@ -1188,11 +1201,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceL2Gradient) {
11881201
ArgDef scaled_dy_arg_def = IA("Masked_Scaled_dY");
11891202
result.emplace_back(NodeDef("Where", {IA("Masked_Y"), ZERO, IA("Scaled_dY")}, {scaled_dy_arg_def}));
11901203

1191-
if (!keepdims && attributes.find("axes") != attributes.end()) {
1192-
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
1204+
if (!keepdims) {
1205+
size_t numInputs = GetSrcNodeInputSize();
11931206
scaled_dy_arg_def = IA("Unsqueezed_Masked_Scaled_dY");
1194-
result.emplace_back(
1195-
NodeDef("Unsqueeze", {IA("Masked_Scaled_dY")}, {scaled_dy_arg_def}, {MakeAttribute("axes", axes_values)}));
1207+
if (attributes.find("axes") != attributes.end()) {
1208+
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
1209+
result.emplace_back(
1210+
NodeDef("Unsqueeze", {IA("Masked_Scaled_dY")}, {scaled_dy_arg_def}, {MakeAttribute("axes", axes_values)}));
1211+
} else if (numInputs == 2) { // optional input 'axes' is available as input I(1)
1212+
result.emplace_back(
1213+
NodeDef("Unsqueeze", {IA("Masked_Scaled_dY"), I(1)}, {scaled_dy_arg_def}));
1214+
}
11961215
}
11971216

11981217
result.emplace_back(NodeDef("Mul", {I(0), scaled_dy_arg_def}, {GI(0)}));

orttraining/orttraining/test/gradient/gradient_ops_test.cc

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,10 @@ TEST(GradientCheckerTest, ReduceMeanGrad) {
607607

608608
OpDef op_def_opset13{"ReduceMean", kOnnxDomain, 13};
609609
RunReductionTests(op_def_opset13);
610+
611+
// axes is input from opset 18.
612+
OpDef op_def_opset18{"ReduceMean", kOnnxDomain, 18};
613+
RunReductionTests(op_def_opset18, true, true);
610614
}
611615

612616
TEST(GradientCheckerTest, ReduceSumGrad) {
@@ -619,6 +623,10 @@ TEST(GradientCheckerTest, ReduceSumGrad) {
619623
OpDef op_def_13{"ReduceSum", kOnnxDomain, 13};
620624

621625
RunReductionTests(op_def_13, true, true);
626+
627+
OpDef op_def_18{"ReduceSum", kOnnxDomain, 18};
628+
629+
RunReductionTests(op_def_18, true, true);
622630
}
623631

624632
TEST(GradientCheckerTest, ReduceL2Grad) {
@@ -641,13 +649,22 @@ TEST(GradientCheckerTest, ReduceL2Grad) {
641649
{MakeAttribute("axes", axes)}));
642650
EXPECT_IS_TINY(max_error);
643651
}
652+
653+
// axes is input from opset 18
654+
OpDef op_def_18{"ReduceL2", kOnnxDomain, 18};
655+
656+
RunReductionTests(op_def_18, true, true);
644657
}
645658

646659
TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
647660
// Attribute axes supports negative values from opset 11.
648661
OpDef op_def{"ReduceLogSumExp", kOnnxDomain, 11};
649662

650663
RunReductionTests(op_def);
664+
665+
OpDef op_def_opset18{"ReduceLogSumExp", kOnnxDomain, 18};
666+
667+
RunReductionTests(op_def_opset18, true, true);
651668
}
652669

653670
TEST(GradientCheckerTest, ReluGrad) {
@@ -698,6 +715,13 @@ TEST(GradientCheckerTest, SplitGrad) {
698715
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def_13, {shape}, {{3, 5}, {3, 5}, {3, 5}}, &max_error,
699716
{MakeAttribute("axis", int64_t(0))}));
700717
EXPECT_IS_TINY(max_error);
718+
719+
// opset18 test
720+
OpDef op_def_18{"Split", kOnnxDomain, 18};
721+
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def_18, {shape}, {{3, 5}, {3, 5}, {3, 5}}, &max_error,
722+
{MakeAttribute("axis", int64_t(0)),
723+
MakeAttribute("num_outputs", int64_t(3))}));
724+
EXPECT_IS_TINY(max_error);
701725
}
702726

703727
template <typename T>
@@ -2733,7 +2757,7 @@ TEST(GradientCheckerTest, TileGrad) {
27332757
TEST(GradientCheckerTest, PadGrad) {
27342758
float max_error;
27352759
GradientChecker<float, float, float> gradient_checker;
2736-
OpDef op_def{"Pad", kOnnxDomain, 11};
2760+
OpDef op_def{"Pad", kOnnxDomain, 18};
27372761

27382762
{
27392763
TensorInfo x_info({2, 4}, true);
@@ -2803,7 +2827,7 @@ TEST(GradientCheckerTest, PadGrad) {
28032827
TEST(GradientCheckerTest, ScatterNDGrad) {
28042828
float max_error;
28052829
GradientChecker<float, float, float> gradient_checker;
2806-
OpDef op_def{"ScatterND", kOnnxDomain, 11};
2830+
OpDef op_def{"ScatterND", kOnnxDomain, 18};
28072831

28082832
{
28092833
TensorInfo data_info({8}, true);
@@ -2887,7 +2911,7 @@ TEST(GradientCheckerTest, ScatterNDGrad) {
28872911
TEST(GradientCheckerTest, ScatterElementsGrad) {
28882912
float max_error;
28892913
GradientChecker<float, float, float> gradient_checker;
2890-
OpDef op_def{"ScatterElements", kOnnxDomain, 13};
2914+
OpDef op_def{"ScatterElements", kOnnxDomain, 18};
28912915

28922916
{ // without axis
28932917
TensorInfo data_info({3, 3}, true);

0 commit comments

Comments
 (0)