|
6 | 6 |
|
7 | 7 | import re |
8 | 8 | from dataclasses import dataclass |
9 | | -from typing import Any, List, Optional, Union |
| 9 | +from typing import List, Optional, Union |
10 | 10 |
|
11 | | -from executorch.backends.vulkan.test.op_tests.utils.codegen_base import ( |
| 11 | +from executorch.backends.vulkan.test.op_tests.utils.aten_types import ( |
12 | 12 | AT_INT_ARRAY_REF, |
13 | 13 | AT_SCALAR, |
14 | 14 | AT_TENSOR, |
15 | 15 | AT_TENSOR_LIST, |
16 | 16 | BOOL, |
17 | | - CppTestFileGen, |
18 | 17 | DOUBLE, |
19 | 18 | INT, |
20 | 19 | OPT_AT_DOUBLE_ARRAY_REF, |
|
28 | 27 | OPT_SCALAR_TYPE, |
29 | 28 | STRING, |
30 | 29 | TENSOR_VECTOR, |
31 | | - TestSuite, |
32 | | - TestSuiteGen, |
33 | 30 | THREE_TENSOR_TUPLE, |
34 | 31 | TWO_TENSOR_TUPLE, |
35 | 32 | ) |
| 33 | +from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite |
36 | 34 |
|
37 | 35 | from torchgen.api import cpp |
38 | 36 | from torchgen.api.types import CppSignatureGroup |
39 | | - |
40 | 37 | from torchgen.gen import generate_static_dispatch_backend_call, translate_args |
41 | | - |
42 | 38 | from torchgen.gen_aoti_c_shim import gen_static_dispatch_backend_call_signature |
43 | 39 | from torchgen.model import NativeFunction, Variant |
44 | 40 |
|
45 | | -################################## |
46 | | -## Custom Test Suite Definition ## |
47 | | -################################## |
48 | | - |
49 | | - |
50 | | -@dataclass |
51 | | -class VkTestSuite(TestSuite): |
52 | | - def __init__(self, input_cases: List[Any]): |
53 | | - super().__init__(input_cases) |
54 | | - self.storage_types: List[str] = ["utils::kTexture3D"] |
55 | | - self.layouts: List[str] = ["utils::kChannelsPacked"] |
56 | | - self.data_gen: str = "make_rand_tensor" |
57 | | - |
58 | | - |
59 | | -########################## |
60 | | -## Code Generator Class ## |
61 | | -########################## |
| 41 | +################################### |
| 42 | +## Compute Graph Code Generation ## |
| 43 | +################################### |
62 | 44 |
|
63 | 45 |
|
64 | 46 | @dataclass |
@@ -105,6 +87,8 @@ def vk_out(self): |
105 | 87 |
|
106 | 88 |
|
107 | 89 | class ComputeGraphGen: |
| 90 | + backend_key = None |
| 91 | + |
108 | 92 | def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite): |
109 | 93 | self.op_reg_name = op_reg_name |
110 | 94 | self.f = f |
@@ -230,7 +214,7 @@ def gen_decl(self, fn_name: str, ret_type: str = "void") -> str: |
230 | 214 |
|
231 | 215 | def create_aten_fn_call(self) -> str: |
232 | 216 | func_call = generate_static_dispatch_backend_call( |
233 | | - self.f_sig, self.f, TestSuiteGen.backend_key |
| 217 | + self.f_sig, self.f, ComputeGraphGen.backend_key |
234 | 218 | )[7:].replace("::cpu", "") |
235 | 219 |
|
236 | 220 | return func_call |
@@ -611,147 +595,3 @@ def gen_op_check_fn(self) -> str: |
611 | 595 | op_check_fn += "\n }" |
612 | 596 |
|
613 | 597 | return op_check_fn |
614 | | - |
615 | | - |
616 | | -################################## |
617 | | -## Test Fixture Code Generation ## |
618 | | -################################## |
619 | | - |
620 | | -test_fixture_template = """ |
621 | | -class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<at::ScalarType, utils::StorageType, utils::GPUMemoryLayout>> {{ |
622 | | - protected: |
623 | | - ComputeGraph* graph; |
624 | | - at::ScalarType test_dtype = at::kFloat; |
625 | | - float rtol = {rtol}; |
626 | | - float atol = {atol}; |
627 | | -
|
628 | | - void SetUp() override {{ |
629 | | - GraphConfig config; |
630 | | - utils::StorageType default_storage_type; |
631 | | - utils::GPUMemoryLayout default_memory_layout; |
632 | | - std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam(); |
633 | | - config.set_storage_type_override(default_storage_type); |
634 | | - config.set_memory_layout_override(default_memory_layout); |
635 | | - graph = new ComputeGraph(config); |
636 | | -
|
637 | | - if (test_dtype == at::kHalf) {{ |
638 | | - rtol = 1e-2; |
639 | | - atol = 1e-2; |
640 | | - }} |
641 | | - }} |
642 | | -
|
643 | | - void TearDown() override {{ |
644 | | - delete graph; |
645 | | - graph = nullptr; |
646 | | - }} |
647 | | -
|
648 | | - {check_fn} |
649 | | -}}; |
650 | | -""" |
651 | | - |
652 | | - |
653 | | -class VkTestSuiteGen(TestSuiteGen): |
654 | | - def __init__(self, op_reg_name: str, f: NativeFunction, inputs: VkTestSuite): |
655 | | - super().__init__(f, inputs) |
656 | | - self.op_reg_name = op_reg_name |
657 | | - self.generator = ComputeGraphGen(self.op_reg_name, self.f, self.suite_def) |
658 | | - |
659 | | - def generate_fixture_cpp(self) -> str: |
660 | | - check_fn = "" |
661 | | - if not self.suite_def.requires_prepack: |
662 | | - check_fn = self.generator.gen_op_check_fn() |
663 | | - |
664 | | - prepacked_check_fn = "" |
665 | | - if self.suite_def.supports_prepack(): |
666 | | - self.generator.should_prepack = True |
667 | | - prepacked_check_fn = self.generator.gen_op_check_fn() |
668 | | - check_fn += "\n\n " |
669 | | - check_fn += prepacked_check_fn |
670 | | - |
671 | | - return test_fixture_template.format( |
672 | | - op_name=self.op_name, |
673 | | - check_fn=check_fn, |
674 | | - rtol=self.suite_def.rtol, |
675 | | - atol=self.suite_def.atol, |
676 | | - ) |
677 | | - |
678 | | - def gen_parameterization(self) -> str: |
679 | | - dtypes = self.suite_def.dtypes |
680 | | - storage_types = self.suite_def.storage_types |
681 | | - layouts = self.suite_def.layouts |
682 | | - |
683 | | - return f""" |
684 | | -INSTANTIATE_TEST_SUITE_P( |
685 | | - Combos_{self.op_name}, |
686 | | - GeneratedOpsTest_{self.op_name}, |
687 | | - ::testing::Combine( |
688 | | - ::testing::Values({', '.join(dtypes)}), |
689 | | - ::testing::Values({', '.join(storage_types)}), |
690 | | - ::testing::Values({', '.join(layouts)}))); |
691 | | - """ |
692 | | - |
693 | | - |
694 | | -############################## |
695 | | -## Test File Code Generation ## |
696 | | -############################### |
697 | | - |
698 | | -preamble_str = """ |
699 | | -#include <executorch/backends/vulkan/runtime/api/api.h> |
700 | | -#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h> |
701 | | -#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> |
702 | | -
|
703 | | -#include <tuple> |
704 | | -
|
705 | | -using namespace vkcompute; |
706 | | -using TensorOptions = at::TensorOptions; |
707 | | -
|
708 | | -vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { |
709 | | - switch (at_scalartype) { |
710 | | - case c10::kFloat: |
711 | | - return vkapi::kFloat; |
712 | | - case c10::kHalf: |
713 | | - return vkapi::kHalf; |
714 | | - case c10::kInt: |
715 | | - return vkapi::kInt; |
716 | | - case c10::kLong: |
717 | | - return vkapi::kInt; |
718 | | - case c10::kChar: |
719 | | - return vkapi::kChar; |
720 | | - default: |
721 | | - VK_THROW("Unsupported at::ScalarType!"); |
722 | | - } |
723 | | -} |
724 | | -
|
725 | | -#ifdef USE_VULKAN_FP16_INFERENCE |
726 | | -bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-2, float atol=1e-2) { |
727 | | -#else |
728 | | -bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-5, float atol=1e-5) { |
729 | | -#endif |
730 | | - // Skip checking index tensors |
731 | | - if (t1.scalar_type() == at::kLong || t2.scalar_type() == at::kLong) { |
732 | | - return true; |
733 | | - } |
734 | | - bool is_close = at::allclose(t1, t2, rtol, atol); |
735 | | - if (!is_close && t1.numel() < 500) { |
736 | | - std::cout << "reference: " << std::endl; |
737 | | - print(t1, 150); |
738 | | - std::cout << std::endl; |
739 | | - std::cout << "vulkan: " << std::endl; |
740 | | - print(t2, 150); |
741 | | - std::cout << std::endl; |
742 | | - } |
743 | | - return is_close; |
744 | | -} |
745 | | -""" |
746 | | - |
747 | | - |
748 | | -class VkCppTestFileGen(CppTestFileGen): |
749 | | - def __init__(self, out_path: str): |
750 | | - super().__init__(out_path) |
751 | | - |
752 | | - def generate_preamble(self) -> str: |
753 | | - return preamble_str |
754 | | - |
755 | | - def add_suite(self, op_reg_name: str, f: NativeFunction, all_input_cases) -> None: |
756 | | - suites_gen = VkTestSuiteGen(op_reg_name, f, all_input_cases) |
757 | | - self.suites_gens.append(suites_gen) |
0 commit comments