
Commit 3c07478

Samantha Andow authored and zou3519 committed
[functorch] add layer norm support, clean up some binary cross entropy support (pytorch/functorch#807)
* add layer norm support, clean up some binary cross entropy support
* zero returns have the same shape as their input
1 parent b33262d commit 3c07478

File tree

functorch/functorch/_src/decompositions.py
functorch/functorch/_src/eager_transforms.py
functorch/functorch/csrc/DynamicLayer.cpp
functorch/test/test_ops.py

4 files changed: +129, -14 lines

functorch/functorch/_src/decompositions.py

Lines changed: 94 additions & 1 deletion
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 import torch._decomp
-from typing import Tuple
+from typing import Tuple, List, Optional

 aten = torch.ops.aten

@@ -21,6 +21,16 @@ def decorator(f):
     return decorator


+# Functions where we need a special decomposition for jvp but there's another version that
+# should be used more generally (ex. for jvp we need to recompute the mean and variance for
+# the backwards of a normalization function. Without jvp, it should use the saved value)
+decomposition_table_for_jvp = {}
+
+
+def register_decomposition_for_jvp(fn):
+    return register_decomposition(fn, registry=decomposition_table_for_jvp)
+
+
 @maybe_register_decomposition(aten.trace.default)
 def trace(self: Tensor) -> Tensor:
     return torch.sum(torch.diag(self))
@@ -35,3 +45,86 @@ def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
     else:
         buffer = z
     return min - torch.log1p(z), buffer
+
+
+@register_decomposition_for_jvp(aten.native_layer_norm_backward)
+def native_layer_norm_backward(
+    grad_out: Tensor,
+    input: Tensor,
+    normalized_shape: List[int],
+    mean: Tensor,
+    rstd: Tensor,
+    weight: Optional[Tensor],
+    bias: Optional[Tensor],
+    output_mask: List[bool],
+) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
+    input_shape = input.shape
+    input_ndim = input.dim()
+
+    axis = input_ndim - len(normalized_shape)
+    inner_dims = input_shape[axis:]
+    outer_dims = input_shape[:axis]
+    inner_dim_indices = list(range(axis, input_ndim))
+    outer_dim_indices = list(range(0, axis))
+
+    N = 1
+    for i in inner_dims:
+        N *= i
+    M = 1
+    for i in outer_dims:
+        M *= i
+    if M <= 0 or N <= 0:
+        return (
+            input.new_zeros(input_shape),
+            input.new_zeros(input_shape[axis:]),
+            input.new_zeros(input_shape[axis:]),
+        )
+
+    # this is exactly the same as the other decomposition except for here. We recompute the mean and variance
+    # so that they track gradients through input
+    mean_ = torch.mean(input, dim=inner_dim_indices, keepdim=True)
+    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=True)
+    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
+    eps = eps.detach()
+    rstd_ = 1 / torch.sqrt(var + eps)
+
+    x_hat = (input - mean_) * rstd_
+    if weight is not None:
+        grad_x_hat = grad_out * weight
+    else:
+        grad_x_hat = grad_out
+    a = grad_x_hat * N
+    b = torch.sum(grad_x_hat, inner_dim_indices, True)
+    c1 = torch.mul(grad_x_hat, x_hat)
+    c2 = torch.sum(c1, inner_dim_indices, True)
+    c3 = torch.mul(x_hat, c2)
+    inner = a - b - c3
+
+    if output_mask[0]:
+        d_input: Optional[Tensor] = (rstd_ / N) * inner
+    else:
+        d_input = torch.zeros_like(input)  # should be None but doesn't work with vjp
+
+    if output_mask[1] and weight is not None:
+        if len(outer_dim_indices) > 0:
+            d_weight: Optional[Tensor] = torch.sum(
+                grad_out * x_hat, outer_dim_indices, False
+            )
+        else:
+            d_weight = grad_out * x_hat
+    elif weight is not None:
+        d_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
+    else:
+        d_weight = torch.zeros(())  # should be None but doesn't work with vjp
+
+    if output_mask[2] and bias is not None:
+        if len(outer_dim_indices) > 0:
+            d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
+        else:
+            d_bias = grad_out
+    elif bias is not None:
+        d_bias = torch.zeros_like(bias)  # should be None but doesn't work with vjp
+    else:
+        d_bias = torch.zeros(())  # should be None but doesn't work with vjp
+
+    return (d_input, d_weight, d_bias)
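
For reference, the d_input branch above implements the standard layer-norm input gradient. Writing $\hat{x} = (x - \mu)\cdot\mathrm{rstd}$ for the normalized input, $g$ for grad_x_hat (grad_out times weight), and taking sums over the $N$ normalized elements, the decomposition computes

$$\frac{\partial L}{\partial x} \;=\; \frac{\mathrm{rstd}}{N}\Big(N\,g \;-\; \textstyle\sum g \;-\; \hat{x}\,\textstyle\sum g\,\hat{x}\Big).$$

The only difference from the generic decomposition is that mean and rstd are recomputed from input here, rather than read from the saved forward values, so that forward-mode AD can propagate tangents through them.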

functorch/functorch/_src/eager_transforms.py

Lines changed: 9 additions & 3 deletions
@@ -14,7 +14,7 @@
 import torch.autograd.forward_ad as fwAD

 from .vmap import vmap
-from .decompositions import decomposition_table
+from .decompositions import decomposition_table, decomposition_table_for_jvp


 from functorch._C import (
@@ -1276,8 +1276,13 @@ def wrapped(*args, **kwargs):


 def _register_jit_decomposition(decomp, use_python=False):
-    assert decomp in decomposition_table, f"could not find {decomp}"
-    decomp_fn = decomposition_table[decomp]
+    if decomp in decomposition_table_for_jvp:
+        decomposition_table_used = decomposition_table_for_jvp
+    elif decomp in decomposition_table:
+        decomposition_table_used = decomposition_table
+    else:
+        raise RuntimeError(f"could not find decomposition for {decomp}")
+    decomp_fn = decomposition_table_used[decomp]
     if use_python:
         decomp_fn = torch.jit.ignore(decomp_fn)
     sig = inspect.signature(decomp_fn)
@@ -1310,3 +1315,4 @@ def get_function_def(sig):
 _register_jit_decomposition(torch.ops.aten.log_sigmoid_forward.default)
 _register_jit_decomposition(torch.ops.aten.binary_cross_entropy_backward.default)
 _register_jit_decomposition(torch.ops.aten.binary_cross_entropy.default)
+_register_jit_decomposition(torch.ops.aten.native_layer_norm_backward.default)
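
With native_layer_norm_backward registered in decomposition_table_for_jvp and wired up above, forward-mode AD can now be taken through the reverse-mode backward of layer_norm. A minimal sketch of that composition (illustrative only, not part of the commit; the names pullback, x, cotangent, tangent and the shapes are made up):

import torch
import torch.nn.functional as F
from functorch import jvp, vjp

x = torch.randn(3, 4)
cotangent = torch.randn(3, 4)
tangent = torch.randn(3, 4)

def pullback(x):
    # vjp of layer_norm at x, applied to a fixed cotangent; differentiating
    # this function hits native_layer_norm_backward
    _, vjp_fn = vjp(lambda t: F.layer_norm(t, [4]), x)
    return vjp_fn(cotangent)

# jvp-of-vjp: the case that previously had no forward-mode rule for layer_norm
_, output_tangent = jvp(pullback, (x,), (tangent,))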

functorch/functorch/csrc/DynamicLayer.cpp

Lines changed: 1 addition & 0 deletions
@@ -481,6 +481,7 @@ TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
   JVP_DECOMP(log_sigmoid_forward);
   JVP_DECOMP(binary_cross_entropy);
   JVP_DECOMP(binary_cross_entropy_backward);
+  JVP_DECOMP(native_layer_norm_backward);
 }

functorch/test/test_ops.py

Lines changed: 25 additions & 10 deletions
@@ -1148,7 +1148,6 @@ def test_vjpvmap(self, device, dtype, op):
         xfail('nn.functional.hardswish', ''),
         xfail('nn.functional.huber_loss', ''),
         xfail('nn.functional.instance_norm', ''),
-        xfail('nn.functional.layer_norm', ''),
         xfail('nn.functional.logsigmoid', ''),
         xfail('nn.functional.pad', 'circular'),
         xfail('nn.functional.prelu', ''),
@@ -1199,6 +1198,11 @@ def test_jvpvjp(self, device, dtype, op):
             primals_tangents = tree_map(lambda x: torch.randn_like(x), primals)
             cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents)

+            if isinstance(primals[0], torch.Tensor) and primals[0].numel() == 0:
+                # typically the first primal arg is the input. If the input has no elements, we will typically run
+                # into an issue of "Expected Tensor but got None"
+                continue
+
             def push_vjp(primals, cotangents):
                 _, vjp_fn = vjp(fn, *primals)
                 return vjp_fn(cotangents)
@@ -1228,19 +1232,23 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
                 expected = (tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec))
                 return expected

-            def compare_jacobians(primals, cotangents, in_dims=(0, 1)):
-                def get_vjp(primals, cotangents):
+            def compare_jacobians(cotangents_and_primals, in_dims, atol_rtol):
+                def get_vjp(cotangents, *primals):
                     _, vjp_fn = vjp(fn, *primals)
                     return vjp_fn(cotangents)

-                jacobian_jvp = jacfwd(get_vjp, in_dims)(primals, cotangents)
-                jacobian_vjp = jacrev(get_vjp, in_dims)(primals, cotangents)
+                jacobian_jvp = jacfwd(get_vjp, in_dims)(*cotangents_and_primals)
+                jacobian_vjp = jacrev(get_vjp, in_dims)(*cotangents_and_primals)

                 # For dtype changing operations, the jacobians have different dtype.
                 jacobian_jvp = tree_map(lambda x: x.to(torch.float), jacobian_jvp)
                 jacobian_vjp = tree_map(lambda x: x.to(torch.float), jacobian_vjp)

-                self.assertEqual(jacobian_jvp, jacobian_vjp)
+                if atol_rtol is not None:
+                    (atol, rtol) = atol_rtol
+                    self.assertEqual(jacobian_jvp, jacobian_vjp, atol=atol, rtol=rtol)
+                else:
+                    self.assertEqual(jacobian_jvp, jacobian_vjp)

             # HACK: obviously pytorch should also have the same coverage
             # For things that do have the same coverage, we test that jvp x vjp
@@ -1255,12 +1263,19 @@ def get_vjp(primals, cotangents):
                 'log_softmax',
                 'nn.functional.cross_entropy',
                 'nn.functional.binary_cross_entropy',
+                'nn.functional.layer_norm'
             }
             if op.name in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH:
-                in_dims = (0, 1)
-                if op.name == 'nn.functional.binary_cross_entropy':  # reverse second derivative wrt target not defined
-                    in_dims = 1
-                compare_jacobians(primals, cotangents, in_dims)
+                def is_differentiable(t):
+                    return isinstance(t, torch.Tensor) and t.dtype == torch.float32
+                args = (cotangents, *primals)
+                if op.name == 'nn.functional.binary_cross_entropy':
+                    in_dims = (0, 1)  # targets is float32 but isn't differentiable
+                    atol_rtol = 1.5E-4, 1.3e-06
+                else:
+                    in_dims = tuple(i for i in range(len(args)) if is_differentiable(args[i]))
+                    atol_rtol = None
+                compare_jacobians(args, in_dims, atol_rtol)
             else:
                 expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
                 self.assertEqual(result, expected)
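
The compare_jacobians helper above checks that the forward-mode (jacfwd) and reverse-mode (jacrev) Jacobians of an op's vjp agree. A standalone sketch of that check for layer_norm (hypothetical shapes and names, not code from the test file):

import torch
import torch.nn.functional as F
from functorch import vjp, jacfwd, jacrev

x = torch.randn(2, 3)
cotangents = torch.randn(2, 3)

def get_vjp(cotangents, x):
    _, vjp_fn = vjp(lambda t: F.layer_norm(t, [3]), x)
    return vjp_fn(cotangents)

in_dims = (0, 1)  # differentiate w.r.t. both cotangents and x
jacobian_fwd = jacfwd(get_vjp, in_dims)(cotangents, x)
jacobian_rev = jacrev(get_vjp, in_dims)(cotangents, x)
torch.testing.assert_close(jacobian_fwd, jacobian_rev)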
