From 44c70d025c3b6c9cd4ff5f5ca856c6e6a48919eb Mon Sep 17 00:00:00 2001 From: Haodong <220246386@edu.seu.cn> Date: Fri, 14 Mar 2025 16:16:19 +0800 Subject: [PATCH] Load transformer weights via hf_hub_download instead of a local path --- inference/predict.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/inference/predict.py b/inference/predict.py index e0fb0f2..b347141 100644 --- a/inference/predict.py +++ b/inference/predict.py @@ -11,7 +11,7 @@ from ivideogpt.vq_model import CompressiveVQModel from ivideogpt.transformer import HeadModelWithAction from utils import NPZParser - +from huggingface_hub import hf_hub_download device = 'cuda' @@ -104,7 +104,9 @@ def main(): tokens_num_per_dyna=tokens_per_dyna, context=args.context_length, segment_length=args.segment_length).to(device) - state_dict = load_file(os.path.join(args.pretrained_model_name_or_path, 'transformer', 'model.safetensors')) + local_file_path = hf_hub_download(repo_id=args.pretrained_model_name_or_path, filename="transformer/model.safetensors") + state_dict = load_file(local_file_path) + # state_dict = load_file(os.path.join(args.pretrained_model_name_or_path, 'transformer', 'model.safetensors')) model.load_state_dict(state_dict, strict=True) assert model.llm.config.vocab_size == tokenizer.num_vq_embeddings + tokenizer.num_dyn_embeddings + 2 else: