diff --git a/torchchat/cli/download.py b/torchchat/cli/download.py
index 14dfeb062..f145c93fb 100644
--- a/torchchat/cli/download.py
+++ b/torchchat/cli/download.py
@@ -10,7 +10,10 @@
 from pathlib import Path
 from typing import Optional
 
-from torchchat.cli.convert_hf_checkpoint import convert_hf_checkpoint, convert_hf_checkpoint_to_tune
+from torchchat.cli.convert_hf_checkpoint import (
+    convert_hf_checkpoint,
+    convert_hf_checkpoint_to_tune,
+)
 from torchchat.model_config.model_config import (
     load_model_configs,
     ModelConfig,
@@ -57,7 +60,6 @@ def _download_hf_snapshot(
         snapshot_download(
             model_config.distribution_path,
             local_dir=artifact_dir,
-            local_dir_use_symlinks=False,
             token=hf_token,
             ignore_patterns=ignore_patterns,
         )
@@ -77,9 +79,14 @@ def _download_hf_snapshot(
             raise e
 
     # Convert the Multimodal Llama model to the torchtune format.
-    if model_config.name in {"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision"}:
+    if model_config.name in {
+        "meta-llama/Llama-3.2-11B-Vision-Instruct",
+        "meta-llama/Llama-3.2-11B-Vision",
+    }:
         print(f"Converting {model_config.name} to torchtune format...", file=sys.stderr)
-        convert_hf_checkpoint_to_tune( model_dir=artifact_dir, model_name=model_config.name)
+        convert_hf_checkpoint_to_tune(
+            model_dir=artifact_dir, model_name=model_config.name
+        )
 
     else:
         # Convert the model to the torchchat format.
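
Note on the `local_dir_use_symlinks` removal: recent huggingface_hub releases deprecate that argument (downloads into a `local_dir` no longer rely on symlinks), so the call behaves the same without it. Below is a minimal sketch, outside the patch, of the resulting call; the repo id, directory, token, and ignore pattern are hypothetical placeholders, not values from torchchat's config.

# Sketch only: illustrates the snapshot_download call shape after the patch.
from huggingface_hub import snapshot_download

snapshot_download(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",          # model_config.distribution_path
    local_dir="checkpoints/llama-3.2-11b-vision",        # artifact_dir (placeholder path)
    token=None,                                           # hf_token, needed for gated repos
    ignore_patterns=["*.safetensors"],                    # example pattern; actual list comes from the model config
)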