Skip to content

Commit a0077b1

Browse files
Update preprocessing.py
1 parent 5b77963 commit a0077b1

1 file changed

Lines changed: 52 additions & 19 deletions

File tree

ml/preprocessing.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,52 @@
33
from pathlib import Path
44
from shared.signal import Signal
55

6-
FRAME_SAMPLE_RATE = 10 # extract every Nth frame for video
7-
MIN_QUALITY_SCORE = 0.3 # below this, reliability tanks
8-
TARGET_SIZE = (224, 224) # Xception input size
6+
FRAME_SAMPLE_RATE = 10
7+
MIN_QUALITY_SCORE = 0.3
8+
TARGET_SIZE = (224, 224)
9+
10+
# ── Security: all files must live under this directory ──────────────────────
11+
UPLOADS_ROOT = Path("/app/uploads").resolve()
12+
ALLOWED_SUFFIXES = {".jpg", ".jpeg", ".png", ".mp4"}
13+
14+
15+
def _safe_resolve(file_path: str) -> Path:
16+
"""
17+
Resolve the path and enforce it stays inside UPLOADS_ROOT.
18+
Raises ValueError (no internal detail) if anything looks wrong.
19+
"""
20+
try:
21+
# resolve() collapses ../.. and follows symlinks
22+
resolved = Path(file_path).resolve()
23+
except (TypeError, ValueError):
24+
raise ValueError("Invalid file path.")
25+
26+
# Symlink escape check — re-check the real path after resolution
27+
if not str(resolved).startswith(str(UPLOADS_ROOT)):
28+
raise ValueError("File not found.") # don't say why
29+
30+
# Must actually exist and be a regular file (no devices, sockets, etc.)
31+
if not resolved.is_file():
32+
raise ValueError("File not found.")
33+
34+
# Suffix whitelist enforced here, before any I/O
35+
if resolved.suffix.lower() not in ALLOWED_SUFFIXES:
36+
raise ValueError("Unsupported file type.")
37+
38+
return resolved
939

1040

1141
def preprocess(file_path: str) -> Signal:
12-
path = Path(file_path)
42+
path = _safe_resolve(file_path) # raises safely if bad
1343
suffix = path.suffix.lower()
1444

1545
if suffix in (".jpg", ".jpeg", ".png"):
1646
data, quality, extra = _process_image(path)
1747
elif suffix == ".mp4":
1848
data, quality, extra = _process_video(path)
1949
else:
20-
raise ValueError(f"Unsupported file type: {suffix}")
50+
# Unreachable after _safe_resolve, but keeps type-checker happy
51+
raise ValueError("Unsupported file type.")
2152

2253
reliability = _compute_reliability(quality, extra)
2354

@@ -26,7 +57,7 @@ def preprocess(file_path: str) -> Signal:
2657
reliability=reliability,
2758
module="ml.preprocessing",
2859
metadata={
29-
"file": path.name,
60+
"file": path.name, # filename only, never full path
3061
"type": suffix,
3162
"frames": len(data),
3263
"quality_score": quality,
@@ -38,26 +69,23 @@ def preprocess(file_path: str) -> Signal:
3869
def _process_image(path: Path):
3970
img = cv2.imread(str(path))
4071
if img is None:
41-
raise ValueError(f"Could not read image: {path}")
72+
raise ValueError("Could not read file.") # no path in message
4273
frame = _normalise(img)
4374
quality = _compute_quality(img)
4475
extra = {
4576
"pixel_distribution": _check_pixel_distribution(img),
4677
"compression_artifact_score": _check_compression_artifacts(img),
47-
"aspect_ratio_consistent": True, # single frame, always consistent
78+
"aspect_ratio_consistent": True,
4879
}
4980
return [frame], quality, extra
5081

5182

5283
def _process_video(path: Path):
5384
cap = cv2.VideoCapture(str(path))
5485
if not cap.isOpened():
55-
raise ValueError(f"Could not open video: {path}")
86+
raise ValueError("Could not read file.") # no path in message
5687

57-
frames = []
58-
qualities = []
59-
pixel_scores = []
60-
compression_scores = []
88+
frames, qualities, pixel_scores, compression_scores = [], [], [], []
6189
aspect_ratios = set()
6290
idx = 0
6391

@@ -77,7 +105,7 @@ def _process_video(path: Path):
77105
cap.release()
78106

79107
if not frames:
80-
raise ValueError("No frames extracted from video")
108+
raise ValueError("No frames could be extracted.")
81109

82110
extra = {
83111
"pixel_distribution": float(np.mean(pixel_scores)),
@@ -121,6 +149,13 @@ def _check_compression_artifacts(frame: np.ndarray) -> float:
121149
Returns 0.0 (heavily compressed) to 1.0 (clean).
122150
"""
123151
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float32)
152+
153+
# cv2.dct() requires even dimensions — pad raw frames safely
154+
h, w = gray.shape
155+
new_h = h if h % 2 == 0 else h - 1
156+
new_w = w if w % 2 == 0 else w - 1
157+
gray = gray[:new_h, :new_w]
158+
124159
dct = cv2.dct(gray)
125160
high_freq_energy = np.abs(dct[16:, 16:]).mean()
126161
total_energy = np.abs(dct).mean() + 1e-6
@@ -139,17 +174,15 @@ def _compute_reliability(quality: float, extra: dict) -> float:
139174
base = 0.5 + (quality * 0.5) # 0.65 – 1.0 from sharpness
140175

141176
# pixel distribution penalty — unnatural clustering
142-
pixel_score = extra.get("pixel_distribution", 1.0)
143-
if pixel_score < 0.4:
177+
if extra.get("pixel_distribution", 1.0) < 0.4:
144178
base -= 0.15
145179

146180
# compression penalty — heavy artifacts hurt detection
147-
compression_score = extra.get("compression_artifact_score", 1.0)
148-
if compression_score < 0.3:
181+
if extra.get("compression_artifact_score", 1.0) < 0.3:
149182
base -= 0.15
150183

151184
# aspect ratio inconsistency — edited/stitched video
152185
if not extra.get("aspect_ratio_consistent", True):
153186
base -= 0.2
154187

155-
return float(max(round(base, 4), 0.0))
188+
return float(max(round(base, 4), 0.0))

0 commit comments

Comments
 (0)