-
Notifications
You must be signed in to change notification settings - Fork 251
Move imports to prepare for making torchtune an optional dependency #1539
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,13 +17,14 @@ | |
import torch._inductor.config | ||
import torch.distributed as dist | ||
|
||
from torchchat.distributed.utils import( | ||
from torchchat.distributed.logging_utils import SingletonLogger | ||
|
||
from torchchat.distributed.utils import ( | ||
Color as color, | ||
CUDATrackTime, | ||
init_distributed, | ||
GPUMemoryMonitor, | ||
init_distributed, | ||
) | ||
from torchchat.distributed.logging_utils import SingletonLogger | ||
|
||
from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs | ||
from torchchat.model_config.model_config import resolve_model_config | ||
|
@@ -37,15 +38,6 @@ | |
from torchchat.utils.quantize import quantize_model | ||
|
||
|
||
from torchtune.models.convert_weights import meta_to_tune | ||
|
||
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE | ||
|
||
from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune | ||
|
||
from torchtune.training import set_default_dtype | ||
|
||
|
||
@dataclass | ||
class BuilderArgs: | ||
checkpoint_path: Optional[Union[Path, str]] = None | ||
|
@@ -188,15 +180,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": | |
tp = getattr(args, "tp", 1) | ||
chpt_from = getattr(args, "chpt_from", "hf") | ||
sdp_backend_dict = { | ||
'math': torch.nn.attention.SDPBackend.MATH, | ||
'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION, | ||
'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, | ||
'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION, | ||
"math": torch.nn.attention.SDPBackend.MATH, | ||
"flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION, | ||
"efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, | ||
"cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION, | ||
} | ||
attention_backend = sdp_backend_dict[args.attention_backend] | ||
if args.device == "cpu" and (args.attention_backend == "efficient_attention" | ||
or args.attention_backend == "cudnn_attention"): | ||
print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.") | ||
if args.device == "cpu" and ( | ||
args.attention_backend == "efficient_attention" | ||
or args.attention_backend == "cudnn_attention" | ||
): | ||
print( | ||
f"Warning: {args.attention_backend} is not supported on CPU. Using math instead." | ||
) | ||
attention_backend = torch.nn.attention.SDPBackend.MATH | ||
return cls( | ||
checkpoint_dir=checkpoint_dir, | ||
|
@@ -238,12 +234,14 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs": | |
speculative_builder_args.pte_path = None | ||
return speculative_builder_args | ||
|
||
|
||
class TokenizerType(Enum): | ||
NONE = 0 | ||
TIKTOKEN = 1 | ||
SENTENCEPIECE = 2 | ||
HF_TOKENIZER = 3 | ||
|
||
|
||
@dataclass | ||
class TokenizerArgs: | ||
tokenizer_path: Optional[Union[Path, str]] = None | ||
|
@@ -307,9 +305,9 @@ def validate_model( | |
use_sentencepiece = not (use_tiktoken or use_hf_tokenizer) | ||
|
||
if ( | ||
(is_tiktoken and not use_tiktoken) or | ||
(is_hf_tokenizer and not use_hf_tokenizer) or | ||
(is_sentencepiece and not use_sentencepiece) | ||
(is_tiktoken and not use_tiktoken) | ||
or (is_hf_tokenizer and not use_hf_tokenizer) | ||
or (is_sentencepiece and not use_sentencepiece) | ||
): | ||
raise RuntimeError( | ||
"model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format( | ||
|
@@ -416,6 +414,8 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model: | |
|
||
|
||
def _load_checkpoint(builder_args: BuilderArgs): | ||
from torchtune.models.convert_weights import meta_to_tune | ||
|
||
if builder_args.params_table and builder_args.params_table.endswith("Tune"): | ||
print("Loading Tune checkpoint") | ||
meta_checkpoint = torch.load( | ||
|
@@ -458,6 +458,12 @@ def _load_checkpoint(builder_args: BuilderArgs): | |
|
||
|
||
def _load_model_default(builder_args: BuilderArgs) -> Model: | ||
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto, we can drop this into the Flamingo conditional |
||
from torchtune.models.llama3_2_vision._convert_weights import ( | ||
llama3_vision_meta_to_tune, | ||
) | ||
from torchtune.training import set_default_dtype | ||
|
||
assert not builder_args.gguf_path | ||
|
||
model: Model = _init_model_on_meta_device(builder_args) | ||
|
@@ -470,8 +476,9 @@ def _load_model_default(builder_args: BuilderArgs) -> Model: | |
|
||
if model.config.model_type == ModelType.Flamingo: | ||
# TODO: Refactor this. For now, overwrite the model with model loaded from params_path | ||
with set_default_dtype(builder_args.precision), torch.device( | ||
builder_args.device | ||
with ( | ||
set_default_dtype(builder_args.precision), | ||
torch.device(builder_args.device), | ||
): | ||
# It doubles the model size the memory, with redundancies of the initialized weights. | ||
# model = Model.from_params(builder_args.params_path) | ||
|
@@ -507,6 +514,7 @@ def _load_model(builder_args: BuilderArgs) -> Model: | |
# AOTI-compoiled model will load its own weights. | ||
# Release weights here to avoid OOM | ||
import gc | ||
|
||
if hasattr(model, "model"): | ||
model.model = None | ||
gc.collect() | ||
|
@@ -564,6 +572,7 @@ def _initialize_model( | |
|
||
def do_nothing(max_batch_size, max_seq_length): | ||
pass | ||
|
||
model.setup_caches = do_nothing | ||
|
||
model.forward = torch._export.aot_load( | ||
|
@@ -601,6 +610,7 @@ def do_nothing(max_batch_size, max_seq_length): | |
|
||
def do_nothing(max_batch_size, max_seq_length): | ||
pass | ||
|
||
model.setup_caches = do_nothing | ||
|
||
model.forward = aoti_compiled_model | ||
|
@@ -652,12 +662,15 @@ def do_nothing(max_batch_size, max_seq_length): | |
try: | ||
model = torch.load(builder_args.snapshot_path, weights_only=False) | ||
except Exception: | ||
raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}") | ||
raise RuntimeError( | ||
f"Failed to load torchchat snapshot {builder_args.snapshot_path}" | ||
) | ||
# _active_backend() does not allow DSO & AOTI to be true. | ||
# Choose either. | ||
from torchchat.utils.build_utils import set_backend | ||
set_backend (dso=True, pte=False, aoti_package=False) | ||
if (model.config != config): | ||
|
||
set_backend(dso=True, pte=False, aoti_package=False) | ||
if model.config != config: | ||
raise RuntimeError("loaded model architecture mismatch") | ||
## | ||
## import all libraries with custom kernels ans custom operators | ||
|
@@ -675,7 +688,9 @@ def do_nothing(max_batch_size, max_seq_length): | |
logger = SingletonLogger.get_logger() | ||
|
||
gpu_memory_monitor = GPUMemoryMonitor("cuda") | ||
logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}") | ||
logger.info( | ||
f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}" | ||
) | ||
|
||
# Model-level config | ||
if builder_args.params_table: | ||
|
@@ -686,20 +701,16 @@ def do_nothing(max_batch_size, max_seq_length): | |
config = TransformerArgs.from_params(model_config.transformer_args["text"]) | ||
logger.info(f"Transformer Config: {config}") | ||
|
||
#TODO: Move into head of file after solving circular import | ||
from torchchat.distributed.checkpoint_utils import ( | ||
load_model_weights, | ||
) | ||
# TODO: Move into head of file after solving circular import | ||
from torchchat.distributed.checkpoint_utils import load_model_weights | ||
|
||
# Validate pipeline degree | ||
assert config.n_layers % pp_degree == 0 | ||
|
||
# Create device mesh | ||
device_mesh = dist.init_device_mesh( | ||
"cuda", | ||
(pp_degree, tp_degree), | ||
mesh_dim_names=("pp", "tp") | ||
) | ||
"cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp") | ||
) | ||
tp_mesh = device_mesh["tp"] | ||
pp_mesh = device_mesh["pp"] | ||
logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}") | ||
|
@@ -728,7 +739,13 @@ def do_nothing(max_batch_size, max_seq_length): | |
# Load weights | ||
logger.info(f"Loading weights for {pp_rank=} on {device=}") | ||
with CUDATrackTime() as timer: | ||
load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from) | ||
load_model_weights( | ||
model, | ||
builder_args.distribution_path, | ||
device, | ||
config, | ||
builder_args.chpt_from, | ||
) | ||
|
||
logger.info( | ||
f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" | ||
|
@@ -742,7 +759,7 @@ def do_nothing(max_batch_size, max_seq_length): | |
# lanes. | ||
# TODO: bump up the lane count | ||
pipeline_lanes = 1 | ||
seqlen_prefill=1024 | ||
seqlen_prefill = 1024 | ||
with device: | ||
model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes) | ||
|
||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmmm let's push the install further into the check on line 419.
This function would error if we don't have torchtune installed