Skip to content

Commit 4a47e4d

Browse files
authored
[AI Tagger] V2 (#478)
1 parent 6d38e77 commit 4a47e4d

File tree

7 files changed

+368
-337
lines changed

7 files changed

+368
-337
lines changed

plugins/AITagger/ai_server.py

Lines changed: 37 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
from typing import Any, Dict, List
1+
from typing import Any, Dict, List, Optional, Set
22
import aiohttp
33
import pydantic
44
import config
55
import stashapi.log as log
66

7-
current_videopipeline = None
8-
97
# ----------------- AI Server Calling Functions -----------------
108

119
async def post_api_async(session, endpoint, payload):
@@ -38,55 +36,47 @@ async def process_images_async(image_paths, threshold=config.IMAGE_THRESHOLD, re
3836
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)) as session:
3937
return await post_api_async(session, 'process_images/', {"paths": image_paths, "threshold": threshold, "return_confidence": return_confidence})
4038

41-
async def process_video_async(video_path, vr_video=False, frame_interval=config.FRAME_INTERVAL,threshold=config.AI_VIDEO_THRESHOLD, return_confidence=True):
42-
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)) as session:
43-
return await post_api_async(session, 'process_video/', {"path": video_path, "frame_interval": frame_interval, "threshold": threshold, "return_confidence": return_confidence, "vr_video": vr_video})
44-
45-
async def get_image_config_async(threshold=config.IMAGE_THRESHOLD):
39+
async def process_video_async(video_path, vr_video=False, frame_interval=config.FRAME_INTERVAL,threshold=config.AI_VIDEO_THRESHOLD, return_confidence=True, existing_json=None):
4640
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)) as session:
47-
return await get_api_async(session, f'image_pipeline_info/?threshold={threshold}')
41+
return await post_api_async(session, 'process_video/', {"path": video_path, "frame_interval": frame_interval, "threshold": threshold, "return_confidence": return_confidence, "vr_video": vr_video, "existing_json_data": existing_json})
4842

49-
async def get_video_config_async(frame_interval=config.FRAME_INTERVAL, threshold=config.AI_VIDEO_THRESHOLD):
43+
async def find_optimal_marker_settings(existing_json, desired_timespan_data):
5044
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)) as session:
51-
return await get_api_async(session, f'video_pipeline_info/?frame_interval={frame_interval}&threshold={threshold}&return_confidence=True')
52-
45+
return await post_api_async(session, 'optimize_timeframe_settings/', {"existing_json_data": existing_json, "desired_timespan_data": desired_timespan_data})
46+
47+
5348
class VideoResult(pydantic.BaseModel):
54-
result: List[Dict[str, Any]] = pydantic.Field(..., min_items=1)
55-
pipeline_short_name: str
56-
pipeline_version: float
57-
threshold: float
58-
frame_interval: float
59-
return_confidence: bool
49+
result: Dict[str, Any]
6050

61-
class ImageResult(pydantic.BaseModel):
62-
result: List[Dict[str, Any]] = pydantic.Field(..., min_items=1)
63-
pipeline_short_name: str
64-
pipeline_version: float
65-
threshold: float
66-
return_confidence: bool
51+
class TimeFrame(pydantic.BaseModel):
52+
start: float
53+
end: float
54+
totalConfidence: Optional[float]
6755

68-
class ImagePipelineInfo(pydantic.BaseModel):
69-
pipeline_short_name: str
70-
pipeline_version: float
71-
threshold: float
72-
return_confidence: bool
56+
def to_json(self):
57+
return self.model_dump_json(exclude_none=True)
7358

74-
class VideoPipelineInfo(pydantic.BaseModel):
75-
pipeline_short_name: str
76-
pipeline_version: float
77-
threshold: float
78-
frame_interval: float
79-
return_confidence: bool
59+
def __str__(self):
60+
return f"TimeFrame(start={self.start}, end={self.end})"
8061

81-
async def get_current_video_pipeline():
82-
global current_videopipeline
83-
if current_videopipeline is not None:
84-
return current_videopipeline
85-
try:
86-
current_videopipeline = VideoPipelineInfo(**await get_video_config_async())
87-
except aiohttp.ClientConnectionError as e:
88-
log.error(f"Failed to connect to AI server. Is the AI server running at {config.API_BASE_URL}? {e}")
89-
except Exception as e:
90-
log.error(f"Failed to get pipeline info: {e}. Ensure the AI server is running with at least version 1.3.1!")
91-
raise
92-
return current_videopipeline
62+
class VideoTagInfo(pydantic.BaseModel):
63+
video_duration: float
64+
video_tags: Dict[str, Set[str]]
65+
tag_totals: Dict[str, Dict[str, float]]
66+
tag_timespans: Dict[str, Dict[str, List[TimeFrame]]]
67+
68+
@classmethod
69+
def from_json(cls, json_str: str):
70+
log.info(f"json_str: {json_str}")
71+
log.info(f"video_duration: {json_str['video_duration']}, video_tags: {json_str['video_tags']}, tag_totals: {json_str['tag_totals']}, tag_timespans: {json_str['tag_timespans']}")
72+
return cls(video_duration=json_str["video_duration"], video_tags=json_str["video_tags"], tag_totals=json_str["tag_totals"], tag_timespans=json_str["tag_timespans"])
73+
74+
def __str__(self):
75+
return f"VideoTagInfo(video_duration={self.video_duration}, video_tags={self.video_tags}, tag_totals={self.tag_totals}, tag_timespans={self.tag_timespans})"
76+
77+
class ImageResult(pydantic.BaseModel):
78+
result: List[Dict[str, Any]] = pydantic.Field(..., min_items=1)
79+
80+
class OptimizeMarkerSettings(pydantic.BaseModel):
81+
existing_json_data: Any = None
82+
desired_timespan_data: Dict[str, TimeFrame]

0 commit comments

Comments
 (0)