@@ -160,11 +160,23 @@ def __init__(self, model_path: str, threshold: float = DEFAULT_THRESHOLD, use_on
             # PyTorch model loading
             if not TORCH_AVAILABLE:
                 raise ImportError("PyTorch is required for PyTorch models. Install with: pip install torch")
-            if device is None:
-                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+            print(f"[{time.strftime('%H:%M:%S')}] Loading PyTorch model from: {model_path}")
+
+            # Create device for this instance (don't rely on module-level device)
+            model_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+            print(f"[{time.strftime('%H:%M:%S')}] Using device: {model_device}")
+
+            load_start = time.time()
             self.model = DualHeadCnn14Simple(pretrained=False)
-            self.model.load_state_dict(torch.load(self.model_path, map_location=device, weights_only=False))
-            self.model.eval().to(device)
+            self.model.load_state_dict(torch.load(self.model_path, map_location=model_device, weights_only=False))
+            self.model.eval().to(model_device)
+            load_time = time.time() - load_start
+            print(f"[{time.strftime('%H:%M:%S')}] PyTorch model loaded in {load_time:.2f}s on {model_device}")
+
+            # Store device for later use in inference
+            self.device = model_device
             self.onnx_session = None

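The hunk's pattern is to resolve the device per instance and time the checkpoint load. Here is a self-contained sketch of the same pattern, where `TinyClassifier` and `load_checkpoint` are illustrative stand-ins for the repo's `DualHeadCnn14Simple` and its loader, not code from this commit:

```python
import time

import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Stand-in for the repo's DualHeadCnn14Simple; architecture is illustrative only."""

    def __init__(self):
        super().__init__()
        self.head = nn.Linear(64, 1)

    def forward(self, x):
        return torch.sigmoid(self.head(x))


def load_checkpoint(model_path):
    # Resolve the device per instance instead of relying on a module-level
    # global, which is the change the hunk above makes.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TinyClassifier()
    start = time.time()
    # map_location keeps CUDA-trained checkpoints loadable on CPU-only hosts.
    state = torch.load(model_path, map_location=device, weights_only=False)
    model.load_state_dict(state)
    model.eval().to(device)
    print(f"Loaded in {time.time() - start:.2f}s on {device}")
    return model, device
```

Storing the resolved device on the instance (as `self.device` above) means inference never has to re-derive it or fall back to a stale module-level value.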
@@ -333,19 +345,35 @@ def __predict_single_audio(self, audio_file):
             print(f"[{time.strftime('%H:%M:%S')}] Total ONNX processing time: {total_time:.2f}s, AI probability: {ai_prob:.3f}")
         else:
             # PyTorch inference
-            print("Starting PyTorch inference...")
+            print(f"[{time.strftime('%H:%M:%S')}] Starting PyTorch inference...", flush=True)
+            import sys
+            sys.stdout.flush()
+
             if self.model is None:
-                return "Error: PyTorch model not initialized", None, empty_df
+                error_msg = "Error: PyTorch model not initialized"
+                print(f"[{time.strftime('%H:%M:%S')}] {error_msg}", flush=True)
+                return error_msg, None, empty_df
+
+            # Use the device stored during model loading
+            inference_device = getattr(self, 'device', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+
+            print(f"[{time.strftime('%H:%M:%S')}] Using device: {inference_device}", flush=True)
+            inference_start = time.time()
+
+            input_tensor = input_data.to(inference_device)
+            print(f"[{time.strftime('%H:%M:%S')}] Input tensor moved to device, shape: {input_tensor.shape}", flush=True)
+            sys.stdout.flush()

-            if device is None:
-                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-            input_tensor = input_data.to(device)
             ai_prob = predict_ai_only(self.model, input_tensor)
+
             # Ensure it's a Python float
             if isinstance(ai_prob, torch.Tensor):
                 ai_prob = ai_prob.item()
             ai_prob = float(ai_prob)
-            print("Inference complete")
+
+            inference_time = time.time() - inference_start
+            print(f"[{time.strftime('%H:%M:%S')}] PyTorch inference completed in {inference_time:.2f}s, AI probability: {ai_prob:.3f}", flush=True)
+            sys.stdout.flush()

         is_ai = ai_prob > self.threshold
         result = f"**AI-Generated: {'Yes' if is_ai else 'No'}**\n"
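The hunk calls `predict_ai_only`, which this diff does not define. A plausible shape for such a helper, assuming the model emits a probability-like tensor for its AI head; the body below is a guess, not the repo's implementation:

```python
import torch


def predict_ai_only(model, input_tensor):
    # Hypothetical sketch: forward pass without autograd, returning the
    # AI-head score. The real helper in the repo may index a specific
    # head of DualHeadCnn14Simple's output differently.
    with torch.no_grad():
        output = model(input_tensor)
    if isinstance(output, (tuple, list)):
        output = output[0]  # assumption: first element is the AI head
    return output.squeeze()
```

Whatever the helper returns, the caller above already normalizes it with `.item()` and `float()`, so returning a scalar tensor is safe.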
@@ -664,18 +692,21 @@ def patched_hash_file(path, *args, **kwargs):
            theme=gradio_module.themes.Soft()
        )

-        # Configure queue for long-running requests
-        # Note: Modal's ASGI wrapper handles concurrency, so we use a higher limit
-        try:
-            demo.queue(
-                default_concurrency_limit=10,  # Higher limit - Modal handles actual concurrency
-                max_size=50,  # Allow more queued requests
-                api_open=False  # Don't expose queue API
-            )
-            print("Gradio queue configured successfully")
-        except Exception as e:
-            print(f"Warning: Could not configure Gradio queue: {e}")
-            # Continue anyway - queue might not be available in all Gradio versions
+        # For Modal ASGI deployment, don't use queue - Modal handles concurrency
+        # Queue causes session management issues with ASGI
+        if not self.is_modal:
+            # Only enable queue for local deployment
+            try:
+                demo.queue(
+                    default_concurrency_limit=10,
+                    max_size=50,
+                    api_open=False
+                )
+                print("Gradio queue configured successfully (local mode)")
+            except Exception as e:
+                print(f"Warning: Could not configure Gradio queue: {e}")
+        else:
+            print("Gradio queue disabled for Modal ASGI deployment")

        with demo:
            gradio_module.Markdown("""
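The new branch keys queueing on `self.is_modal`. The same policy condensed into a runnable Gradio 4.x sketch; the `build_demo` wrapper and its Markdown content are illustrative:

```python
import gradio as gr


def build_demo(is_modal: bool) -> gr.Blocks:
    with gr.Blocks() as demo:
        gr.Markdown("AI audio detector")
    if not is_modal:
        # Queue only for local launches; behind Modal's ASGI wrapper the
        # platform already fans requests out across containers.
        demo.queue(default_concurrency_limit=10, max_size=50, api_open=False)
    return demo
```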
@@ -864,14 +895,10 @@ def safe_predict(audio_file):
        if not hasattr(demo, 'max_file_size'):
            demo.max_file_size = 1024 * 1024 * 1024  # 1 GB limit

-        # Set root_path to empty string for Modal to prevent '/' path issues
-        # This prevents Gradio from trying to process '/' as a file path on page load
-        if hasattr(demo, 'root_path'):
-            demo.root_path = ""
-        elif hasattr(demo, 'config'):
-            # Try setting it via config if available
-            if hasattr(demo.config, 'root_path'):
-                demo.config.root_path = ""
+        # For Modal ASGI deployment, don't set root_path
+        # Gradio handles routing automatically for ASGI apps
+        # Setting root_path to "" can break session management and cause "Session not found" errors
+        # The monkey-patch for hash_file handles the '/' path issue instead

        return demo

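The final comment defers the '/' path problem to the `patched_hash_file` monkey-patch named in this hunk's header, whose body is not shown in this diff. One way such a patch could look, assuming it targets `gradio.processing_utils.hash_file`; both the target module and the empty-string sentinel are assumptions:

```python
import os

import gradio.processing_utils

_original_hash_file = gradio.processing_utils.hash_file


def patched_hash_file(path, *args, **kwargs):
    # On page load Gradio can be handed '/' as a file path; hashing a
    # directory fails, so short-circuit instead (sentinel value assumed).
    if path == "/" or os.path.isdir(path):
        return ""
    return _original_hash_file(path, *args, **kwargs)


gradio.processing_utils.hash_file = patched_hash_file
```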
@@ -974,11 +1001,11 @@ def create_summary_visualizations(df):
    volumes={"/models": model_volume},
    timeout=600,  # Increased timeout to 10 minutes
    scaledown_window=300,  # Keep container alive for 5 minutes (renamed from container_idle_timeout)
-    # Removed gpu="any" since we're using CPU for ONNX inference
+    gpu="any",  # Use GPU for PyTorch inference
)
@modal.asgi_app()  # Outermost decorator - ASGI apps handle concurrency internally
def gradio_app_modal():
-    """Modal deployment function - uses ONNX model from volume."""
+    """Modal deployment function - uses PyTorch model from volume."""
    import sys
    import os
    sys.path.insert(0, "/root")
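For context on where `gpu="any"` sits, here is the decorator stack in isolation; the app name is assumed for illustration, and the volume name is taken from the error message later in this diff:

```python
import modal

app = modal.App("ai-audio-detector")  # app name assumed for illustration
model_volume = modal.Volume.from_name("ai-audio-models")


@app.function(
    volumes={"/models": model_volume},
    timeout=600,
    scaledown_window=300,
    gpu="any",  # any available GPU; torch.cuda.is_available() picks it up
)
@modal.asgi_app()
def gradio_app_modal():
    ...  # builds the Gradio interface and returns an ASGI app (see below)
```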
@@ -990,45 +1017,46 @@ def gradio_app_modal():
    # Set environment variable for Gradio cache
    os.environ["GRADIO_TEMP_DIR"] = gradio_cache_dir

-    ONNX_MODEL_PATH = "/models/model.onnx"
+    PYTORCH_MODEL_PATH = "/models/model.pth"
    THRESHOLD = DEFAULT_THRESHOLD

    # Check if model exists, if not provide helpful error
-    if not os.path.exists(ONNX_MODEL_PATH):
+    if not os.path.exists(PYTORCH_MODEL_PATH):
        raise FileNotFoundError(
-            f"ONNX model not found at {ONNX_MODEL_PATH}. "
+            f"PyTorch model not found at {PYTORCH_MODEL_PATH}. "
            "Please upload your model to the Modal volume first using:\n"
-            "modal volume put ai-audio-models /path/to/your/model.onnx model.onnx"
+            "modal volume put ai-audio-models /path/to/your/model.pth model.pth"
        )

    interface = GradioAudioInterface(
-        model_path=ONNX_MODEL_PATH,
+        model_path=PYTORCH_MODEL_PATH,
        threshold=THRESHOLD,
-        use_onnx=True,
+        use_onnx=False,  # Use PyTorch instead of ONNX
        is_modal=True
    )
    demo = interface.run_gradio()

-    # In Gradio 4.x, Blocks implements the ASGI interface directly
-    # Set root_path to empty string to prevent '/' path issues in Modal
-    # This is similar to setting root_path="" in demo.launch() but for ASGI deployment
-    try:
-        if hasattr(demo, 'root_path'):
-            demo.root_path = ""
-        # Also try setting via config if available
-        if hasattr(demo, 'config') and hasattr(demo.config, 'root_path'):
-            demo.config.root_path = ""
-    except Exception:
-        pass  # If setting fails, continue anyway - not critical
+    # For Modal ASGI deployment, don't modify root_path
+    # Gradio handles routing automatically for ASGI apps
+    # Modifying root_path can break session management

    # Return the ASGI-compatible app
-    # Gradio Blocks are ASGI-compatible, but demo.app (FastAPI) is more explicit
-    if hasattr(demo, "app"):
-        return demo.app  # FastAPI instance - preferred for Modal
-    elif callable(demo):
-        return demo  # Gradio Blocks are also ASGI-compatible
-    else:
-        raise RuntimeError(f"Expected an ASGI app, but got {type(demo)}")
+    # For Modal, we should return demo.app (FastAPI) which handles ASGI properly
+    # This avoids session management issues with Gradio's queue system
+    try:
+        if hasattr(demo, "app"):
+            # FastAPI instance - preferred for Modal ASGI deployment
+            # This properly handles session management without queue conflicts
+            return demo.app
+        elif callable(demo):
+            # Gradio Blocks are ASGI-compatible but may have session issues
+            return demo
+        else:
+            raise RuntimeError(f"Expected an ASGI app, but got {type(demo)}")
+    except Exception as e:
+        print(f"Error getting ASGI app: {e}")
+        # Fallback: return demo directly
+        return demo


if __name__ == "__main__":
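A minimal sketch of the hand-off the final hunk settles on: prefer the FastAPI instance Gradio attaches to a built Blocks, with `gr.mount_gradio_app` (a documented Gradio API) as the explicit alternative; `make_asgi_app` is an illustrative name, not from this commit:

```python
import gradio as gr
from fastapi import FastAPI


def make_asgi_app(demo: gr.Blocks):
    if hasattr(demo, "app"):
        # FastAPI instance Gradio 4.x attaches to the Blocks internally.
        return demo.app
    # Explicit alternative: mount the Blocks onto a fresh FastAPI app.
    return gr.mount_gradio_app(FastAPI(), demo, path="/")
```

One caveat with the diff's version: the `raise RuntimeError` sits inside its own `try` block, so the `except Exception` clause catches it and returns `demo` anyway; the error path is effectively a logged fallback rather than a hard failure.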