import contextlib
import enum
import torch
from torch.utils import _pytree as pytree

# enum class CastPolicy : uint8_t {
#   lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
#                           // running the op. Currently, lower_precision_fp is
#                           // fp16 for AutocastCUDA, and is defined by user
#                           // (default bf16) for AutocastCPU or other device.
#   fp32, // Cast all inputs to at::kFloat before running the op.
#   fp32_set_opt_dtype, // Treats functions (like softmax) that
#                       // 1. we'd like to run in fp32 and
#                       // 2. have a std::optional<ScalarType> arg that controls
#                       //    the output type.
#                       // fp32_set_opt_dtype wrappers' policy is: if the output
#                       // type is already set, don't touch it, otherwise, set
#                       // it to at::kFloat.
#   fp32_append_dtype, // Treats functions (like norm) that
#                      // 1. we'd like to run in fp32 and
#                      // 2. have some overloads that accept an output type and
#                      //    other overloads that don't.
#                      // fp32_append_dtype wrappers wrap the overloads that
#                      // don't have an output dtype.
#                      // The wrapper policy is: append at::kFloat to the args,
#                      // and redispatch to the type-aware overload.
#   promote, // Run in the widest dtype among several args.
# };
class CastPolicy(enum.Enum):
  LOWER_PRECISION_FP = 0
  FP32 = 1
  FP32_SET_OPT_DTYPE = 2
  FP32_APPEND_DTYPE = 3
  PROMOTE = 4


def execute_policy(policy, args, kwargs, target_lower_fp):
  """Casts the tensors in ``args``/``kwargs`` according to ``policy``."""

  def is_float(a):
    return isinstance(a, torch.Tensor) and a.is_floating_point()

  match policy:
    case CastPolicy.LOWER_PRECISION_FP:
      # Cast every floating-point tensor to the lower-precision dtype.
      return pytree.tree_map_only(is_float, lambda a: a.to(target_lower_fp),
                                  (args, kwargs))
    case CastPolicy.FP32:
      # Cast every floating-point tensor to float32.
      return pytree.tree_map_only(is_float, lambda a: a.to(torch.float32),
                                  (args, kwargs))
    case CastPolicy.PROMOTE:
      # Run in the widest floating-point dtype found among the tensor inputs
      # (inputs may nest tensors in lists, e.g. for aten.cat).
      dtypes = {
          a.dtype for a in pytree.tree_leaves((args, kwargs)) if is_float(a)}
      widest = max(dtypes, key=lambda dtype: dtype.itemsize)
      return pytree.tree_map_only(is_float, lambda a: a.to(widest),
                                  (args, kwargs))
    case _:
      raise AssertionError(f'Policy {policy} not implemented yet.')
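

# A quick illustration of `execute_policy` (a sketch; the tensors below are
# made up for the example):
#
#   a = torch.randn(2, 2, dtype=torch.float32)
#   b = torch.randn(2, 2, dtype=torch.float32)
#   new_args, new_kwargs = execute_policy(
#       CastPolicy.LOWER_PRECISION_FP, (a, b), {}, torch.bfloat16)
#   # new_args now holds bfloat16 copies of a and b; new_kwargs stays {}.

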
@contextlib.contextmanager
def autocast(device, dtype=torch.bfloat16, env=None):
  """Sets the autocast dtype on the torchax environment for the block."""
  del device  # The device argument is currently unused.
  if env is None:
    import torchax
    env = torchax.default_env()
  env.autocast_dtype, old = dtype, env.autocast_dtype
  try:
    yield
  finally:
    # Restore the previous autocast dtype even if the body raises.
    env.autocast_dtype = old
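

# Example usage (a sketch; assumes a torchax environment is active so that
# aten ops are intercepted and `autocast_dtype` is consulted via the policy
# table below; the `device` argument is currently ignored):
#
#   with autocast('jax', dtype=torch.bfloat16):
#     y = torch.matmul(x, w)  # would run under CastPolicy.LOWER_PRECISION_FP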


# https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
autocast_policy = {
    # lower_precision_fp cast policy
    torch.ops.aten.conv1d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv1d.padding: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv2d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv2d.padding: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv3d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv3d.padding: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.bmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.mm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.linalg_vecdot.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.baddbmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.addmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten._addmm_activation.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.addbmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.linear.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten._convolution.deprecated: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.matmul.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_tbc.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.mkldnn_rnn_layer.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_transpose1d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_transpose2d.input: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_transpose3d.input: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.prelu.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.scaled_dot_product_attention.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten._native_multi_head_attention.default: CastPolicy.LOWER_PRECISION_FP,

    # fp32 cast policy
    torch.ops.aten.avg_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.binary_cross_entropy.default: CastPolicy.FP32,
    torch.ops.aten.grid_sampler.default: CastPolicy.FP32,
    torch.ops.aten.polar.default: CastPolicy.FP32,
    torch.ops.aten.prod.default: CastPolicy.FP32,
    torch.ops.aten.prod.dim_int: CastPolicy.FP32,
    torch.ops.aten.prod.dim_Dimname: CastPolicy.FP32,
    torch.ops.aten.quantile.default: CastPolicy.FP32,
    torch.ops.aten.quantile.scalar: CastPolicy.FP32,
    torch.ops.aten.nanquantile.default: CastPolicy.FP32,
    torch.ops.aten.nanquantile.scalar: CastPolicy.FP32,
    torch.ops.aten.stft.default: CastPolicy.FP32,
    torch.ops.aten.stft.center: CastPolicy.FP32,
    torch.ops.aten.cdist.default: CastPolicy.FP32,
    torch.ops.aten.grid_sampler_2d.default: CastPolicy.FP32,
    torch.ops.aten._grid_sampler_2d_cpu_fallback.default: CastPolicy.FP32,
    torch.ops.aten.grid_sampler_3d.default: CastPolicy.FP32,
    torch.ops.aten.trace.default: CastPolicy.FP32,
    torch.ops.aten.view_as_complex.default: CastPolicy.FP32,
    torch.ops.aten.cholesky.default: CastPolicy.FP32,
    torch.ops.aten.cholesky_inverse.default: CastPolicy.FP32,
    torch.ops.aten.cholesky_solve.default: CastPolicy.FP32,
    torch.ops.aten.inverse.default: CastPolicy.FP32,
    torch.ops.aten.lu_solve.default: CastPolicy.FP32,
    torch.ops.aten.orgqr.default: CastPolicy.FP32,
    torch.ops.aten.ormqr.default: CastPolicy.FP32,
    torch.ops.aten.pinverse.default: CastPolicy.FP32,
    torch.ops.aten.max_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.max_unpool2d.default: CastPolicy.FP32,
    torch.ops.aten.max_unpool3d.default: CastPolicy.FP32,
    torch.ops.aten.adaptive_avg_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.reflection_pad1d.default: CastPolicy.FP32,
    torch.ops.aten.reflection_pad2d.default: CastPolicy.FP32,
    torch.ops.aten.replication_pad1d.default: CastPolicy.FP32,
    torch.ops.aten.replication_pad2d.default: CastPolicy.FP32,
    torch.ops.aten.replication_pad3d.default: CastPolicy.FP32,
    torch.ops.aten.mse_loss.default: CastPolicy.FP32,
    torch.ops.aten.cosine_embedding_loss.default: CastPolicy.FP32,
    torch.ops.aten.nll_loss.default: CastPolicy.FP32,
    torch.ops.aten.nll_loss2d.default: CastPolicy.FP32,
    torch.ops.aten.hinge_embedding_loss.default: CastPolicy.FP32,
    torch.ops.aten.poisson_nll_loss.default: CastPolicy.FP32,
    torch.ops.aten.smooth_l1_loss.default: CastPolicy.FP32,
    torch.ops.aten.cross_entropy_loss.default: CastPolicy.FP32,
    torch.ops.aten.l1_loss.default: CastPolicy.FP32,
    torch.ops.aten.huber_loss.default: CastPolicy.FP32,
    torch.ops.aten.margin_ranking_loss.default: CastPolicy.FP32,
    torch.ops.aten.soft_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.triplet_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.multi_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.ctc_loss.IntList: CastPolicy.FP32,
    torch.ops.aten.ctc_loss.Tensor: CastPolicy.FP32,
    torch.ops.aten.kl_div.default: CastPolicy.FP32,
    torch.ops.aten.multilabel_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.binary_cross_entropy_with_logits.default: CastPolicy.FP32,
    torch.ops.aten.fft_fft.default: CastPolicy.FP32,
    torch.ops.aten.fft_ifft.default: CastPolicy.FP32,
    torch.ops.aten.fft_fft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_ifft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_fftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_ifftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_rfft.default: CastPolicy.FP32,
    torch.ops.aten.fft_irfft.default: CastPolicy.FP32,
    torch.ops.aten.fft_rfft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_irfft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_rfftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_irfftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_hfft.default: CastPolicy.FP32,
    torch.ops.aten.fft_ihfft.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cond.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cond.p_str: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.default: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.tol_tensor: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.atol_rtol_float: CastPolicy.FP32,
    torch.ops.aten.linalg_solve.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cholesky.default: CastPolicy.FP32,
    torch.ops.aten.linalg_svdvals.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eigvals.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eigvalsh.default: CastPolicy.FP32,
    torch.ops.aten.linalg_inv.default: CastPolicy.FP32,
    torch.ops.aten.linalg_householder_product.default: CastPolicy.FP32,
    torch.ops.aten.linalg_tensorinv.default: CastPolicy.FP32,
    torch.ops.aten.linalg_tensorsolve.default: CastPolicy.FP32,
    torch.ops.aten.fake_quantize_per_tensor_affine.default: CastPolicy.FP32,
    torch.ops.aten.geqrf.default: CastPolicy.FP32,
    torch.ops.aten._lu_with_info.default: CastPolicy.FP32,
    torch.ops.aten.qr.default: CastPolicy.FP32,
    torch.ops.aten.svd.default: CastPolicy.FP32,
    torch.ops.aten.triangular_solve.default: CastPolicy.FP32,
    torch.ops.aten.fractional_max_pool2d.default: CastPolicy.FP32,
    torch.ops.aten.fractional_max_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.adaptive_max_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.multilabel_margin_loss_forward.default: CastPolicy.FP32,
    torch.ops.aten.linalg_qr.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cholesky_ex.default: CastPolicy.FP32,
    torch.ops.aten.linalg_svd.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eig.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eigh.default: CastPolicy.FP32,
    torch.ops.aten.linalg_lstsq.default: CastPolicy.FP32,
    torch.ops.aten.linalg_inv_ex.default: CastPolicy.FP32,

    # promote
    torch.ops.aten.stack.default: CastPolicy.PROMOTE,
    torch.ops.aten.cat.default: CastPolicy.PROMOTE,
    torch.ops.aten.index_copy.default: CastPolicy.PROMOTE,
    torch.ops.aten.index_copy.dimname: CastPolicy.PROMOTE,
}
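

# Illustrative helper (a sketch, not necessarily how torchax's dispatcher is
# wired up; the name `maybe_autocast` is hypothetical): given an aten op and
# its arguments, apply the op's cast policy, if any, before redispatching.
def maybe_autocast(op, args, kwargs, target_lower_fp=torch.bfloat16):
  policy = autocast_policy.get(op)
  if policy is None:
    # Ops without an autocast policy pass through unchanged.
    return args, kwargs
  return execute_policy(policy, args, kwargs, target_lower_fp)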