Skip to content

Commit 552dbcb

Browse files
Merge pull request #1646 from roboflow/fix/assertions-in-rfdetr-seg-tests
Loose RFDetr Seg integration tests assertions for ONNX which seems to be quite numerically unstable
2 parents fb2a614 + fb0d69f commit 552dbcb

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

inference_experimental/tests/integration_tests/models/test_rfdetr_seg_predictions_onnx.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def test_package_with_stretch_against_torch_input(
8383
# then
8484
assert len(predictions) == 1
8585
assert np.allclose(
86-
predictions[0].xyxy.cpu().numpy(), np.array([[138, 325, 1262, 556]]), atol=1
86+
predictions[0].xyxy.cpu().numpy(), np.array([[138, 325, 1262, 556]]), atol=5
8787
)
8888
assert 205000 <= np.sum(predictions[0].mask.cpu().numpy()) <= 207000
8989

@@ -110,11 +110,11 @@ def test_package_with_stretch_against_torch_list_input(
110110
# then
111111
assert len(predictions) == 2
112112
assert np.allclose(
113-
predictions[0].xyxy.cpu().numpy(), np.array([[138, 325, 1262, 556]]), atol=1
113+
predictions[0].xyxy.cpu().numpy(), np.array([[138, 325, 1262, 556]]), atol=5
114114
)
115115
assert 205000 <= np.sum(predictions[0].mask.cpu().numpy()) <= 207000
116116
assert np.allclose(
117-
predictions[1].xyxy.cpu().numpy(), np.array([[138, 325, 1262, 556]]), atol=1
117+
predictions[1].xyxy.cpu().numpy(), np.array([[138, 325, 1262, 556]]), atol=5
118118
)
119119
assert 205000 <= np.sum(predictions[1].mask.cpu().numpy()) <= 207000
120120

@@ -141,11 +141,11 @@ def test_package_with_stretch_against_torch_batch_input(
141141
# then
142142
assert len(predictions) == 2
143143
assert np.allclose(
144-
predictions[0].xyxy.cpu().numpy(), np.array([[138, 325, 1262, 556]]), atol=1
144+
predictions[0].xyxy.cpu().numpy(), np.array([[138, 325, 1262, 556]]), atol=5
145145
)
146146
assert 205000 <= np.sum(predictions[0].mask.cpu().numpy()) <= 207000
147147
assert np.allclose(
148-
predictions[1].xyxy.cpu().numpy(), np.array([[138, 325, 1262, 556]]), atol=1
148+
predictions[1].xyxy.cpu().numpy(), np.array([[138, 325, 1262, 556]]), atol=5
149149
)
150150
assert 205000 <= np.sum(predictions[1].mask.cpu().numpy()) <= 207000
151151

@@ -524,7 +524,7 @@ def test_package_with_static_crop_stretch_against_torch_input(
524524
# then
525525
assert len(predictions) == 1
526526
assert np.allclose(
527-
predictions[0].xyxy.cpu().numpy(), np.array([[321, 331, 963, 561]]), atol=1
527+
predictions[0].xyxy.cpu().numpy(), np.array([[321, 331, 963, 561]]), atol=5
528528
)
529529
assert 120000 <= np.sum(predictions[0].mask.cpu().numpy()) <= 122000
530530

@@ -551,11 +551,11 @@ def test_package_with_static_crop_stretch_against_torch_list_input(
551551
# then
552552
assert len(predictions) == 2
553553
assert np.allclose(
554-
predictions[0].xyxy.cpu().numpy(), np.array([[321, 331, 963, 561]]), atol=1
554+
predictions[0].xyxy.cpu().numpy(), np.array([[321, 331, 963, 561]]), atol=5
555555
)
556556
assert 120000 <= np.sum(predictions[0].mask.cpu().numpy()) <= 122000
557557
assert np.allclose(
558-
predictions[1].xyxy.cpu().numpy(), np.array([[321, 331, 963, 561]]), atol=1
558+
predictions[1].xyxy.cpu().numpy(), np.array([[321, 331, 963, 561]]), atol=5
559559
)
560560
assert 120000 <= np.sum(predictions[1].mask.cpu().numpy()) <= 122000
561561

@@ -582,11 +582,11 @@ def test_package_with_static_crop_stretch_against_torch_stack_input(
582582
# then
583583
assert len(predictions) == 2
584584
assert np.allclose(
585-
predictions[0].xyxy.cpu().numpy(), np.array([[321, 331, 963, 561]]), atol=1
585+
predictions[0].xyxy.cpu().numpy(), np.array([[321, 331, 963, 561]]), atol=5
586586
)
587587
assert 120000 <= np.sum(predictions[0].mask.cpu().numpy()) <= 122000
588588
assert np.allclose(
589-
predictions[1].xyxy.cpu().numpy(), np.array([[321, 331, 963, 561]]), atol=1
589+
predictions[1].xyxy.cpu().numpy(), np.array([[321, 331, 963, 561]]), atol=5
590590
)
591591
assert 120000 <= np.sum(predictions[1].mask.cpu().numpy()) <= 122000
592592

0 commit comments

Comments
 (0)