
Commit 4b666a7 (parent fff8647)

update transformer config

5 files changed: +6 −6 lines changed
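All five hunks make the same mechanical change: the text transformer's config, previously reached as model.config.transformer_args["text"], is now read directly from model.model.config. A minimal sketch of the two access paths, using SimpleNamespace as a hypothetical stand-in for torchchat's wrapper classes:

from types import SimpleNamespace

# Hypothetical stand-ins for torchchat's model wrappers; field values are illustrative.
text_config = SimpleNamespace(max_seq_length=2048, block_size=2048, n_local_heads=32)

# Old access path: per-modality dict on the outer wrapper's config.
old_model = SimpleNamespace(config=SimpleNamespace(transformer_args={"text": text_config}))

# New access path: the inner transformer exposes its config directly.
new_model = SimpleNamespace(model=SimpleNamespace(config=text_config))

assert (old_model.config.transformer_args["text"].max_seq_length
        == new_model.model.config.max_seq_length)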

distributed/parallelize_llama.py
Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@ def apply_tp(
     # after we apply TP to the model. Because we don't want to change model code
     # when applying TP. We need to have change to ensure KVCache has the correct
     # size as k and v.
-    model.config.transformer_args["text"].n_local_heads = model.config.transformer_args["text"].n_local_heads // tp_mesh.size()
+    model.model.config.n_local_heads = model.model.config.n_local_heads // tp_mesh.size()

     # Apply tensor parallelism to every transformer block
     for transformer_block in model.layers:
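For context, the division mirrors how tensor parallelism shards attention heads across ranks: each rank keeps n_heads // world_size heads, so the per-rank KV cache must be sized for the local count rather than the full head count. A worked sketch of the arithmetic, with hypothetical values (n_heads, head_dim, and the 4-way mesh are illustrative, not taken from the commit):

# Hypothetical numbers illustrating the n_local_heads adjustment above.
n_heads = 32        # total attention heads in the model (assumed)
head_dim = 128      # per-head dimension (assumed)
tp_world_size = 4   # size of the tensor-parallel mesh (assumed)

# After TP, each rank owns only a slice of the heads.
n_local_heads = n_heads // tp_world_size  # 8 heads per rank

# The per-rank KV cache is therefore allocated for the local heads only:
# (batch, n_local_heads, max_seq_length, head_dim) instead of the full n_heads.
max_batch_size, max_seq_length = 1, 2048
kv_cache_shape = (max_batch_size, n_local_heads, max_seq_length, head_dim)
print(kv_cache_shape)  # (1, 8, 2048, 128)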

torchchat/cli/builder.py
Lines changed: 1 addition & 1 deletion

@@ -563,7 +563,7 @@ def _initialize_model(
     model.setup_caches(
         max_batch_size=1,
         max_seq_length=max_seq_length
-        or model.config.transformer_args["text"].max_seq_length,
+        or model.model.config.max_seq_length,
     )

     model.to(dtype=builder_args.precision)
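The `max_seq_length or ...` idiom falls back to the model's configured maximum whenever the caller passes a falsy value (None or 0). A minimal sketch of the pattern, with a hypothetical config object standing in for model.model.config:

# Hypothetical stand-in for model.model.config; names are illustrative.
class Config:
    max_seq_length = 2048

def resolve_seq_length(max_seq_length, config):
    # `or` selects the config default whenever the argument is falsy,
    # matching the fallback passed to setup_caches above.
    return max_seq_length or config.max_seq_length

print(resolve_seq_length(None, Config()))  # 2048 (falls back to config)
print(resolve_seq_length(512, Config()))   # 512  (explicit value wins)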

torchchat/export.py
Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ def export_for_server(
             torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),
         )

-        seq = Dim("seq", min=1, max=model.config.transformer_args["text"].max_seq_length)
+        seq = Dim("seq", min=1, max=model.model.config.max_seq_length)
         # Specify that the first dimension of each input is that batch size
         dynamic_shapes = {"tokens": {1: seq}, "input_pos": {0: seq}}
     else:
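Dim comes from torch.export and declares a symbolic dimension that the exported program may vary between min and max at runtime. A minimal sketch of declaring a dynamic sequence dimension this way; the toy module and the 2048 bound are illustrative, not torchchat's actual export path:

import torch
from torch.export import Dim, export

class Toy(torch.nn.Module):
    def forward(self, tokens):
        return tokens * 2

# Symbolic sequence length, bounded by an assumed max_seq_length of 2048.
seq = Dim("seq", min=1, max=2048)

# Dimension 1 of `tokens` may vary at runtime; dimension 0 stays static.
ep = export(
    Toy(),
    (torch.zeros(1, 8, dtype=torch.int),),
    dynamic_shapes={"tokens": {1: seq}},
)
print(ep)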

torchchat/usages/eval.py
Lines changed: 1 addition & 1 deletion

@@ -59,7 +59,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
     T = prompt.size(0)
     T_new = T + max_new_tokens
     if max_seq_length is None:
-        max_seq_length = min(T_new, model.config.transformer_args["text"].block_size)
+        max_seq_length = min(T_new, model.model.config.block_size)

     device, dtype = prompt.device, prompt.dtype
     # create an empty tensor of the expected final shape and
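The min() clamp sizes the cache for the prompt plus the requested new tokens, but never beyond the model's context window. A worked example with illustrative numbers (block_size mirrors model.model.config.block_size):

# Illustrative numbers only.
block_size = 2048          # model context window (assumed)
T = 1900                   # prompt length
max_new_tokens = 300

T_new = T + max_new_tokens               # 2200 tokens requested in total
max_seq_length = min(T_new, block_size)
print(max_seq_length)                    # 2048: clamped to the context window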

torchchat/usages/openai_api.py
Lines changed: 2 additions & 2 deletions

@@ -233,11 +233,11 @@ def __init__(self, *args, **kwargs):

         super().__init__(*args, **kwargs)
         self.max_seq_length = (
-            self.model.config.transformer_args["text"].max_seq_length
+            self.model.model.config.max_seq_length
             + self.speculative_builder_args.speculate_k
             + 1
             if self.draft_model is not None
-            else self.model.config.transformer_args["text"].max_seq_length
+            else self.model.model.config.max_seq_length
         )
         # The System fingerprint is a unique identifier for the model and its configuration.
         self.system_fingerprint = (
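When a draft model is present, the limit is padded by speculate_k + 1: each speculative step can add up to speculate_k drafted tokens plus one token verified by the target model. A worked sketch of the same conditional with illustrative values (the base limit and speculate_k are assumptions, not taken from the commit):

# Illustrative values; base_max_seq_length mirrors self.model.model.config.max_seq_length.
base_max_seq_length = 2048   # target model's configured limit (assumed)
speculate_k = 5              # drafted tokens per speculative step (assumed)
draft_model = object()       # stand-in: any non-None value enables speculation

max_seq_length = (
    base_max_seq_length + speculate_k + 1  # headroom for k drafts + 1 verified token
    if draft_model is not None
    else base_max_seq_length
)
print(max_seq_length)  # 2054 with speculation, 2048 without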
