1010from pathlib import Path
1111from typing import Optional
1212
13- from torchchat .cli .convert_hf_checkpoint import convert_hf_checkpoint , convert_hf_checkpoint_to_tune
13+ from torchchat .cli .convert_hf_checkpoint import (
14+ convert_hf_checkpoint ,
15+ convert_hf_checkpoint_to_tune ,
16+ )
1417from torchchat .model_config .model_config import (
1518 load_model_configs ,
1619 ModelConfig ,
2023
# By default, download models from Hugging Face into the Hugging Face hub cache directory.
#
# The lookup mirrors the one huggingface_hub performs, so that the directory torchchat
# inspects is the directory huggingface_hub downloads into by default:
#   1. $HF_HUB_CACHE, or its legacy spelling $HUGGINGFACE_HUB_CACHE, names the hub cache itself.
#   2. Otherwise the hub cache is the "hub" sub-directory of $HF_HOME.
#   3. $HF_HOME itself defaults to "$XDG_CACHE_HOME/huggingface", i.e. "~/.cache/huggingface".
# Note that $HF_HOME is the *parent* of the hub cache, not the hub cache itself.
def _resolve_hf_hub_cache() -> Path:
    """Return the Hugging Face hub cache directory selected by the current environment."""
    hf_home = os.environ.get("HF_HOME")
    if hf_home is None:
        cache_root = os.environ.get("XDG_CACHE_HOME", os.path.join("~", ".cache"))
        hf_home = os.path.join(cache_root, "huggingface")

    hub_cache = os.environ.get(
        "HF_HUB_CACHE",
        os.environ.get("HUGGINGFACE_HUB_CACHE", os.path.join(hf_home, "hub")),
    )
    return Path(os.path.expanduser(hub_cache))


HUGGINGFACE_HOME_PATH = _resolve_hf_hub_cache()

# Opt in to hf_transfer-based downloads unless the user has already made a choice.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
2737
28- def _download_hf_snapshot (
29- model_config : ModelConfig , hf_token : Optional [str ]
30- ):
38+
def _cleanup_hf_models_from_torchchat_dir(models_dir: Path):
    """
    Delete legacy copies of Hugging Face models from the torchchat models directory.

    Previously, all models were stored in the torchchat models directory
    (by default ~/.torchchat/model-cache). Hugging Face models are now stored in
    the Hugging Face cache directory instead, so for every known model config that
    is distributed as a Hugging Face snapshot, this removes models_dir / <model name>
    when that path exists.
    """
    hf_snapshot_configs = (
        config
        for config in load_model_configs().values()
        if config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot
    )
    for config in hf_snapshot_configs:
        legacy_dir = models_dir / config.name
        if not os.path.exists(legacy_dir):
            # Nothing was ever stored in the old location for this model.
            continue
        print(
            f"Cleaning up old model artifacts in {legacy_dir}. New artifacts will be downloaded to {HUGGINGFACE_HOME_PATH}"
        )
        shutil.rmtree(legacy_dir)
53+
54+
55+ def _download_hf_snapshot (model_config : ModelConfig , hf_token : Optional [str ]):
3156 from huggingface_hub import model_info , snapshot_download
3257 from requests .exceptions import HTTPError
3358
3459 # Download and store the HF model artifacts.
3560 model_dir = get_model_dir (model_config , None )
36- print (f"Downloading { model_config .name } from Hugging Face to { model_dir } " , file = sys .stderr , flush = True )
61+ print (
62+ f"Downloading { model_config .name } from Hugging Face to { model_dir } " ,
63+ file = sys .stderr ,
64+ flush = True ,
65+ )
3766 try :
3867 # Fetch the info about the model's repo
3968 model_info = model_info (model_config .distribution_path , token = hf_token )
@@ -81,14 +110,17 @@ def _download_hf_snapshot(
81110 else :
82111 raise e
83112
84- # Update the model dir to include the snapshot we just downloaded.
113+ # Update the model dir to include the snapshot we just downloaded.
85114 model_dir = get_model_dir (model_config , None )
86115 print ("Model downloaded to" , model_dir )
87116
88117 # Convert the Multimodal Llama model to the torchtune format.
89- if model_config .name in {"meta-llama/Llama-3.2-11B-Vision-Instruct" , "meta-llama/Llama-3.2-11B-Vision" }:
118+ if model_config .name in {
119+ "meta-llama/Llama-3.2-11B-Vision-Instruct" ,
120+ "meta-llama/Llama-3.2-11B-Vision" ,
121+ }:
90122 print (f"Converting { model_config .name } to torchtune format..." , file = sys .stderr )
91- convert_hf_checkpoint_to_tune ( model_dir = model_dir , model_name = model_config .name )
123+ convert_hf_checkpoint_to_tune (model_dir = model_dir , model_name = model_config .name )
92124
93125 else :
94126 # Convert the model to the torchchat format.
@@ -108,32 +140,44 @@ def _download_direct(
108140 print (f"Downloading { url } ..." , file = sys .stderr )
109141 urllib .request .urlretrieve (url , str (local_path .absolute ()))
110142
143+
def _get_hf_artifact_dir(model_config: ModelConfig) -> Path:
    """
    Returns the Hugging Face cache folder for the given model.

    This is the per-repository root folder (the one holding blobs, refs and
    snapshots), named "models--<distribution path with '/' replaced by '--'>"
    and located directly under HUGGINGFACE_HOME_PATH.

    Only valid for models distributed as Hugging Face snapshots.
    """
    channel = model_config.distribution_channel
    assert channel == ModelDistributionChannel.HuggingFaceSnapshot

    repo_folder = "models--" + model_config.distribution_path.replace("/", "--")
    return HUGGINGFACE_HOME_PATH / repo_folder
119158
120159
def get_model_dir(model_config: ModelConfig, models_dir: Optional[Path]) -> Path:
    """
    Returns the directory where the model artifacts are expected to be stored.

    For Hugging Face artifacts, this is the snapshot directory that the "main" ref
    points to when the model's cache folder exists, or the expected (not yet
    created) cache folder otherwise. models_dir is not used in this case.
    For all other distribution channels, this is models_dir / model_config.name.

    For CLI usage, pass in args.model_directory.

    Raises:
        FileNotFoundError: if the Hugging Face cache folder for the model exists
            but its "refs/main" file does not.
    """
    if (
        model_config.distribution_channel
        == ModelDistributionChannel.HuggingFaceSnapshot
    ):
        artifact_dir = _get_hf_artifact_dir(model_config)

        # If neither of these paths exists, the model hasn't been downloaded yet.
        if not os.path.isdir(artifact_dir) and not os.path.isdir(
            artifact_dir / "snapshots"
        ):
            return artifact_dir

        # refs/main holds the name of the snapshot folder to use. Read it inside a
        # context manager so the file handle is closed deterministically.
        with open(artifact_dir / "refs" / "main", "r") as ref_file:
            snapshot = ref_file.read().strip()
        return artifact_dir / "snapshots" / snapshot

    return models_dir / model_config.name
139183
@@ -164,9 +208,7 @@ def download_and_convert(
164208 os .makedirs (temp_dir , exist_ok = True )
165209
166210 try :
167- if (
168- model_config .distribution_channel == ModelDistributionChannel .DirectDownload
169- ):
211+ if model_config .distribution_channel == ModelDistributionChannel .DirectDownload :
170212 _download_direct (model_config , temp_dir )
171213 else :
172214 raise RuntimeError (
@@ -187,7 +229,7 @@ def download_and_convert(
187229
def is_model_downloaded(model: str, models_dir: Path) -> bool:
    """
    Returns True if the model's artifact directory exists and is not empty.

    The model name is resolved to its config with resolve_model_config, and the
    directory to inspect is obtained from get_model_dir.
    """
    model_config = resolve_model_config(model)

    # Check if the model directory exists and is not empty.
    model_dir = get_model_dir(model_config, models_dir)
    # os.listdir returns a list; convert it so the function honours its bool contract.
    return os.path.isdir(model_dir) and bool(os.listdir(model_dir))
@@ -242,7 +284,10 @@ def remove_main(args) -> None:
242284 if not os .path .isdir (model_dir ):
243285 print (f"Model { args .model } has no downloaded artifacts in { model_dir } ." )
244286 return
245- if model_config .distribution_channel == ModelDistributionChannel .HuggingFaceSnapshot :
287+ if (
288+ model_config .distribution_channel
289+ == ModelDistributionChannel .HuggingFaceSnapshot
290+ ):
246291 # For HuggingFace models, we need to remove the entire root directory.
247292 model_dir = _get_hf_artifact_dir (model_config )
248293
@@ -265,12 +310,15 @@ def where_main(args) -> None:
265310 model_dir = get_model_dir (model_config , args .model_directory )
266311
267312 if not os .path .isdir (model_dir ):
268- raise RuntimeError (f"Model { args .model } has no downloaded artifacts in { model_dir } ." )
313+ raise RuntimeError (
314+ f"Model { args .model } has no downloaded artifacts in { model_dir } ."
315+ )
269316
270317 print (str (os .path .abspath (model_dir )))
271318 exit (0 )
272319
273320
# Subcommand to download model artifacts.
def download_main(args) -> None:
    """
    Entry point for the download subcommand.

    First removes Hugging Face model artifacts left in args.model_directory, then
    downloads and converts args.model, passing along args.model_directory and
    args.hf_token.
    """
    models_dir = args.model_directory
    _cleanup_hf_models_from_torchchat_dir(models_dir)
    download_and_convert(args.model, models_dir, args.hf_token)
0 commit comments