Skip to content

Commit 8e04011

Browse files
committed
Add util to export/import mlflow data to remote mlflow
1 parent b2df00d commit 8e04011

File tree

1 file changed

+143
-0
lines changed

1 file changed

+143
-0
lines changed

utils/export_import_mlflow_exp.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import os
2+
import json
3+
import tarfile
4+
import mlflow
5+
import argparse
6+
import shutil
7+
from mlflow.tracking import MlflowClient
8+
from botocore.exceptions import NoCredentialsError
9+
10+
def export_run(experiment_id, run_id, export_dir="mlflow_export", mlruns_dir=None):
    """
    Export an MLflow run (metadata + artifacts) to a gzipped tarball.

    Parameters
    ----------
    experiment_id : str
        ID of the experiment containing the run.
    run_id : str
        ID of the run to export.
    export_dir : str
        Directory where the run data is staged before compression.
    mlruns_dir : str, optional
        Local mlruns directory; when given, used as the MLflow tracking URI.

    Returns
    -------
    str
        Path of the created ``<export_dir>.tar.gz`` archive.

    Raises
    ------
    ValueError
        If the experiment or run does not exist.
    """
    if mlruns_dir:
        print(f"Setting MLflow tracking URI to local directory: {mlruns_dir}")
        mlflow.set_tracking_uri(mlruns_dir)

    client = MlflowClient()
    experiment = client.get_experiment(experiment_id)

    if experiment is None:
        raise ValueError(f"Experiment ID {experiment_id} does not exist.")
    print(f"Found experiment: {experiment.name}")

    run = client.get_run(run_id)
    if run is None:
        raise ValueError(f"Run ID {run_id} does not exist in Experiment ID {experiment_id}.")
    print(f"Found run: {run_id}")

    os.makedirs(export_dir, exist_ok=True)
    print(f"Created export directory: {export_dir}")

    # Save run metadata so the run can be replayed on the remote server.
    run_data = {
        "experiment_id": experiment_id,
        "experiment_name": experiment.name,
        "run_id": run_id,
        "params": run.data.params,
        "metrics": run.data.metrics,
        "tags": run.data.tags,
    }
    with open(os.path.join(export_dir, "run.json"), "w") as f:
        json.dump(run_data, f, indent=4)
    print("Saved run metadata")

    # Stage artifacts under <export_dir>/<run_id>.
    artifact_dir = os.path.join(export_dir, run_id)
    os.makedirs(artifact_dir, exist_ok=True)
    print(f"Created artifact directory: {artifact_dir}")

    # BUG FIX: list_artifacts' second argument is an artifact sub-path within
    # the run, not the mlruns directory — list the run's artifact root instead.
    artifacts = client.list_artifacts(run_id)
    if artifacts:
        # BUG FIX: download via the client API instead of
        # os.system("mlflow artifacts download ...") — the shell call was
        # injection-prone, re-downloaded everything once per artifact, and the
        # FileNotFoundError handler around it could never fire for os.system.
        for artifact in artifacts:
            try:
                client.download_artifacts(run_id, artifact.path, artifact_dir)
                print(f"Downloaded artifact: {artifact.path}")
            except (OSError, mlflow.exceptions.MlflowException) as exc:
                print(f"Artifact {artifact.path} could not be downloaded ({exc}), skipping.")
    else:
        print("No artifacts to export.")

    # Compress the staged data; the archive's top-level entry is
    # basename(export_dir) (import_run relies on this layout).
    tar_path = f"{export_dir}.tar.gz"
    with tarfile.open(tar_path, "w:gz") as tar:
        tar.add(export_dir, arcname=os.path.basename(export_dir))
    print(f"Compressed exported data to: {tar_path}")

    return tar_path
72+
73+
def import_run(tar_path, remote_url):
    """
    Import an MLflow run from a tar file into a remote MLflow server.

    Parameters
    ----------
    tar_path : str
        Path of a ``.tar.gz`` archive produced by ``export_run``.
    remote_url : str
        Tracking URI of the destination MLflow server.
    """
    extract_dir = tar_path.replace(".tar.gz", "")
    # BUG FIX: the archive stores its contents under basename(extract_dir)
    # (export_run tars with arcname=basename(export_dir)), so extract into the
    # PARENT directory. Extracting into extract_dir itself nested everything
    # one level too deep, breaking the run.json lookup below.
    with tarfile.open(tar_path, "r:gz") as tar:
        tar.extractall(os.path.dirname(extract_dir) or ".")
    print(f"Extracted tar file to: {extract_dir}")

    mlflow.set_tracking_uri(remote_url)
    print(f"Setting MLflow tracking URI to remote URL: {remote_url}")
    client = MlflowClient()

    # Read the metadata written by export_run.
    with open(os.path.join(extract_dir, "run.json"), "r") as f:
        run_data = json.load(f)
    print("Read run metadata")

    # Create the experiment, or reuse it if it already exists.
    experiment_name = run_data['experiment_name']
    try:
        experiment_id = client.create_experiment(experiment_name)
        print(f"Created new experiment: {experiment_name} with ID: {experiment_id}")
    except mlflow.exceptions.MlflowException:
        experiment = client.get_experiment_by_name(experiment_name)
        experiment_id = experiment.experiment_id
        print(f"Experiment {experiment_name} already exists with ID: {experiment_id}")

    # BUG FIX: re-imported runs receive fresh run IDs on the remote server, so
    # comparing run_id values could never detect a duplicate. The original run
    # ID is stored as the remote run NAME (the mlflow.runName tag, set via
    # run_name= below), so compare against that instead.
    remote_runs = client.search_runs(experiment_ids=[experiment_id])
    if any(r.data.tags.get("mlflow.runName") == run_data['run_id'] for r in remote_runs):
        print(f"Run ID {run_data['run_id']} already exists in the remote. Skipping import.")
        return

    # Recreate the run, replaying params/metrics/tags onto the remote server.
    with mlflow.start_run(experiment_id=experiment_id, run_name=run_data['run_id']):
        for param, value in run_data['params'].items():
            mlflow.log_param(param, value)
        for metric, value in run_data['metrics'].items():
            mlflow.log_metric(metric, value)
        for tag, value in run_data['tags'].items():
            mlflow.set_tag(tag, value)
        print("Logged run parameters, metrics, and tags")

        # Upload any artifacts staged under <extract_dir>/<run_id>.
        artifact_path = os.path.join(extract_dir, run_data['run_id'])
        if os.path.exists(artifact_path) and os.listdir(artifact_path):
            try:
                mlflow.log_artifacts(artifact_path)
                print(f"Uploaded artifacts from: {artifact_path}")
            except NoCredentialsError:
                # Best-effort: the artifact store may be S3-backed; keep the
                # run metadata even when credentials are missing.
                print("No AWS credentials found. Skipping artifact upload.")
        else:
            print("No artifacts to upload.")

    print(f"Run successfully imported to {remote_url} with Experiment ID {experiment_id}")

    # Clean up the extracted copy.
    shutil.rmtree(extract_dir)
    print(f"Deleted local export directory: {extract_dir}")
132+
133+
if __name__ == "__main__":
    # CLI entry point: export a run from a local mlruns store, then push it
    # to a remote MLflow tracking server.
    arg_parser = argparse.ArgumentParser(description="Export and import an MLflow run.")
    arg_parser.add_argument("--experiment_id", type=str, required=True, help="MLflow experiment ID.")
    arg_parser.add_argument("--run_id", type=str, required=True, help="MLflow run ID.")
    arg_parser.add_argument("--remote_mlflow_url", type=str, required=True, help="Remote MLflow tracking server URL.")
    arg_parser.add_argument("--mlruns_dir", type=str, required=True, help="Directory of the local mlruns.")
    cli_args = arg_parser.parse_args()

    print('Warning. There is no way to check if this RUN id was already uploaded. So please check by hand if you did it already.')
    exported_tar = export_run(cli_args.experiment_id, cli_args.run_id, mlruns_dir=cli_args.mlruns_dir)
    import_run(exported_tar, cli_args.remote_mlflow_url)

0 commit comments

Comments
 (0)