@@ -122,7 +122,21 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsiz
         input.dtype
     )  # cast back to input.dtype
     else:
-        c = torch.ops.aten._weight_int4pack_mm(
+        # copied from https://github.com/pytorch/pytorch/blob/0052943bee624c06d8c36a371efdf7b56972ad9e/torch/_meta_registrations.py#L3308
+        def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
+            torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
+            torch._check(w.dim() == 4, lambda: "w must be a 4D tensor")
+            torch._check(
+                x.dtype in [torch.float32, torch.float16, torch.bfloat16],
+                lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
+            )
+            torch._check(
+                w.dtype is torch.int32,
+                lambda: f"expected w to be int32, got {w.dtype}",
+            )
+            return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype)
+
+        c = meta__weight_int4pack_mm(
             input,
             weight_int4pack,
             groupsize,
@@ -610,10 +624,29 @@ def load_model_and_state_dict(
             q, s, z = Q4_0.unpack(t)
             scales_and_zeros = pack_scales_and_zeros(s, z)
             q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
-
+
             if torch.device(device).type == "cpu":
-                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
-                    q_uint8, inner_k_tiles
+                # Copied from https://github.com/pytorch/pytorch/blob/0052943bee624c06d8c36a371efdf7b56972ad9e/torch/_meta_registrations.py#L3273
+                def meta__convert_weight_to_int4pack(w, inner_k_tiles):
+                    torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
+                    torch._check(
+                        w.dtype is torch.uint8,
+                        lambda: f"expected w to be uint8, got {w.dtype}",
+                    )
+                    n = w.size(0)
+                    k = w.size(1) * 2  # w is [n][k / 2] uint8
+                    return w.new_empty(
+                        (
+                            n // 8,
+                            k // (inner_k_tiles * 16),
+                            32,
+                            inner_k_tiles // 2,
+                        ),
+                        dtype=torch.int32,
+                    )
+
+                weight_int4pack = meta__convert_weight_to_int4pack(
+                    q_uint8, inner_k_tiles
                 )
             else:
                 weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(