@@ -202,6 +202,13 @@ def embedding_2bit(
202202 weight_quant_max : int ,
203203 indices : torch .Tensor ,
204204) -> torch .Tensor :
205+ assert (
206+ weight_quant_min == - 2
207+ ), "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
208+ assert (
209+ weight_quant_max == 1
210+ ), "embedding_2bit in ExecuTorch expects weight_quant_max == 1"
211+
205212 embedding_weight_checks (weight , weight_scales , weight_zero_points )
206213 group_size = (4 * weight .size (1 )) // (
207214 weight_scales .size (1 ) if weight_scales .dim () == 2 else 1
@@ -257,6 +264,13 @@ def embedding_2bit_dtype(
257264 indices : torch .Tensor ,
258265 dtype : Optional [torch .dtype ],
259266) -> torch .Tensor :
267+ assert (
268+ weight_quant_min == - 2
269+ ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
270+ assert (
271+ weight_quant_max == 1
272+ ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"
273+
260274 embedding_weight_checks (weight , weight_scales , weight_zero_points )
261275 group_size = (4 * weight .size (1 )) // (
262276 weight_scales .size (1 ) if weight_scales .dim () == 2 else 1
@@ -334,6 +348,13 @@ def embedding_4bit(
334348 weight_quant_max : int ,
335349 indices : torch .Tensor ,
336350) -> torch .Tensor :
351+ assert (
352+ weight_quant_min == - 8
353+ ), "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
354+ assert (
355+ weight_quant_max == 7
356+ ), "embedding_4bit in ExecuTorch expects weight_quant_max == 7"
357+
337358 embedding_weight_checks (weight , weight_scales , weight_zero_points )
338359 group_size = (2 * weight .size (1 )) // (
339360 weight_scales .size (1 ) if weight_scales .dim () == 2 else 1
@@ -387,6 +408,13 @@ def embedding_4bit_dtype(
387408 indices : torch .Tensor ,
388409 dtype : Optional [torch .dtype ],
389410) -> torch .Tensor :
411+ assert (
412+ weight_quant_min == - 8
413+ ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
414+ assert (
415+ weight_quant_max == 7
416+ ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"
417+
390418 embedding_weight_checks (weight , weight_scales , weight_zero_points )
391419 group_size = (2 * weight .size (1 )) // (
392420 weight_scales .size (1 ) if weight_scales .dim () == 2 else 1
0 commit comments