Skip to content

Commit b2df00d

Browse files
committed
Add mlflow url. Store and print the mlflow id
1 parent 004ebf1 commit b2df00d

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

agents/attackers/random/random_agent.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def select_action(self, observation:Observation)->Action:
6767
parser.add_argument("--test_each", help="Evaluate performance during testing every this number of episodes.", default=10, type=int)
6868
parser.add_argument("--logdir", help="Folder to store logs", default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs"))
6969
parser.add_argument("--evaluate", help="Evaluate the agent and report, instead of playing the game only once.", default=True)
70+
parser.add_argument("--mlflow_url", help="URL for mlflow tracking server. If not provided, mlflow will store locally.", default=None)
7071
args = parser.parse_args()
7172

7273
if not os.path.exists(args.logdir):
@@ -94,7 +95,8 @@ def select_action(self, observation:Observation)->Action:
9495

9596
# Mlflow experiment name
9697
experiment_name = "Evaluation of Random Agent"
97-
mlflow.set_tracking_uri("http://127.0.0.1:8000")
98+
if args.mlflow_url:
99+
mlflow.set_tracking_uri(args.mlflow_url)
98100
mlflow.set_experiment(experiment_name)
99101
# Register in the game
100102
observation = agent.register()
@@ -203,4 +205,15 @@ def select_action(self, observation:Observation)->Action:
203205
agent.logger.info(text)
204206
print(text)
205207
agent._logger.info("Terminating interaction")
206-
agent.terminate_connection()
208+
agent.terminate_connection()
209+
210+
# Print and log the mlflow experiment ID, run ID, and storage location
211+
experiment_id = run.info.experiment_id
212+
run_id = run.info.run_id
213+
storage_location = "locally" if not args.mlflow_url else f"at {args.mlflow_url}"
214+
print(f"MLflow Experiment ID: {experiment_id}")
215+
print(f"MLflow Run ID: {run_id}")
216+
print(f"Experiment saved {storage_location}")
217+
agent._logger.info(f"MLflow Experiment ID: {experiment_id}")
218+
agent._logger.info(f"MLflow Run ID: {run_id}")
219+
agent._logger.info(f"Experiment saved {storage_location}")

0 commit comments

Comments
 (0)