Commit 92cdb9e

gradio_app changes
1 parent e18c56c commit 92cdb9e

3 files changed: +208 -28 lines changed

3 files changed

+208
-28
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -301,4 +301,5 @@ predictions_*.xlsx
 *.pth
 *.pth.gz
 
-EVALUATION_EXPLANATION.md
+EVALUATION_EXPLANATION.md
+MODAL*.md

gradio_app.py

Lines changed: 205 additions & 26 deletions
@@ -3,13 +3,34 @@
 import tempfile
 import zipfile
 import shutil
+import time
 from pathlib import Path
 from datetime import datetime
 import numpy as np
-import torch
-import pandas as pd
 import traceback
-import matplotlib
+
+# Conditional imports for Modal deploy-time parsing
+try:
+    import pandas as pd
+    PANDAS_AVAILABLE = True
+except ImportError:
+    PANDAS_AVAILABLE = False
+    pd = None
+
+try:
+    import matplotlib
+    MATPLOTLIB_AVAILABLE = True
+except ImportError:
+    MATPLOTLIB_AVAILABLE = False
+    matplotlib = None
+
+# Conditional imports - only import torch if needed (for local PyTorch mode)
+try:
+    import torch
+    TORCH_AVAILABLE = True
+except ImportError:
+    TORCH_AVAILABLE = False
+    torch = None
 
 from predict import (
     preprocess_audio,
@@ -71,7 +92,11 @@ def patched_json_schema(schema, defs=None):
     GRADIO_AVAILABLE = False
     gr = None
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# Device definition - only needed for PyTorch mode
+if TORCH_AVAILABLE:
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+else:
+    device = None  # Not needed for ONNX mode
 
 class GradioAudioInterface:
     def __init__(self, model_path: str, threshold: float = DEFAULT_THRESHOLD, use_onnx: bool = False, is_modal: bool = False):
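These availability flags let gradio_app.py be imported, and the Modal app parsed at deploy time, on machines without torch, pandas, or matplotlib. One caveat: the PyTorch branches below rebind the module-level `device` inside functions (`if device is None: device = ...`), which Python treats as a local assignment unless a `global device` declaration exists outside the visible hunks. A minimal sketch of the same intent using a lazy helper; `_resolve_device` is a hypothetical name, not part of the commit:

# Hedged sketch, not the commit's code: availability flag plus a lazy device helper.
try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    torch = None

def _resolve_device():
    if not TORCH_AVAILABLE:
        return None  # ONNX-only deployments never need a torch device
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = _resolve_device()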
@@ -84,16 +109,59 @@ def __init__(self, model_path: str, threshold: float = DEFAULT_THRESHOLD, use_on
             if not ONNXRUNTIME_AVAILABLE:
                 raise ImportError("onnxruntime is required for ONNX models. Install with: pip install onnxruntime")
             try:
+                print(f"[{time.strftime('%H:%M:%S')}] Loading ONNX model from: {model_path}")
+
                 # Configure ONNX session options for stability
                 sess_options = ort.SessionOptions()
                 sess_options.intra_op_num_threads = 1  # Avoid threading issues in container
                 sess_options.inter_op_num_threads = 1
-                self.onnx_session = ort.InferenceSession(model_path, sess_options)
+
+                # Try to use GPU if available (for Modal GPU instances)
+                # NOTE: Start with CPU to avoid GPU provider issues that can cause hangs
+                providers = ['CPUExecutionProvider']
+
+                if is_modal:
+                    # On Modal with GPU, try CUDA provider but only if explicitly needed
+                    # CPU is more reliable and avoids hanging issues
+                    available_providers = ort.get_available_providers()
+                    if 'CUDAExecutionProvider' in available_providers:
+                        # Add CUDA as fallback (will use CPU first, then GPU if CPU fails)
+                        # Actually, let's use CPU only for now to avoid hanging
+                        print(f"[{time.strftime('%H:%M:%S')}] CUDAExecutionProvider available but using CPU for stability")
+                        # providers.insert(0, 'CUDAExecutionProvider')  # Uncomment to try GPU
+                    else:
+                        print(f"[{time.strftime('%H:%M:%S')}] CUDAExecutionProvider not available, using CPU")
+
+                print(f"[{time.strftime('%H:%M:%S')}] Using providers: {providers}")
+
+                load_start = time.time()
+                self.onnx_session = ort.InferenceSession(
+                    model_path,
+                    sess_options,
+                    providers=providers
+                )
+                load_time = time.time() - load_start
+                print(f"[{time.strftime('%H:%M:%S')}] ONNX model loaded in {load_time:.2f}s")
+                print(f"[{time.strftime('%H:%M:%S')}] Using providers: {self.onnx_session.get_providers()}")
+
+                # Log model input/output info
+                for input_info in self.onnx_session.get_inputs():
+                    print(f"[{time.strftime('%H:%M:%S')}] Model input: {input_info.name}, shape: {input_info.shape}, type: {input_info.type}")
+                for output_info in self.onnx_session.get_outputs():
+                    print(f"[{time.strftime('%H:%M:%S')}] Model output: {output_info.name}, shape: {output_info.shape}, type: {output_info.type}")
+
                 self.model = None
             except Exception as e:
-                raise RuntimeError(f"Failed to load ONNX model from {model_path}: {str(e)}")
+                error_msg = f"Failed to load ONNX model from {model_path}: {str(e)}"
+                print(f"[{time.strftime('%H:%M:%S')}] {error_msg}")
+                print(f"[{time.strftime('%H:%M:%S')}] Traceback: {traceback.format_exc()}")
+                raise RuntimeError(error_msg)
         else:
             # PyTorch model loading
+            if not TORCH_AVAILABLE:
+                raise ImportError("PyTorch is required for PyTorch models. Install with: pip install torch")
+            if device is None:
+                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
             self.model = DualHeadCnn14Simple(pretrained=False)
             self.model.load_state_dict(torch.load(self.model_path, map_location=device, weights_only=False))
             self.model.eval().to(device)
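The commented-out `providers.insert(0, 'CUDAExecutionProvider')` line is the hook for re-enabling GPU later. ONNX Runtime tries providers left to right, so listing CUDA first with CPU after it gives GPU inference with a CPU fallback. A short sketch of that ordering, not part of the commit; the model path is a placeholder:

# Provider fallback ordering in ONNX Runtime (sketch only; "model.onnx" is a placeholder).
import onnxruntime as ort

providers = ["CPUExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers():
    providers.insert(0, "CUDAExecutionProvider")  # try GPU first, fall back to CPU

session = ort.InferenceSession("model.onnx", providers=providers)
print(session.get_providers())  # shows which providers were actually registered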
@@ -181,36 +249,96 @@ def __predict_single_audio(self, audio_file):
                 return f"Error: Path is not a regular file: {audio_file}. Got type: {type(audio_path)}", None, empty_df
 
             # Preprocess audio file
-            print(f"Processing audio file: {audio_file}")
+            print(f"[{time.strftime('%H:%M:%S')}] Processing audio file: {audio_file}", flush=True)
+            import sys
+            sys.stdout.flush()
+            preprocess_start = time.time()
             input_data = preprocess_audio(str(audio_file))
-            print("Audio preprocessing complete")
+            preprocess_time = time.time() - preprocess_start
+            print(f"[{time.strftime('%H:%M:%S')}] Audio preprocessing completed in {preprocess_time:.2f}s", flush=True)
+            print(f"[{time.strftime('%H:%M:%S')}] Preprocessed tensor shape: {input_data.shape}, dtype: {input_data.dtype}", flush=True)
+            sys.stdout.flush()
 
             if self.onnx:
                 # ONNX inference
-                print("Starting ONNX inference...")
+                start_time = time.time()
+                print(f"[{time.strftime('%H:%M:%S')}] Starting ONNX inference...")
+
                 # input_data shape: [1, audio_length] -> squeeze to [audio_length] -> reshape to [1, audio_length]
                 input_tensor = input_data.squeeze(0).numpy().reshape(1, -1).astype(np.float32)
+                print(f"[{time.strftime('%H:%M:%S')}] Input tensor shape: {input_tensor.shape}, dtype: {input_tensor.dtype}")
 
                 # Verify input shape matches ONNX model expectations
                 if self.onnx_session is None:
-                    print("Error: ONNX session is None")
-                    return "Error: ONNX session not initialized", None, empty_df
+                    error_msg = "Error: ONNX session not initialized"
+                    print(f"[{time.strftime('%H:%M:%S')}] {error_msg}", flush=True)
+                    return error_msg, None, empty_df
+
+                import sys
+                sys.stdout.flush()
+
+                # Get expected input shape from model
+                try:
+                    input_name = self.onnx_session.get_inputs()[0].name
+                    expected_shape = self.onnx_session.get_inputs()[0].shape
+                    print(f"[{time.strftime('%H:%M:%S')}] Model expects input '{input_name}' with shape: {expected_shape}", flush=True)
+                except Exception as e:
+                    print(f"[{time.strftime('%H:%M:%S')}] Warning: Could not get model input info: {e}", flush=True)
+
+                print(f"[{time.strftime('%H:%M:%S')}] Running ONNX session.run()...", flush=True)
+                sys.stdout.flush()
+
+                try:
+                    inference_start = time.time()
+
+                    # Ensure input is contiguous and correct shape
+                    if not input_tensor.flags['C_CONTIGUOUS']:
+                        input_tensor = np.ascontiguousarray(input_tensor)
+
+                    # Validate input shape matches expected
+                    print(f"[{time.strftime('%H:%M:%S')}] Expected shape: {expected_shape}, Got: {input_tensor.shape}", flush=True)
+
+                    # Handle dynamic batch dimension
+                    if len(expected_shape) == 2 and (expected_shape[0] == -1 or expected_shape[0] == 'batch_size'):
+                        # Dynamic batch size - ensure we have batch dimension
+                        if len(input_tensor.shape) == 1:
+                            input_tensor = input_tensor.reshape(1, -1)
+                    print(f"[{time.strftime('%H:%M:%S')}] Final input shape: {input_tensor.shape}", flush=True)
+
+                    # Run inference with explicit input name
+                    print(f"[{time.strftime('%H:%M:%S')}] Calling session.run()...", flush=True)
+                    print(f"[{time.strftime('%H:%M:%S')}] Active providers: {self.onnx_session.get_providers()}", flush=True)
+                    sys.stdout.flush()
+
+                    # This is the critical call - if it hangs, we'll see it in logs
+                    outputs = self.onnx_session.run(
+                        ['binary_logit', 'tag_logits'],
+                        {input_name: input_tensor}
+                    )
+                    inference_time = time.time() - inference_start
+                    print(f"[{time.strftime('%H:%M:%S')}] ONNX inference completed in {inference_time:.2f}s", flush=True)
+                    sys.stdout.flush()
+                except Exception as e:
+                    error_msg = f"ONNX inference failed: {str(e)}"
+                    print(f"[{time.strftime('%H:%M:%S')}] {error_msg}")
+                    print(f"[{time.strftime('%H:%M:%S')}] Traceback: {traceback.format_exc()}")
+                    return error_msg, None, empty_df
 
-                print("Running session.run()...")
-                outputs = self.onnx_session.run(
-                    ['binary_logit', 'tag_logits'],
-                    {'audio': input_tensor}
-                )
-                print("Inference complete")
                 binary_logit, _ = outputs
+                print(f"[{time.strftime('%H:%M:%S')}] Raw binary_logit: {binary_logit}, shape: {binary_logit.shape}")
+
                 # Convert numpy scalar to Python float for consistency
                 ai_prob = float(1 / (1 + np.exp(-binary_logit[0, 0])))
+                total_time = time.time() - start_time
+                print(f"[{time.strftime('%H:%M:%S')}] Total ONNX processing time: {total_time:.2f}s, AI probability: {ai_prob:.3f}")
             else:
                 # PyTorch inference
                 print("Starting PyTorch inference...")
                 if self.model is None:
                     return "Error: PyTorch model not initialized", None, empty_df
 
+                if device is None:
+                    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                 input_tensor = input_data.to(device)
                 ai_prob = predict_ai_only(self.model, input_tensor)
                 # Ensure it's a Python float
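The new call requests the outputs by name and keys the feed dict on the model's reported input name instead of the hard-coded 'audio'. A standalone check of the same call shape, not part of the commit; the file path and the dummy waveform are placeholders, and the output names are assumed to match the export:

# Hedged standalone check of the exported model (placeholder path and dummy input).
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name                # e.g. 'audio' for this model
audio = np.zeros((1, 16000), dtype=np.float32)           # [batch, samples] dummy waveform

binary_logit, tag_logits = session.run(["binary_logit", "tag_logits"], {input_name: audio})
ai_prob = float(1 / (1 + np.exp(-binary_logit[0, 0])))   # sigmoid on the binary head
print(f"AI probability: {ai_prob:.3f}")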
@@ -380,6 +508,8 @@ def __predict_folder_batch(self, zip_file: str):
                     ai_prob = float(1 / (1 + np.exp(-binary_logit[0, 0])))
                 else:
                     # PyTorch inference
+                    if device is None:
+                        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                     input_tensor = input_data.to(device)
                     ai_prob = predict_ai_only(self.model, input_tensor)
 
@@ -534,6 +664,19 @@ def patched_hash_file(path, *args, **kwargs):
             theme=gradio_module.themes.Soft()
         )
 
+        # Configure queue for long-running requests
+        # Note: Modal's ASGI wrapper handles concurrency, so we use a higher limit
+        try:
+            demo.queue(
+                default_concurrency_limit=10,  # Higher limit - Modal handles actual concurrency
+                max_size=50,  # Allow more queued requests
+                api_open=False  # Don't expose queue API
+            )
+            print("Gradio queue configured successfully")
+        except Exception as e:
+            print(f"Warning: Could not configure Gradio queue: {e}")
+            # Continue anyway - queue might not be available in all Gradio versions
+
         with demo:
             gradio_module.Markdown("""
             # AI-Generated Audio Detection
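The queue settings only take effect once the Blocks app is served. A local smoke test of the same configuration, assuming Gradio 4.x semantics; `build_interface()` is a hypothetical stand-in for whatever constructs the Blocks object above:

# Hedged local test of the queue settings (build_interface is hypothetical).
demo = build_interface()
demo.queue(
    default_concurrency_limit=10,  # default per-event concurrency for listeners without their own limit
    max_size=50,                   # cap on requests waiting in the queue
    api_open=False,                # direct API calls cannot bypass the queue
)
demo.launch(server_name="0.0.0.0", server_port=7860)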
@@ -571,30 +714,64 @@ def patched_hash_file(path, *args, **kwargs):
 
             def safe_predict(audio_file):
                 """Wrapper to catch Gradio preprocessing errors and Modal's root directory bug"""
+                import sys
+                # Log to both stdout and stderr for maximum visibility in Modal
+                log_msg = f"[{time.strftime('%H:%M:%S')}] ===== safe_predict CALLED ====="
+                print(log_msg, flush=True)
+                print(log_msg, file=sys.stderr, flush=True)
+                sys.stdout.flush()
+                sys.stderr.flush()
+
+                log_msg = f"[{time.strftime('%H:%M:%S')}] audio_file type: {type(audio_file)}, value: {audio_file}"
+                print(log_msg, flush=True)
+                print(log_msg, file=sys.stderr, flush=True)
+                sys.stdout.flush()
+                sys.stderr.flush()
+
                 # Early filter for Modal's root directory bug
                 if audio_file is None:
+                    print(f"[{time.strftime('%H:%M:%S')}] audio_file is None, returning early", flush=True)
                     return self.__predict_single_audio(None)
                 # Handle list inputs (Modal/Gradio might return lists)
                 if isinstance(audio_file, list):
                     if len(audio_file) == 0:
+                        print(f"[{time.strftime('%H:%M:%S')}] audio_file is empty list, returning early", flush=True)
                         return self.__predict_single_audio(None)
                     audio_file = audio_file[0]
+                    print(f"[{time.strftime('%H:%M:%S')}] Extracted from list: {audio_file}", flush=True)
                 # Filter out root directory and favicon (Modal bug)
                 # Also check if it's a directory
                 if isinstance(audio_file, str):
                     if audio_file == "/" or audio_file == "/favicon.ico" or audio_file == "" or audio_file == ".":
+                        print(f"[{time.strftime('%H:%M:%S')}] Invalid path detected: {audio_file}, returning early", flush=True)
                         return self.__predict_single_audio(None)
                     # Check if it's a directory path
                     try:
                         import os
                         if os.path.isdir(audio_file):
+                            print(f"[{time.strftime('%H:%M:%S')}] Path is directory: {audio_file}, returning early", flush=True)
                             return self.__predict_single_audio(None)
-                    except Exception:
+                    except Exception as e:
+                        print(f"[{time.strftime('%H:%M:%S')}] Error checking if directory: {e}", flush=True)
                         pass  # If check fails, continue
+
+                log_msg = f"[{time.strftime('%H:%M:%S')}] Calling __predict_single_audio with: {audio_file}"
+                print(log_msg, flush=True)
+                print(log_msg, file=sys.stderr, flush=True)
+                sys.stdout.flush()
+                sys.stderr.flush()
+
                 try:
-                    return self.__predict_single_audio(audio_file)
+                    result = self.__predict_single_audio(audio_file)
+                    log_msg = f"[{time.strftime('%H:%M:%S')}] ===== safe_predict completed successfully ====="
+                    print(log_msg, flush=True)
+                    print(log_msg, file=sys.stderr, flush=True)
+                    sys.stdout.flush()
+                    sys.stderr.flush()
+                    return result
                 except (IsADirectoryError, OSError) as e:
                     # Handle directory errors gracefully
+                    print(f"[{time.strftime('%H:%M:%S')}] Directory error caught: {e}", flush=True)
                     if "Is a directory" in str(e) or "IsADirectoryError" in str(type(e).__name__):
                         return self.__predict_single_audio(None)
                     raise
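The wrapper repeats the same timestamp-print-to-stdout-and-stderr-then-flush sequence several times. A sketch of factoring it into one helper; `log_both` is a hypothetical name, not something defined in the commit:

# Hedged sketch: one helper for the repeated dual-stream logging pattern.
import sys
import time

def log_both(msg: str) -> None:
    line = f"[{time.strftime('%H:%M:%S')}] {msg}"
    print(line, flush=True)                    # container stdout
    print(line, file=sys.stderr, flush=True)   # mirrored to stderr for Modal's log capture

log_both("===== safe_predict CALLED =====")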
@@ -605,6 +782,8 @@ def safe_predict(audio_file):
                     error_msg += "- Corrupted audio file\n"
                     error_msg += "- Missing audio codecs\n\n"
                     error_msg += f"Technical details: {traceback.format_exc()}"
+                    print(f"[{time.strftime('%H:%M:%S')}] Exception in safe_predict: {error_msg}", flush=True)
+                    sys.stdout.flush()
                     empty_df = pd.DataFrame(columns=['Filename', 'AI-Generated', 'Confidence', 'Genre', 'Mood', 'Tempo (BPM)', 'Energy'])
                     return error_msg, None, empty_df
 
@@ -794,11 +973,10 @@ def create_summary_visualizations(df):
     image=image,
     volumes={"/models": model_volume},
     timeout=600,  # Increased timeout to 10 minutes
-    container_idle_timeout=300,  # Keep container alive for 5 minutes
-    gpu="any",  # Use any available GPU
+    scaledown_window=300,  # Keep container alive for 5 minutes (renamed from container_idle_timeout)
+    # Removed gpu="any" since we're using CPU for ONNX inference
 )
-@modal.concurrent(max_inputs=100)
-@modal.asgi_app()
+@modal.asgi_app()  # Outermost decorator - ASGI apps handle concurrency internally
 def gradio_app_modal():
     """Modal deployment function - uses ONNX model from volume."""
     import sys
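The decorator stack now pairs `@app.function(...)` (image, volume, timeout, scaledown window) directly with `@modal.asgi_app()`, dropping the explicit `@modal.concurrent` layer and the GPU request. A sketch of that deployment shape with placeholder names; the app name, image contents, volume name, and `build_interface()` helper are assumptions, not values from the repository:

# Hedged sketch of the deployment shape; all names below are placeholders.
import modal

app = modal.App("ai-audio-detector")
image = modal.Image.debian_slim().pip_install("gradio", "onnxruntime", "numpy")
model_volume = modal.Volume.from_name("models", create_if_missing=True)

@app.function(
    image=image,
    volumes={"/models": model_volume},
    timeout=600,            # allow long single requests (10 minutes)
    scaledown_window=300,   # keep an idle container warm for 5 minutes
)
@modal.asgi_app()
def serve():
    demo = build_interface()  # hypothetical: builds the Gradio Blocks shown above
    return demo.app if hasattr(demo, "app") else demo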
@@ -843,11 +1021,12 @@ def gradio_app_modal():
     except Exception:
         pass  # If setting fails, continue anyway - not critical
 
-    # Return the underlying FastAPI app which is ASGI compatible
+    # Return the ASGI-compatible app
+    # Gradio Blocks are ASGI-compatible, but demo.app (FastAPI) is more explicit
     if hasattr(demo, "app"):
-        return demo.app
+        return demo.app  # FastAPI instance - preferred for Modal
     elif callable(demo):
-        return demo
+        return demo  # Gradio Blocks are also ASGI-compatible
     else:
         raise RuntimeError(f"Expected an ASGI app, but got {type(demo)}")
