
Commit 5d5664c

Update missed callsite when introducing ModelArgs (#1071)

1 parent 9a276c3 commit 5d5664c

File tree: 4 files changed, +20 −19 lines

build/convert_hf_checkpoint.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -17,7 +17,7 @@
 sys.path.append(str(wd.resolve()))
 sys.path.append(str((wd / "build").resolve()))
 
-from build.model import TransformerArgs
+from build.model import ModelArgs
 
 
 @torch.inference_mode()
@@ -32,7 +32,7 @@ def convert_hf_checkpoint(
     if model_name is None:
         model_name = model_dir.name
 
-    config = TransformerArgs.from_name(model_name)
+    config = ModelArgs.from_name(model_name).text_transformer_args
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
```
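
Both hunks make the same substitution: `ModelArgs` now wraps the transformer configuration, and the value that used to come straight from `TransformerArgs.from_name()` is reached through the `text_transformer_args` field. A minimal runnable sketch of that relationship, with stubbed fields and a stubbed name lookup (illustrative only, not the actual `build/model.py` source):

```python
# Illustrative stubs only -- the config fields and values here are
# assumptions, not torchchat's real code.
from dataclasses import dataclass, field


@dataclass
class TransformerArgs:
    dim: int = 4096      # illustrative fields
    n_layers: int = 32


@dataclass
class ModelArgs:
    # The wrapper that callsites now go through; a multimodal model
    # could carry other sub-configs alongside the text transformer's.
    text_transformer_args: TransformerArgs = field(default_factory=TransformerArgs)

    @classmethod
    def from_name(cls, name: str) -> "ModelArgs":
        # The real method resolves `name` against known configurations;
        # stubbed here.
        return cls()


# Old callsite: config = TransformerArgs.from_name(model_name)
# New callsite:
config = ModelArgs.from_name("some-model").text_transformer_args
print(f"Model config {config.__dict__}")
```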

build/model_dist.py

Lines changed: 13 additions & 12 deletions
```diff
@@ -112,18 +112,19 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
 
         # print(f"stage output shape: {x.shape}")
         return x
-
-    @classmethod
-    def from_name(cls, name: str):
-        return cls(TransformerArgs.from_name(name))
-
-    @classmethod
-    def from_table(cls, name: str):
-        return cls(TransformerArgs.from_table(name))
-
-    @classmethod
-    def from_params(cls, params_path: str):
-        return cls(TransformerArgs.from_params(params_path))
+
+    # Temporarily disabled: these constructors are missing essential input
+    # @classmethod
+    # def from_name(cls, name: str):
+    #     return cls(TransformerArgs.from_name(name))
+
+    # @classmethod
+    # def from_table(cls, name: str):
+    #     return cls(TransformerArgs.from_table(name))
+
+    # @classmethod
+    # def from_params(cls, params_path: str):
+    #     return cls(TransformerArgs.from_params(params_path))
 
     @classmethod
     def from_gguf(cls, gguf_path: str, **kwargs):
```
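
With these three constructors disabled, a `TransformerStage` can no longer be built from just a model name, table entry, or params path; callers now build the `TransformerArgs` themselves, as the `dist_run.py` hunk below shows. A hedged sketch of that calling convention (the stage-specific inputs the disabled constructors were missing are deliberately elided):

```python
# Assumed calling convention after this change; requires the torchchat
# repo on sys.path. The extra stage inputs are intentionally not shown.
from build.model import ModelArgs
from build.model_dist import TransformerStage

config = ModelArgs.from_name("Transformer-2-7b-chat-hf").text_transformer_args
# stage = TransformerStage(config, ...)  # remaining essential inputs elided
```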

dist_run.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -11,12 +11,12 @@
 import torch.distributed as dist
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 
-from build.model import TransformerArgs
+from build.model import ModelArgs
 from build.model_dist import TransformerStage
 
 # Model config
 def main():
-    config = TransformerArgs.from_name("Transformer-2-7b-chat-hf")
+    config = ModelArgs.from_name("Transformer-2-7b-chat-hf").text_transformer_args
     print(config)
 
     # Construct a device mesh with available devices (multi-host or single host)
```
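
The trailing context line points at what `main()` does next: building a device mesh over the available ranks. For reference, a minimal standalone sketch using the public `torch.distributed` device-mesh API (an assumption for illustration, not code from this repo; launch with `torchrun`):

```python
# Standalone device-mesh sketch (public PyTorch API, assumed for
# illustration). Run with: torchrun --nproc-per-node=2 <this_file>.py
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="gloo")  # use "nccl" on GPU hosts
# One flat mesh dimension across all ranks, e.g. one pipeline stage per rank.
mesh = init_device_mesh("cpu", (dist.get_world_size(),))
print(f"rank {dist.get_rank()}: {mesh}")
dist.destroy_process_group()
```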

docs/ADVANCED-USERS.md

Lines changed: 3 additions & 3 deletions
```diff
@@ -123,14 +123,14 @@ For example, for the stories15M model, this would be expressed as
 
 For models using a configuration not in the list of known
 configurations, you can construct the model by initializing the
-`TransformerArgs` dataclass that controls model construction from a
+`ModelArgs` dataclass that controls model construction from a
 parameter json using the `params-path ${PARAMS_PATH}` containing the
 appropriate model parameters to initialize the `ModelArgs` for the
 model. (We use the model constructor `Model.from_params()`).
 
 The parameter file should be in JSON format specifying these
-parameters. You can find the `TransformerArgs` data class in
-[`model.py`](https://github.com/pytorch/torchchat/blob/main/model.py#L22).
+parameters. You can find the `ModelArgs` data class in
+[`model.py`](https://github.com/pytorch/torchchat/blob/main/build/model.py#L70).
 
 The final way to initialize a torchchat model is from GGUF. You load a
 GGUF model with the option `--load-gguf ${MODELNAME}.gguf`. Presently,
```
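
The doc hunk describes the `params-path` flow; a hedged sketch of such a parameter file follows. The field names and values are illustrative (loosely modeled on llama-style configs such as stories15M) and must match the keys the `ModelArgs` data class actually accepts:

```python
# Write an illustrative params JSON for the params-path flow. Keys and
# values here are assumptions; check ModelArgs in build/model.py for
# the accepted fields.
import json

params = {
    "dim": 288,           # embedding dimension
    "n_layers": 6,        # number of transformer layers
    "n_heads": 6,         # attention heads
    "vocab_size": 32000,  # tokenizer vocabulary size
}
with open("params.json", "w") as f:
    json.dump(params, f, indent=2)
# Then construct the model with Model.from_params("params.json") or the
# equivalent params-path command-line option.
```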
