Skip to content

Commit 5a992e4

Browse files
committed
routes: Extract shared auth, model validation, and error handling utilities
Consolidate duplicated boilerplate across all batch pipeline route handlers into shared utility functions (check_auth_token, check_model_id, execute_pipeline) and a shared RESPONSES dict in routes/utils.py. This addresses Victor's feedback on PR livepeer#900 about reducing redundant pipeline interface patterns and making common behavior properly defined in one place rather than copy-pasted across every route. https://claude.ai/code/session_01RvBGa2npztEMxwfAHH3Xve
1 parent 50a742c commit 5a992e4

File tree

11 files changed

+304
-493
lines changed

11 files changed

+304
-493
lines changed

runner/src/runner/routes/audio_to_text.py

Lines changed: 25 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import logging
2-
import os
32
from typing import Annotated, Dict, Tuple, Union
43

5-
import torch
64
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
75
from fastapi.responses import JSONResponse
86
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
@@ -12,9 +10,12 @@
1210
from runner.routes.utils import (
1311
HTTPError,
1412
TextResponse,
13+
RESPONSES,
14+
check_auth_token,
15+
check_model_id,
16+
execute_pipeline,
1517
file_exceeds_max_size,
1618
get_media_duration_ffmpeg,
17-
handle_pipeline_exception,
1819
http_error,
1920
parse_key_from_metadata,
2021
)
@@ -37,21 +38,11 @@
3738
),
3839
}
3940

40-
RESPONSES = {
41-
status.HTTP_200_OK: {
42-
"content": {
43-
"application/json": {
44-
"schema": {
45-
"x-speakeasy-name-override": "data",
46-
}
47-
}
48-
},
49-
},
50-
status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
51-
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
41+
# Extend shared RESPONSES with additional status codes for this route.
42+
AUDIO_RESPONSES = {
43+
**RESPONSES,
5244
status.HTTP_413_REQUEST_ENTITY_TOO_LARGE: {"model": HTTPError},
5345
status.HTTP_415_UNSUPPORTED_MEDIA_TYPE: {"model": HTTPError},
54-
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
5546
}
5647

5748

@@ -76,7 +67,7 @@ def parse_return_timestamps(value: str) -> Union[bool, str]:
7667
@router.post(
7768
"/audio-to-text",
7869
response_model=TextResponse,
79-
responses=RESPONSES,
70+
responses=AUDIO_RESPONSES,
8071
description="Transcribe audio files to text.",
8172
operation_id="genAudioToText",
8273
summary="Audio To Text",
@@ -86,7 +77,7 @@ def parse_return_timestamps(value: str) -> Union[bool, str]:
8677
@router.post(
8778
"/audio-to-text/",
8879
response_model=TextResponse,
89-
responses=RESPONSES,
80+
responses=AUDIO_RESPONSES,
9081
include_in_schema=False,
9182
)
9283
def audio_to_text(
@@ -116,23 +107,12 @@ def audio_to_text(
116107
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
117108
):
118109
return_timestamps = parse_return_timestamps(return_timestamps)
119-
auth_token = os.environ.get("AUTH_TOKEN")
120-
if auth_token:
121-
if not token or token.credentials != auth_token:
122-
return JSONResponse(
123-
status_code=status.HTTP_401_UNAUTHORIZED,
124-
headers={"WWW-Authenticate": "Bearer"},
125-
content=http_error("Invalid bearer token."),
126-
)
127110

128-
if model_id != "" and model_id != pipeline.model_id:
129-
return JSONResponse(
130-
status_code=status.HTTP_400_BAD_REQUEST,
131-
content=http_error(
132-
f"pipeline configured with {pipeline.model_id} but called with "
133-
f"{model_id}."
134-
),
135-
)
111+
if auth_error := check_auth_token(token):
112+
return auth_error
113+
114+
if model_error := check_model_id(model_id, pipeline.model_id):
115+
return model_error
136116

137117
if file_exceeds_max_size(audio, 50 * 1024 * 1024):
138118
return JSONResponse(
@@ -154,17 +134,14 @@ def audio_to_text(
154134
content=http_error("Unable to calculate duration of file"),
155135
)
156136

157-
try:
158-
return pipeline(
159-
audio=audio, return_timestamps=return_timestamps, duration=duration
160-
)
161-
except Exception as e:
162-
if isinstance(e, torch.cuda.OutOfMemoryError):
163-
# TODO: Investigate why not all VRAM memory is cleared.
164-
torch.cuda.empty_cache()
165-
logger.error(f"AudioToText pipeline error: {e}")
166-
return handle_pipeline_exception(
167-
e,
168-
default_error_message="Audio-to-text pipeline error.",
169-
custom_error_config=PIPELINE_ERROR_CONFIG,
170-
)
137+
result, error = execute_pipeline(
138+
pipeline,
139+
default_error_message="Audio-to-text pipeline error.",
140+
custom_error_config=PIPELINE_ERROR_CONFIG,
141+
audio=audio,
142+
return_timestamps=return_timestamps,
143+
duration=duration,
144+
)
145+
if error:
146+
return error
147+
return result

runner/src/runner/routes/image_to_image.py

Lines changed: 27 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
11
import logging
2-
import os
32
import random
43
from typing import Annotated, Dict, Tuple, Union
54

6-
import torch
75
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
8-
from fastapi.responses import JSONResponse
96
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
107
from PIL import Image, ImageFile
118

129
from runner.dependencies import get_pipeline
1310
from runner.pipelines.base import Pipeline
1411
from runner.routes.utils import (
15-
HTTPError,
1612
ImageResponse,
17-
handle_pipeline_exception,
18-
http_error,
13+
RESPONSES,
14+
check_auth_token,
15+
check_model_id,
16+
execute_pipeline,
1917
image_to_data_url,
2018
)
2119

@@ -35,21 +33,6 @@
3533
)
3634
}
3735

38-
RESPONSES = {
39-
status.HTTP_200_OK: {
40-
"content": {
41-
"application/json": {
42-
"schema": {
43-
"x-speakeasy-name-override": "data",
44-
}
45-
}
46-
},
47-
},
48-
status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
49-
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
50-
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
51-
}
52-
5336

5437
# TODO: Make model_id and other None properties optional once Go codegen tool supports
5538
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
@@ -153,23 +136,11 @@ async def image_to_image(
153136
pipeline: Pipeline = Depends(get_pipeline),
154137
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
155138
):
156-
auth_token = os.environ.get("AUTH_TOKEN")
157-
if auth_token:
158-
if not token or token.credentials != auth_token:
159-
return JSONResponse(
160-
status_code=status.HTTP_401_UNAUTHORIZED,
161-
headers={"WWW-Authenticate": "Bearer"},
162-
content=http_error("Invalid bearer token."),
163-
)
139+
if auth_error := check_auth_token(token):
140+
return auth_error
164141

165-
if model_id != "" and model_id != pipeline.model_id:
166-
return JSONResponse(
167-
status_code=status.HTTP_400_BAD_REQUEST,
168-
content=http_error(
169-
f"pipeline configured with {pipeline.model_id} but called with "
170-
f"{model_id}."
171-
),
172-
)
142+
if model_error := check_model_id(model_id, pipeline.model_id):
143+
return model_error
173144

174145
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
175146
seeds = [seed + i for i in range(num_images_per_prompt)]
@@ -181,30 +152,25 @@ async def image_to_image(
181152
images = []
182153
has_nsfw_concept = []
183154
for seed in seeds:
184-
try:
185-
imgs, nsfw_checks = pipeline(
186-
prompt=prompt,
187-
image=image,
188-
strength=strength,
189-
loras=loras,
190-
guidance_scale=guidance_scale,
191-
image_guidance_scale=image_guidance_scale,
192-
negative_prompt=negative_prompt,
193-
safety_check=safety_check,
194-
seed=seed,
195-
num_images_per_prompt=1,
196-
num_inference_steps=num_inference_steps,
197-
)
198-
except Exception as e:
199-
if isinstance(e, torch.cuda.OutOfMemoryError):
200-
# TODO: Investigate why not all VRAM memory is cleared.
201-
torch.cuda.empty_cache()
202-
logger.error(f"ImageToImagePipeline pipeline error: {e}")
203-
return handle_pipeline_exception(
204-
e,
205-
default_error_message="Image-to-image pipeline error.",
206-
custom_error_config=PIPELINE_ERROR_CONFIG,
207-
)
155+
result, error = execute_pipeline(
156+
pipeline,
157+
default_error_message="Image-to-image pipeline error.",
158+
custom_error_config=PIPELINE_ERROR_CONFIG,
159+
prompt=prompt,
160+
image=image,
161+
strength=strength,
162+
loras=loras,
163+
guidance_scale=guidance_scale,
164+
image_guidance_scale=image_guidance_scale,
165+
negative_prompt=negative_prompt,
166+
safety_check=safety_check,
167+
seed=seed,
168+
num_images_per_prompt=1,
169+
num_inference_steps=num_inference_steps,
170+
)
171+
if error:
172+
return error
173+
imgs, nsfw_checks = result
208174
images.extend(imgs)
209175
has_nsfw_concept.extend(nsfw_checks)
210176

runner/src/runner/routes/image_to_text.py

Lines changed: 23 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import logging
2-
import os
32
from typing import Annotated, Dict, Tuple, Union
43

5-
import torch
64
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
75
from fastapi.responses import JSONResponse
86
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
@@ -13,8 +11,11 @@
1311
from runner.routes.utils import (
1412
HTTPError,
1513
ImageToTextResponse,
14+
RESPONSES,
15+
check_auth_token,
16+
check_model_id,
17+
execute_pipeline,
1618
file_exceeds_max_size,
17-
handle_pipeline_exception,
1819
http_error,
1920
)
2021

@@ -31,27 +32,17 @@
3132
)
3233
}
3334

34-
RESPONSES = {
35-
status.HTTP_200_OK: {
36-
"content": {
37-
"application/json": {
38-
"schema": {
39-
"x-speakeasy-name-override": "data",
40-
}
41-
}
42-
},
43-
},
44-
status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
45-
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
35+
# Extend shared RESPONSES with additional status codes for this route.
36+
IMAGE_TO_TEXT_RESPONSES = {
37+
**RESPONSES,
4638
status.HTTP_413_REQUEST_ENTITY_TOO_LARGE: {"model": HTTPError},
47-
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
4839
}
4940

5041

5142
@router.post(
5243
"/image-to-text",
5344
response_model=ImageToTextResponse,
54-
responses=RESPONSES,
45+
responses=IMAGE_TO_TEXT_RESPONSES,
5546
description="Transform image files to text.",
5647
operation_id="genImageToText",
5748
summary="Image To Text",
@@ -61,7 +52,7 @@
6152
@router.post(
6253
"/image-to-text/",
6354
response_model=ImageToTextResponse,
64-
responses=RESPONSES,
55+
responses=IMAGE_TO_TEXT_RESPONSES,
6556
include_in_schema=False,
6657
)
6758
async def image_to_text(
@@ -79,23 +70,11 @@ async def image_to_text(
7970
pipeline: Pipeline = Depends(get_pipeline),
8071
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
8172
):
82-
auth_token = os.environ.get("AUTH_TOKEN")
83-
if auth_token:
84-
if not token or token.credentials != auth_token:
85-
return JSONResponse(
86-
status_code=status.HTTP_401_UNAUTHORIZED,
87-
headers={"WWW-Authenticate": "Bearer"},
88-
content=http_error("Invalid bearer token"),
89-
)
73+
if auth_error := check_auth_token(token):
74+
return auth_error
9075

91-
if model_id != "" and model_id != pipeline.model_id:
92-
return JSONResponse(
93-
status_code=status.HTTP_400_BAD_REQUEST,
94-
content=http_error(
95-
f"pipeline configured with {pipeline.model_id} but called with "
96-
f"{model_id}"
97-
),
98-
)
76+
if model_error := check_model_id(model_id, pipeline.model_id):
77+
return model_error
9978

10079
if file_exceeds_max_size(image, 50 * 1024 * 1024):
10180
return JSONResponse(
@@ -104,15 +83,13 @@ async def image_to_text(
10483
)
10584

10685
image = Image.open(image.file).convert("RGB")
107-
try:
108-
return ImageToTextResponse(text=pipeline(prompt=prompt, image=image))
109-
except Exception as e:
110-
if isinstance(e, torch.cuda.OutOfMemoryError):
111-
# TODO: Investigate why not all VRAM memory is cleared.
112-
torch.cuda.empty_cache()
113-
logger.error(f"ImageToTextPipeline error: {e}")
114-
return handle_pipeline_exception(
115-
e,
116-
default_error_message="Image-to-text pipeline error.",
117-
custom_error_config=PIPELINE_ERROR_CONFIG,
118-
)
86+
result, error = execute_pipeline(
87+
pipeline,
88+
default_error_message="Image-to-text pipeline error.",
89+
custom_error_config=PIPELINE_ERROR_CONFIG,
90+
prompt=prompt,
91+
image=image,
92+
)
93+
if error:
94+
return error
95+
return ImageToTextResponse(text=result)

0 commit comments

Comments
 (0)