 import torch._inductor.config
 import torch.nn as nn
 
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.elastic.multiprocessing.errors import record
-from torch.distributed.elastic.utils.distributed import get_free_port
-
-from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
-
 from torchchat.model import Model, ModelArgs, ModelType
 
 from torchchat.model_config.model_config import resolve_model_config
@@ -464,77 +458,11 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
     return model
 
 
-def _maybe_init_distributed(
-    builder_args: BuilderArgs,
-) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-    """
-    Initialize distributed related setups if the user specified
-    using distributed inference. If not, this is a no-op.
-
-    Args:
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-    Returns:
-        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-            - The first element is an optional DeviceMesh object,
-              which describes the mesh topology of devices for the DTensor.
-            - The second element is an optional ParallelDims object,
-              which represents the parallel dimensions configuration.
-    """
-    if not builder_args.use_distributed:
-        return None, None
-    dist_config = "llama3_8B.toml"  # TODO - integrate with chat cmd line
-
-    world_mesh, parallel_dims = launch_distributed(dist_config)
-
-    assert (
-        world_mesh is not None and parallel_dims is not None
-    ), f"failed to launch distributed using {dist_config}"
-
-    return world_mesh, parallel_dims
-
-
-def _maybe_parallelize_model(
-    model: nn.Module,
-    builder_args: BuilderArgs,
-    world_mesh: DeviceMesh,
-    parallel_dims: ParallelDims,
-) -> nn.Module:
-    """
-    We parallelize the module and load the distributed checkpoint to the model
-    if the user specifies using distributed inference. If not, this is a no-op.
-
-    Args:
-        model (:class:`nn.Module`):
-            Module to be parallelized.
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-        world_mesh (:class:`DeviceMesh`):
-            Object which describes the mesh topology
-            of devices for the DTensor.
-        parallel_dims (:class:`ParallelDims`):
-            Object which represents the parallel dimensions configuration.
-    Returns:
-        A :class:`nn.Module` object which is parallelized and checkpoint loaded
-        if the user specifies using distributed inference.
-    """
-    if world_mesh is None:
-        return model
-    assert parallel_dims is not None
-    print("Applying model parallel to model ...")
-    parallelize_llama(model, world_mesh, parallel_dims)
-    return load_checkpoints_to_model(model, builder_args, world_mesh)
-
-
 def _load_model(builder_args: BuilderArgs) -> Model:
-    # world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
-    # elif builder_args.use_distributed:
-    #     model = _init_model_on_meta_device(builder_args)
     else:
         model = _load_model_default(builder_args)
-    # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
 
     if builder_args.dso_path or builder_args.aoti_package_path:
         # AOTI-compiled model will load its own weights.
@@ -706,4 +634,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
         return "TikToken"
     if tokenizers:
         return "Tokenizers"
-    return "SentencePiece"
+    return "SentencePiece"
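
For context on what this commit removes: the deleted helpers composed into a distributed load path that the commented-out lines in `_load_model` hint at. Below is a minimal sketch of that pre-change flow, reconstructed from the removed code above. The wrapper name `_load_model_distributed` is hypothetical, and `_init_model_on_meta_device` is assumed to exist with the signature implied by the commented-out call in the diff.

```python
# Hypothetical reconstruction of the distributed path removed by this commit.
# All helper names come from the deleted lines above, except
# _load_model_distributed, which is an illustrative wrapper (assumption).
import torch.nn as nn


def _load_model_distributed(builder_args) -> nn.Module:
    # Launch the distributed runtime from the hard-coded TOML config and get
    # back the device mesh plus the parallel-dims configuration; both are
    # None when --distributed inference was not requested (removed helper).
    world_mesh, parallel_dims = _maybe_init_distributed(builder_args)

    # In the distributed case the model was materialized on the meta device,
    # per the commented-out call in the diff (assumed signature).
    model = _init_model_on_meta_device(builder_args)

    # Shard the module over the mesh and load the distributed checkpoint;
    # a no-op returning the model unchanged when world_mesh is None
    # (removed helper).
    return _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
```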