|
1 | 1 | #!/usr/bin/env python3 |
2 | 2 |
|
3 | | -# Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | +# Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
4 | 4 | # |
5 | 5 | # Redistribution and use in source and binary forms, with or without |
6 | 6 | # modification, are permitted provided that the following conditions |
|
32 | 32 |
|
33 | 33 | import os |
34 | 34 | import queue |
| 35 | +import threading |
35 | 36 | import time |
36 | 37 | import unittest |
37 | 38 | from functools import partial |
@@ -606,53 +607,212 @@ def test_wrong_shape(self): |
class NonDecoupledTest(tu.TestResultCollector):
    """Verify the decoupled 'repeat_int32' model served in non-decoupled mode.

    Each test sweeps a small (IN, DELAY, WAIT) data matrix via subTest so
    one failing combination does not mask the others.
    """

    def setUp(self):
        self.model_name_ = "repeat_int32"
        # Rows are ("IN", "DELAY", "WAIT"); the two non-zero rows cover both
        # DELAY > WAIT and DELAY < WAIT orderings.
        self.data_matrix = [
            # ("IN", "DELAY", "WAIT")
            ([1], [0], [0]),
            ([1], [4000], [2000]),
            ([1], [2000], [4000]),
        ]

        # For grpc async infer test
        self.callback_error = None
        self.callback_result = None
        self.callback_invoked_event = threading.Event()

    def _input_data(self, in_value, delay_value, wait_value):
        """Return the numpy tensors for one (IN, DELAY, WAIT) combination."""
        return {
            "IN": np.array(in_value, dtype=np.int32),
            "DELAY": np.array(delay_value, dtype=np.uint32),
            "WAIT": np.array(wait_value, dtype=np.uint32),
        }

    def _build_inputs(self, client_module, input_data):
        """Build the InferInput list for either grpcclient or httpclient.

        Relies on set_data_from_numpy() returning the InferInput itself,
        the same chaining the individual tests used before extraction.
        """
        return [
            client_module.InferInput("IN", [1], "INT32").set_data_from_numpy(
                input_data["IN"]
            ),
            client_module.InferInput("DELAY", [1], "UINT32").set_data_from_numpy(
                input_data["DELAY"]
            ),
            client_module.InferInput("WAIT", [1], "UINT32").set_data_from_numpy(
                input_data["WAIT"]
            ),
        ]

    def _reset_callback_state(self):
        """Clear results left over from a previous async inference."""
        self.callback_error = None
        self.callback_result = None
        self.callback_invoked_event.clear()

    def _async_callback(self, result, error):
        """Callback for async_infer."""
        self.callback_error = error
        self.callback_result = result
        self.callback_invoked_event.set()

    def test_grpc(self):
        """Synchronous grpc inference succeeds for every data point."""
        for in_value, delay_value, wait_value in self.data_matrix:
            with self.subTest(IN=in_value, DELAY=delay_value, WAIT=wait_value):
                inputs = self._build_inputs(
                    grpcclient, self._input_data(in_value, delay_value, wait_value)
                )

                triton_client = grpcclient.InferenceServerClient(
                    url="localhost:8001", verbose=True
                )

                # Expect the inference is successful
                res = triton_client.infer(model_name=self.model_name_, inputs=inputs)
                self.assertEqual(1, res.as_numpy("OUT")[0])
                self.assertEqual(0, res.as_numpy("IDX")[0])

    def test_http(self):
        """Synchronous http inference succeeds for every data point."""
        for in_value, delay_value, wait_value in self.data_matrix:
            with self.subTest(IN=in_value, DELAY=delay_value, WAIT=wait_value):
                inputs = self._build_inputs(
                    httpclient, self._input_data(in_value, delay_value, wait_value)
                )

                triton_client = httpclient.InferenceServerClient(
                    url="localhost:8000", verbose=True
                )

                # Expect the inference is successful
                res = triton_client.infer(model_name=self.model_name_, inputs=inputs)
                self.assertEqual(1, res.as_numpy("OUT")[0])
                self.assertEqual(0, res.as_numpy("IDX")[0])

    def test_grpc_async(self):
        """Asynchronous grpc inference succeeds and leaves the model healthy."""
        for in_value, delay_value, wait_value in self.data_matrix:
            with self.subTest(IN=in_value, DELAY=delay_value, WAIT=wait_value):
                inputs = self._build_inputs(
                    grpcclient, self._input_data(in_value, delay_value, wait_value)
                )

                triton_client = grpcclient.InferenceServerClient(
                    url="localhost:8001",
                    verbose=True,
                )

                # Clear previous results
                self._reset_callback_state()

                try:
                    triton_client.async_infer(
                        model_name=self.model_name_,
                        inputs=inputs,
                        callback=self._async_callback,
                    )
                except Exception as e:
                    # fail() raises; subTest records it and the loop moves on.
                    self.fail(f"Failed to initiate async_infer: {e}")

                # Wait for the callback to be invoked, with a timeout
                self.assertTrue(
                    self.callback_invoked_event.wait(timeout=10),
                    "Callback not invoked within timeout.",
                )

                # Expect the inference is successful
                self.assertIsNone(
                    self.callback_error, f"Inference failed: {self.callback_error}"
                )
                self.assertIsNotNone(self.callback_result, "Inference result is None.")
                self.assertEqual(1, self.callback_result.as_numpy("OUT")[0])
                self.assertEqual(0, self.callback_result.as_numpy("IDX")[0])

                # Wait and check server/model health
                time.sleep(5)
                self.assertTrue(triton_client.is_model_ready(self.model_name_))

    def test_grpc_async_cancel(self):
        """Cancelling an in-flight async grpc request surfaces CANCELLED."""
        # Only the non-zero delay rows, so there is a window in which the
        # in-flight request can actually be cancelled.
        data_matrix = [
            # ("IN", "DELAY", "WAIT")
            ([1], [4000], [2000]),
            ([1], [2000], [4000]),
        ]

        for in_value, delay_value, wait_value in data_matrix:
            with self.subTest(IN=in_value, DELAY=delay_value, WAIT=wait_value):
                inputs = self._build_inputs(
                    grpcclient, self._input_data(in_value, delay_value, wait_value)
                )

                triton_client = grpcclient.InferenceServerClient(
                    url="localhost:8001",
                    verbose=True,
                )

                # Clear previous results
                self._reset_callback_state()

                request_handle = None
                try:
                    request_handle = triton_client.async_infer(
                        model_name=self.model_name_,
                        inputs=inputs,
                        callback=self._async_callback,
                    )
                except Exception as e:
                    # fail() raises; subTest records it and the loop moves on.
                    self.fail(f"Failed to initiate async_infer: {e}")

                # Allow request to be fully initiated
                time.sleep(0.5)

                # Attempt to cancel the request
                if request_handle:
                    try:
                        request_handle.cancel()
                    except Exception as e:
                        self.fail(f"Error calling request_handle.cancel(): {e}")
                else:
                    self.fail("Invalid request_handle, cannot cancel.")

                # Wait for the callback to be invoked
                self.assertTrue(
                    self.callback_invoked_event.wait(timeout=10),
                    "Callback not invoked within timeout after cancellation.",
                )

                # Expect the inference is failed
                self.assertIsInstance(
                    self.callback_error,
                    InferenceServerException,
                    f"Unexpected error type: {type(self.callback_error)}",
                )
                self.assertIn(
                    "StatusCode.CANCELLED",
                    self.callback_error.status(),
                )

                # Wait and check server/model health
                time.sleep(5)
                self.assertTrue(triton_client.is_model_ready(self.model_name_))
|
657 | 817 |
|
658 | 818 | if __name__ == "__main__": |
|
0 commit comments