
Commit 013afef

Merge pull request #18 from saichandrapandraju/global-state
feat: Refactor Detector lifecycle to use FastAPI app.state
2 parents ff88f67 + 0a3bd13 commit 013afef

File tree

5 files changed: +50 −19 lines changed


detectors/common/app.py

Lines changed: 11 additions & 0 deletions
@@ -94,6 +94,17 @@ async def http_exception_handler(self, request, exc):
             content={"code": exc.status_code, "message": exc.detail},
         )
 
+    def set_detector(self, detector) -> None:
+        """Store detector in app.state"""
+        self.state.detector = detector
+
+    def get_detector(self):
+        """Retrieve detector from app.state"""
+        return getattr(self.state, 'detector', None)
+
+    def cleanup_detector(self) -> None:
+        """Clean up detector resources"""
+        self.state.detector = None
 
 async def health():
     return "ok"

detectors/huggingface/app.py

Lines changed: 9 additions & 5 deletions
@@ -15,15 +15,16 @@
     Error,
 )
 
-detector_objects = {}
-
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    detector_objects["detector"] = Detector()
+    app.set_detector(Detector())
     yield
     # Clean up the ML models and release the resources
-    detector_objects.clear()
+    detector: Detector = app.get_detector()
+    if detector and hasattr(detector, 'close'):
+        detector.close()
+    app.cleanup_detector()
 
 
 app = FastAPI(lifespan=lifespan, dependencies=[])
@@ -44,4 +45,7 @@ async def detector_unary_handler(
     request: ContentAnalysisHttpRequest,
     detector_id: Annotated[str, Header(example="en_syntax_slate.38m.hap")],
 ):
-    return ContentsAnalysisResponse(root=detector_objects["detector"].run(request))
+    detector: Detector = app.get_detector()
+    if not detector:
+        raise RuntimeError("Detector is not initialized")
+    return ContentsAnalysisResponse(root=detector.run(request))
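
Taken together, the lifespan hook and the guarded handler form a complete detector lifecycle. A runnable sketch of the same wiring with a toy detector (the Detector class below is a stand-in, not the project's HuggingFace one):

from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.testclient import TestClient  # requires httpx

class Detector:
    """Toy stand-in for the HuggingFace detector."""
    def run(self, text: str) -> str:
        return f"analyzed: {text}"
    def close(self) -> None:
        pass  # real code releases model resources here

@asynccontextmanager
async def lifespan(app: FastAPI):
    app.state.detector = Detector()            # startup
    yield
    detector = getattr(app.state, "detector", None)
    if detector and hasattr(detector, "close"):
        detector.close()                       # shutdown
    app.state.detector = None

app = FastAPI(lifespan=lifespan)

@app.get("/analyze")
def analyze(text: str):
    detector = getattr(app.state, "detector", None)
    if not detector:
        # Mirrors the diff: fail loudly on a request that beat initialization
        raise RuntimeError("Detector is not initialized")
    return {"result": detector.run(text)}

with TestClient(app) as client:  # the context manager drives the lifespan
    print(client.get("/analyze", params={"text": "hello"}).json())
    # {'result': 'analyzed: hello'}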

detectors/huggingface/detector.py

Lines changed: 18 additions & 1 deletion
@@ -16,7 +16,7 @@
     ContentAnalysisResponse,
     ContentsAnalysisResponse,
 )
-
+import gc
 
 class Detector:
     risk_names = [
@@ -280,3 +280,20 @@ def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
                 raise ValueError("Unsupported model type for analysis.")
             contents_analyses.append(analyses)
         return contents_analyses
+
+
+    def close(self) -> None:
+        """Clean up model and tokenizer resources."""
+
+        if self.model:
+            if hasattr(self.model, 'to') and hasattr(self.model, 'device') and self.model.device.type != "cpu":
+                self.model = self.model.to(torch.device("cpu"))
+            self.model = None
+
+        if self.tokenizer:
+            self.tokenizer = None
+
+        gc.collect()
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
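
close() follows the usual PyTorch teardown order: move weights off the GPU, drop the Python references, force a collection, then flush the CUDA caching allocator. A standalone sketch of that sequence (ModelHolder is illustrative, not from the repo):

import gc
import torch

class ModelHolder:
    """Illustrative holder mirroring the Detector's model/tokenizer fields."""

    def __init__(self, model=None, tokenizer=None):
        self.model = model
        self.tokenizer = tokenizer

    def close(self) -> None:
        if self.model is not None:
            # Moving to CPU first makes the CUDA copies unreachable
            if hasattr(self.model, "to") and hasattr(self.model, "device") \
                    and self.model.device.type != "cpu":
                self.model = self.model.to(torch.device("cpu"))
            self.model = None    # drop the last strong reference
        self.tokenizer = None
        gc.collect()             # collect the now-dead tensors
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # hand cached blocks back to the driver

The ordering matters: empty_cache() only returns blocks that are no longer referenced, which is why the references are dropped and gc.collect() runs before it.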

detectors/llm_judge/app.py

Lines changed: 11 additions & 12 deletions
@@ -13,20 +13,18 @@
     Error,
 )
 
-detector_objects: Dict[str, LLMJudgeDetector] = {}
-
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """Application lifespan management."""
-    try:
-        detector_objects["detector"] = LLMJudgeDetector()
-        yield
-    finally:
-        # Clean up resources
-        if "detector" in detector_objects:
-            await detector_objects["detector"].close()
-        detector_objects.clear()
+
+    app.set_detector(LLMJudgeDetector())
+    yield
+    # Clean up resources
+    detector: LLMJudgeDetector = app.get_detector()
+    if detector and hasattr(detector, 'close'):
+        await detector.close()
+    app.cleanup_detector()
 
 
 app = FastAPI(lifespan=lifespan, dependencies=[])
@@ -49,7 +47,8 @@ async def detector_unary_handler(
     detector_id: Annotated[str, Header(example="llm_judge_safety")],
 ):
     """Analyze content using LLM-as-Judge evaluation."""
-    return ContentsAnalysisResponse(root=await detector_objects["detector"].run(request))
+    detector: LLMJudgeDetector = app.get_detector()
+    return ContentsAnalysisResponse(root=await detector.run(request))
 
 
 @app.get(
@@ -62,7 +61,7 @@ async def detector_unary_handler(
 )
 async def list_metrics():
     """List all available evaluation metrics."""
-    detector = detector_objects.get("detector")
+    detector: LLMJudgeDetector = app.get_detector()
     if not detector:
         return {"metrics": [], "total": 0}
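
Unlike the HuggingFace detector, the judge's close() is awaited, which suggests it tears down an async resource. A minimal sketch under that assumption (the httpx.AsyncClient inside is a guess at what needs closing, not confirmed by the diff):

import asyncio
import httpx  # assumption: the judge holds an async HTTP client

class AsyncDetector:
    """Illustrative async-closable detector."""

    def __init__(self) -> None:
        self._client = httpx.AsyncClient()

    async def run(self, prompt: str) -> str:
        # A real judge would POST the prompt to a model endpoint here
        return f"verdict for: {prompt}"

    async def close(self) -> None:
        await self._client.aclose()  # coroutine, hence `await detector.close()`

async def main() -> None:
    detector = AsyncDetector()
    try:
        print(await detector.run("is this content safe?"))
    finally:
        await detector.close()

asyncio.run(main())

Note that, unlike the HuggingFace handler, detector_unary_handler here does not guard against a missing detector; a request arriving after a failed startup would surface as an AttributeError rather than the explicit RuntimeError.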

detectors/llm_judge/detector.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ class LLMJudgeDetector:
 
     def __init__(self) -> None:
         """Initialize the LLM Judge Detector."""
-        self.judge = None
+        self.judge: Judge = None
         self.available_metrics = set(BUILTIN_METRICS.keys())
 
         # Get configuration from environment
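
One nit a static type checker would raise: self.judge: Judge = None assigns None to a non-optional annotation. A variant mypy would accept (the Judge class below is just a stand-in declaration):

from typing import Optional

class Judge:
    """Stand-in for the project's Judge type."""

class LLMJudgeDetector:
    def __init__(self) -> None:
        # Optional[...] reflects that the judge is absent until configured
        self.judge: Optional[Judge] = None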
