Skip to content

Commit bd34e74

Browse files
authored
Cadence ops: Support for contiguous svd
Differential Revision: D84676357 Pull Request resolved: pytorch#15142
1 parent c016f29 commit bd34e74

File tree

3 files changed

+87
-16
lines changed

3 files changed

+87
-16
lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def _validate_ref_impl_exists() -> None:
6060
"cadence::quantized_softmax.per_tensor",
6161
"cadence::quantized_conv2d_nchw", # We should only support per_tensor variant, should remove
6262
"cadence::quantized_relu", # We should only support per_tensor variant, should remove
63-
"cadence::linalg_svd",
6463
"cadence::quantized_conv2d_nhwc", # We should only support per_tensor variant, should remove
6564
"cadence::quantized_softmax",
6665
"cadence::quantized_w8a32_gru",

backends/cadence/aot/ref_implementations.py

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
# Registry to track all ops with reference implementations
2222
_REGISTERED_REF_IMPLEMENTATIONS: set[str] = set()
2323

24+
_OUTPUTS_TYPE = torch.Tensor | tuple[torch.Tensor, ...]
25+
2426

2527
# Custom impl wrapper that tracks registrations
2628
def impl_tracked(
2729
lib: Library, op_name: str
28-
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
30+
) -> Callable[[Callable[..., _OUTPUTS_TYPE]], Callable[..., _OUTPUTS_TYPE]]:
2931
"""Wrapper around impl that tracks registered ops."""
3032
_REGISTERED_REF_IMPLEMENTATIONS.add(op_name)
3133
return impl(lib, op_name)
@@ -312,7 +314,7 @@ def quantized_add_per_tensor(
312314
dequant_Y = Y_scale * (Y - Y_zero_point)
313315

314316
# q_min/q_max are unused args
315-
return quantize_per_tensor(
317+
out = quantize_per_tensor(
316318
dequant_X + dequant_Y,
317319
out_scale,
318320
out_zero_point,
@@ -321,6 +323,9 @@ def quantized_add_per_tensor(
321323
dtype,
322324
)
323325

326+
assert isinstance(out, torch.Tensor)
327+
return out
328+
324329

325330
@impl_tracked(m, "quantized_add_asym8sxasym8s_asym8s.per_tensor")
326331
def quantized_add_asym8sxasym8s_asym8s_per_tensor(
@@ -338,9 +343,11 @@ def quantized_add_asym8sxasym8s_asym8s_per_tensor(
338343
if Y.dtype != torch.int8:
339344
raise ValueError("Y dtype must be torch.int8")
340345

341-
return quantized_add_per_tensor(
346+
out = quantized_add_per_tensor(
342347
X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point
343348
)
349+
assert isinstance(out, torch.Tensor)
350+
return out
344351

345352

346353
@impl_tracked(m, "quantized_add_asym8uxasym8u_asym8u.per_tensor")
@@ -359,9 +366,11 @@ def quantized_add_asym8uxasym8u_asym8u_per_tensor(
359366
if Y.dtype != torch.uint8:
360367
raise ValueError("Y dtype must be torch.int8")
361368

362-
return quantized_add_per_tensor(
369+
out = quantized_add_per_tensor(
363370
X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point
364371
)
372+
assert isinstance(out, torch.Tensor)
373+
return out
365374

366375

367376
def quantized_linear_common(
@@ -407,14 +416,16 @@ def quantized_linear_common(
407416
(weight - weight_zero_point).float(),
408417
bias.float(),
409418
)
410-
return quantize_per_tensor(
419+
out = quantize_per_tensor(
411420
out,
412421
out_scale,
413422
out_zero_point,
414423
torch.iinfo(dtype).min,
415424
torch.iinfo(dtype).max,
416425
dtype,
417-
).reshape(*leading_dims, N)
426+
)
427+
assert isinstance(out, torch.Tensor)
428+
return out.reshape(*leading_dims, N)
418429

419430

420431
def quantized_linear_variant(
@@ -576,14 +587,16 @@ def quantized_matmul(
576587
(X - X_zero_point).float(),
577588
(Y - Y_zero_point).float(),
578589
)
579-
return quantize_per_tensor(
590+
out = quantize_per_tensor(
580591
out,
581592
out_scale,
582593
out_zero_point,
583594
torch.iinfo(X.dtype).min,
584595
torch.iinfo(X.dtype).max,
585596
X.dtype,
586597
)
598+
assert isinstance(out, torch.Tensor)
599+
return out
587600

588601

589602
@impl_tracked(m, "quantized_matmul_asym8sxasym8s_asym8s")
@@ -603,7 +616,7 @@ def quantized_matmul_asym8sxasym8s_asym8s(
603616
if Y.dtype != torch.int8:
604617
raise ValueError("Y dtype must be torch.int8")
605618

606-
return quantized_matmul(
619+
out = quantized_matmul(
607620
X,
608621
X_zero_point,
609622
Y,
@@ -614,6 +627,8 @@ def quantized_matmul_asym8sxasym8s_asym8s(
614627
out_zero_point,
615628
transposed,
616629
)
630+
assert isinstance(out, torch.Tensor)
631+
return out
617632

618633

619634
@impl_tracked(m, "quantized_matmul_asym8uxasym8u_asym8u")
@@ -633,7 +648,7 @@ def quantized_matmul_asym8uxasym8u_asym8u(
633648
if Y.dtype != torch.uint8:
634649
raise ValueError("Y dtype must be torch.uint8")
635650

636-
return quantized_matmul(
651+
out = quantized_matmul(
637652
X,
638653
X_zero_point,
639654
Y,
@@ -644,6 +659,8 @@ def quantized_matmul_asym8uxasym8u_asym8u(
644659
out_zero_point,
645660
transposed,
646661
)
662+
assert isinstance(out, torch.Tensor)
663+
return out
647664

648665

649666
@impl_tracked(m, "quantized_layer_norm.per_tensor")
@@ -681,18 +698,21 @@ def quantized_layer_norm_per_tensor(
681698
float_input_tensor = dequantize_per_tensor(
682699
input_tensor, X_scale, X_zero_point, -128, 127, input_tensor.dtype
683700
)
701+
assert isinstance(float_input_tensor, torch.Tensor)
684702
out = torch.nn.functional.layer_norm(
685703
float_input_tensor, normalized_shape, weight, bias, eps=eps
686704
)
687705

688-
return quantize_per_tensor(
706+
out = quantize_per_tensor(
689707
out,
690708
output_scale,
691709
output_zero_point,
692710
torch.iinfo(input_tensor.dtype).min,
693711
torch.iinfo(input_tensor.dtype).max,
694712
input_tensor.dtype,
695713
)
714+
assert isinstance(out, torch.Tensor)
715+
return out
696716

697717

698718
def quantized_conv_per_tensor(
@@ -754,14 +774,16 @@ def quantized_conv_per_tensor(
754774
else:
755775
raise ValueError("Input tensor must be 3D or 4D")
756776

757-
return quantize_per_tensor(
777+
out = quantize_per_tensor(
758778
float_out,
759779
output_scale,
760780
output_zero_point,
761781
torch.iinfo(input_tensor.dtype).min,
762782
torch.iinfo(input_tensor.dtype).max,
763783
input_tensor.dtype,
764784
)
785+
assert isinstance(out, torch.Tensor)
786+
return out
765787

766788

767789
@impl_tracked(m, "quantized_conv2d_nchw.per_tensor")
@@ -983,7 +1005,7 @@ def variant(
9831005
# Call the appropriate base function
9841006
match layout:
9851007
case "nchw":
986-
return quantized_conv2d_nchw_per_tensor(
1008+
out = quantized_conv2d_nchw_per_tensor(
9871009
input_tensor,
9881010
weight,
9891011
bias,
@@ -1000,7 +1022,7 @@ def variant(
10001022
out_shift,
10011023
)
10021024
case "nhwc":
1003-
return quantized_conv2d_nhwc_per_tensor(
1025+
out = quantized_conv2d_nhwc_per_tensor(
10041026
input_tensor,
10051027
weight,
10061028
bias,
@@ -1019,6 +1041,9 @@ def variant(
10191041
case _:
10201042
raise ValueError(f"Unknown layout {layout}")
10211043

1044+
assert isinstance(out, torch.Tensor)
1045+
return out
1046+
10221047
return variant
10231048

10241049
return decorator
@@ -1293,14 +1318,16 @@ def quantized_relu_common(
12931318
dequantized_X = torch.where(
12941319
X > X_zero_point, X - X_zero_point, torch.zeros_like(X)
12951320
).to(torch.float32)
1296-
return quantize_per_tensor(
1321+
out = quantize_per_tensor(
12971322
dequantized_X,
12981323
out_scale,
12991324
out_zero_point,
13001325
torch.iinfo(X.dtype).min,
13011326
torch.iinfo(X.dtype).max,
13021327
X.dtype,
13031328
)
1329+
assert isinstance(out, torch.Tensor)
1330+
return out
13041331

13051332

13061333
def quantized_relu_variant(
@@ -1557,7 +1584,7 @@ def im2row_per_tensor(
15571584
in_zero_point: int,
15581585
channel_last: bool = False,
15591586
) -> torch.Tensor:
1560-
return im2row(
1587+
out = im2row(
15611588
input_tensor,
15621589
kernel_size,
15631590
dilation,
@@ -1566,6 +1593,8 @@ def im2row_per_tensor(
15661593
torch.tensor(in_zero_point, dtype=torch.int32),
15671594
channel_last,
15681595
)
1596+
assert isinstance(out, torch.Tensor)
1597+
return out
15691598

15701599

15711600
@impl_tracked(m, "transposed_im2row")
@@ -1773,3 +1802,15 @@ def idma_load(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.T
17731802
@impl_tracked(m, "idma_wait")
17741803
def idma_wait(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.Tensor:
17751804
return src.clone()
1805+
1806+
1807+
@impl_tracked(m, "linalg_svd")
1808+
def linalg_svd(
1809+
A: torch.Tensor,
1810+
full_matrices: bool = False,
1811+
compute_uv: bool = True,
1812+
driver: str | None = None,
1813+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1814+
assert compute_uv
1815+
U, S, Vh = torch.linalg.svd(A, full_matrices=full_matrices, driver=driver)
1816+
return U.contiguous(), S.contiguous(), Vh.contiguous()

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2632,3 +2632,34 @@ def test_quantized_embedding_byte(
26322632
expected_out,
26332633
)
26342634
)
2635+
2636+
@expand(
2637+
[
2638+
*[
2639+
(
2640+
dtype,
2641+
(4, 4),
2642+
full_matrices,
2643+
)
2644+
for dtype in [torch.float32, torch.float64]
2645+
for full_matrices in [True, False]
2646+
]
2647+
]
2648+
)
2649+
def test_linalg_svd_outputs_are_contiguous(
2650+
self,
2651+
dtype: torch.dtype,
2652+
shape: tuple[int, int],
2653+
full_matrices: bool,
2654+
) -> None:
2655+
m, n = shape
2656+
a = torch.eye(m, n, dtype=dtype)
2657+
2658+
U, S, Vh = torch.ops.cadence.linalg_svd(a, full_matrices)
2659+
2660+
self.assertTrue(U.is_contiguous(), "U not contiguous")
2661+
self.assertTrue(S.is_contiguous(), "S not contiguous")
2662+
self.assertTrue(Vh.is_contiguous(), "Vh not contiguous")
2663+
self.assertTrue(U.dtype == dtype, "U dtype mismatch")
2664+
self.assertTrue(S.dtype == dtype, "S dtype mismatch")
2665+
self.assertTrue(Vh.dtype == dtype, "Vh dtype mismatch")

0 commit comments

Comments
 (0)