Skip to content

Commit 4c6874e

Browse files
committed
🔧 update model split function
1 parent f513056 commit 4c6874e

File tree

1 file changed

+94
-10
lines changed

1 file changed

+94
-10
lines changed

merle/model_split.py

Lines changed: 94 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,87 @@ def reassemble_blob(
608608
return True
609609

610610

611+
def reassemble_blob_streaming(
    part1_path: Path,
    s3_bucket: str,
    s3_key: str,
    output_path: Path,
    region: str,
    expected_sha256: str | None = None,
) -> bool:
    """
    Reassemble a split blob by streaming part2 directly from S3.

    This avoids downloading part2 to disk first, which is critical for
    staying within Lambda's 10GB ephemeral storage limit when the model
    is large (e.g., 9GB blob + 3GB S3 portion would exceed 10GB).

    Args:
        part1_path: Path to part 1 file (in Docker image)
        s3_bucket: S3 bucket containing part 2
        s3_key: S3 key for part 2
        output_path: Path for reassembled file (in ephemeral storage)
        region: AWS region
        expected_sha256: Optional SHA256 hash to verify

    Returns:
        True if successful, False if the SHA256 check fails (the partial
        output file is deleted in that case). S3/IO errors propagate.
    """
    # FIX: the original log line was missing a separator between the S3 key
    # and the output path ("...{s3_key}{output_path}"), garbling the message.
    logger.info(f"Reassembling blob: {part1_path} + s3://{s3_bucket}/{s3_key} -> {output_path}")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Only hash when verification was requested — hashing 10GB+ isn't free.
    hasher = hashlib.sha256() if expected_sha256 else None

    chunk_size = 64 * 1024 * 1024  # 64MB chunks

    def _pump(stream, out) -> int:
        """Copy *stream* to *out* in chunks, feeding the hasher; return bytes copied."""
        copied = 0
        while chunk := stream.read(chunk_size):
            out.write(chunk)
            copied += len(chunk)
            if hasher:
                hasher.update(chunk)
        return copied

    with output_path.open("wb") as out:
        # Write part 1 from local file (in Docker image, read-only)
        logger.info(f"Writing part 1 from {part1_path}")
        with part1_path.open("rb") as part1:
            bytes_written = _pump(part1, out)
        logger.info(f"Part 1 complete: {bytes_written} bytes written")

        # Stream part 2 directly from S3 (never hits disk)
        logger.info(f"Streaming part 2 from s3://{s3_bucket}/{s3_key}")
        s3 = boto3.client("s3", region_name=region)
        response = s3.get_object(Bucket=s3_bucket, Key=s3_key)
        body = response["Body"]
        try:
            part2_bytes = _pump(body, out)
        finally:
            # FIX: close the streaming body even if a read/write fails
            # mid-transfer; the original leaked the HTTP connection on error.
            body.close()
        bytes_written += part2_bytes
        logger.info(f"Part 2 complete: {part2_bytes} bytes streamed from S3")

    if expected_sha256 and hasher:
        actual_sha256 = hasher.hexdigest()
        if actual_sha256 != expected_sha256:
            logger.error(f"SHA256 mismatch! Expected {expected_sha256}, got {actual_sha256}")
            # Don't leave a corrupt blob for a later cold start to find.
            output_path.unlink()
            return False
        logger.info("SHA256 verification passed")

    logger.info(f"Reassembly complete: {output_path} ({bytes_written} bytes)")
    return True
691+
611692
def copy_model_to_output(model_name: str, output_dir: Path) -> dict:
612693
"""
613694
Copy model files to output directory for Docker image inclusion.
@@ -822,6 +903,7 @@ def reassemble_at_runtime(
822903
Reassemble a split model at Lambda runtime.
823904
824905
This function should be called during Lambda cold start if split_metadata.json exists.
906+
Uses streaming from S3 to avoid exceeding Lambda's 10GB ephemeral storage limit.
825907
826908
Args:
827909
source_models_dir: Path to models in Docker image (/var/task/models)
@@ -872,23 +954,25 @@ def reassemble_at_runtime(
872954
if not dest.exists():
873955
shutil.copy2(blob_file, dest)
874956

875-
# Download part 2 from S3
957+
# Get S3 info and region
876958
s3_info = metadata["s3"]
877959
effective_region = region or s3_info.get("region") or os.environ.get("AWS_REGION") or "us-east-1"
878960

879-
part2_temp = target_models_dir / "temp_part2"
880-
download_from_s3(s3_info["bucket"], s3_info["key"], part2_temp, effective_region)
881-
882-
# Reassemble the blob
961+
# Reassemble the blob by streaming part2 directly from S3
962+
# This avoids downloading part2 to disk first, which would exceed
963+
# Lambda's 10GB ephemeral storage limit for large models
883964
part1_path = blobs_src / part1_filename
884965
output_path = blobs_dst / split_blob_name
885966
expected_sha256 = metadata["blob_split"]["original_blob_sha256"]
886967

887-
success = reassemble_blob(part1_path, part2_temp, output_path, expected_sha256)
888-
889-
# Clean up temp file
890-
if part2_temp.exists():
891-
part2_temp.unlink()
968+
success = reassemble_blob_streaming(
969+
part1_path=part1_path,
970+
s3_bucket=s3_info["bucket"],
971+
s3_key=s3_info["key"],
972+
output_path=output_path,
973+
region=effective_region,
974+
expected_sha256=expected_sha256,
975+
)
892976

893977
if success:
894978
logger.info("Model reassembly complete")

0 commit comments

Comments
 (0)