Skip to content

Commit 37930f1

Browse files
author
BitsAdmin
committed
Merge branch 'feat/llm_stream_check_python' into 'integration_2025-05-22_911100755970'
feat: [development task] waf runtime (1244130) See merge request iaasng/volcengine-python-sdk!618
2 parents 33a47be + b11ef24 commit 37930f1

File tree

3 files changed

+155
-0
lines changed

3 files changed

+155
-0
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# coding: utf-8
2+
3+
# flake8: noqa
4+
5+
"""
6+
waf
7+
8+
No description provided (generated by Swagger Codegen https://github.com/swagger-api/swagger-codegen) # noqa: E501
9+
10+
OpenAPI spec version: common-version
11+
12+
Generated by: https://github.com/swagger-api/swagger-codegen.git
13+
"""
14+
15+
from volcenginesdkwaf import *
16+
from volcenginesdkwafruntime.api.waf_runtime_api import WAFRuntimeApi
17+
from volcenginesdkwafruntime.models.llm_stream_session import LLMStreamSession
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
2+
from volcenginesdkwaf import WAFApi, CheckLLMResponseStreamRequest
3+
from volcenginesdkwafruntime.models.llm_stream_session import LLMStreamSession
4+
5+
# Minimum number of unsent characters that must accumulate before an
# intermediate stream chunk is actually sent to the WAF API.
global_llm_send_len = 10

class WAFRuntimeApi(WAFApi):
    """Extends ``WAFApi``, overriding ``check_llm_response_stream`` to merge
    and throttle streamed LLM output through an ``LLMStreamSession`` before
    calling the underlying API.
    """

    @staticmethod
    def _unwrap_response(resp):
        """Return the data object from a ``*_with_http_info`` result.

        ``check_llm_response_stream_with_http_info`` may return either the
        response object itself or a ``(data, status, headers)`` tuple.
        """
        if isinstance(resp, tuple) and len(resp) > 0:
            return resp[0]
        return resp

    def check_llm_response_stream(
        self,
        body: CheckLLMResponseStreamRequest,
        **kwargs
    ):
        """Override of the parent method that adds a ``session`` kwarg and
        merges streamed content before dispatching to the API.

        Args:
            body: Request body carrying ``content``, ``use_stream`` and
                ``msg_id``.
            **kwargs: Other call options. Must include ``session`` (an
                ``LLMStreamSession`` recording the stream content and state);
                may include ``async_req`` etc.

        Returns:
            CheckLLMResponseStreamResponse: the API result, or the cached
            previous result while the unsent length is below the threshold.

        Raises:
            ValueError: if the ``session`` kwarg is missing.
        """
        session: LLMStreamSession = kwargs.pop('session', None)
        if session is None:
            raise ValueError("session parameter is required")

        # 1. Accumulate the new chunk into the session's stream buffer.
        if body.content:
            session.append_stream_buf(body.content)

        # 2. Final chunk (use_stream == 2) or very first call (no msg_id
        #    yet): send immediately, without threshold accumulation.
        if body.use_stream == 2 or not session.msg_id:
            body.content = session.get_stream_buf()
            body.msg_id = session.get_msg_id()

            # BUG FIX: async callers get the raw async result back, matching
            # the accumulating branch below. The original code fell through
            # and read ``.msg_id`` off the async handle, which would crash.
            if kwargs.get('async_req'):
                return self.check_llm_response_stream_with_http_info(body, **kwargs)

            response = self._unwrap_response(
                self.check_llm_response_stream_with_http_info(body, **kwargs))

            # First call: remember the server-assigned msg_id.
            if (not session.msg_id) and response.msg_id:
                session.set_msg_id(response.msg_id)

            # Cache the response and reset the unsent-length counter.
            session.set_default_body(response)
            session.set_stream_send_len(0)
            return response

        # 3. Intermediate chunk: only call the API once enough unsent
        #    characters have accumulated.
        if session.get_stream_send_len() > global_llm_send_len:
            return self._send_accumulated(body, session, **kwargs)

        # Below the threshold: return the last cached response.
        default_body = session.get_default_body()
        if default_body:
            return default_body

        # No cached response yet (should not happen in practice, because the
        # first call takes the immediate-send branch). Force one real send.
        # BUG FIX: the original code recursed into this method WITHOUT
        # re-adding the popped ``session`` kwarg, so it always raised
        # ValueError — and even with the session restored it would have
        # re-appended the chunk and recursed indefinitely.
        return self._send_accumulated(body, session, **kwargs)

    def _send_accumulated(self, body, session, **kwargs):
        """Send the session's full buffered content and cache the result."""
        body.content = session.get_stream_buf()
        body.msg_id = session.get_msg_id()

        # Async callers get the raw async result back.
        if kwargs.get('async_req'):
            return self.check_llm_response_stream_with_http_info(body, **kwargs)

        response = self._unwrap_response(
            self.check_llm_response_stream_with_http_info(body, **kwargs))

        session.set_default_body(response)
        # Keep the full buffer (more chunks may follow) but reset the
        # unsent-length counter.
        session.set_stream_send_len(0)
        return response
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from typing import Optional
2+
3+
from volcenginesdkwaf import CheckLLMResponseStreamResponse
4+
5+
6+
class LLMStreamSession:
    """Python counterpart of the Java ``LLMStreamSession`` class.

    Holds the mutable state of one LLM streaming session: the accumulated
    stream content, the number of characters appended since the last send,
    the server-assigned message id, and the most recent API response.
    """

    def __init__(
        self,
        stream_buf: str = "",
        stream_send_len: int = 0,
        msg_id: str = "",
        default_body: Optional["CheckLLMResponseStreamResponse"] = None
    ):
        # Full streamed content accumulated so far.
        self.stream_buf = stream_buf
        # Character count appended since the last successful send.
        self.stream_send_len = stream_send_len
        # Message id assigned by the server on the first call.
        self.msg_id = msg_id
        # Last API response, reused while below the send threshold.
        self.default_body = default_body

    def get_stream_buf(self) -> str:
        """Return the full accumulated stream content."""
        return self.stream_buf

    def set_stream_buf(self, stream_buf: str) -> None:
        """Replace the accumulated stream content."""
        self.stream_buf = stream_buf

    def get_stream_send_len(self) -> int:
        """Return the number of characters not yet sent."""
        return self.stream_send_len

    def set_stream_send_len(self, stream_send_len: int) -> None:
        """Set the number of characters not yet sent."""
        self.stream_send_len = stream_send_len

    def get_msg_id(self) -> str:
        """Return the server-assigned message id ("" if not yet assigned)."""
        return self.msg_id

    def set_msg_id(self, msg_id: str) -> None:
        """Record the server-assigned message id."""
        self.msg_id = msg_id

    def get_default_body(self) -> Optional["CheckLLMResponseStreamResponse"]:
        """Return the most recently cached API response, if any."""
        return self.default_body

    def set_default_body(self, default_body: Optional["CheckLLMResponseStreamResponse"]) -> None:
        """Cache the most recent API response."""
        self.default_body = default_body

    def append_stream_buf(self, text: Optional[str]) -> None:
        """Append *text* to the stream buffer and bump the unsent length."""
        if text is None:
            return
        self.stream_buf += text
        self.stream_send_len += len(text)

0 commit comments

Comments
 (0)