|
| 1 | +"""Internal pipeline functions for SceneFlow API. |
| 2 | +
|
| 3 | +This module contains implementation details and should not be imported directly. |
| 4 | +Use the public API functions from sceneflow.api instead. |
| 5 | +""" |
| 6 | + |
| 7 | +import logging |
| 8 | +from typing import Dict, List, Optional, Tuple |
| 9 | + |
| 10 | +from sceneflow.shared.config import RankingConfig |
| 11 | +from sceneflow.detection import EnergyRefiner |
| 12 | +from sceneflow.shared.models import RankedFrame, FrameScore, FrameFeatures |
| 13 | +from sceneflow.core import CutPointRanker |
| 14 | +from sceneflow.detection import SpeechDetector |
| 15 | +from sceneflow.selection import LLMFrameSelector |
| 16 | +from sceneflow.utils.video import get_video_duration |
| 17 | + |
| 18 | +logger = logging.getLogger(__name__) |
| 19 | + |
| 20 | + |
def detect_speech_end(
    video_path: str,
    use_energy_refinement: bool,
    energy_threshold_db: float,
    energy_lookback_frames: int,
) -> Tuple[float, float, List[Dict[str, float]]]:
    """Find the moment speech ends in a video.

    Runs voice-activity detection (VAD) first, then — when requested —
    tightens the detected timestamp with an energy-analysis pass.

    Args:
        video_path: Path of the video file to analyse.
        use_energy_refinement: Whether to run energy-based refinement.
        energy_threshold_db: dB threshold handed to the EnergyRefiner.
        energy_lookback_frames: Lookback window (frames) for the refiner.

    Returns:
        ``(speech_end_time, visual_search_end_time, vad_timestamps)``.
        ``visual_search_end_time`` is the pre-refinement timestamp when
        refinement moved the end earlier, and ``-1.0`` otherwise.
    """
    logger.info("Stage 1: Detecting speech end time...")

    vad_end, vad_timestamps = SpeechDetector().get_speech_timestamps(video_path)
    logger.info("VAD detected speech end at: %.4fs", vad_end)

    refined_end = vad_end
    baseline_end = vad_end

    if use_energy_refinement:
        logger.info("Stage 1.5: Refining VAD-detected speech end time with energy analysis...")

        pre_refine = refined_end
        refiner = EnergyRefiner(
            threshold_db=energy_threshold_db,
            lookback_frames=energy_lookback_frames,
        )
        refinement = refiner.refine_speech_end(pre_refine, video_path)
        refined_end = refinement.refined_timestamp

        if refinement.frames_adjusted > 0:
            logger.info(
                "Energy refinement adjusted timestamp by %d frames",
                refinement.frames_adjusted,
            )
            baseline_end = pre_refine
        else:
            logger.info("Energy refinement: No adjustment needed")

    # If refinement pulled the end earlier, report the original timestamp
    # so the visual search can still scan up to it; -1.0 means "no limit".
    visual_search_end = baseline_end if refined_end < baseline_end else -1.0

    return refined_end, visual_search_end, vad_timestamps
| 65 | + |
| 66 | + |
def rank_frames(
    video_path: str,
    speech_end_time: float,
    duration: float,
    config: Optional[RankingConfig],
    sample_rate: int,
    visual_search_end_time: float = -1.0,
    return_internals: bool = False
) -> Tuple[List[RankedFrame], Optional[List[FrameFeatures]], Optional[List[FrameScore]]]:
    """Rank candidate cut frames in the window after speech ends.

    Args:
        video_path: Path of the video to analyse.
        speech_end_time: Start of the analysis window, in seconds.
        duration: Full video duration; used as the window end by default.
        config: Optional ranking configuration forwarded to CutPointRanker.
        sample_rate: Frame sampling rate for the ranker.
        visual_search_end_time: Explicit window end; values <= 0 mean
            "scan to the end of the video".
        return_internals: When True, also return per-frame features/scores.

    Returns:
        ``(ranked_frames, features, scores)`` — the last two are ``None``
        unless ``return_internals`` is set.
    """
    window_end = duration if visual_search_end_time <= 0 else visual_search_end_time

    logger.info("Stage 2: Ranking frames based on visual quality...")
    logger.info("Analyzing frames from %.4fs to %.4fs", speech_end_time, window_end)

    ranker = CutPointRanker(config)

    if not return_internals:
        ranked = ranker.rank_frames(
            video_path=video_path,
            start_time=speech_end_time,
            end_time=window_end,
            sample_rate=sample_rate,
        )
        return ranked, None, None

    ranked, frame_features, frame_scores = ranker.rank_frames(
        video_path=video_path,
        start_time=speech_end_time,
        end_time=window_end,
        sample_rate=sample_rate,
        return_internals=True,
    )
    return ranked, frame_features, frame_scores
| 101 | + |
| 102 | + |
def select_best_with_llm(
    video_path: str,
    ranked_frames: List[RankedFrame],
    speech_end_time: float,
    duration: float,
    scores: List[FrameScore],
    features: List[FrameFeatures],
    openai_api_key: Optional[str]
) -> RankedFrame:
    """Use an LLM to select the best frame from the top-ranked candidates.

    Falls back to the top-ranked frame when there is only one candidate or
    when the LLM call fails for any reason.

    Args:
        video_path: Path of the analysed video.
        ranked_frames: Candidate frames, best first. Must be non-empty.
        speech_end_time: When speech ends, in seconds (context for the LLM).
        duration: Total video duration, in seconds.
        scores: Currently unused; kept for interface compatibility.
        features: Currently unused; kept for interface compatibility.
        openai_api_key: API key forwarded to the LLM selector.

    Returns:
        The selected frame.

    Raises:
        ValueError: If ``ranked_frames`` is empty.
    """
    # Explicit guard: previously an empty list crashed with an opaque
    # IndexError at ranked_frames[0].
    if not ranked_frames:
        raise ValueError("ranked_frames must not be empty")
    if len(ranked_frames) < 2:
        # Nothing to choose between — skip the LLM round-trip entirely.
        return ranked_frames[0]
    try:
        selector = LLMFrameSelector(api_key=openai_api_key)
        return selector.select_best_frame(
            video_path=video_path,
            ranked_frames=ranked_frames,
            speech_end_time=speech_end_time,
            video_duration=duration,
        )
    except Exception as e:
        # Deliberate best-effort fallback: any LLM failure degrades to the
        # deterministic top-ranked result instead of aborting the pipeline.
        logger.warning("LLM selection failed: %s, using top result", e)
        return ranked_frames[0]
| 126 | + |
| 127 | + |
def upload_to_airtable(
    video_path: str,
    best_frame: RankedFrame,
    scores: List[FrameScore],
    features: List[FrameFeatures],
    speech_end_time: float,
    duration: float,
    config: Optional[RankingConfig],
    sample_rate: int,
    airtable_access_token: Optional[str],
    airtable_base_id: Optional[str],
    airtable_table_name: Optional[str]
) -> str:
    """Upload the analysis results for the selected frame to Airtable.

    Args:
        video_path: Path of the analysed video.
        best_frame: The frame chosen as the best cut point.
        scores: Per-frame scores; matched to ``best_frame`` by frame_index.
        features: Per-frame features; matched to ``best_frame`` by frame_index.
        speech_end_time: Detected end of speech, in seconds.
        duration: Total video duration, in seconds.
        config: Optional ranking configuration whose weights are recorded.
        sample_rate: Frame sampling rate used during ranking.
        airtable_access_token: Airtable API token (None uses downstream default).
        airtable_base_id: Airtable base id (None uses downstream default).
        airtable_table_name: Airtable table name (None uses downstream default).

    Returns:
        The id of the created Airtable record.

    Raises:
        RuntimeError: If the best frame's score or features cannot be found.
    """
    # Local import keeps the Airtable integration an optional dependency for
    # callers that never upload.
    from sceneflow.integration import upload_to_airtable as airtable_upload

    best_score = next((s for s in scores if s.frame_index == best_frame.frame_index), None)
    best_features = next((f for f in features if f.frame_index == best_frame.frame_index), None)

    # Compare against None explicitly: a present-but-falsy score/features
    # object must not be mistaken for missing data.
    if best_score is None or best_features is None:
        raise RuntimeError("Could not upload to Airtable - missing data")

    config_dict = {
        "sample_rate": sample_rate,
        "weights": {
            # Fallback weights when no config is supplied — presumably these
            # mirror RankingConfig's defaults; TODO confirm.
            "eye": config.eye_weight if config is not None else 0.4,
            "mouth": config.mouth_weight if config is not None else 0.6,
        },
    }

    return airtable_upload(
        video_path=video_path,
        best_frame=best_frame,
        frame_score=best_score,
        frame_features=best_features,
        speech_end_time=speech_end_time,
        duration=duration,
        config_dict=config_dict,
        access_token=airtable_access_token,
        base_id=airtable_base_id,
        table_name=airtable_table_name,
    )