@@ -122,21 +122,7 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsize):
             input.dtype
         )  # cast back to input.dtype
     else:
-        # copied from https://github.com/pytorch/pytorch/blob/0052943bee624c06d8c36a371efdf7b56972ad9e/torch/_meta_registrations.py#L3308
-        def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
-            torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
-            torch._check(w.dim() == 4, lambda: "w must be a 4D tensor")
-            torch._check(
-                x.dtype in [torch.float32, torch.float16, torch.bfloat16],
-                lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
-            )
-            torch._check(
-                w.dtype is torch.int32,
-                lambda: f"expected w to be int32, got {w.dtype}",
-            )
-            return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype)
-
-        c = meta__weight_int4pack_mm(
+        c = torch.ops.aten._weight_int4pack_mm_for_cpu(
             input,
             weight_int4pack,
             groupsize,
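
Note: the helper removed above was a copy of PyTorch's meta registration, which only allocates an uninitialized [m, n] output via new_empty and never performs the matmul; the replacement dispatches to the real CPU kernel. A minimal sketch of the new op's contract, assuming PyTorch >= 2.5 (where the *_for_cpu ops are available); the sizes below are illustrative, not taken from this change:

    import torch

    m, k, n, groupsize, inner_k_tiles = 4, 128, 32, 32, 2
    x = torch.randn(m, k, dtype=torch.bfloat16)          # activations: [m, k]
    q = torch.randint(0, 16, (n, k), dtype=torch.int32)  # int4 values stored as int32
    # One (scale, zero-point) pair per group per output channel: [k // groupsize, n, 2].
    scales_and_zeros = torch.ones(k // groupsize, n, 2, dtype=torch.bfloat16)

    packed = torch.ops.aten._convert_weight_to_int4pack_for_cpu(q, inner_k_tiles)
    c = torch.ops.aten._weight_int4pack_mm_for_cpu(x, packed, groupsize, scales_and_zeros)
    assert c.shape == (m, n) and c.dtype == x.dtype
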
@@ -626,27 +612,10 @@ def load_model_and_state_dict(
     q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
 
     if torch.device(device).type == "cpu":
-        # Copied from https://github.com/pytorch/pytorch/blob/0052943bee624c06d8c36a371efdf7b56972ad9e/torch/_meta_registrations.py#L3273
-        def meta__convert_weight_to_int4pack(w, inner_k_tiles):
-            torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
-            torch._check(
-                w.dtype is torch.uint8,
-                lambda: f"expected w to be uint8, got {w.dtype}",
+        weight_int4pack = (
+            torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                q, inner_k_tiles
             )
-            n = w.size(0)
-            k = w.size(1) * 2  # w is [n][k / 2] uint8
-            return w.new_empty(
-                (
-                    n // 8,
-                    k // (inner_k_tiles * 16),
-                    32,
-                    inner_k_tiles // 2,
-                ),
-                dtype=torch.int32,
-            )
-
-        weight_int4pack = meta__convert_weight_to_int4pack(
-            q_uint8, inner_k_tiles
         )
     else:
         weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
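
Unlike the deleted meta stub, which expected the nibble-packed uint8 tensor (q_uint8), the CPU packing op consumes the unpacked int32 [n, k] tensor q directly, which is why the call site switches arguments. A minimal layout sketch, again assuming PyTorch >= 2.5, with illustrative sizes:

    import torch

    n, k, inner_k_tiles = 32, 128, 2
    q = torch.randint(0, 16, (n, k), dtype=torch.int32)  # unpacked int4 values

    packed = torch.ops.aten._convert_weight_to_int4pack_for_cpu(q, inner_k_tiles)
    # On CPU the packed layout is two nibbles per byte: [n, k // 2] uint8
    # (an assumption based on current PyTorch behavior, not stated in this diff).
    assert packed.dtype == torch.uint8 and packed.shape == (n, k // 2)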