2424 pack_scales_and_zeros ,
2525)
2626
27+ from torchao .dtypes .utils import is_device
28+
2729
2830logger : logging .Logger = logging .getLogger (__name__ )
2931
@@ -128,6 +130,7 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsiz
128130 groupsize ,
129131 scales_and_zeros ,
130132 )
133+
131134 new_shape = origin_input_size [:- 1 ] + (out_features ,)
132135 c = c .reshape (new_shape )
133136 return c
@@ -178,16 +181,27 @@ def __init__(
178181 ), "must specify both weights and scales_and_zeros, or neither"
179182
180183 if weight is None :
181- weight = torch .empty (
182- (
183- out_features // 8 ,
184- in_features // (inner_k_tiles * 16 ),
185- 32 ,
186- inner_k_tiles // 2 ,
187- ),
188- dtype = torch .int32 ,
189- device = device ,
190- )
184+ if is_device (device , "cpu" ):
185+ weight = torch .empty (
186+ (
187+ out_features ,
188+ in_features // 2 ,
189+ ),
190+ dtype = torch .uint8 ,
191+ device = device ,
192+ )
193+ else :
194+ weight = torch .empty (
195+ (
196+ out_features // 8 ,
197+ in_features // (inner_k_tiles * 16 ),
198+ 32 ,
199+ inner_k_tiles // 2 ,
200+ ),
201+ dtype = torch .int32 ,
202+ device = device ,
203+ )
204+
191205 scales_and_zeros = torch .empty (
192206 (in_features // groupsize , out_features , 2 ),
193207 dtype = get_precision (),
@@ -223,12 +237,17 @@ def _prepare_weight_and_scales_and_zeros(
223237 weight_int32 , scales_and_zeros = group_quantize_tensor (
224238 weight_bf16 , n_bit = 4 , groupsize = groupsize
225239 )
226- weight_uint8 = (weight_int32 [::, ::2 ] << 4 | weight_int32 [::, 1 ::2 ]).to (
227- torch .uint8
228- )
229- weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack (
230- weight_uint8 , inner_k_tiles
231- )
240+ if is_device (weight_int32 .device .type , "cpu" ):
241+ weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack_for_cpu (
242+ weight_int32 , inner_k_tiles
243+ )
244+ else :
245+ weight_uint8 = (weight_int32 [::, ::2 ] << 4 | weight_int32 [::, 1 ::2 ]).to (
246+ torch .uint8
247+ )
248+ weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack (
249+ weight_uint8 , inner_k_tiles
250+ )
232251 return weight_int4pack , scales_and_zeros
233252
234253 @classmethod
@@ -609,17 +628,14 @@ def load_model_and_state_dict(
609628 if load_state_dict :
610629 q , s , z = Q4_0 .unpack (t )
611630 scales_and_zeros = pack_scales_and_zeros (s , z )
612- q_uint8 = (q [::, ::2 ] << 4 | q [::, 1 ::2 ]).to (torch .uint8 )
613-
614- if torch .device (device ).type == "cpu" :
615- weight_int4pack = (
616- torch .ops .aten ._convert_weight_to_int4pack_for_cpu (
617- q , inner_k_tiles
618- )
631+ if is_device (q .device .type , "cpu" ):
632+ weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack_for_cpu (
633+ q , inner_k_tiles
619634 )
620635 else :
636+ q_tmp = (q [::, ::2 ] << 4 | q [::, 1 ::2 ]).to (torch .uint8 )
621637 weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack (
622- q_uint8 , inner_k_tiles
638+ q_tmp , inner_k_tiles
623639 )
624640 state_dict [f"{ fqn } .weight" ] = weight_int4pack
625641 state_dict [f"{ fqn } .scales_and_zeros" ] = scales_and_zeros
@@ -632,7 +648,7 @@ def load_model_and_state_dict(
632648 in_features = in_features ,
633649 out_features = out_features ,
634650 bias = False ,
635- device = "meta " ,
651+ device = "cpu " ,
636652 groupsize = Q4_0 .groupsize ,
637653 inner_k_tiles = inner_k_tiles ,
638654 ),
0 commit comments