Commit f2c7213

Eashan Garg authored and facebook-github-bot committed
Added support for quantized_softmax in ref_implementations (pytorch#15426)
Summary: Add support for quantized_softmax ref implementation

Reviewed By: DrJessop

Differential Revision: D85188129
1 parent 94def70 commit f2c7213

3 files changed: +226, -21 lines


backends/cadence/aot/ops_registrations.py

Lines changed: 1 addition & 21 deletions
@@ -49,36 +49,16 @@ def _validate_ref_impl_exists() -> None:
         "cadence::roi_align_box_processor",
     }
 
-    # All of these should either
-    # 1. be removed
-    # 2. have a reference implementation added to ref_implementations.py
-    _WARN_ONLY = {
-        "cadence::quantized_softmax.per_tensor",
-        "cadence::quantized_softmax",
-    }
-
     ref_impls = get_registered_ref_implementations()
-    warn_impls = []
     error_impls = []
     for op_name in _REGISTERED_META_KERNELS:
         # Strip the namespace prefix if present (e.g., "cadence::" -> "")
         op_name_clean = op_name.split("::")[-1] if "::" in op_name else op_name
 
         if op_name_clean not in ref_impls:
-            if op_name in _WARN_ONLY:
-                warn_impls.append(op_name)
-            elif op_name not in _SKIP_OPS:
+            if op_name not in _SKIP_OPS:
                 error_impls.append(op_name)
 
-    if warn_impls:
-        warn_msg = (
-            f"The following {len(warn_impls)} meta kernel registrations are missing reference implementations:\n"
-            + "\n".join(f"  - {op}" for op in warn_impls)
-            + "\n\nPlease add reference implementations in ref_implementations.py using "
-            + "@impl_tracked(m, '<op_name>')."
-        )
-        logging.warning(warn_msg)
-
     if error_impls:
         error_msg = (
             f"The following {len(error_impls)} meta kernel registrations are missing reference implementations:\n"

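To make the tightened check concrete, here is a minimal self-contained sketch (the sets below are illustrative stand-ins, not the real _REGISTERED_META_KERNELS, _SKIP_OPS, or registry contents): with _WARN_ONLY removed, any registered meta kernel outside _SKIP_OPS that lacks a reference implementation is collected into error_impls and reported as an error rather than a warning.

# Illustrative only: stand-in values for the registered meta kernels, the skip set,
# and the registry returned by get_registered_ref_implementations().
registered_meta_kernels = {"cadence::quantized_softmax", "cadence::roi_align_box_processor"}
skip_ops = {"cadence::roi_align_box_processor"}
ref_impls = {"quantized_softmax"}  # names registered via @impl_tracked(m, "<op_name>")

error_impls = []
for op_name in registered_meta_kernels:
    # Strip the namespace prefix, e.g. "cadence::quantized_softmax" -> "quantized_softmax"
    op_name_clean = op_name.split("::")[-1] if "::" in op_name else op_name
    if op_name_clean not in ref_impls and op_name not in skip_ops:
        error_impls.append(op_name)

assert not error_impls  # any non-skipped op without a ref implementation is now an error
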
backends/cadence/aot/ref_implementations.py

Lines changed: 91 additions & 0 deletions
@@ -2054,3 +2054,94 @@ def softmax_f32_f32(
     assert input_tensor.dtype == torch.float32, "input_tensor must be float32"
     assert not half_to_float, "half_to_float is not supported"
     return torch.nn.functional.softmax(input_tensor, dim=dim, dtype=torch.float32)
+
+
+def quantized_softmax_per_tensor_common(
+    input_tensor: torch.Tensor,
+    mask: torch.Tensor | None,
+    dim: int,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    """
+    Quantized softmax operation.
+
+    Args:
+        - input_tensor (Tensor): The quantized input tensor
+        - mask (Tensor): Mask tensor
+        - dim (int): The dimension along which softmax is computed
+        - in_scale (float): The scale of the input quantization
+        - in_zero_point (int): The zero point of the input quantization
+        - out_scale (float): The scale of the output quantization
+        - out_zero_point (int): The zero point of the output quantization
+    """
+    # TODO: T228751479 - Add support for mask parameter in softmax
+    assert mask is None
+    supported_dtypes = [torch.int8, torch.uint8, torch.int16]
+    if input_tensor.dtype not in supported_dtypes:
+        raise ValueError(
+            f"Input dtype must be one of {supported_dtypes}. Got {input_tensor.dtype}"
+        )
+
+    float_input_tensor = dequantize_per_tensor(
+        input_tensor,
+        in_scale,
+        in_zero_point,
+        torch.iinfo(input_tensor.dtype).min,
+        torch.iinfo(input_tensor.dtype).max,
+        input_tensor.dtype,
+    )
+
+    softmax_output = torch.nn.functional.softmax(float_input_tensor, dim=dim)
+
+    return quantize_per_tensor(
+        softmax_output,
+        out_scale,
+        out_zero_point,
+        torch.iinfo(input_tensor.dtype).min,
+        torch.iinfo(input_tensor.dtype).max,
+        input_tensor.dtype,
+    )
+
+@impl_tracked(m, "quantized_softmax.per_tensor")
+def quantized_softmax_per_tensor(
+    input_tensor: torch.Tensor,
+    mask: torch.Tensor | None,
+    dim: int,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    return quantized_softmax_per_tensor_common(
+        input_tensor,
+        mask,
+        dim,
+        in_scale,
+        in_zero_point,
+        out_scale,
+        out_zero_point,
+    )
+
+
+@impl_tracked(m, "quantized_softmax")
+def quantized_softmax(
+    input_tensor: torch.Tensor,
+    mask: torch.Tensor | None,
+    dim: int,
+    in_scale: torch.Tensor,
+    in_zero_point: torch.Tensor,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    return quantized_softmax_per_tensor_common(
+        input_tensor,
+        mask,
+        dim,
+        float(in_scale.item()),
+        int(in_zero_point.item()),
+        out_scale,
+        out_zero_point,
+    )

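The new reference implementation follows a dequantize -> float softmax -> requantize pattern and saturates at the output dtype's range. As a sanity check (not part of the commit), the following standalone snippet, with plain torch ops standing in for the Cadence dequantize_per_tensor/quantize_per_tensor helpers, reproduces the expected values of the "basic_int8_dim_1" test case added below:

import torch

# Values from the "basic_int8_dim_1" test case.
x = torch.tensor([[10, 20, 30]], dtype=torch.int8)
in_scale, in_zero_point = 0.1, 0
out_scale, out_zero_point = 0.004, 0

# Dequantize: (q - zero_point) * scale -> [[1.0, 2.0, 3.0]]
x_fp = (x.to(torch.float32) - in_zero_point) * in_scale

# Float softmax along dim=1 -> roughly [[0.0900, 0.2447, 0.6652]]
y_fp = torch.nn.functional.softmax(x_fp, dim=1)

# Requantize and clamp to the int8 range; 0.6652 / 0.004 saturates at 127.
qmin, qmax = torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max
y = torch.clamp(torch.round(y_fp / out_scale) + out_zero_point, qmin, qmax).to(torch.int8)
print(y)  # tensor([[ 23,  61, 127]], dtype=torch.int8)

The int16 case behaves the same way but has more headroom, which is why its expected output reaches 166 instead of saturating at 127.
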
backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 134 additions & 0 deletions
@@ -3079,3 +3079,137 @@ def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None:
         self.assertIn(
             "Hidden dimension must be a multiple of 4", str(context.exception)
         )
+
+    @expand(
+        [
+            (
+                "basic_int8_dim_1",
+                torch.tensor([[10, 20, 30]], dtype=torch.int8),
+                None,
+                1,
+                0.1,
+                0,
+                0.004,
+                0,
+                torch.int8,
+                torch.tensor([[23, 61, 127]], dtype=torch.int8),
+            ),
+            (
+                "uint8_with_zero_points",
+                torch.tensor([[128, 130, 132]], dtype=torch.uint8),
+                None,
+                1,
+                0.1,
+                128,
+                0.004,
+                128,
+                torch.uint8,
+                torch.tensor([[195, 210, 228]], dtype=torch.uint8),
+            ),
+            (
+                "basic_int16",
+                torch.tensor([[100, 200, 300]], dtype=torch.int16),
+                None,
+                1,
+                0.01,
+                0,
+                0.004,
+                0,
+                torch.int16,
+                torch.tensor([[23, 61, 166]], dtype=torch.int16),
+            ),
+            (
+                "multi_row_int8",
+                torch.tensor(
+                    [[10, 20, 30], [5, 10, 15]], dtype=torch.int8
+                ),
+                None,
+                1,
+                0.1,
+                0,
+                0.004,
+                0,
+                torch.int8,
+                torch.tensor(
+                    [[23, 61, 127], [47, 77, 127]], dtype=torch.int8
+                ),
+            ),
+            (
+                "softmax_dim_0",
+                torch.tensor([[10, 20], [30, 40]], dtype=torch.int8),
+                None,
+                0,
+                0.1,
+                0,
+                0.004,
+                0,
+                torch.int8,
+                torch.tensor([[30, 30], [127, 127]], dtype=torch.int8),
+            ),
+        ]
+    )
+    def test_quantized_softmax_per_tensor(
+        self,
+        name: str,
+        input_tensor: torch.Tensor,
+        mask: torch.Tensor | None,
+        dim: int,
+        in_scale: float,
+        in_zero_point: int,
+        out_scale: float,
+        out_zero_point: int,
+        dtype: torch.dtype,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.quantized_softmax.per_tensor(
+            input_tensor,
+            mask,
+            dim,
+            in_scale,
+            in_zero_point,
+            out_scale,
+            out_zero_point,
+        )
+
+        # Verify output properties
+        self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype} in {name}")
+        self.assertEqual(
+            output.shape,
+            input_tensor.shape,
+            f"Output shape should match input shape in {name}",
+        )
+
+        # Verify output matches expected values (allowing for small quantization errors)
+        # For softmax, we expect outputs to be in [0, 1] range when dequantized
+        self.assertTrue(
+            torch.allclose(
+                output.to(torch.float32),
+                expected_output.to(torch.float32),
+                rtol=0.05,
+                atol=5.0,
+            ),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
+
+    def test_quantized_softmax(self) -> None:
+        # Test quantized_softmax (default variant with tensor scale/zero_point)
+        input_tensor = torch.tensor([[10, 20, 30]], dtype=torch.int8)
+        in_scale = torch.tensor([0.1])
+        in_zero_point = torch.tensor([0])
+        output = torch.ops.cadence.quantized_softmax(
+            input_tensor,
+            None,  # mask
+            1,  # dim
+            in_scale,
+            in_zero_point,
+            0.004,  # out_scale
+            0,  # out_zero_point
+        )
+
+        # Verify output properties
+        self.assertEqual(output.dtype, torch.int8, "Output dtype should be int8")
+        self.assertEqual(
+            output.shape,
+            input_tensor.shape,
+            "Output shape should match input shape",
+        )
