Skip to content

Commit b8d1254

Browse files
committed
debug
1 parent d2f9820 commit b8d1254

File tree

3 files changed

+163
-42
lines changed

3 files changed

+163
-42
lines changed

.github/workflows/mlflow-ci.yml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,15 @@ jobs:
277277
run: |
278278
cd applications/mlflow
279279
echo "Installing Python dependencies for tests..."
280-
pip install mlflow pandas scikit-learn
280+
pip install mlflow pandas scikit-learn requests urllib3
281281
282282
echo "Running MLflow application tests against ${{ steps.expose-port.outputs.hostname }}"
283-
python tests/mlflow_test.py ${{ steps.expose-port.outputs.hostname }} --protocol https
283+
echo "This may take some time as it will retry connections for up to 2 minutes"
284+
python tests/mlflow_test.py ${{ steps.expose-port.outputs.hostname }} \
285+
--protocol https \
286+
--connection-timeout 60 \
287+
--debug
288+
timeout-minutes: 5
284289

285290
- name: Install troubleshoot
286291
run: curl -L https://github.com/replicatedhq/troubleshoot/releases/latest/download/support-bundle_linux_amd64.tar.gz | tar xzvf -

applications/mlflow/Makefile

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,15 @@ test-replicated-helm-with-values: registry-login
143143
MLFLOW_VALUES_ARGS=""; \
144144
if [ -n "$$MLFLOW_VALUES" ]; then \
145145
echo "Using MLflow values file: $$MLFLOW_VALUES"; \
146+
# Check if values file exists
147+
if [ ! -f "$$MLFLOW_VALUES" ]; then \
148+
echo "ERROR: Values file '$$MLFLOW_VALUES' does not exist"; \
149+
exit 1; \
150+
fi; \
146151
MLFLOW_VALUES_ARGS="--values $$MLFLOW_VALUES"; \
152+
echo "Values args: $$MLFLOW_VALUES_ARGS"; \
153+
else \
154+
echo "No custom values file provided. Using default values."; \
147155
fi; \
148156
\
149157
# Create namespace if it doesn't exist
@@ -163,6 +171,7 @@ test-replicated-helm-with-values: registry-login
163171
# Install MLflow chart from Replicated registry with custom values
164172
echo "Installing mlflow chart from Replicated registry with custom values..."; \
165173
echo "Chart path: $$OCI_URL/mlflow"; \
174+
echo "Using values args: $$MLFLOW_VALUES_ARGS"; \
166175
helm upgrade --install mlflow-values-test $$OCI_URL/mlflow \
167176
--namespace values-test \
168177
$$MLFLOW_VALUES_ARGS \

applications/mlflow/tests/mlflow_test.py

Lines changed: 147 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,104 @@
66
import subprocess
77
import mlflow
88
from mlflow.models import infer_signature
9+
import requests
10+
import time
11+
import socket
12+
from urllib.parse import urlparse
13+
import logging
914

1015
import pandas as pd
1116
from sklearn import datasets
1217
from sklearn.model_selection import train_test_split
1318
from sklearn.linear_model import LogisticRegression
1419
from sklearn.metrics import accuracy_score
1520

16-
def run_mlflow_test(tracking_uri):
21+
# Configure logging
22+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23+
logger = logging.getLogger(__name__)
24+
25+
def check_server_connection(tracking_uri, timeout=30, retry_interval=5):
26+
"""
27+
Check if the MLflow server is reachable
28+
29+
Args:
30+
tracking_uri: The URI of the MLflow server
31+
timeout: Maximum time in seconds to wait for the server
32+
retry_interval: Interval in seconds between retries
33+
34+
Returns:
35+
bool: True if the server is reachable, False otherwise
36+
"""
37+
logger.info(f"Checking connection to MLflow server at {tracking_uri}")
38+
39+
url = tracking_uri
40+
if not url.endswith('/'):
41+
url += '/'
42+
43+
# Add health check endpoint if using standard MLflow API
44+
health_url = f"{url}api/2.0/mlflow/experiments/list"
45+
46+
# Parse URL to get host and port for socket check
47+
parsed_url = urlparse(tracking_uri)
48+
host = parsed_url.hostname
49+
port = parsed_url.port or (443 if parsed_url.scheme == 'https' else 80)
50+
51+
start_time = time.time()
52+
while time.time() - start_time < timeout:
53+
# First try a basic socket connection
54+
try:
55+
socket.create_connection((host, port), timeout=5)
56+
logger.info(f"Socket connection to {host}:{port} successful")
57+
except (socket.timeout, socket.error, ConnectionRefusedError) as e:
58+
logger.warning(f"Socket connection failed: {e}")
59+
logger.info(f"Retrying in {retry_interval} seconds...")
60+
time.sleep(retry_interval)
61+
continue
62+
63+
# Then try an HTTP request to the API
64+
try:
65+
response = requests.get(health_url, timeout=5, verify=False)
66+
if response.status_code == 200:
67+
logger.info(f"MLflow server is reachable at {tracking_uri}")
68+
return True
69+
else:
70+
logger.warning(f"MLflow server returned status code {response.status_code}")
71+
except requests.exceptions.RequestException as e:
72+
logger.warning(f"HTTP request failed: {e}")
73+
74+
logger.info(f"Retrying in {retry_interval} seconds...")
75+
time.sleep(retry_interval)
76+
77+
logger.error(f"Could not connect to MLflow server at {tracking_uri} after {timeout} seconds")
78+
return False
79+
80+
def run_mlflow_test(tracking_uri, connection_timeout=60):
1781
"""
1882
Run MLflow test with the specified tracking URI
1983
2084
Args:
2185
tracking_uri: The URI to use for the MLflow tracking server
86+
connection_timeout: Timeout in seconds for server connection
2287
2388
Returns:
2489
True if the test passed, False otherwise
2590
"""
2691
try:
27-
print(f"Setting MLflow tracking URI to: {tracking_uri}")
92+
logger.info(f"Setting MLflow tracking URI to: {tracking_uri}")
93+
94+
# Disable SSL warnings for self-signed certificates
95+
import urllib3
96+
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
97+
98+
# Check if the server is reachable before proceeding
99+
if not check_server_connection(tracking_uri, timeout=connection_timeout):
100+
logger.error("Failed to connect to MLflow server, aborting test")
101+
return False
102+
28103
mlflow.set_tracking_uri(tracking_uri)
29104

30105
# Load the Iris dataset
106+
logger.info("Loading dataset and training model...")
31107
X, y = datasets.load_iris(return_X_y=True)
32108

33109
# Split the data into training and test sets
@@ -39,7 +115,7 @@ def run_mlflow_test(tracking_uri):
39115
params = {
40116
"solver": "lbfgs",
41117
"max_iter": 1000,
42-
"multi_class": "auto",
118+
"multi_class": "auto", # Deprecated but keeping for now
43119
"random_state": 8888,
44120
}
45121

@@ -53,47 +129,69 @@ def run_mlflow_test(tracking_uri):
53129
# Calculate metrics
54130
accuracy = accuracy_score(y_test, y_pred)
55131

56-
print("Current tracking URI:", mlflow.get_tracking_uri())
132+
logger.info(f"Current tracking URI: {mlflow.get_tracking_uri()}")
133+
logger.info(f"Model trained with accuracy: {accuracy:.4f}")
57134

58135
# Create a new MLflow Experiment
59-
mlflow.set_experiment("MLflow CI Test")
136+
logger.info("Creating MLflow experiment...")
137+
experiment_name = "MLflow CI Test"
138+
try:
139+
experiment = mlflow.get_experiment_by_name(experiment_name)
140+
if experiment is None:
141+
experiment_id = mlflow.create_experiment(experiment_name)
142+
logger.info(f"Created new experiment with ID: {experiment_id}")
143+
else:
144+
logger.info(f"Using existing experiment with ID: {experiment.experiment_id}")
145+
mlflow.set_experiment(experiment_name)
146+
except Exception as e:
147+
logger.error(f"Failed to create or set experiment: {e}")
148+
return False
60149

61150
# Start an MLflow run
62-
with mlflow.start_run():
63-
# Log the hyperparameters
64-
mlflow.log_params(params)
65-
66-
# Log the loss metric
67-
mlflow.log_metric("accuracy", accuracy)
68-
69-
# Set a tag that we can use to remind ourselves what this run was for
70-
mlflow.set_tag("Training Info", "CI Test for MLflow")
71-
72-
# Infer the model signature
73-
signature = infer_signature(X_train, lr.predict(X_train))
74-
75-
# Log the model
76-
model_info = mlflow.sklearn.log_model(
77-
sk_model=lr,
78-
artifact_path="iris_model",
79-
registered_model_name="ci-test-model",
80-
signature=signature
81-
)
82-
83-
print(f"Model URI: {model_info.model_uri}")
84-
85-
# Load the model back for predictions as a generic Python Function model
151+
logger.info("Starting MLflow run...")
86152
try:
87-
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
88-
predictions = loaded_model.predict(X_test[:3])
89-
print(f"Test predictions: {predictions}")
90-
return True
153+
with mlflow.start_run():
154+
# Log the hyperparameters
155+
mlflow.log_params(params)
156+
157+
# Log the loss metric
158+
mlflow.log_metric("accuracy", accuracy)
159+
160+
# Set a tag that we can use to remind ourselves what this run was for
161+
mlflow.set_tag("Training Info", "CI Test for MLflow")
162+
163+
# Infer the model signature
164+
signature = infer_signature(X_train, lr.predict(X_train))
165+
166+
# Log the model
167+
logger.info("Logging model to MLflow...")
168+
model_info = mlflow.sklearn.log_model(
169+
sk_model=lr,
170+
artifact_path="iris_model",
171+
registered_model_name="ci-test-model",
172+
signature=signature
173+
)
174+
175+
logger.info(f"Model URI: {model_info.model_uri}")
176+
177+
# Load the model back for predictions as a generic Python Function model
178+
try:
179+
logger.info("Loading model for predictions...")
180+
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
181+
predictions = loaded_model.predict(X_test[:3])
182+
logger.info(f"Test predictions: {predictions}")
183+
return True
184+
except Exception as e:
185+
logger.error(f"Error loading model: {e}")
186+
return False
91187
except Exception as e:
92-
print(f"Error loading model: {e}")
188+
logger.error(f"Error during MLflow run: {e}")
93189
return False
94190

95191
except Exception as e:
96-
print(f"Test failed with error: {e}")
192+
logger.error(f"Test failed with error: {e}")
193+
import traceback
194+
logger.error(traceback.format_exc())
97195
return False
98196

99197
def ensure_dependencies():
@@ -102,21 +200,29 @@ def ensure_dependencies():
102200
import mlflow
103201
import pandas
104202
import sklearn
203+
import requests
105204
except ImportError:
106-
print("Installing required dependencies...")
205+
logger.info("Installing required dependencies...")
107206
subprocess.check_call([
108207
sys.executable, "-m", "pip", "install",
109-
"mlflow", "pandas", "scikit-learn"
208+
"mlflow", "pandas", "scikit-learn", "requests"
110209
])
111210

112211
def main():
113212
parser = argparse.ArgumentParser(description="MLflow CI testing tool")
114213
parser.add_argument("hostname", help="Hostname of the MLflow server")
115214
parser.add_argument("--port", type=int, help="Port number (if not included in hostname)")
116215
parser.add_argument("--protocol", default="https", help="Protocol (http or https, default: https)")
216+
parser.add_argument("--connection-timeout", type=int, default=60,
217+
help="Timeout in seconds for server connection (default: 60)")
218+
parser.add_argument("--debug", action="store_true", help="Enable debug logs")
117219

118220
args = parser.parse_args()
119221

222+
# Set logging level based on debug flag
223+
if args.debug:
224+
logging.getLogger().setLevel(logging.DEBUG)
225+
120226
# Build the tracking URI
121227
tracking_uri = f"{args.protocol}://{args.hostname}"
122228
if args.port:
@@ -126,13 +232,14 @@ def main():
126232
ensure_dependencies()
127233

128234
# Run the test
129-
success = run_mlflow_test(tracking_uri)
235+
logger.info(f"Starting MLflow test against server: {tracking_uri}")
236+
success = run_mlflow_test(tracking_uri, connection_timeout=args.connection_timeout)
130237

131238
if success:
132-
print("✅ MLflow test completed successfully")
239+
logger.info("✅ MLflow test completed successfully")
133240
sys.exit(0)
134241
else:
135-
print("❌ MLflow test failed")
242+
logger.error("❌ MLflow test failed")
136243
sys.exit(1)
137244

138245
if __name__ == "__main__":

0 commit comments

Comments
 (0)