This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 6c7374e

Merge branch 'main' into benchmarking_script

2 parents: 91e9909 + f20f5e7

File tree: 3 files changed, +27 −24 lines changed

dist_run.py

Lines changed: 13 additions & 6 deletions
@@ -55,9 +55,12 @@
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
 # For details on the name-to-distribution mapping, see README.md or models.json.
+
+# Name : HF distribution name, dtype, and model dimension
 NAME_TO_DISTRIBUTION_AND_DTYPE = {
-    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
-    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
+    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16, 4096),
+    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16, 4096),
+    "llama3-70b": ("meta-llama/Meta-Llama-3-70B-Instruct", torch.bfloat16, 8192),
 }
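The lookup table now carries each model's embedding dimension alongside its Hugging Face distribution name and dtype. As a sanity check, the hardcoded dimensions can be cross-checked against the corresponding Hugging Face configs; the sketch below is illustrative and not part of this commit (it assumes the transformers package and access to the gated meta-llama checkpoints):

# Sanity-check sketch (not part of this commit): verify the hardcoded model
# dimensions against the Hugging Face configs.
import torch
from transformers import AutoConfig

NAME_TO_DISTRIBUTION_AND_DTYPE = {
    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16, 4096),
    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16, 4096),
    "llama3-70b": ("meta-llama/Meta-Llama-3-70B-Instruct", torch.bfloat16, 8192),
}

for name, (distribution, _dtype, dim) in NAME_TO_DISTRIBUTION_AND_DTYPE.items():
    hf_config = AutoConfig.from_pretrained(distribution)
    assert hf_config.hidden_size == dim, (
        f"{name}: table says {dim}, HF config says {hf_config.hidden_size}"
    )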

@@ -314,8 +317,12 @@ def main(args):
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
     logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

-    distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
-    logger.info(f"Using model weights from {distribution} and dtype {model_dtype}")
+    distribution, model_dtype, model_dimension = NAME_TO_DISTRIBUTION_AND_DTYPE[
+        model_name
+    ]
+    logger.info(
+        f"Using model weights from {distribution}, dtype {model_dtype} and model dimension {model_dimension}"
+    )

     # Model-level config
     model_config = ModelArgs.from_name(distribution)
@@ -338,6 +345,7 @@ def main(args):

     # Tensor parallel is enabled in this program
     tp_degree = world_size // pp_degree
+    logger.info(f"Using TP degree {tp_degree} and PP degree {pp_degree}")

     # Create device mesh
     mesh_dimensions = (pp_degree, tp_degree)
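The new log line makes the parallelism split explicit. For concreteness, a quick sketch of the arithmetic with illustrative numbers not taken from this commit (8 ranks split into 2 pipeline stages):

world_size = 8   # total ranks, illustrative
pp_degree = 2    # pipeline-parallel stages, illustrative
tp_degree = world_size // pp_degree  # 4-way tensor parallelism per stage
mesh_dimensions = (pp_degree, tp_degree)  # device mesh of shape (2, 4)
assert pp_degree * tp_degree == world_size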
@@ -388,7 +396,6 @@ def main(args):
     # sense. Thus it is interchangeable with micro-batch size below.
     batch_size = len(prompt)
     seqlen_prefill = 1024  # sequence length
-    dim = 4096  # embedding dimension

     # Setup KV caches (after model distribution)
     # The number of cache lanes is the same as the maximum number of
@@ -419,7 +426,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
             0, config.vocab_size, (batch_size, seqlen), device=device
         )
         activation = torch.rand(
-            batch_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, model_dimension, device=device, dtype=model_dtype
         )
         logits = torch.rand(
             batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
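With the hardcoded dim = 4096 removed, the example activation takes its last dimension from the lookup table, which is what lets the 70B model (dimension 8192) work. A rough sketch of the resulting tensor shapes, assuming illustrative values (batch_size=4, seqlen=1024, model_dimension=8192, and a vocab size of 128256, matching the Llama 3 tokenizer):

import torch

batch_size, seqlen = 4, 1024
model_dimension, vocab_size = 8192, 128256  # illustrative: llama3-70b / Llama 3
tokens = torch.randint(0, vocab_size, (batch_size, seqlen))
activation = torch.rand(batch_size, seqlen, model_dimension, dtype=torch.bfloat16)
logits = torch.rand(batch_size, seqlen, vocab_size, dtype=torch.bfloat16)
print(tokens.shape)      # torch.Size([4, 1024])
print(activation.shape)  # torch.Size([4, 1024, 8192])
print(logits.shape)      # torch.Size([4, 1024, 128256])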
Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer

-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
-model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct")
 print("Model weights and tokenizer downloaded")

torchchat/generate.py

Lines changed: 12 additions & 16 deletions
@@ -24,6 +24,15 @@

 from PIL import Image

+# torchtune model definition dependencies
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+from torchtune.generation import sample as tune_sample
+from torchtune.models.llama3 import llama3_tokenizer
+
+from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
+from torchtune.training import set_default_dtype
+
 from torchchat.cli.builder import (
     _initialize_model,
     _initialize_tokenizer,
@@ -35,15 +44,6 @@
 from torchchat.utils.build_utils import device_sync, set_precision
 from torchchat.utils.device_info import get_device_info

-# torchtune model definition dependencies
-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
-from torchtune.generation import sample as tune_sample
-from torchtune.models.llama3 import llama3_tokenizer
-
-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-from torchtune.training import set_default_dtype
-

 class _ChatFormatter(ABC):
     def __init__(self, tokenizer):
@@ -1155,13 +1155,9 @@ def callback(x, *, done_generating=False):
             print(
                 f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds"
             )
-            aggregate_metrics["tokens_per_sec_jit_compile"] = tokens_sec
-            # Don't continue here.... because we need to report and reset
-            # continue
-        else:
-            aggregate_metrics["tokens_per_sec"].append(tokens_sec)
-            aggregate_metrics["first_token_per_sec"].append(first_token_sec)
-            aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec)
+        aggregate_metrics["tokens_per_sec"].append(tokens_sec)
+        aggregate_metrics["first_token_per_sec"].append(first_token_sec)
+        aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec)

         logging.info(
             f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\
