
Commit 89cc066

Add autocast feature as torchax.amp.autocast. (#9364)
1 parent f20b89f commit 89cc066
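
The new torchax.amp.autocast context manager mirrors torch.autocast: while it is active, eligible ops (matmuls, convolutions, attention, etc.) run in a lower-precision dtype. A minimal usage sketch, following the test added in this commit (the device='jax' tensors and the ('jax', dtype=..., env=...) signature come from torchax/test/test_amp.py; the 4x4 shapes are arbitrary):

import torch
import torchax

env = torchax.default_env()
with env:
  a = torch.randn(4, 4, device='jax')
  b = torch.randn(4, 4, device='jax')
  # Inside autocast, matmul inputs are cast to bfloat16, so the result is bfloat16.
  with torchax.amp.autocast('jax', dtype=torch.bfloat16, env=env):
    c = a @ b
  print(c.dtype)  # torch.bfloat16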

File tree: 3 files changed, +380 -0 lines changed

3 files changed

+380
-0
lines changed

torchax/test/test_amp.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
import unittest
import jax
import jax.numpy as jnp
import torchax
from torchax import interop
import torch


class AutocastTest(unittest.TestCase):

  def setUp(self):
    self.env = torchax.default_env()

  def test_auto_cast_ir(self):
    with self.env:
      with torchax.amp.autocast('jax', dtype=torch.bfloat16, env=self.env):
        a = jax.ShapeDtypeStruct((2, 2), jnp.float32)
        b = jax.ShapeDtypeStruct((2, 2), jnp.float32)
        ir_text = jax.jit(interop.jax_view(torch.matmul)).lower(a, b).as_text()
      self.assertIn('tensor<2x2xbf16>', ir_text)

  def test_auto_cast_matmul(self):
    with self.env:
      a = torch.randn(2, 2, device='jax')
      b = torch.randn(2, 2, device='jax')
      with torchax.amp.autocast('jax', dtype=torch.bfloat16, env=self.env):
        c = a @ b

      self.assertEqual(c.dtype, torch.bfloat16)

      with torch.autocast('cpu', dtype=torch.bfloat16):
        c_cpu = a.cpu() @ b.cpu()

      self.assertTrue(torch.allclose(c.cpu(), c_cpu))


if __name__ == '__main__':
  unittest.main()
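
To run just this test file locally (assuming torchax and a working JAX installation), either invocation below should work:

python torchax/test/test_amp.py
python -m pytest torchax/test/test_amp.py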

torchax/torchax/amp.py

Lines changed: 333 additions & 0 deletions
@@ -0,0 +1,333 @@
import contextlib
import enum
import torch
from torch.utils import _pytree as pytree


# enum class CastPolicy : uint8_t {
#   lower_precision_fp = 0,  // Cast all inputs to lower_precision_fp before
#                            // running the op. Currently, lower_precision_fp is
#                            // fp16 for AutocastCUDA, and is defined by user
#                            // (default bf16) for AutocastCPU or other device.
#   fp32,  // Cast all inputs to at::kFloat before running the op.
#   fp32_set_opt_dtype,  // Treats functions (like softmax) that
#                        //   1. we'd like to run in fp32 and
#                        //   2. have a std::optional<ScalarType> arg that controls
#                        //      the output type.
#                        // fp32_set_opt_dtype wrappers' policy is: if the output
#                        // type is already set, don't touch it, otherwise, set
#                        // it to at::kFloat.
#   fp32_append_dtype,  // Treats functions (like norm) that
#                       //   1. we'd like to run in fp32 and
#                       //   2. have some overloads that accept an output type and
#                       //      other overloads that don't.
#                       // fp32_append_dtype wrappers wrap the overloads that don't
#                       // have an output dtype.
#                       // The wrapper policy is: append at::kFloat to the args,
#                       // and redispatch to the type-aware overload.
#   promote,  // Run in the widest dtype among several args.
# };
class CastPolicy(enum.Enum):
  LOWER_PRECISION_FP = 0
  FP32 = 1
  FP32_SET_OPT_DTYPE = 2
  FP32_APPEND_DTYPE = 3
  PROMOTE = 4


def execute_policy(policy, args, kwargs, target_lower_fp):

  def is_float(a):
    return isinstance(a, torch.Tensor) and a.is_floating_point()

  match policy:
    case CastPolicy.LOWER_PRECISION_FP:
      return pytree.tree_map_only(is_float, lambda a: a.to(target_lower_fp),
                                  (args, kwargs))
    case CastPolicy.FP32:
      return pytree.tree_map_only(is_float, lambda a: a.to(torch.float32),
                                  (args, kwargs))
    case CastPolicy.PROMOTE:
      dtypes = set(a.dtype for a in args)
      widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1]
      return pytree.tree_map_only(is_float, lambda a: a.to(widest),
                                  (args, kwargs))
    case _:
      raise AssertionError(f'Policy {policy} not implemented yet.')


@contextlib.contextmanager
def autocast(device, dtype=torch.bfloat16, env=None):
  del device
  if env is None:
    import torchax
    env = torchax.default_env()
  env.autocast_dtype, old = dtype, env.autocast_dtype
  yield
  env.autocast_dtype = old


# https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
autocast_policy = {
    torch.ops.aten.conv1d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv1d.padding: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv2d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv2d.padding: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv3d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv3d.padding: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.bmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.mm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.linalg_vecdot.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.baddbmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.addmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten._addmm_activation.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.addbmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.linear.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten._convolution.deprecated: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.matmul.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_tbc.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.mkldnn_rnn_layer.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_transpose1d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_transpose2d.input: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_transpose3d.input: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.prelu.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.scaled_dot_product_attention.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten._native_multi_head_attention.default: CastPolicy.LOWER_PRECISION_FP,

    # fp32 cast policy
    torch.ops.aten.avg_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.binary_cross_entropy.default: CastPolicy.FP32,
    torch.ops.aten.grid_sampler.default: CastPolicy.FP32,
    torch.ops.aten.polar.default: CastPolicy.FP32,
    torch.ops.aten.prod.default: CastPolicy.FP32,
    torch.ops.aten.prod.dim_int: CastPolicy.FP32,
    torch.ops.aten.prod.dim_Dimname: CastPolicy.FP32,
    torch.ops.aten.quantile.default: CastPolicy.FP32,
    torch.ops.aten.quantile.scalar: CastPolicy.FP32,
    torch.ops.aten.nanquantile.default: CastPolicy.FP32,
    torch.ops.aten.nanquantile.scalar: CastPolicy.FP32,
    torch.ops.aten.stft.default: CastPolicy.FP32,
    torch.ops.aten.stft.center: CastPolicy.FP32,
    torch.ops.aten.cdist.default: CastPolicy.FP32,
    torch.ops.aten.grid_sampler_2d.default: CastPolicy.FP32,
    torch.ops.aten._grid_sampler_2d_cpu_fallback.default: CastPolicy.FP32,
    torch.ops.aten.grid_sampler_3d.default: CastPolicy.FP32,
    torch.ops.aten.trace.default: CastPolicy.FP32,
    torch.ops.aten.view_as_complex.default: CastPolicy.FP32,
    torch.ops.aten.cholesky.default: CastPolicy.FP32,
    torch.ops.aten.cholesky_inverse.default: CastPolicy.FP32,
    torch.ops.aten.cholesky_solve.default: CastPolicy.FP32,
    torch.ops.aten.inverse.default: CastPolicy.FP32,
    torch.ops.aten.lu_solve.default: CastPolicy.FP32,
    torch.ops.aten.orgqr.default: CastPolicy.FP32,
    torch.ops.aten.ormqr.default: CastPolicy.FP32,
    torch.ops.aten.pinverse.default: CastPolicy.FP32,
    torch.ops.aten.max_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.max_unpool2d.default: CastPolicy.FP32,
    torch.ops.aten.max_unpool3d.default: CastPolicy.FP32,
    torch.ops.aten.adaptive_avg_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.reflection_pad1d.default: CastPolicy.FP32,
    torch.ops.aten.reflection_pad2d.default: CastPolicy.FP32,
    torch.ops.aten.replication_pad1d.default: CastPolicy.FP32,
    torch.ops.aten.replication_pad2d.default: CastPolicy.FP32,
    torch.ops.aten.replication_pad3d.default: CastPolicy.FP32,
    torch.ops.aten.mse_loss.default: CastPolicy.FP32,
    torch.ops.aten.cosine_embedding_loss.default: CastPolicy.FP32,
    torch.ops.aten.nll_loss.default: CastPolicy.FP32,
    torch.ops.aten.nll_loss2d.default: CastPolicy.FP32,
    torch.ops.aten.hinge_embedding_loss.default: CastPolicy.FP32,
    torch.ops.aten.poisson_nll_loss.default: CastPolicy.FP32,
    torch.ops.aten.smooth_l1_loss.default: CastPolicy.FP32,
    torch.ops.aten.cross_entropy_loss.default: CastPolicy.FP32,
    torch.ops.aten.l1_loss.default: CastPolicy.FP32,
    torch.ops.aten.huber_loss.default: CastPolicy.FP32,
    torch.ops.aten.margin_ranking_loss.default: CastPolicy.FP32,
    torch.ops.aten.soft_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.triplet_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.multi_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.ctc_loss.IntList: CastPolicy.FP32,
    torch.ops.aten.ctc_loss.Tensor: CastPolicy.FP32,
    torch.ops.aten.kl_div.default: CastPolicy.FP32,
    torch.ops.aten.multilabel_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.binary_cross_entropy_with_logits.default: CastPolicy.FP32,
    torch.ops.aten.fft_fft.default: CastPolicy.FP32,
    torch.ops.aten.fft_ifft.default: CastPolicy.FP32,
    torch.ops.aten.fft_fft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_ifft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_fftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_ifftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_rfft.default: CastPolicy.FP32,
    torch.ops.aten.fft_irfft.default: CastPolicy.FP32,
    torch.ops.aten.fft_rfft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_irfft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_rfftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_irfftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_hfft.default: CastPolicy.FP32,
    torch.ops.aten.fft_ihfft.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cond.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cond.p_str: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.default: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.tol_tensor: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.atol_rtol_float: CastPolicy.FP32,
    torch.ops.aten.linalg_solve.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cholesky.default: CastPolicy.FP32,
    torch.ops.aten.linalg_svdvals.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eigvals.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eigvalsh.default: CastPolicy.FP32,
    torch.ops.aten.linalg_inv.default: CastPolicy.FP32,
    torch.ops.aten.linalg_householder_product.default: CastPolicy.FP32,
    torch.ops.aten.linalg_tensorinv.default: CastPolicy.FP32,
    torch.ops.aten.linalg_tensorsolve.default: CastPolicy.FP32,
    torch.ops.aten.fake_quantize_per_tensor_affine.default: CastPolicy.FP32,
    torch.ops.aten.geqrf.default: CastPolicy.FP32,
    torch.ops.aten._lu_with_info.default: CastPolicy.FP32,
    torch.ops.aten.qr.default: CastPolicy.FP32,
    torch.ops.aten.svd.default: CastPolicy.FP32,
    torch.ops.aten.triangular_solve.default: CastPolicy.FP32,
    torch.ops.aten.fractional_max_pool2d.default: CastPolicy.FP32,
    torch.ops.aten.fractional_max_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.adaptive_max_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.multilabel_margin_loss_forward.default: CastPolicy.FP32,
    torch.ops.aten.linalg_qr.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cholesky_ex.default: CastPolicy.FP32,
    torch.ops.aten.linalg_svd.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eig.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eigh.default: CastPolicy.FP32,
    torch.ops.aten.linalg_lstsq.default: CastPolicy.FP32,
    torch.ops.aten.linalg_inv_ex.default: CastPolicy.FP32,

    # promote
    torch.ops.aten.stack.default: CastPolicy.PROMOTE,
    torch.ops.aten.cat.default: CastPolicy.PROMOTE,
    torch.ops.aten.index_copy.default: CastPolicy.PROMOTE,
    torch.ops.aten.index_copy.dimname: CastPolicy.PROMOTE,
}
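
The remaining changed file in this commit is not shown in this extract; presumably it wires autocast_policy and execute_policy into the environment's op dispatch. A hypothetical sketch of that hook (dispatch_op and its op argument are illustrative names, not the actual torchax API):

# Hypothetical illustration only: how a dispatcher could consult the policy table.
def dispatch_op(env, op, args, kwargs):
  if getattr(env, 'autocast_dtype', None) is not None and op in autocast_policy:
    policy = autocast_policy[op]
    # Cast the inputs according to the op's CastPolicy before lowering.
    args, kwargs = execute_policy(policy, args, kwargs, env.autocast_dtype)
  return op(*args, **kwargs)  # continue with the normal lowering for this op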
