Skip to content

Commit a91324c

Browse files
authored
[BOO] Add an env variable for toggling backward boo convolutions (iree-org#865)
This is temporarily set as defaulting to "do not use backward boo kernels" as our performance on many backward convs is not competitive. --------- Signed-off-by: zjgarvey <[email protected]>
1 parent 5e6baa3 commit a91324c

File tree

3 files changed

+58
-1
lines changed

3 files changed

+58
-1
lines changed

iree/turbine/kernel/boo/ops/conv.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# See https://llvm.org/LICENSE.txt for license information.
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7+
import os
78
from typing import Sequence, Tuple
89

910
import torch
@@ -14,6 +15,8 @@
1415
"boo_conv",
1516
]
1617

18+
BOO_USE_BACKWARD_KERNELS = int(os.getenv("BOO_USE_BACKWARD_KERNELS", "0"))
19+
1720

1821
@torch.library.custom_op("iree_turbine::boo_convolution", mutates_args=())
1922
def boo_convolution(
@@ -163,6 +166,44 @@ def _b(
163166
return input_grad, weight_grad, bias_grad
164167

165168

169+
def pytorch_convolution_backward(ctx, grad_output):
170+
x, w = ctx.saved_tensors
171+
172+
mask = tuple((ctx.needs_input_grad[i] for i in range(3)))
173+
174+
# return to NCHW if necessary
175+
rank = len(x.shape)
176+
perm = [0] + [rank - 1] + list(range(1, rank - 1))
177+
inv_perm = [0] + list(range(2, rank)) + [1]
178+
if ctx.input_layout.endswith("C"):
179+
x = x.permute(perm)
180+
if ctx.kernel_layout.endswith("C"):
181+
w = w.permute(perm)
182+
if ctx.output_layout.endswith("C"):
183+
grad_output = grad_output.permute(perm)
184+
185+
input_grad, weight_grad, bias_grad = torch.ops.aten.convolution_backward(
186+
grad_output,
187+
x,
188+
w,
189+
None,
190+
ctx.stride,
191+
ctx.padding,
192+
ctx.dilation,
193+
False,
194+
[0] * len(ctx.stride),
195+
ctx.groups,
196+
mask,
197+
)
198+
199+
if ctx.input_layout.endswith("C"):
200+
input_grad = input_grad.permute(inv_perm)
201+
if ctx.kernel_layout.endswith("C"):
202+
weight_grad = weight_grad.permute(inv_perm)
203+
# return `None` for attribute args
204+
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None
205+
206+
166207
def boo_convolution_backward(ctx, grad_output):
167208
x, w = ctx.saved_tensors
168209

@@ -216,8 +257,14 @@ def boo_convolution_context(
216257
ctx.use_bias = b is not None
217258

218259

260+
_backward_to_register = (
261+
boo_convolution_backward
262+
if (BOO_USE_BACKWARD_KERNELS)
263+
else pytorch_convolution_backward
264+
)
265+
219266
boo_convolution.register_autograd(
220-
boo_convolution_backward, setup_context=boo_convolution_context
267+
_backward_to_register, setup_context=boo_convolution_context
221268
)
222269

223270

tests/kernel/boo/modeling/boo_conv_2d_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
# See https://llvm.org/LICENSE.txt for license information.
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7+
import os
8+
9+
# enable backward boo kernels for testing
10+
os.environ["BOO_USE_BACKWARD_KERNELS"] = "1"
11+
712
import unittest
813
import tempfile
914
from pathlib import Path

tests/kernel/boo/ops/boo_conv_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
import os
2+
3+
# enable backward boo kernels for testing
4+
os.environ["BOO_USE_BACKWARD_KERNELS"] = "1"
5+
16
import unittest
27
import pytest
38
import tempfile

0 commit comments

Comments
 (0)