[1/n llava] unify model construction ppl #1153
Conversation
distributed/parallelize_llama.py
Outdated
  # when applying TP. We need this change to ensure KVCache has the correct
  # size as k and v.
- model.config.transformer_args["text"].n_local_heads = model.config.transformer_args["text"].n_local_heads // tp_mesh.size()
+ model.model.config.n_local_heads = model.model.config.n_local_heads // tp_mesh.size()
model.model is really hard to reason about... what type is it?
The former was clunky, but legible. I'm not sure about this
I'm not happy with "text" either; it wasn't sustainable, especially if the number of modules increases.
It needs fixing. model.model might not be perfect yet, but it's close.
That's annoying; I agree 100%.
I will remove model.model as soon as I can.
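For context on why the division matters, here is a minimal self-contained sketch; all sizes are made up for illustration and are not torchchat defaults:

```python
import torch

# Illustrative numbers only. Under tensor parallelism each rank keeps
# n_heads // tp_size attention heads, so the per-rank KV cache must be
# allocated with the reduced head count -- which is what the diff above
# writes back into the config as n_local_heads.
n_heads, head_dim, tp_size = 32, 128, 4
n_local_heads = n_heads // tp_size  # 8 heads per rank after sharding

max_batch_size, max_seq_length = 1, 2048
k_cache = torch.zeros(max_batch_size, n_local_heads, max_seq_length, head_dim)
v_cache = torch.zeros(max_batch_size, n_local_heads, max_seq_length, head_dim)
assert k_cache.shape[1] == 8  # matches the sharded k/v tensors produced under TP
```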
torchchat/model.py
Outdated
  class Transformer(nn.Module):
-     def __init__(self, config: TransformerArgs) -> None:
+     def __init__(self, config: Dict[str, Any]) -> None:
Not a fan of this one; Transformer taking TransformerArgs is the most intuitive setup and matches the other classes.
SG. Will bring it back.
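For reference, a minimal sketch of the setup being asked for, with an illustrative subset of fields (torchchat's real TransformerArgs has many more):

```python
from dataclasses import dataclass

import torch.nn as nn

@dataclass
class TransformerArgs:
    # illustrative subset of fields; the real dataclass is much richer
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32

class Transformer(nn.Module):
    # typed config, as preferred above: readers and type checkers both
    # know exactly what the constructor expects
    def __init__(self, config: TransformerArgs) -> None:
        super().__init__()
        self.config = config
```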
  )
  elif generator_args.chat_mode:
      if (
          max_seq_length := self.model.config.transformer_args.get("text", None)
Your changes are right; just calling out that the old implementation was broken in 26c1d8b
torchchat/generate.py
Outdated
      if text_transformer_args is not None
      else 2048
  ),
  encoded.size(0) + generator_args.max_new_tokens, max_seq_length
Note that this is a departure from the original code, where the second argument to min is block_size (which represents a different max_seq_length; confusing, I know).
While we want to move away from using block_size, let's not do it in this diff.
Good catch! Not sure why this happened; probably a typo. Will fix it.
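To make the fix concrete, a self-contained sketch with dummy stand-ins; in generate.py these values come from the encoded prompt, generator_args, and the text transformer's config:

```python
import torch

# Dummy stand-ins for illustration only.
encoded = torch.zeros(16, dtype=torch.int64)  # stand-in for the encoded prompt
max_new_tokens = 200                          # stand-in for generator_args.max_new_tokens
block_size = 2048                             # stand-in for text_transformer_args.block_size

# As in the original code, the second argument to min is block_size,
# the model's maximum context length.
max_seq_length = min(encoded.size(0) + max_new_tokens, block_size)
assert max_seq_length == 216
```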
  model_type (ModelType): The type of the model. This attribute is used to categorize the model into different classes.
  transformer_args (Dict[str, Dict[str, Any]]): A dictionary containing the parameters for each transformer in the model.
      The outer dictionary has transformer names as keys and inner dictionaries as values. Each inner dictionary contains
      the parameter names and their corresponding values for the respective transformer.
> Each inner dictionary contains the parameter names and their corresponding values for the respective transformer.

This sounds like the intent of TransformerArgs; why can't we use that instead of Dict[str, Any]?
For unification: this arg is responsible for describing the architecture of all models, including tune backends, chat backends, and even mixed backends, so we need a unified way to describe how to set them up.
For chat-backend modules, the inner Dict will be converted into TransformerArgs afterwards.
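A hypothetical illustration of that layout (module names and parameter values are invented for the example):

```python
from typing import Any, Dict

# Invented example: outer keys name each transformer module, and each inner
# dict carries that module's parameters. For chat-backend modules the inner
# dict is converted into TransformerArgs later in construction.
transformer_args: Dict[str, Dict[str, Any]] = {
    "text": {"dim": 4096, "n_layers": 32, "n_heads": 32},  # chat backend
    "encoder": {"embed_dim": 1280, "patch_size": 14},      # e.g. a tune-backend vision tower
}
```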
  def __init__(
      self,
-     transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]],
+     transformer_args: Dict[str, Dict[str, Any]],
We should find a way to reconcile Dict[str, Any] into a TransformerArgs in a future PR.
This works well for now since we have 3 "cases", but storing/passing around an untyped Dict makes me nervous.
More than agree. My mental model would be an abstract class containing the essential APIs for all module configurations, with a different implementation for each kind of transformer (e.g. ours, tune, etc.). Dict[str, Any] is not a great way to do it.
Let me add some comments in our codebase to highlight that.
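A rough sketch of that mental model; every name below is hypothetical and only illustrates the proposed shape, not anything that exists in torchchat today:

```python
from abc import ABC, abstractmethod
from typing import Any, Dict

class ModuleConfig(ABC):
    """Hypothetical base class: the essential API every module
    configuration would implement, replacing raw Dict[str, Any] plumbing."""

    @classmethod
    @abstractmethod
    def from_dict(cls, params: Dict[str, Any]) -> "ModuleConfig":
        """Build a typed config from the untyped inner dict."""

class ChatTransformerConfig(ModuleConfig):
    """One concrete flavor (ours); a tune-backend config would be another."""

    def __init__(self, dim: int, n_layers: int) -> None:
        self.dim = dim
        self.n_layers = n_layers

    @classmethod
    def from_dict(cls, params: Dict[str, Any]) -> "ChatTransformerConfig":
        return cls(dim=params["dim"], n_layers=params["n_layers"])
```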
  super().__init__()
  self.config = config
  self.model = self.build_model()
  self.text_transformer_args = None
Please add a comment on this, since it is a special case.
This PR adopts the same pipeline to construct both the chat model and the tune model; previously we used TransformerArgs to construct the chat-backend model but a plain dictionary for the tune-backend model.
It also fixes some annoyingly hacky configuration handling.
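To make "same pipeline" concrete, a self-contained sketch of the idea; the builder registry and the Linear stand-ins are invented for illustration and are not torchchat's actual API:

```python
from typing import Any, Callable, Dict

import torch.nn as nn

# Invented stand-ins: in torchchat the entries would construct the real
# chat-backend Transformer and tune-backend modules respectively.
MODULE_BUILDERS: Dict[str, Callable[[Dict[str, Any]], nn.Module]] = {
    "text": lambda p: nn.Linear(p["dim"], p["dim"]),
    "encoder": lambda p: nn.Linear(p["embed_dim"], p["embed_dim"]),
}

def build_model(transformer_args: Dict[str, Dict[str, Any]]) -> nn.ModuleDict:
    # one dict-driven loop builds every submodule, regardless of backend
    return nn.ModuleDict(
        {name: MODULE_BUILDERS[name](params) for name, params in transformer_args.items()}
    )

model = build_model({"text": {"dim": 64}, "encoder": {"embed_dim": 32}})
print(model)
```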