-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlocal_download_model.py
More file actions
25 lines (21 loc) · 898 Bytes
/
local_download_model.py
File metadata and controls
25 lines (21 loc) · 898 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
"""
This script downloads the `frccn_fine_tuned_model` artifact from the Union
platform, loads it with PyTorch on the CPU, and re-saves it to a local `.pth` file.
"""
import torch
from flytekit.types.file import FlyteFile
from union import Artifact, UnionRemote
# --------------------------------------------------
# Fetch the fine-tuned model artifact from Union and write a local copy
# --------------------------------------------------
# Handle for the artifact registered under the name "frccn_fine_tuned_model".
FRCCNFineTunedModel = Artifact(name="frccn_fine_tuned_model")

# No version is pinned here. Per the original author's note, the latest version
# is downloaded when none is given; to pin one, add a keyword argument such as
# version="anmrqcq8pfbnlp42j2vp/n3/0/o0".
query = FRCCNFineTunedModel.query(project="default", domain="development")

# Resolve the query through a UnionRemote constructed with no arguments, then
# request the artifact's value as a FlyteFile.
remote = UnionRemote()
artifact = remote.get_artifact(query=query)
model_file: FlyteFile = artifact.get(as_type=FlyteFile)

# Whatever download() returns is handed straight to torch.load.
# NOTE(review): weights_only=False allows torch.load to unpickle arbitrary
# Python objects, so this should only be run against trusted artifacts.
local_checkpoint = model_file.download()
model = torch.load(local_checkpoint, map_location="cpu", weights_only=False)

# Despite its name, save_dir is the output *file* path, not a directory.
save_dir = "local_frccn_faster_rcnn_trained_union.pth"
torch.save(model, save_dir)