This repository was archived by the owner on Sep 10, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 248
Integrate distributed inference into torchchat cli #1327
Merged
Merged
Changes from 22 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
481e00b
add pp_dim, distributed, num_gpus, num_nodes as cmd line args
lessw2020 2f1787c
add tp_dim
lessw2020 fd3ddcd
add elastic_launch
lessw2020 bf79697
working, can now launch from cli
lessw2020 26a9455
Remove numpy < 2.0 pin to align with pytorch (#1301)
larryliu0820 5f0ca00
Update torchtune pin to 0.4.0-dev20241010 (#1300)
vmpuri 598caf5
Unbreak gguf util CI job by fixing numpy version (#1307)
larryliu0820 6fe1646
Remove apparently-unused import torchvision in model.py (#1305)
swolchok 78debce
remove global var for tokenizer type + patch tokenizer to allow list …
mreso 2eefb13
make pp tp visible in interface
mreso e8bb076
Add llama 3.1 to dist_run.py
mreso 1faa052
[WIP] Move dist inf into its own generator
mreso 11f29fc
Add initial generator interface to dist inference
mreso adcf232
Added generate method and placeholder scheduler
mreso 3836928
use prompt parameter for dist generation
mreso 3f6fa2d
Enforce tp>=2
mreso fd9f704
Build tokenizer from TokenizerArgs
mreso e8f7c98
Disable torchchat format + constrain possible models for distributed
mreso 9ec55fb
disable calling dist_run.py directly for now
mreso 80f8138
Restore original dist_run.py for now
mreso abf0679
Merge branch 'main' into refactor/dist_run
mreso 99606ab
disable _maybe_parallelize_model again
mreso 4b8cdcb
Reenable arg.model_name in dist_run.py
mreso b8f88fd
Use singleton logger instead of print in generate
mreso 2d37d27
Address PR comments; try/expect in launch_dist_inference; added comments
mreso File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,20 +16,14 @@ | |
| import torch._inductor.config | ||
| import torch.nn as nn | ||
|
|
||
| from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune | ||
|
|
||
| from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama | ||
|
|
||
| from torch.distributed.device_mesh import DeviceMesh | ||
| from torch.distributed.elastic.multiprocessing.errors import record | ||
| from torch.distributed.elastic.utils.distributed import get_free_port | ||
|
|
||
| from torchtune.models.convert_weights import meta_to_tune | ||
|
|
||
| from torchtune.training import set_default_dtype | ||
| from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama | ||
|
|
||
| from torchchat.model import Model, ModelArgs, ModelType | ||
|
|
||
| from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE | ||
|
|
||
| from torchchat.model_config.model_config import resolve_model_config | ||
| from torchchat.utils.build_utils import ( | ||
| device_sync, | ||
|
|
@@ -40,6 +34,14 @@ | |
| from torchchat.utils.measure_time import measure_time | ||
| from torchchat.utils.quantize import quantize_model | ||
|
|
||
| from torchtune.models.convert_weights import meta_to_tune | ||
|
|
||
| from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE | ||
|
|
||
| from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune | ||
|
|
||
| from torchtune.training import set_default_dtype | ||
|
|
||
|
|
||
| @dataclass | ||
| class BuilderArgs: | ||
|
|
@@ -55,7 +57,10 @@ class BuilderArgs: | |
| device: Optional[str] = None | ||
| precision: torch.dtype = torch.float32 | ||
| setup_caches: bool = False | ||
| use_distributed: bool = False | ||
| distributed: bool = False | ||
| pp: int = 1 | ||
| tp: int = 1 | ||
| chpt_from: str = "hf" | ||
| is_chat_model: bool = False | ||
| prefill_possible: bool = False | ||
| dynamic_shapes: bool = False | ||
|
|
@@ -87,7 +92,9 @@ def __post_init__(self): | |
| ] | ||
| for param, param_msg in ignored_params: | ||
| if param: | ||
| print(f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified") | ||
| print( | ||
| f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified" | ||
| ) | ||
| else: | ||
| self.prefill_possible = True | ||
|
|
||
|
|
@@ -153,7 +160,11 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": | |
| dtype = torch.float16 | ||
| else: | ||
| dtype = name_to_dtype(args.dtype, args.device) | ||
|
|
||
| # distributed args | ||
| distributed = getattr(args, "distributed", False) | ||
| pp = getattr(args, "pp", 1) | ||
| tp = getattr(args, "tp", 1) | ||
| chpt_from = getattr(args, "chpt_from", "hf") | ||
| return cls( | ||
| checkpoint_dir=checkpoint_dir, | ||
| checkpoint_path=checkpoint_path, | ||
|
|
@@ -167,7 +178,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": | |
| device=args.device, | ||
| precision=dtype, | ||
| setup_caches=(output_dso_path or output_pte_path), | ||
| use_distributed=args.distributed, | ||
| distributed=distributed, | ||
| pp=pp, | ||
| tp=tp, | ||
| chpt_from=chpt_from, | ||
| is_chat_model=is_chat_model, | ||
| dynamic_shapes=getattr(args, "dynamic_shapes", False), | ||
| max_seq_length=getattr(args, "max_seq_length", None), | ||
|
|
@@ -397,10 +411,10 @@ def _load_model_default(builder_args: BuilderArgs) -> Model: | |
| # does not host any actual values, need to reinitialize them in the actual | ||
| # device. Only do those buffer initialization, without initializing the entire | ||
| # model. | ||
| decoder_config = model.config.transformer_args['decoder'] | ||
| head_dim = decoder_config['embed_dim'] // decoder_config['num_heads'] | ||
| max_seq_len = decoder_config['max_seq_len'] | ||
| rope_base = decoder_config['rope_base'] | ||
| decoder_config = model.config.transformer_args["decoder"] | ||
| head_dim = decoder_config["embed_dim"] // decoder_config["num_heads"] | ||
| max_seq_len = decoder_config["max_seq_len"] | ||
| rope_base = decoder_config["rope_base"] | ||
| for submodule in model.modules(): | ||
| if isinstance(submodule, Llama3ScaledRoPE): | ||
| submodule.__init__(head_dim, max_seq_len, rope_base) | ||
|
|
@@ -476,18 +490,19 @@ def _maybe_parallelize_model( | |
|
|
||
|
|
||
| def _load_model(builder_args: BuilderArgs) -> Model: | ||
| world_mesh, parallel_dims = _maybe_init_distributed(builder_args) | ||
| # world_mesh, parallel_dims = _maybe_init_distributed(builder_args) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this code is now effectively dead and we should just remove it but a later PR. |
||
| if builder_args.gguf_path: | ||
| model = _load_model_gguf(builder_args) | ||
| elif builder_args.use_distributed: | ||
| model = _init_model_on_meta_device(builder_args) | ||
| # elif builder_args.use_distributed: | ||
| # model = _init_model_on_meta_device(builder_args) | ||
| else: | ||
| model = _load_model_default(builder_args) | ||
| model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims) | ||
| # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims) | ||
|
|
||
| model = model.to(device=builder_args.device, dtype=builder_args.precision) | ||
| return model.eval() | ||
|
|
||
|
|
||
| def _initialize_model( | ||
| builder_args: BuilderArgs, | ||
| quantize, | ||
|
|
@@ -496,7 +511,6 @@ def _initialize_model( | |
| support_tensor_subclass: bool = True, | ||
| ) -> Model: | ||
| print("Loading model...") | ||
|
|
||
| if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path): | ||
| print("Setting gguf_kwargs for generate.") | ||
| is_dso = builder_args.dso_path is not None | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.