|
6 | 6 | CompletionStreamV1Request, |
7 | 7 | CompletionSyncResponse, |
8 | 8 | CompletionSyncV1Request, |
| 9 | + CreateBatchCompletionsModelConfig, |
| 10 | + CreateBatchCompletionsRequest, |
| 11 | + CreateBatchCompletionsRequestContent, |
| 12 | + CreateBatchCompletionsResponse, |
9 | 13 | ) |
10 | 14 |
|
11 | 15 | COMPLETION_TIMEOUT = 300 |
| 16 | +HTTP_TIMEOUT = 60 |
12 | 17 |
|
13 | 18 |
|
14 | 19 | class Completion(APIEngine): |
@@ -397,3 +402,96 @@ def _create_stream(**kwargs): |
397 | 402 | timeout=timeout, |
398 | 403 | ) |
399 | 404 | return CompletionSyncResponse.parse_obj(response) |
| 405 | + |
| 406 | + @classmethod |
| 407 | + def batch_create( |
| 408 | + cls, |
| 409 | + output_data_path: str, |
| 410 | + model_config: CreateBatchCompletionsModelConfig, |
| 411 | + content: Optional[CreateBatchCompletionsRequestContent] = None, |
| 412 | + input_data_path: Optional[str] = None, |
| 413 | + data_parallelism: int = 1, |
| 414 | + max_runtime_sec: int = 24 * 3600, |
| 415 | + ) -> CreateBatchCompletionsResponse: |
| 416 | + """ |
| 417 | + Creates a batch completion for the provided input data. The job runs offline and does not depend on an existing model endpoint. |
| 418 | +
|
| 419 | + Prompts can be passed in from an input file, or as a part of the request. |
| 420 | +
|
| 421 | + Args: |
| 422 | + output_data_path (str): |
| 423 | + The path to the output file. The output file will be a JSON file containing the completions. |
| 424 | +
|
| 425 | + model_config (CreateBatchCompletionsModelConfig): |
| 426 | + The model configuration to use for the batch completion. |
| 427 | +
|
| 428 | + content (Optional[CreateBatchCompletionsRequestContent]): |
| 429 | + The content to use for the batch completion. Either one of `content` or `input_data_path` must be provided. |
| 430 | +
|
| 431 | + input_data_path (Optional[str]): |
| 432 | + The path to the input file. The input file should be a JSON file with data of type `BatchCompletionsRequestContent`. Either one of `content` or `input_data_path` must be provided. |
| 433 | +
|
| 434 | + data_parallelism (int): |
| 435 | + The number of parallel jobs to run. Data will be evenly distributed to the jobs. Defaults to 1. |
| 436 | +
|
| 437 | + max_runtime_sec (int): |
| 438 | + The maximum runtime of the batch completion in seconds. Defaults to 24 hours. |
| 439 | +
|
| 440 | + Returns: |
| 441 | + response (CreateBatchCompletionsResponse): The response containing the job id. |
| 442 | +
|
| 443 | + === "Batch completions with prompts in the request" |
| 444 | + ```python |
| 445 | + from llmengine import Completion |
| 446 | + from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent |
| 447 | +
|
| 448 | + response = Completion.batch_create( |
| 449 | + output_data_path="s3://my-path", |
| 450 | + model_config=CreateBatchCompletionsModelConfig( |
| 451 | + model="llama-2-7b", |
| 452 | + checkpoint_path="s3://checkpoint-path", |
| 453 | + labels={"team":"my-team", "product":"my-product"} |
| 454 | + ), |
| 455 | + content=CreateBatchCompletionsRequestContent( |
| 456 | + prompts=["What is deep learning", "What is a neural network"], |
| 457 | + max_new_tokens=10, |
| 458 | + temperature=0.0 |
| 459 | + ) |
| 460 | + ) |
| 461 | + print(response.json()) |
| 462 | + ``` |
| 463 | +
|
| 464 | + === "Batch completions with prompts in a file and with 2 parallel jobs" |
| 465 | + ```python |
| 466 | + from llmengine import Completion |
| 467 | + from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent |
| 468 | +
|
| 469 | + # Store CreateBatchCompletionsRequestContent data into input file "s3://my-input-path" |
| 470 | +
|
| 471 | + response = Completion.batch_create( |
| 472 | + input_data_path="s3://my-input-path", |
| 473 | + output_data_path="s3://my-output-path", |
| 474 | + model_config=CreateBatchCompletionsModelConfig( |
| 475 | + model="llama-2-7b", |
| 476 | + checkpoint_path="s3://checkpoint-path", |
| 477 | + labels={"team":"my-team", "product":"my-product"} |
| 478 | + ), |
| 479 | + data_parallelism=2 |
| 480 | + ) |
| 481 | + print(response.json()) |
| 482 | + ``` |
| 483 | + """ |
| 484 | + data = CreateBatchCompletionsRequest( |
| 485 | + model_config=model_config, |
| 486 | + content=content, |
| 487 | + input_data_path=input_data_path, |
| 488 | + output_data_path=output_data_path, |
| 489 | + data_parallelism=data_parallelism, |
| 490 | + max_runtime_sec=max_runtime_sec, |
| 491 | + ).dict() |
| 492 | + response = cls.post_sync( |
| 493 | + resource_name="v1/llm/batch-completions", |
| 494 | + data=data, |
| 495 | + timeout=HTTP_TIMEOUT, |
| 496 | + ) |
| 497 | + return CreateBatchCompletionsResponse.parse_obj(response) |
0 commit comments