 # Registry to track all ops with reference implementations
 _REGISTERED_REF_IMPLEMENTATIONS: set[str] = set()

+_OUTPUTS_TYPE = torch.Tensor | tuple[torch.Tensor, ...]
+

 # Custom impl wrapper that tracks registrations
 def impl_tracked(
     lib: Library, op_name: str
-) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+) -> Callable[[Callable[..., _OUTPUTS_TYPE]], Callable[..., _OUTPUTS_TYPE]]:
     """Wrapper around impl that tracks registered ops."""
     _REGISTERED_REF_IMPLEMENTATIONS.add(op_name)
     return impl(lib, op_name)
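For orientation, `impl_tracked` is applied as a decorator exactly like `impl` from `torch.library`: it records the op name in `_REGISTERED_REF_IMPLEMENTATIONS` and then delegates to `impl(lib, op_name)`. A minimal sketch of the registration pattern used throughout this file (body elided; `m` is the `Library` instance defined earlier in the module):

@impl_tracked(m, "quantized_add_asym8sxasym8s_asym8s.per_tensor")
def quantized_add_asym8sxasym8s_asym8s_per_tensor(
    X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point
):
    ...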
@@ -312,7 +314,7 @@ def quantized_add_per_tensor(
     dequant_Y = Y_scale * (Y - Y_zero_point)

     # q_min/q_max are unused args
-    return quantize_per_tensor(
+    out = quantize_per_tensor(
         dequant_X + dequant_Y,
         out_scale,
         out_zero_point,
@@ -321,6 +323,9 @@
         dtype,
     )

+    assert isinstance(out, torch.Tensor)
+    return out
+

 @impl_tracked(m, "quantized_add_asym8sxasym8s_asym8s.per_tensor")
 def quantized_add_asym8sxasym8s_asym8s_per_tensor(
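As a quick sanity check on the dequantize → add → requantize flow in `quantized_add_per_tensor`, here is a hand-worked scalar example; it assumes the usual affine requantization (round(x / scale) + zero_point, then clamp) inside `quantize_per_tensor`:

# Worked per-tensor add, assuming round-and-clamp affine requantization.
X, X_scale, X_zero_point = 10, 0.5, 2
Y, Y_scale, Y_zero_point = 6, 1.0, 1
out_scale, out_zero_point = 0.25, 0
dequant_X = X_scale * (X - X_zero_point)  # 0.5 * (10 - 2) = 4.0
dequant_Y = Y_scale * (Y - Y_zero_point)  # 1.0 * (6 - 1)  = 5.0
out = round((dequant_X + dequant_Y) / out_scale) + out_zero_point  # 36
assert -128 <= out <= 127  # fits the int8 range, so no clamping here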
@@ -338,9 +343,11 @@ def quantized_add_asym8sxasym8s_asym8s_per_tensor(
     if Y.dtype != torch.int8:
         raise ValueError("Y dtype must be torch.int8")

-    return quantized_add_per_tensor(
+    out = quantized_add_per_tensor(
         X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point
     )
+    assert isinstance(out, torch.Tensor)
+    return out


 @impl_tracked(m, "quantized_add_asym8uxasym8u_asym8u.per_tensor")
@@ -359,9 +366,11 @@ def quantized_add_asym8uxasym8u_asym8u_per_tensor(
     if Y.dtype != torch.uint8:
         raise ValueError("Y dtype must be torch.uint8")

-    return quantized_add_per_tensor(
+    out = quantized_add_per_tensor(
         X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point
     )
+    assert isinstance(out, torch.Tensor)
+    return out


 def quantized_linear_common(
@@ -407,14 +416,16 @@ def quantized_linear_common(
         (weight - weight_zero_point).float(),
         bias.float(),
     )
-    return quantize_per_tensor(
+    out = quantize_per_tensor(
         out,
         out_scale,
         out_zero_point,
         torch.iinfo(dtype).min,
         torch.iinfo(dtype).max,
         dtype,
-    ).reshape(*leading_dims, N)
+    )
+    assert isinstance(out, torch.Tensor)
+    return out.reshape(*leading_dims, N)


 def quantized_linear_variant(
@@ -576,14 +587,16 @@ def quantized_matmul(
         (X - X_zero_point).float(),
         (Y - Y_zero_point).float(),
     )
-    return quantize_per_tensor(
+    out = quantize_per_tensor(
         out,
         out_scale,
         out_zero_point,
         torch.iinfo(X.dtype).min,
         torch.iinfo(X.dtype).max,
         X.dtype,
     )
+    assert isinstance(out, torch.Tensor)
+    return out


 @impl_tracked(m, "quantized_matmul_asym8sxasym8s_asym8s")
@@ -603,7 +616,7 @@ def quantized_matmul_asym8sxasym8s_asym8s(
     if Y.dtype != torch.int8:
         raise ValueError("Y dtype must be torch.int8")

-    return quantized_matmul(
+    out = quantized_matmul(
         X,
         X_zero_point,
         Y,
@@ -614,6 +627,8 @@
         out_zero_point,
         transposed,
     )
+    assert isinstance(out, torch.Tensor)
+    return out


 @impl_tracked(m, "quantized_matmul_asym8uxasym8u_asym8u")
@@ -633,7 +648,7 @@ def quantized_matmul_asym8uxasym8u_asym8u(
     if Y.dtype != torch.uint8:
         raise ValueError("Y dtype must be torch.uint8")

-    return quantized_matmul(
+    out = quantized_matmul(
         X,
         X_zero_point,
         Y,
@@ -644,6 +659,8 @@
         out_zero_point,
         transposed,
     )
+    assert isinstance(out, torch.Tensor)
+    return out


 @impl_tracked(m, "quantized_layer_norm.per_tensor")
@@ -681,18 +698,21 @@ def quantized_layer_norm_per_tensor(
     float_input_tensor = dequantize_per_tensor(
         input_tensor, X_scale, X_zero_point, -128, 127, input_tensor.dtype
     )
+    assert isinstance(float_input_tensor, torch.Tensor)
     out = torch.nn.functional.layer_norm(
         float_input_tensor, normalized_shape, weight, bias, eps=eps
     )

-    return quantize_per_tensor(
+    out = quantize_per_tensor(
         out,
         output_scale,
         output_zero_point,
         torch.iinfo(input_tensor.dtype).min,
         torch.iinfo(input_tensor.dtype).max,
         input_tensor.dtype,
     )
+    assert isinstance(out, torch.Tensor)
+    return out


 def quantized_conv_per_tensor(
@@ -754,14 +774,16 @@ def quantized_conv_per_tensor(
     else:
         raise ValueError("Input tensor must be 3D or 4D")

-    return quantize_per_tensor(
+    out = quantize_per_tensor(
         float_out,
         output_scale,
         output_zero_point,
         torch.iinfo(input_tensor.dtype).min,
         torch.iinfo(input_tensor.dtype).max,
         input_tensor.dtype,
     )
+    assert isinstance(out, torch.Tensor)
+    return out


 @impl_tracked(m, "quantized_conv2d_nchw.per_tensor")
@@ -983,7 +1005,7 @@ def variant(
             # Call the appropriate base function
             match layout:
                 case "nchw":
-                    return quantized_conv2d_nchw_per_tensor(
+                    out = quantized_conv2d_nchw_per_tensor(
                         input_tensor,
                         weight,
                         bias,
@@ -1000,7 +1022,7 @@ def variant(
                         out_shift,
                     )
                 case "nhwc":
-                    return quantized_conv2d_nhwc_per_tensor(
+                    out = quantized_conv2d_nhwc_per_tensor(
                         input_tensor,
                         weight,
                         bias,
@@ -1019,6 +1041,9 @@
                 case _:
                     raise ValueError(f"Unknown layout {layout}")

+            assert isinstance(out, torch.Tensor)
+            return out
+
         return variant

     return decorator
@@ -1293,14 +1318,16 @@ def quantized_relu_common(
     dequantized_X = torch.where(
         X > X_zero_point, X - X_zero_point, torch.zeros_like(X)
     ).to(torch.float32)
-    return quantize_per_tensor(
+    out = quantize_per_tensor(
         dequantized_X,
         out_scale,
         out_zero_point,
         torch.iinfo(X.dtype).min,
         torch.iinfo(X.dtype).max,
         X.dtype,
     )
+    assert isinstance(out, torch.Tensor)
+    return out


 def quantized_relu_variant(
@@ -1557,7 +1584,7 @@ def im2row_per_tensor(
     in_zero_point: int,
     channel_last: bool = False,
 ) -> torch.Tensor:
-    return im2row(
+    out = im2row(
         input_tensor,
         kernel_size,
         dilation,
@@ -1566,6 +1593,8 @@ def im2row_per_tensor(
         torch.tensor(in_zero_point, dtype=torch.int32),
         channel_last,
     )
+    assert isinstance(out, torch.Tensor)
+    return out


 @impl_tracked(m, "transposed_im2row")
@@ -1773,3 +1802,15 @@ def idma_load(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.T
 @impl_tracked(m, "idma_wait")
 def idma_wait(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.Tensor:
     return src.clone()
+
+
+@impl_tracked(m, "linalg_svd")
+def linalg_svd(
+    A: torch.Tensor,
+    full_matrices: bool = False,
+    compute_uv: bool = True,
+    driver: str | None = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    assert compute_uv
+    U, S, Vh = torch.linalg.svd(A, full_matrices=full_matrices, driver=driver)
+    return U.contiguous(), S.contiguous(), Vh.contiguous()
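For reference, with `full_matrices=False` this returns the reduced factorization: for an (m, n) input, `U` is (m, k), `S` is (k,), and `Vh` is (k, n) with k = min(m, n). A reconstruction check against this reference implementation could look like (illustrative sketch only):

A = torch.randn(5, 3)
U, S, Vh = linalg_svd(A)
# The input should be recovered up to floating-point error.
torch.testing.assert_close(U @ torch.diag(S) @ Vh, A, atol=1e-5, rtol=1e-5)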