generated from ultralytics/template
-
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest_prediction.py
More file actions
77 lines (57 loc) · 2.41 KB
/
test_prediction.py
File metadata and controls
77 lines (57 loc) · 2.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
Test YOLO model predictions locally before deploying to Replicate.
Usage:
python test_prediction.py --model yolo11n --image test.jpg
python test_prediction.py --model custom --image test.jpg
"""
import argparse
import sys
from pathlib import Path
# Append the directory containing this script to the import path. When the file
# is run directly (`python test_prediction.py ...`), Python already puts that
# directory first on sys.path, so this mainly matters when the module is
# imported from elsewhere. Individual model directories are put on sys.path
# separately, inside test_prediction().
sys.path.append(str(Path(__file__).parent))
def test_prediction(model_dir: str, image_path: str) -> bool:
    """Run one local prediction through the ``Predictor`` in ``<model_dir>/predict.py``.

    Steps: check the model directory, list its ``*.pt`` weights, import
    ``Predictor`` from the directory's ``predict.py``, call ``setup()``, then
    ``predict(image=...)`` on ``image_path``. Every outcome is reported with
    ``print``; no exception is propagated to the caller.

    Args:
        model_dir: Directory (absolute, or relative to the current working
            directory) holding ``predict.py`` and the ``.pt`` weights.
        image_path: Image file passed to ``Predictor.predict`` wrapped in ``cog.Path``.

    Returns:
        True when setup and prediction both succeed, False otherwise.
    """
    model_path = Path(model_dir)
    # is_dir() rather than exists(): a plain file with this name is not a model directory.
    if not model_path.is_dir():
        print(f"❌ Model directory {model_dir} not found")
        return False

    pt_files = sorted(model_path.glob("*.pt"))
    if not pt_files:
        print(f"❌ No PyTorch model files found in {model_dir}")
        if model_dir == "yolo11n":
            print("💡 Make sure yolo11n.pt is in the yolo11n/ directory")
        else:
            print("💡 Make sure your custom .pt model is in the custom/ directory")
        return False
    print(f"✅ Found PyTorch models: {[f.name for f in pt_files]}")

    # Fail early with a clear message instead of a confusing "No module named 'predict'".
    if not (model_path / "predict.py").is_file():
        print(f"❌ No predict.py found in {model_dir}")
        return False

    search_path = str(model_path)
    sys.path.insert(0, search_path)
    # A `predict` module imported by an earlier call (possibly from a different
    # model directory) would otherwise be reused from the import cache.
    sys.modules.pop("predict", None)
    try:
        from predict import Predictor

        predictor = Predictor()
        predictor.setup()

        from cog import Path as CogPath

        result = predictor.predict(image=CogPath(image_path))
    except ImportError as e:
        print(f"❌ Import error: {e}")
        # The failing import may come from predict.py's own dependencies or from cog itself.
        print("💡 Make sure the dependencies are installed: pip install ultralytics cog")
        return False
    except Exception as e:
        print(f"❌ Prediction failed: {e}")
        return False
    finally:
        # Undo the sys.path change so repeated calls don't accumulate entries.
        if search_path in sys.path:
            sys.path.remove(search_path)

    print(f"✅ Prediction successful: {result}")
    print(f"💡 Ready to deploy: cd {model_dir} && cog push r8.im/your-username/your-model")
    return True


# The function name matches pytest's default ``test*`` pattern and this file is
# named ``test_*.py``. Without this flag, running ``pytest`` in the repository
# would collect it as a test and error on the missing ``model_dir``/``image_path``
# fixtures.
test_prediction.__test__ = False
def main():
    """Parse ``--model``/``--image`` from the command line and run one local prediction.

    Exits with status 1 when the image path is not an existing file, or when
    ``test_prediction`` explicitly returns ``False``. Invalid arguments are
    handled by argparse (which exits with status 2).
    """
    parser = argparse.ArgumentParser(description="Test YOLO predictions locally")
    parser.add_argument("--model", required=True, choices=["yolo11n", "custom"], help="Model directory to test")
    parser.add_argument("--image", required=True, help="Path to test image")
    args = parser.parse_args()

    # is_file() rather than exists(): a directory cannot be used as the test image.
    if not Path(args.image).is_file():
        print(f"❌ Image file {args.image} not found")
        sys.exit(1)

    # Only an explicit False counts as failure; any other result (e.g. None) exits normally.
    if test_prediction(args.model, args.image) is False:
        sys.exit(1)


if __name__ == "__main__":
    main()