Skip to content

Commit daf4f13

Browse files
committed
重写接口CheckLLMResponseStreamResponse ,支持流式调用
1 parent 85c204a commit daf4f13

File tree

3 files changed

+154
-0
lines changed

3 files changed

+154
-0
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# coding: utf-8
2+
3+
# flake8: noqa
4+
5+
"""
6+
waf
7+
8+
No description provided (generated by Swagger Codegen https://github.com/swagger-api/swagger-codegen) # noqa: E501
9+
10+
OpenAPI spec version: common-version
11+
12+
Generated by: https://github.com/swagger-api/swagger-codegen.git
13+
"""
14+
15+
from volcenginesdkwaf import *
16+
from volcenginesdkwafruntime.api.waf_runtime_api import WAFRuntimeApi
17+
from volcenginesdkwafruntime.models.llm_stream_session import LLMStreamSession
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
2+
from volcenginesdkwaf import WAFApi, CheckLLMResponseStreamRequest
3+
from volcenginesdkwafruntime.models.llm_stream_session import LLMStreamSession
4+
5+
6+
class WAFRuntimeApi(WAFApi):
7+
"""继承自 WAFApi 并重写 check_llm_response_stream 方法"""
8+
def check_llm_response_stream(
9+
self,
10+
body: CheckLLMResponseStreamRequest,
11+
**kwargs
12+
):
13+
session: LLMStreamSession = kwargs.pop('session', None)
14+
if session is None:
15+
raise ValueError("session parameter is required")
16+
"""
17+
重写父类方法,增加 session 参数,实现流式内容的合并与处理
18+
19+
Args:
20+
body: 请求体,包含 content、use_stream 和 msg_id 等信息
21+
session: LLM 流会话对象,用于记录流式内容和状态
22+
**kwargs: 其他参数,包括 async_req 等
23+
24+
Returns:
25+
CheckLLMResponseStreamResponse: 处理结果
26+
"""
27+
# 1. 拼接内容到 session.stream_buf
28+
if body.content:
29+
session.append_stream_buf(body.content)
30+
31+
# 2. 处理 use_stream 为 2 或 没有msgid (第一次调用) 的情况(直接发送,不累计长度)
32+
if (body.use_stream == 2
33+
or not session.msg_id):
34+
# 准备请求体
35+
body.content = session.get_stream_buf()
36+
body.msg_id = session.get_msg_id()
37+
# 调用原始 API
38+
# 同步调用并处理结果
39+
resp = self.check_llm_response_stream_with_http_info(body, **kwargs)
40+
if isinstance(resp, tuple):
41+
response = resp[0] # 获取元组的第一个元素
42+
else:
43+
response = resp
44+
45+
# 首次调用时(没有msgid) ,保存 msg_id 到 session
46+
if (not session.msg_id) and response.msg_id:
47+
session.set_msg_id(response.msg_id)
48+
49+
# 保存默认响应
50+
session.set_default_body(response)
51+
52+
# 重置流缓冲区和发送长度
53+
session.set_stream_send_len(0)
54+
55+
return response
56+
57+
# 3. 处理 use_stream 为其他值的情况(累计长度,超过阈值才发送)
58+
else:
59+
# 如果未发送长度超过 10 个字符,调用 API
60+
if session.get_stream_send_len() > 10:
61+
# 准备请求体,使用 session 中的完整流内容
62+
body.content = session.get_stream_buf()
63+
body.msg_id = session.get_msg_id()
64+
# 调用原始 API
65+
if kwargs.get('async_req'):
66+
return self.check_llm_response_stream_with_http_info(body, **kwargs)
67+
68+
# 同步调用并处理结果
69+
resp = self.check_llm_response_stream_with_http_info(body, **kwargs)
70+
if isinstance(resp, tuple):
71+
response = resp[0] # 获取元组的第一个元素
72+
else:
73+
response = resp
74+
75+
# 保存默认响应
76+
session.set_default_body(response)
77+
78+
# 重置发送长度,保留完整流内容(因为可能还有后续数据)
79+
session.set_stream_send_len(0)
80+
return response
81+
# 如果未发送长度不足 10 个字符,返回上次的结果
82+
else:
83+
default_body = session.get_default_body()
84+
if default_body:
85+
return default_body
86+
else:
87+
# 如果没有默认结果,调用一次 API(这种情况理论上不会发生,因为首次调用 use_stream 应为 0)
88+
return self.check_llm_response_stream(body, **kwargs)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from typing import Optional
2+
3+
from volcenginesdkwaf import CheckLLMResponseStreamResponse
4+
5+
6+
class LLMStreamSession:
7+
"""对应 Java 中的 LLMStreamSession 类"""
8+
9+
def __init__(
10+
self,
11+
stream_buf: str = "",
12+
stream_send_len: int = 0,
13+
msg_id: str = "",
14+
default_body: Optional[CheckLLMResponseStreamResponse] = None
15+
):
16+
self.stream_buf = stream_buf
17+
self.stream_send_len = stream_send_len
18+
self.msg_id = msg_id
19+
self.default_body = default_body
20+
21+
def get_stream_buf(self) -> str:
22+
return self.stream_buf
23+
24+
def set_stream_buf(self, stream_buf: str) -> None:
25+
self.stream_buf = stream_buf
26+
27+
def get_stream_send_len(self) -> int:
28+
return self.stream_send_len
29+
30+
def set_stream_send_len(self, stream_send_len: int) -> None:
31+
self.stream_send_len = stream_send_len
32+
33+
def get_msg_id(self) -> str:
34+
return self.msg_id
35+
36+
def set_msg_id(self, msg_id: str) -> None:
37+
self.msg_id = msg_id
38+
39+
def get_default_body(self) -> Optional[CheckLLMResponseStreamResponse]:
40+
return self.default_body
41+
42+
def set_default_body(self, default_body: Optional[CheckLLMResponseStreamResponse]) -> None:
43+
self.default_body = default_body
44+
45+
def append_stream_buf(self, text: Optional[str]) -> None:
46+
"""追加文本到流缓冲区并更新发送长度"""
47+
if text is not None:
48+
self.stream_buf += text
49+
self.stream_send_len += len(text)

0 commit comments

Comments
 (0)