21
21
logger = logging .getLogger ()
22
22
23
23
24
- # yapf conflicts with isort for this docstring
25
- # yapf: disable
26
24
"""
27
25
tensorize_vllm_model.py is a script that can be used to serialize and
28
26
deserialize vLLM models. These models can be loaded using tensorizer
@@ -132,7 +130,8 @@ def get_parser():
132
130
"can be loaded using tensorizer directly to the GPU "
133
131
"extremely quickly. Tensor encryption and decryption is "
134
132
"also supported, although libsodium must be installed to "
135
- "use it." )
133
+ "use it."
134
+ )
136
135
parser = EngineArgs .add_cli_args (parser )
137
136
138
137
parser .add_argument (
@@ -144,13 +143,14 @@ def get_parser():
144
143
"along with the model by instantiating a TensorizerConfig object, "
145
144
"creating a dict from it with TensorizerConfig.to_serializable(), "
146
145
"and passing it to LoRARequest's initializer with the kwarg "
147
- "tensorizer_config_dict."
146
+ "tensorizer_config_dict." ,
148
147
)
149
148
150
- subparsers = parser .add_subparsers (dest = ' command' , required = True )
149
+ subparsers = parser .add_subparsers (dest = " command" , required = True )
151
150
152
151
serialize_parser = subparsers .add_parser (
153
- 'serialize' , help = "Serialize a model to `--serialized-directory`" )
152
+ "serialize" , help = "Serialize a model to `--serialized-directory`"
153
+ )
154
154
155
155
serialize_parser .add_argument (
156
156
"--suffix" ,
@@ -163,7 +163,9 @@ def get_parser():
163
163
"`--suffix` is `v1`, the serialized model tensors will be "
164
164
"saved to "
165
165
"`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
166
- "If none is provided, a random UUID will be used." ))
166
+ "If none is provided, a random UUID will be used."
167
+ ),
168
+ )
167
169
serialize_parser .add_argument (
168
170
"--serialized-directory" ,
169
171
type = str ,
@@ -175,108 +177,127 @@ def get_parser():
175
177
"and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
176
178
"be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
177
179
"where `suffix` is given by `--suffix` or a random UUID if not "
178
- "provided." )
180
+ "provided." ,
181
+ )
179
182
180
183
serialize_parser .add_argument (
181
184
"--serialization-kwargs" ,
182
185
type = tensorizer_kwargs_arg ,
183
186
required = False ,
184
- help = ("A JSON string containing additional keyword arguments to "
185
- "pass to Tensorizer's TensorSerializer during "
186
- "serialization." ))
187
+ help = (
188
+ "A JSON string containing additional keyword arguments to "
189
+ "pass to Tensorizer's TensorSerializer during "
190
+ "serialization."
191
+ ),
192
+ )
187
193
188
194
serialize_parser .add_argument (
189
195
"--keyfile" ,
190
196
type = str ,
191
197
required = False ,
192
- help = ("Encrypt the model weights with a randomly-generated binary key,"
193
- " and save the key at this path" ))
198
+ help = (
199
+ "Encrypt the model weights with a randomly-generated binary key,"
200
+ " and save the key at this path"
201
+ ),
202
+ )
194
203
195
204
deserialize_parser = subparsers .add_parser (
196
- 'deserialize' ,
197
- help = ("Deserialize a model from `--path-to-tensors`"
198
- " to verify it can be loaded and used." ))
205
+ "deserialize" ,
206
+ help = (
207
+ "Deserialize a model from `--path-to-tensors`"
208
+ " to verify it can be loaded and used."
209
+ ),
210
+ )
199
211
200
212
deserialize_parser .add_argument (
201
213
"--path-to-tensors" ,
202
214
type = str ,
203
215
required = False ,
204
- help = "The local path or S3 URI to the model tensors to deserialize. " )
216
+ help = "The local path or S3 URI to the model tensors to deserialize. " ,
217
+ )
205
218
206
219
deserialize_parser .add_argument (
207
220
"--serialized-directory" ,
208
221
type = str ,
209
222
required = False ,
210
223
help = "Directory with model artifacts for loading. Assumes a "
211
- "model.tensors file exists therein. Can supersede "
212
- "--path-to-tensors." )
224
+ "model.tensors file exists therein. Can supersede "
225
+ "--path-to-tensors." ,
226
+ )
213
227
214
228
deserialize_parser .add_argument (
215
229
"--keyfile" ,
216
230
type = str ,
217
231
required = False ,
218
- help = ("Path to a binary key to use to decrypt the model weights,"
219
- " if the model was serialized with encryption" ))
232
+ help = (
233
+ "Path to a binary key to use to decrypt the model weights,"
234
+ " if the model was serialized with encryption"
235
+ ),
236
+ )
220
237
221
238
deserialize_parser .add_argument (
222
239
"--deserialization-kwargs" ,
223
240
type = tensorizer_kwargs_arg ,
224
241
required = False ,
225
- help = ("A JSON string containing additional keyword arguments to "
226
- "pass to Tensorizer's `TensorDeserializer` during "
227
- "deserialization." ))
242
+ help = (
243
+ "A JSON string containing additional keyword arguments to "
244
+ "pass to Tensorizer's `TensorDeserializer` during "
245
+ "deserialization."
246
+ ),
247
+ )
228
248
229
249
TensorizerArgs .add_cli_args (deserialize_parser )
230
250
231
251
return parser
232
252
233
- def merge_extra_config_with_tensorizer_config ( extra_cfg : dict ,
234
- cfg : TensorizerConfig ):
253
+
254
+ def merge_extra_config_with_tensorizer_config ( extra_cfg : dict , cfg : TensorizerConfig ):
235
255
for k , v in extra_cfg .items ():
236
256
if hasattr (cfg , k ):
237
257
setattr (cfg , k , v )
238
258
logger .info (
239
259
"Updating TensorizerConfig with %s from "
240
- "--model-loader-extra-config provided" , k
260
+ "--model-loader-extra-config provided" ,
261
+ k ,
241
262
)
242
263
264
+
243
265
def deserialize (args , tensorizer_config ):
244
266
if args .lora_path :
245
267
tensorizer_config .lora_dir = tensorizer_config .tensorizer_dir
246
- llm = LLM (model = args .model ,
247
- load_format = "tensorizer" ,
248
- tensor_parallel_size = args .tensor_parallel_size ,
249
- model_loader_extra_config = tensorizer_config ,
250
- enable_lora = True ,
268
+ llm = LLM (
269
+ model = args .model ,
270
+ load_format = "tensorizer" ,
271
+ tensor_parallel_size = args .tensor_parallel_size ,
272
+ model_loader_extra_config = tensorizer_config ,
273
+ enable_lora = True ,
251
274
)
252
275
sampling_params = SamplingParams (
253
- temperature = 0 ,
254
- max_tokens = 256 ,
255
- stop = ["[/assistant]" ]
276
+ temperature = 0 , max_tokens = 256 , stop = ["[/assistant]" ]
256
277
)
257
278
258
279
# Truncating this as the extra text isn't necessary
259
- prompts = [
260
- "[user] Write a SQL query to answer the question based on ..."
261
- ]
280
+ prompts = ["[user] Write a SQL query to answer the question based on ..." ]
262
281
263
282
# Test LoRA load
264
283
print (
265
284
llm .generate (
266
- prompts ,
267
- sampling_params ,
268
- lora_request = LoRARequest ("sql-lora" ,
269
- 1 ,
270
- args .lora_path ,
271
- tensorizer_config_dict = tensorizer_config
272
- .to_serializable ())
285
+ prompts ,
286
+ sampling_params ,
287
+ lora_request = LoRARequest (
288
+ "sql-lora" ,
289
+ 1 ,
290
+ args .lora_path ,
291
+ tensorizer_config_dict = tensorizer_config .to_serializable (),
292
+ ),
273
293
)
274
294
)
275
295
else :
276
- llm = LLM (model = args .model ,
277
- load_format = "tensorizer" ,
278
- tensor_parallel_size = args .tensor_parallel_size ,
279
- model_loader_extra_config = tensorizer_config
296
+ llm = LLM (
297
+ model = args .model ,
298
+ load_format = "tensorizer" ,
299
+ tensor_parallel_size = args .tensor_parallel_size ,
300
+ model_loader_extra_config = tensorizer_config ,
280
301
)
281
302
return llm
282
303
@@ -285,17 +306,20 @@ def main():
285
306
parser = get_parser ()
286
307
args = parser .parse_args ()
287
308
288
- s3_access_key_id = (getattr (args , 's3_access_key_id' , None )
289
- or os .environ .get ("S3_ACCESS_KEY_ID" , None ))
290
- s3_secret_access_key = (getattr (args , 's3_secret_access_key' , None )
291
- or os .environ .get ("S3_SECRET_ACCESS_KEY" , None ))
292
- s3_endpoint = (getattr (args , 's3_endpoint' , None )
293
- or os .environ .get ("S3_ENDPOINT_URL" , None ))
309
+ s3_access_key_id = getattr (args , "s3_access_key_id" , None ) or os .environ .get (
310
+ "S3_ACCESS_KEY_ID" , None
311
+ )
312
+ s3_secret_access_key = getattr (
313
+ args , "s3_secret_access_key" , None
314
+ ) or os .environ .get ("S3_SECRET_ACCESS_KEY" , None )
315
+ s3_endpoint = getattr (args , "s3_endpoint" , None ) or os .environ .get (
316
+ "S3_ENDPOINT_URL" , None
317
+ )
294
318
295
319
credentials = {
296
320
"s3_access_key_id" : s3_access_key_id ,
297
321
"s3_secret_access_key" : s3_secret_access_key ,
298
- "s3_endpoint" : s3_endpoint
322
+ "s3_endpoint" : s3_endpoint ,
299
323
}
300
324
301
325
model_ref = args .model
@@ -309,25 +333,25 @@ def main():
309
333
if args .model_loader_extra_config :
310
334
extra_config = json .loads (args .model_loader_extra_config )
311
335
312
-
313
- tensorizer_dir = (args .serialized_directory or
314
- extra_config .get ("tensorizer_dir" ))
315
- tensorizer_uri = (getattr (args , "path_to_tensors" , None )
316
- or extra_config .get ("tensorizer_uri" ))
336
+ tensorizer_dir = args .serialized_directory or extra_config .get ("tensorizer_dir" )
337
+ tensorizer_uri = getattr (args , "path_to_tensors" , None ) or extra_config .get (
338
+ "tensorizer_uri"
339
+ )
317
340
318
341
if tensorizer_dir and tensorizer_uri :
319
- parser .error ("--serialized-directory and --path-to-tensors "
320
- "cannot both be provided" )
342
+ parser .error (
343
+ "--serialized-directory and --path-to-tensors cannot both be provided"
344
+ )
321
345
322
346
if not tensorizer_dir and not tensorizer_uri :
323
- parser .error ("Either --serialized-directory or --path-to-tensors "
324
- " must be provided")
325
-
347
+ parser .error (
348
+ "Either --serialized-directory or --path-to-tensors must be provided"
349
+ )
326
350
327
351
if args .command == "serialize" :
328
352
engine_args = EngineArgs .from_cli_args (args )
329
353
330
- input_dir = tensorizer_dir .rstrip ('/' )
354
+ input_dir = tensorizer_dir .rstrip ("/" )
331
355
suffix = args .suffix if args .suffix else uuid .uuid4 ().hex
332
356
base_path = f"{ input_dir } /vllm/{ model_ref } /{ suffix } "
333
357
if engine_args .tensor_parallel_size > 1 :
@@ -339,15 +363,14 @@ def main():
339
363
tensorizer_uri = model_path ,
340
364
encryption_keyfile = keyfile ,
341
365
serialization_kwargs = args .serialization_kwargs or {},
342
- ** credentials
366
+ ** credentials ,
343
367
)
344
368
345
369
if args .lora_path :
346
370
tensorizer_config .lora_dir = tensorizer_config .tensorizer_dir
347
371
tensorize_lora_adapter (args .lora_path , tensorizer_config )
348
372
349
- merge_extra_config_with_tensorizer_config (extra_config ,
350
- tensorizer_config )
373
+ merge_extra_config_with_tensorizer_config (extra_config , tensorizer_config )
351
374
tensorize_vllm_model (engine_args , tensorizer_config )
352
375
353
376
elif args .command == "deserialize" :
@@ -356,11 +379,10 @@ def main():
356
379
tensorizer_dir = args .serialized_directory ,
357
380
encryption_keyfile = keyfile ,
358
381
deserialization_kwargs = args .deserialization_kwargs or {},
359
- ** credentials
382
+ ** credentials ,
360
383
)
361
384
362
- merge_extra_config_with_tensorizer_config (extra_config ,
363
- tensorizer_config )
385
+ merge_extra_config_with_tensorizer_config (extra_config , tensorizer_config )
364
386
deserialize (args , tensorizer_config )
365
387
else :
366
388
raise ValueError ("Either serialize or deserialize must be specified." )
0 commit comments