Commit 49cd61b

Added slides
1 parent 92cdb9e commit 49cd61b

File tree

5 files changed: +85 -57 lines changed

gradio_app.py

Lines changed: 84 additions & 56 deletions
@@ -160,11 +160,23 @@ def __init__(self, model_path: str, threshold: float = DEFAULT_THRESHOLD, use_on
             # PyTorch model loading
             if not TORCH_AVAILABLE:
                 raise ImportError("PyTorch is required for PyTorch models. Install with: pip install torch")
-            if device is None:
-                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+            print(f"[{time.strftime('%H:%M:%S')}] Loading PyTorch model from: {model_path}")
+
+            # Create device for this instance (don't rely on module-level device)
+            model_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+            print(f"[{time.strftime('%H:%M:%S')}] Using device: {model_device}")
+
+            load_start = time.time()
             self.model = DualHeadCnn14Simple(pretrained=False)
-            self.model.load_state_dict(torch.load(self.model_path, map_location=device, weights_only=False))
-            self.model.eval().to(device)
+            self.model.load_state_dict(torch.load(self.model_path, map_location=model_device, weights_only=False))
+            self.model.eval().to(model_device)
+            load_time = time.time() - load_start
+            print(f"[{time.strftime('%H:%M:%S')}] PyTorch model loaded in {load_time:.2f}s on {model_device}")
+
+            # Store device for later use in inference
+            self.device = model_device
             self.onnx_session = None

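Note on this hunk: the old code consulted a `device` name that could be unset in a fresh container, which is what the per-instance `model_device` (stored as `self.device`) fixes. The pattern in isolation, as a minimal sketch with illustrative names that are not from the repo:

import torch

class DeviceScopedModel:
    """Resolve the device once at construction and pin it to the instance (sketch)."""

    def __init__(self, model_path: str):
        # Decided here, not read from a module-level global that may be unset
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.state = torch.load(model_path, map_location=self.device)

    def to_device(self, batch: torch.Tensor) -> torch.Tensor:
        # Inference reuses the stored device, so inputs always match the weights
        return batch.to(self.device)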
@@ -333,19 +345,35 @@ def __predict_single_audio(self, audio_file):
             print(f"[{time.strftime('%H:%M:%S')}] Total ONNX processing time: {total_time:.2f}s, AI probability: {ai_prob:.3f}")
         else:
             # PyTorch inference
-            print("Starting PyTorch inference...")
+            print(f"[{time.strftime('%H:%M:%S')}] Starting PyTorch inference...", flush=True)
+            import sys
+            sys.stdout.flush()
+
             if self.model is None:
-                return "Error: PyTorch model not initialized", None, empty_df
+                error_msg = "Error: PyTorch model not initialized"
+                print(f"[{time.strftime('%H:%M:%S')}] {error_msg}", flush=True)
+                return error_msg, None, empty_df
+
+            # Use the device stored during model loading
+            inference_device = getattr(self, 'device', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+
+            print(f"[{time.strftime('%H:%M:%S')}] Using device: {inference_device}", flush=True)
+            inference_start = time.time()
+
+            input_tensor = input_data.to(inference_device)
+            print(f"[{time.strftime('%H:%M:%S')}] Input tensor moved to device, shape: {input_tensor.shape}", flush=True)
+            sys.stdout.flush()

-            if device is None:
-                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-            input_tensor = input_data.to(device)
             ai_prob = predict_ai_only(self.model, input_tensor)
+
             # Ensure it's a Python float
             if isinstance(ai_prob, torch.Tensor):
                 ai_prob = ai_prob.item()
             ai_prob = float(ai_prob)
-            print("Inference complete")
+
+            inference_time = time.time() - inference_start
+            print(f"[{time.strftime('%H:%M:%S')}] PyTorch inference completed in {inference_time:.2f}s, AI probability: {ai_prob:.3f}", flush=True)
+            sys.stdout.flush()

         is_ai = ai_prob > self.threshold
         result = f"**AI-Generated: {'Yes' if is_ai else 'No'}**\n"
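`predict_ai_only` itself is untouched by this commit, so its body is not shown here. For orientation, a plausible shape consistent with the dual-head model used in train.py, offered as an assumption rather than the repo's actual code:

import torch

def predict_ai_only_sketch(model, input_tensor):
    """Hypothetical stand-in for predict_ai_only: binary head only."""
    with torch.no_grad():
        binary_logits, _ = model(input_tensor)  # tag head ignored, as in train.py
    # squeeze + .item() yields the plain Python float the caller expects,
    # though the caller above still guards with isinstance()/float() defensively
    return torch.sigmoid(binary_logits).squeeze().item()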
@@ -664,18 +692,21 @@ def patched_hash_file(path, *args, **kwargs):
             theme=gradio_module.themes.Soft()
         )

-        # Configure queue for long-running requests
-        # Note: Modal's ASGI wrapper handles concurrency, so we use a higher limit
-        try:
-            demo.queue(
-                default_concurrency_limit=10,  # Higher limit - Modal handles actual concurrency
-                max_size=50,  # Allow more queued requests
-                api_open=False  # Don't expose queue API
-            )
-            print("Gradio queue configured successfully")
-        except Exception as e:
-            print(f"Warning: Could not configure Gradio queue: {e}")
-            # Continue anyway - queue might not be available in all Gradio versions
+        # For Modal ASGI deployment, don't use queue - Modal handles concurrency
+        # Queue causes session management issues with ASGI
+        if not self.is_modal:
+            # Only enable queue for local deployment
+            try:
+                demo.queue(
+                    default_concurrency_limit=10,
+                    max_size=50,
+                    api_open=False
+                )
+                print("Gradio queue configured successfully (local mode)")
+            except Exception as e:
+                print(f"Warning: Could not configure Gradio queue: {e}")
+        else:
+            print("Gradio queue disabled for Modal ASGI deployment")

         with demo:
             gradio_module.Markdown("""
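For reference, a local run takes the queued branch of the new `if not self.is_modal` gate. A usage sketch built from the names in this diff (the local model path is an assumption):

from gradio_app import GradioAudioInterface, DEFAULT_THRESHOLD

# Local mode keeps the queue; on Modal the ASGI wrapper owns concurrency instead.
interface = GradioAudioInterface(
    model_path="model.pth",    # assumed local path
    threshold=DEFAULT_THRESHOLD,
    use_onnx=False,
    is_modal=False,            # takes the demo.queue(...) branch above
)
demo = interface.run_gradio()
demo.launch()                  # Gradio serves the queued app locally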
@@ -864,14 +895,10 @@ def safe_predict(audio_file):
         if not hasattr(demo, 'max_file_size'):
             demo.max_file_size = 1024 * 1024 * 1024  # 1 GB limit

-        # Set root_path to empty string for Modal to prevent '/' path issues
-        # This prevents Gradio from trying to process '/' as a file path on page load
-        if hasattr(demo, 'root_path'):
-            demo.root_path = ""
-        elif hasattr(demo, 'config'):
-            # Try setting it via config if available
-            if hasattr(demo.config, 'root_path'):
-                demo.config.root_path = ""
+        # For Modal ASGI deployment, don't set root_path
+        # Gradio handles routing automatically for ASGI apps
+        # Setting root_path to "" can break session management and cause "Session not found" errors
+        # The monkey-patch for hash_file handles the '/' path issue instead

         return demo

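The hash_file monkey-patch referenced in the new comment lives earlier in the file (its hunk header, `def patched_hash_file(path, *args, **kwargs):`, is visible above) and is outside this diff. In spirit it is a guard of roughly this shape, shown only as an illustration:

import os

def make_patched_hash_file(original_hash_file):
    """Illustrative guard; original_hash_file stands in for whatever Gradio function is wrapped."""
    def patched_hash_file(path, *args, **kwargs):
        # On page load Gradio can be handed '/' as if it were a file path;
        # short-circuit anything that is not a real file before hashing
        if path == "/" or not os.path.isfile(path):
            return ""
        return original_hash_file(path, *args, **kwargs)
    return patched_hash_file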
@@ -974,11 +1001,11 @@ def create_summary_visualizations(df):
     volumes={"/models": model_volume},
     timeout=600,  # Increased timeout to 10 minutes
     scaledown_window=300,  # Keep container alive for 5 minutes (renamed from container_idle_timeout)
-    # Removed gpu="any" since we're using CPU for ONNX inference
+    gpu="any",  # Use GPU for PyTorch inference
 )
 @modal.asgi_app()  # Outermost decorator - ASGI apps handle concurrency internally
 def gradio_app_modal():
-    """Modal deployment function - uses ONNX model from volume."""
+    """Modal deployment function - uses PyTorch model from volume."""
     import sys
     import os
     sys.path.insert(0, "/root")
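Since `gpu="any"` only requests a device, a startup probe confirming that PyTorch can actually see it is cheap insurance; a minimal sketch:

import torch

def log_gpu_status() -> None:
    """Log what the container actually received; useful right after cold start."""
    if torch.cuda.is_available():
        print(f"CUDA available: {torch.cuda.get_device_name(0)}")
    else:
        # The gpu="any" request was not satisfied (or drivers are missing);
        # the interface will silently fall back to CPU via its getattr() default
        print("CUDA not available; inference will run on CPU")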
@@ -990,45 +1017,46 @@ def gradio_app_modal():
     # Set environment variable for Gradio cache
     os.environ["GRADIO_TEMP_DIR"] = gradio_cache_dir

-    ONNX_MODEL_PATH = "/models/model.onnx"
+    PYTORCH_MODEL_PATH = "/models/model.pth"
     THRESHOLD = DEFAULT_THRESHOLD

     # Check if model exists, if not provide helpful error
-    if not os.path.exists(ONNX_MODEL_PATH):
+    if not os.path.exists(PYTORCH_MODEL_PATH):
         raise FileNotFoundError(
-            f"ONNX model not found at {ONNX_MODEL_PATH}. "
+            f"PyTorch model not found at {PYTORCH_MODEL_PATH}. "
             "Please upload your model to the Modal volume first using:\n"
-            "modal volume put ai-audio-models /path/to/your/model.onnx model.onnx"
+            "modal volume put ai-audio-models /path/to/your/model.pth model.pth"
         )

     interface = GradioAudioInterface(
-        model_path=ONNX_MODEL_PATH,
+        model_path=PYTORCH_MODEL_PATH,
         threshold=THRESHOLD,
-        use_onnx=True,
+        use_onnx=False,  # Use PyTorch instead of ONNX
         is_modal=True
     )
     demo = interface.run_gradio()

-    # In Gradio 4.x, Blocks implements the ASGI interface directly
-    # Set root_path to empty string to prevent '/' path issues in Modal
-    # This is similar to setting root_path="" in demo.launch() but for ASGI deployment
-    try:
-        if hasattr(demo, 'root_path'):
-            demo.root_path = ""
-        # Also try setting via config if available
-        if hasattr(demo, 'config') and hasattr(demo.config, 'root_path'):
-            demo.config.root_path = ""
-    except Exception:
-        pass  # If setting fails, continue anyway - not critical
+    # For Modal ASGI deployment, don't modify root_path
+    # Gradio handles routing automatically for ASGI apps
+    # Modifying root_path can break session management

     # Return the ASGI-compatible app
-    # Gradio Blocks are ASGI-compatible, but demo.app (FastAPI) is more explicit
-    if hasattr(demo, "app"):
-        return demo.app  # FastAPI instance - preferred for Modal
-    elif callable(demo):
-        return demo  # Gradio Blocks are also ASGI-compatible
-    else:
-        raise RuntimeError(f"Expected an ASGI app, but got {type(demo)}")
+    # For Modal, we should return demo.app (FastAPI) which handles ASGI properly
+    # This avoids session management issues with Gradio's queue system
+    try:
+        if hasattr(demo, "app"):
+            # FastAPI instance - preferred for Modal ASGI deployment
+            # This properly handles session management without queue conflicts
+            return demo.app
+        elif callable(demo):
+            # Gradio Blocks are ASGI-compatible but may have session issues
+            return demo
+        else:
+            raise RuntimeError(f"Expected an ASGI app, but got {type(demo)}")
+    except Exception as e:
+        print(f"Error getting ASGI app: {e}")
+        # Fallback: return demo directly
+        return demo


 if __name__ == "__main__":
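An alternative to reaching into `demo.app` is to mount the Blocks on an explicitly created FastAPI app, which makes the returned object unambiguous. A sketch using Gradio's public `mount_gradio_app` helper (whether this also sidesteps the session issues described in the comments above is untested here):

import gradio as gr
from fastapi import FastAPI

def build_asgi_app(demo: gr.Blocks) -> FastAPI:
    """Wrap Gradio Blocks in an explicit FastAPI app for ASGI hosts like Modal."""
    app = FastAPI()
    # mount_gradio_app returns the FastAPI app with the Gradio UI served at `path`
    return gr.mount_gradio_app(app, demo, path="/")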
505 KB

slides/architecture.png (75.4 KB)

slides/audio_presentation.key (1.35 MB, binary file not shown)

train.py

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ def train_loop(
         # Fix label smoothing
         labels = torch.where(labels > 0.5, 1 - smooth/2, smooth/2)

-        # Forward pass
+        # Forward pass (not using tag_logits)
         binary_logits, _ = model(each_input)
         loss = loss_fn(binary_logits, labels)

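The smoothing line maps hard labels symmetrically to soft targets: with `smooth = 0.1` (a value assumed here for illustration), positives become 0.95 and negatives 0.05. A quick check of that mapping:

import torch

smooth = 0.1  # assumed value for illustration
labels = torch.tensor([0.0, 1.0, 1.0, 0.0])
smoothed = torch.where(labels > 0.5, 1 - smooth/2, smooth/2)
print(smoothed)  # tensor([0.0500, 0.9500, 0.9500, 0.0500])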