
Commit a3abe1a

mlazos authored and pytorchmergebot committed
Add support for bfloat16 atomic adds in fbcode (pytorch#141857)
This adds support for bfloat16 atomic adds in fbcode (OSS will have to wait until those changes are upstreamed to Triton). Originally I attempted to write inline asm, but the Triton API was not flexible enough to support this use case. In the long run, the right answer is to implement this properly in OSS Triton.

Relevant issues:
* pytorch#137425 (fbcode only)
* pytorch#97016

Pull Request resolved: pytorch#141857
Approved by: https://github.com/eellison
1 parent d51e6fa commit a3abe1a
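
The pattern this commit targets is an index_select in the forward pass whose backward accumulates gradients via index_add, which Inductor lowers to tl.atomic_add when the dtype supports it. Below is a minimal sketch of that pattern, adapted from the tests added in this commit; it assumes a CUDA device with SM >= 80 and is for illustration only.

import torch

def f(x, y):
    # index_select in the forward; its backward scatters gradients with
    # index_add, which Inductor implements with tl.atomic_add when supported
    return torch.index_select(x, 0, y)

x = torch.randn(2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = torch.ones(713268, dtype=torch.int64, device="cuda")

out = torch.compile(f)(x, y)
# In fbcode the generated backward kernel uses tl.atomic_add on bfloat16;
# in OSS it falls back to the eager aten.index_add kernel (see the diffs below).
out.sum().backward()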

5 files changed: 88 additions, 13 deletions

test/inductor/test_cuda_repro.py

Lines changed: 46 additions & 0 deletions
@@ -1477,6 +1477,52 @@ def foo(inp):
         foo_c = torch.compile(foo)
         torch.testing.assert_allclose(foo(inp), foo_c(inp))
 
+    @unittest.skipIf(
+        not config.is_fbcode(),
+        "bfloat16 atomic add is only supported in fbcode today #97016",
+    )
+    @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80")
+    def test_atomic_add_bfloat16(self):
+        def f(x, y):
+            return torch.index_select(x, 0, y)
+
+        x = torch.randn(
+            2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True
+        )
+        y = torch.ones(713268, dtype=torch.int64, device="cuda")
+        x_ref = x.clone().detach().requires_grad_(True)
+        y_ref = y.clone().detach()
+
+        out, (_, bw_code) = run_fw_bw_and_get_code(lambda: torch.compile(f)(x, y))
+        fc = FileCheck()
+        fc.check("tl.atomic_add")
+        fc.run(bw_code)
+
+        self.assertEqual(f(x_ref, y_ref), out)
+
+    @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80")
+    @unittest.skipIf(
+        config.is_fbcode(),
+        "bfloat16 atomic add is supported in fbcode, so we won't fallback",
+    )
+    def test_index_add_fallback(self):
+        def f(x, y):
+            return torch.index_select(x, 0, y)
+
+        x = torch.randn(
+            2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True
+        )
+        y = torch.ones(713268, dtype=torch.int64, device="cuda")
+        x_ref = x.clone().detach().requires_grad_(True)
+        y_ref = y.clone().detach()
+
+        out, (_, bw_code) = run_fw_bw_and_get_code(lambda: torch.compile(f)(x, y))
+        fc = FileCheck()
+        fc.check("aten.index_add")
+        fc.run(bw_code)
+
+        self.assertEqual(f(x_ref, y_ref), out)
+
     @requires_multigpu()
     def test_not_initializing_wrong_device(self):
         device_stats = torch.cuda.memory_stats("cuda:0")

torch/_inductor/decomposition.py

Lines changed: 22 additions & 1 deletion
@@ -17,10 +17,12 @@
 )
 from torch._decomp.decompositions import (
     _grid_sampler_2d as decomp_grid_sampler_2d,
+    _index_add,
     pw_cast_for_opmath,
 )
 from torch._decomp.decompositions_for_rng import extra_random_decomps
 from torch._dynamo.utils import counters
+from torch._environment import is_fbcode
 from torch._higher_order_ops.out_dtype import out_dtype
 from torch._inductor.utils import pad_listlike
 from torch._prims_common import (
@@ -48,6 +50,7 @@
 inductor_decompositions = get_decompositions(
     [
         aten._adaptive_avg_pool2d_backward,
+        aten.index_select,
         aten.addmv,
         aten.arange,
         aten.bitwise_and_,
@@ -58,7 +61,6 @@
         aten.flip,
         aten.gelu,
         aten.hardtanh,
-        aten.index_select,
         aten.lcm,
         aten.leaky_relu,
         aten.linalg_vector_norm,
@@ -101,6 +103,7 @@
         aten._softmax_backward_data,
         aten.clamp_max,
         aten.clamp_min,
+        aten.index_add,  # we conditionally call this decomp
         aten.glu,  # inductor lowers this directly
         aten.select_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
         aten.slice_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
@@ -173,6 +176,24 @@ def full(
     return NotImplemented
 
 
+@register_decomposition([aten.index_add])
+def index_add(
+    x: torch.Tensor,
+    dim: int,
+    index: torch.Tensor,
+    tensor: torch.Tensor,
+    *,
+    alpha: torch.types.Number = 1,
+) -> torch.Tensor:
+    # If we are not in fbcode and dtype is bfloat16
+    # fallback to index_add kernel
+    # see https://github.com/pytorch/pytorch/issues/137425 for details
+    if not is_fbcode() and x.dtype == torch.bfloat16:
+        return NotImplemented
+    else:
+        return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha)
+
+
 # Not really sure how to put this into the main library. PrimTorch wants
 # empty_permuted to go to the prim, and typically users don't really want
 # to decompose to empty_strided (but inductor is OK with it, because we are
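
For illustration only, here is a rough sketch of the two outcomes of the conditional decomposition above, via a hypothetical direct call with toy inputs (normally Inductor invokes it through the decomposition table):

import torch
from torch._inductor.decomposition import index_add

x = torch.zeros(4, 3, dtype=torch.bfloat16)
idx = torch.tensor([0, 2])
src = torch.ones(2, 3, dtype=torch.bfloat16)

# Outside fbcode this returns NotImplemented for bfloat16 inputs, so
# aten.index_add stays in the ATen graph and the FALLBACK_ALLOW_LIST entry
# (see lowering.py below) routes it to the eager kernel; in fbcode it
# decomposes via _index_add and can later lower to tl.atomic_add.
print(index_add(x, 0, idx, src))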

torch/_inductor/graph.py

Lines changed: 1 addition & 1 deletion
@@ -1086,7 +1086,7 @@ def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) ->
         ), f"{target} is not an OpOverload"
         base_name = target.name().split(".")[0]
         if base_name in FALLBACK_ALLOW_LIST:
-            make_fallback(target)
+            make_fallback(target, warn=False, override_decomp=True)
         elif config.implicit_fallbacks:
             error = (
                 MissingOperatorWithDecomp

torch/_inductor/lowering.py

Lines changed: 12 additions & 9 deletions
@@ -75,6 +75,13 @@
 from .virtualized import ops, V
 
 
+# TODO(jansel): we should implement decomps or lowerings for these
+# https://github.com/pytorch/torchdynamo/issues/327
+FALLBACK_ALLOW_LIST = {
+    "torchvision::roi_align",
+    "aten::index_add",
+}
+
 log = logging.getLogger(__name__)
 lowerings: Dict[Union[Callable[..., Any], str], Callable[..., Any]] = {}
 # Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints
@@ -1869,8 +1876,10 @@ def check_skip_condition(node, parent, is_output):
     return check_skip_condition(node, node, is_output=True)
 
 
-def make_fallback(op, layout_constraint=None, warn=True):
-    assert op not in decompositions, f"both a fallback and a decomp for same op: {op}"
+def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False):
+    assert (
+        op not in decompositions or override_decomp
+    ), f"both a fallback and a decomp for same op: {op}"
     if (
         warn
         and bool(os.getenv("CI"))
@@ -1880,6 +1889,7 @@ def make_fallback(op, layout_constraint=None, warn=True):
             config.fallback_random
             and op in torch._decomp.decompositions_for_rng.extra_random_decomps
         )
+        and not override_decomp
     ):
         # Note: 'warn' is holdover from when this was a warning, but for ops that previously
         # set warn=False we do not want a CI error.
@@ -2325,13 +2335,6 @@ def apply_constraint(arg, fx_arg):
     return args, kwargs
 
 
-# TODO(jansel): we should implement decomps or lowerings for these
-# https://github.com/pytorch/torchdynamo/issues/327
-FALLBACK_ALLOW_LIST = {
-    "torchvision::roi_align",
-}
-
-
 def sdpa_constraint(fx_node, *args, **kwargs):
     # sdpa requires dense last dimension]
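
The new override_decomp flag exists because aten.index_add now has both a conditional decomposition and an allow-listed fallback. As a sketch, the hypothetical direct call below mirrors what graph.py (above) does for ops in FALLBACK_ALLOW_LIST and shows the relaxed assertion:

import torch
from torch._inductor.lowering import make_fallback

# Before this change the assertion rejected any op that already had a decomp;
# with override_decomp=True the fallback is registered anyway, which is how
# graph.py handles aten::index_add from FALLBACK_ALLOW_LIST.
make_fallback(torch.ops.aten.index_add.default, warn=False, override_decomp=True)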
23372340

torch/_inductor/utils.py

Lines changed: 7 additions & 2 deletions
@@ -1929,8 +1929,13 @@ def device_need_guard(device: str):
 
 
 def needs_fallback_due_to_atomic_add_limitations(dtype):
-    # tl.atomic_add does NOT support the following types
-    return dtype in {torch.int64, torch.bool, torch.bfloat16}
+    # tl.atomic add has bfloat16 support in fbcode
+    # but not in OSS https://github.com/pytorch/pytorch/issues/97016
+    # we will fallback until the code is upstreamed to OSS
+    if config.is_fbcode() and dtype == torch.bfloat16:
+        return False
+    else:
+        return dtype in {torch.int64, torch.bool, torch.bfloat16}
 
 
 def use_scatter_fallback(
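
A quick sanity check of the helper's new behavior, as a sketch that assumes an environment where torch._inductor imports cleanly:

import torch
from torch._inductor import config
from torch._inductor.utils import needs_fallback_due_to_atomic_add_limitations

# int64 and bool still always require the fallback
assert needs_fallback_due_to_atomic_add_limitations(torch.int64)
assert needs_fallback_due_to_atomic_add_limitations(torch.bool)

# bfloat16 only requires it outside fbcode, where Triton lacks bf16 atomic_add
assert needs_fallback_due_to_atomic_add_limitations(torch.bfloat16) == (not config.is_fbcode())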
