66import subprocess
77import mlflow
88from mlflow .models import infer_signature
9+ import requests
10+ import time
11+ import socket
12+ from urllib .parse import urlparse
13+ import logging
914
1015import pandas as pd
1116from sklearn import datasets
1217from sklearn .model_selection import train_test_split
1318from sklearn .linear_model import LogisticRegression
1419from sklearn .metrics import accuracy_score
1520
16- def run_mlflow_test (tracking_uri ):
21+ # Configure logging
22+ logging .basicConfig (level = logging .INFO , format = '%(asctime)s - %(levelname)s - %(message)s' )
23+ logger = logging .getLogger (__name__ )
24+
def _tcp_port_open(host, port, connect_timeout=5):
    """Return True if a TCP connection to ``(host, port)`` can be opened."""
    try:
        # Use the socket as a context manager so the probe connection is
        # closed right away; otherwise every retry leaks a file descriptor.
        with socket.create_connection((host, port), timeout=connect_timeout):
            pass
    except OSError as e:
        # socket.timeout, socket.error, socket.gaierror and
        # ConnectionRefusedError are all OSError subclasses.
        logger.warning(f"Socket connection failed: {e}")
        return False

    logger.info(f"Socket connection to {host}:{port} successful")
    return True


def _http_probe_ok(probe_urls, request_timeout=5):
    """Return True as soon as one of ``probe_urls`` answers with HTTP 200."""
    for probe_url in probe_urls:
        try:
            # SECURITY NOTE: verify=False disables TLS certificate checks.
            # It is kept because this tool is run against servers with
            # self-signed certificates; do not reuse this for untrusted hosts.
            response = requests.get(probe_url, timeout=request_timeout, verify=False)
        except requests.exceptions.RequestException as e:
            # A transport-level failure will almost certainly repeat for the
            # other endpoint too, so give up on this attempt immediately.
            logger.warning(f"HTTP request failed: {e}")
            return False

        if response.status_code == 200:
            return True
        logger.warning(
            f"MLflow server returned status code {response.status_code} for {probe_url}"
        )
    return False


def check_server_connection(tracking_uri, timeout=30, retry_interval=5):
    """
    Check if the MLflow server is reachable

    Each attempt first opens (and closes) a plain TCP connection to the
    server's host/port and then issues an HTTP GET against the server.
    Attempts are repeated until one succeeds or ``timeout`` seconds elapse.

    Args:
        tracking_uri: The URI of the MLflow server (must include a scheme,
            e.g. ``https://host:5000``)
        timeout: Maximum time in seconds to wait for the server
        retry_interval: Interval in seconds between retries

    Returns:
        bool: True if the server is reachable, False otherwise
    """
    logger.info(f"Checking connection to MLflow server at {tracking_uri}")

    # Parse URL to get host and port for socket check
    parsed_url = urlparse(tracking_uri)
    host = parsed_url.hostname
    if not host:
        # Without a scheme urlparse() yields hostname=None, which would make
        # the socket probe silently target localhost and the HTTP probe fail
        # on every retry. Fail fast instead of waiting out the timeout.
        logger.error(f"Could not determine a host from tracking URI {tracking_uri!r}")
        return False
    port = parsed_url.port or (443 if parsed_url.scheme == 'https' else 80)

    base_url = tracking_uri if tracking_uri.endswith('/') else tracking_uri + '/'
    # NOTE(review): "health" is presumably the MLflow server's liveness route,
    # while the legacy "experiments/list" REST route appears to have been
    # removed in MLflow 2.x -- confirm against the deployed server version.
    # The legacy route is kept as a fallback for older servers.
    probe_urls = (
        f"{base_url}health",
        f"{base_url}api/2.0/mlflow/experiments/list",
    )

    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if _tcp_port_open(host, port) and _http_probe_ok(probe_urls):
            logger.info(f"MLflow server is reachable at {tracking_uri}")
            return True

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline, and only announce a retry that
        # will actually happen.
        pause = min(retry_interval, remaining)
        logger.info(f"Retrying in {pause:g} seconds...")
        time.sleep(pause)

    logger.error(f"Could not connect to MLflow server at {tracking_uri} after {timeout} seconds")
    return False
79+
80+ def run_mlflow_test (tracking_uri , connection_timeout = 60 ):
1781 """
1882 Run MLflow test with the specified tracking URI
1983
2084 Args:
2185 tracking_uri: The URI to use for the MLflow tracking server
86+ connection_timeout: Timeout in seconds for server connection
2287
2388 Returns:
2489 True if the test passed, False otherwise
2590 """
2691 try :
27- print (f"Setting MLflow tracking URI to: { tracking_uri } " )
92+ logger .info (f"Setting MLflow tracking URI to: { tracking_uri } " )
93+
94+ # Disable SSL warnings for self-signed certificates
95+ import urllib3
96+ urllib3 .disable_warnings (urllib3 .exceptions .InsecureRequestWarning )
97+
98+ # Check if the server is reachable before proceeding
99+ if not check_server_connection (tracking_uri , timeout = connection_timeout ):
100+ logger .error ("Failed to connect to MLflow server, aborting test" )
101+ return False
102+
28103 mlflow .set_tracking_uri (tracking_uri )
29104
30105 # Load the Iris dataset
106+ logger .info ("Loading dataset and training model..." )
31107 X , y = datasets .load_iris (return_X_y = True )
32108
33109 # Split the data into training and test sets
@@ -39,7 +115,7 @@ def run_mlflow_test(tracking_uri):
39115 params = {
40116 "solver" : "lbfgs" ,
41117 "max_iter" : 1000 ,
42- "multi_class" : "auto" ,
118+ "multi_class" : "auto" , # Deprecated but keeping for now
43119 "random_state" : 8888 ,
44120 }
45121
@@ -53,47 +129,69 @@ def run_mlflow_test(tracking_uri):
53129 # Calculate metrics
54130 accuracy = accuracy_score (y_test , y_pred )
55131
56- print ("Current tracking URI:" , mlflow .get_tracking_uri ())
132+ logger .info (f"Current tracking URI: { mlflow .get_tracking_uri ()} " )
133+ logger .info (f"Model trained with accuracy: { accuracy :.4f} " )
57134
58135 # Create a new MLflow Experiment
59- mlflow .set_experiment ("MLflow CI Test" )
136+ logger .info ("Creating MLflow experiment..." )
137+ experiment_name = "MLflow CI Test"
138+ try :
139+ experiment = mlflow .get_experiment_by_name (experiment_name )
140+ if experiment is None :
141+ experiment_id = mlflow .create_experiment (experiment_name )
142+ logger .info (f"Created new experiment with ID: { experiment_id } " )
143+ else :
144+ logger .info (f"Using existing experiment with ID: { experiment .experiment_id } " )
145+ mlflow .set_experiment (experiment_name )
146+ except Exception as e :
147+ logger .error (f"Failed to create or set experiment: { e } " )
148+ return False
60149
61150 # Start an MLflow run
62- with mlflow .start_run ():
63- # Log the hyperparameters
64- mlflow .log_params (params )
65-
66- # Log the loss metric
67- mlflow .log_metric ("accuracy" , accuracy )
68-
69- # Set a tag that we can use to remind ourselves what this run was for
70- mlflow .set_tag ("Training Info" , "CI Test for MLflow" )
71-
72- # Infer the model signature
73- signature = infer_signature (X_train , lr .predict (X_train ))
74-
75- # Log the model
76- model_info = mlflow .sklearn .log_model (
77- sk_model = lr ,
78- artifact_path = "iris_model" ,
79- registered_model_name = "ci-test-model" ,
80- signature = signature
81- )
82-
83- print (f"Model URI: { model_info .model_uri } " )
84-
85- # Load the model back for predictions as a generic Python Function model
151+ logger .info ("Starting MLflow run..." )
86152 try :
87- loaded_model = mlflow .pyfunc .load_model (model_info .model_uri )
88- predictions = loaded_model .predict (X_test [:3 ])
89- print (f"Test predictions: { predictions } " )
90- return True
153+ with mlflow .start_run ():
154+ # Log the hyperparameters
155+ mlflow .log_params (params )
156+
157+ # Log the loss metric
158+ mlflow .log_metric ("accuracy" , accuracy )
159+
160+ # Set a tag that we can use to remind ourselves what this run was for
161+ mlflow .set_tag ("Training Info" , "CI Test for MLflow" )
162+
163+ # Infer the model signature
164+ signature = infer_signature (X_train , lr .predict (X_train ))
165+
166+ # Log the model
167+ logger .info ("Logging model to MLflow..." )
168+ model_info = mlflow .sklearn .log_model (
169+ sk_model = lr ,
170+ artifact_path = "iris_model" ,
171+ registered_model_name = "ci-test-model" ,
172+ signature = signature
173+ )
174+
175+ logger .info (f"Model URI: { model_info .model_uri } " )
176+
177+ # Load the model back for predictions as a generic Python Function model
178+ try :
179+ logger .info ("Loading model for predictions..." )
180+ loaded_model = mlflow .pyfunc .load_model (model_info .model_uri )
181+ predictions = loaded_model .predict (X_test [:3 ])
182+ logger .info (f"Test predictions: { predictions } " )
183+ return True
184+ except Exception as e :
185+ logger .error (f"Error loading model: { e } " )
186+ return False
91187 except Exception as e :
92- print (f"Error loading model : { e } " )
188+ logger . error (f"Error during MLflow run : { e } " )
93189 return False
94190
95191 except Exception as e :
96- print (f"Test failed with error: { e } " )
192+ logger .error (f"Test failed with error: { e } " )
193+ import traceback
194+ logger .error (traceback .format_exc ())
97195 return False
98196
99197def ensure_dependencies ():
@@ -102,21 +200,29 @@ def ensure_dependencies():
102200 import mlflow
103201 import pandas
104202 import sklearn
203+ import requests
105204 except ImportError :
106- print ("Installing required dependencies..." )
205+ logger . info ("Installing required dependencies..." )
107206 subprocess .check_call ([
108207 sys .executable , "-m" , "pip" , "install" ,
109- "mlflow" , "pandas" , "scikit-learn"
208+ "mlflow" , "pandas" , "scikit-learn" , "requests"
110209 ])
111210
112211def main ():
113212 parser = argparse .ArgumentParser (description = "MLflow CI testing tool" )
114213 parser .add_argument ("hostname" , help = "Hostname of the MLflow server" )
115214 parser .add_argument ("--port" , type = int , help = "Port number (if not included in hostname)" )
116215 parser .add_argument ("--protocol" , default = "https" , help = "Protocol (http or https, default: https)" )
216+ parser .add_argument ("--connection-timeout" , type = int , default = 60 ,
217+ help = "Timeout in seconds for server connection (default: 60)" )
218+ parser .add_argument ("--debug" , action = "store_true" , help = "Enable debug logs" )
117219
118220 args = parser .parse_args ()
119221
222+ # Set logging level based on debug flag
223+ if args .debug :
224+ logging .getLogger ().setLevel (logging .DEBUG )
225+
120226 # Build the tracking URI
121227 tracking_uri = f"{ args .protocol } ://{ args .hostname } "
122228 if args .port :
@@ -126,13 +232,14 @@ def main():
126232 ensure_dependencies ()
127233
128234 # Run the test
129- success = run_mlflow_test (tracking_uri )
235+ logger .info (f"Starting MLflow test against server: { tracking_uri } " )
236+ success = run_mlflow_test (tracking_uri , connection_timeout = args .connection_timeout )
130237
131238 if success :
132- print ("✅ MLflow test completed successfully" )
239+ logger . info ("✅ MLflow test completed successfully" )
133240 sys .exit (0 )
134241 else :
135- print ("❌ MLflow test failed" )
242+ logger . error ("❌ MLflow test failed" )
136243 sys .exit (1 )
137244
138245if __name__ == "__main__" :
0 commit comments