Max-autotune Support on CPU with GEMM Template Tutorial
==============================================================

**Author**: `Jiong Gong <https://github.com/jgong5>`__, `Leslie Fang <https://github.com/leslie-fang-intel>`__, `Chunyuan Wu <https://github.com/chunyuan-w>`__

Prerequisites:
----------------
- `torch.compile and TorchInductor concepts in PyTorch <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__

Introduction
------------
The ``max-autotune`` mode for the Inductor CPU backend in ``torch.compile`` profiles multiple implementations of operations at compile time and selects the best-performing one,
trading longer compilation times for improved runtime performance. This enhancement is particularly beneficial for GEMM-related operations.
In the Inductor CPU backend, we have introduced a C++ template-based GEMM implementation as an alternative to the ATen-based approach that relies on the oneDNN and MKL libraries.
This is similar to the max-autotune mode on CUDA, where implementations from ATen, Triton, and CUTLASS are considered.

We have covered the most popular data types, including FP32, BF16, FP16, and INT8, with epilogue fusions for x86 CPUs.

How to activate ``max-autotune`` mode
-------------------------------------
To activate the ``max-autotune`` mode in PyTorch, set the ``mode`` argument to ``max-autotune`` when compiling your model using ``torch.compile``.
If you prefer to bypass the tuning process and always use the C++ template implementations, you can configure this via an environment variable:
``export TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=CPP``.
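
As a minimal sketch of the two options (the single linear layer here is just a stand-in model, and the environment variable is shown as a comment because it must be set before launching Python; a complete runnable example follows in the next section):

.. code:: python

    import torch

    # Optional: bypass tuning and always use the C++ GEMM template by setting,
    # before launching Python:
    #   export TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=CPP

    model = torch.nn.Linear(16, 32).eval()  # stand-in model
    x = torch.randn(64, 16)

    compiled = torch.compile(model, mode="max-autotune")  # turn on "max-autotune" mode
    with torch.no_grad():
        y = compiled(x)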

Example code
------------
The code below is an example of using the ``max-autotune`` mode on a simple neural network with a linear layer followed by a ReLU activation.
Run the example with the environment variable ``TORCHINDUCTOR_FREEZING=1`` set.
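
For instance, if the example below is saved as ``toy_example.py`` (a placeholder file name), it could be launched like this:

.. code:: shell

    # TORCHINDUCTOR_FREEZING=1 turns on Inductor freezing (constant folding of
    # weights), which this example relies on.
    TORCHINDUCTOR_FREEZING=1 python toy_example.py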

.. code:: python

    import torch
    from torch._inductor import config
    config.trace.log_autotuning_results = True  # enable the log of autotuning results

    class M(torch.nn.Module):
        def __init__(
            self,
            in_features,
            out_features,
            bias,
            **kwargs,
        ):
            super().__init__()
            self.linear = torch.nn.Linear(
                in_features,
                out_features,
                bias,
                **kwargs,
            )
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.linear(x)
            x = self.relu(x)
            return x

    amp_enabled = True
    batch_size = 64
    in_features = 16
    out_features = 32
    bias = True

    x = torch.randn(batch_size, in_features)
    model = M(in_features, out_features, bias)

    with torch.no_grad(), torch.cpu.amp.autocast(enabled=amp_enabled):
        compiled = torch.compile(model, mode="max-autotune")  # turn on "max-autotune" mode
        y = compiled(x)

When running the above code snippet, you will see an autotuning result like the one below (the performance numbers are for demonstration purposes only).
In this case, the C++ template outperforms the ATen kernel, so it is selected.

.. code:: shell

    AUTOTUNE linear_unary(64x16, 32x16, 32)
      cpp_packed_gemm_0 0.2142 ms 100.0%
      _linear_pointwise 0.2441 ms 87.7%

We can check the generated output code by setting ``export TORCH_LOGS="+output_code"``.
When the C++ template is selected, the generated code no longer contains ``torch.ops.mkldnn._linear_pointwise.default`` (for bfloat16) or ``torch.ops.mkl._mkl_linear.default`` (for float32).
Instead, we find a kernel based on the C++ GEMM template, ``cpp_fused__to_copy_relu_1``
(only part of the code is shown below for simplicity), with the bias and ReLU epilogues fused inside the C++ GEMM template kernel.

.. code:: python

    cpp_fused__to_copy_relu_1 = async_compile.cpp_pybinding(['const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'bfloat16*'], '''

    ...

    template <bool accum>
    inline void kernel_micro_gemm_amx_kernel_32_2(
        AMXState& amx_state,
        const bfloat16* __restrict__ A,
        const bfloat16* __restrict__ B,
        float* __restrict__ C,
        int64_t K,
        int64_t lda,
        int64_t ldb,
        int64_t ldc,
        uint8_t tilecfg_rows
    ) {
        ...
    }

    ...

    template <bool accum>
    inline void kernel_micro_gemm(
        AMXState& amx_state,
        const bfloat16* __restrict__ A,
        const bfloat16* __restrict__ B,
        float* __restrict__ C,
        int64_t M,
        int64_t N,
        int64_t K,
        int64_t lda,
        int64_t ldb,
        int64_t ldc
    ) {
        ...
    }

    extern "C"
    void kernel(const bfloat16* X, const bfloat16* W, const bfloat16* inp, bfloat16* Y)
    {
        constexpr int64_t num_threads = 40;
        constexpr int64_t N = 32;
        constexpr int64_t K = 16;
        constexpr int64_t M = static_cast<int64_t>(64L);
        ...
        #pragma omp parallel num_threads(40)
        {
            const int tid = omp_get_thread_num();
            ...
            for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
                ...
                for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                    ...
                    for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                        ...
                        for (int64_t nci = nc; nci < nc_block_end; nci++) {
                            if (kc == k_block_start) {
                                kernel_micro_gemm<static_cast<bool>(false)>(
                                    ...
                                );
                            } else {
                                kernel_micro_gemm<static_cast<bool>(true)>(
                                    ...
                                );
                            }
                        }
                    }
                    {
                        {
                            // Epilogue fusion here for bias and relu
                            #pragma GCC ivdep
                            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
                            {
                                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L))
                                {
                                    auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(inp + static_cast<int64_t>(n_start + x1), static_cast<int64_t>(16));
                                    auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16));
                                    auto tmp1 = at::vec::convert<float>(tmp0);
                                    auto tmp3 = tmp1 + tmp2;
                                    auto tmp4 = at::vec::convert<bfloat16>(tmp3);
                                    auto tmp5 = static_cast<float>(0.0);
                                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                                    auto tmp7 = at::vec::maximum(tmp3, tmp6);
                                    auto tmp8 = at::vec::convert<bfloat16>(tmp7);
                                    tmp8.store(Y + static_cast<int64_t>(n_start + x1 + (32L*m_start) + (32L*x0)), static_cast<int64_t>(16));
                                }
                                ...
                            }
                        }
                    }
                }
            }
            ...
        }
    }
    ''')
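
As a quick, optional sanity check (not part of the original tutorial code), you can verify that the compiled model produces the same results as eager mode. The snippet below reuses ``model``, ``x``, ``y``, and ``amp_enabled`` from the example above:

.. code:: python

    # Run the eager-mode model under the same no_grad/autocast context and
    # compare it against the compiled result `y` from the example above.
    with torch.no_grad(), torch.cpu.amp.autocast(enabled=amp_enabled):
        y_eager = model(x)

    # Default tolerances are dtype-aware; loosen rtol/atol if needed for bfloat16.
    torch.testing.assert_close(y, y_eager)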

Conclusion
------------
In this tutorial, we introduced max-autotune support on CPU with a GEMM template. We explained the API for activating this feature and demonstrated
the code generated by the GEMM template.

This feature is in the prototype stage. If you have any feature requests or run into any issues, please file a bug report at `GitHub issues <https://github.com/pytorch/pytorch/issues>`_.