|
| 1 | + |
| 2 | +from volcenginesdkwaf import WAFApi, CheckLLMResponseStreamRequest |
| 3 | +from volcenginesdkwafruntime.models.llm_stream_session import LLMStreamSession |
| 4 | + |
| 5 | +global_llm_send_len = 10 |
| 6 | + |
| 7 | +class WAFRuntimeApi(WAFApi): |
| 8 | + """继承自 WAFApi 并重写 check_llm_response_stream 方法""" |
| 9 | + def check_llm_response_stream( |
| 10 | + self, |
| 11 | + body: CheckLLMResponseStreamRequest, |
| 12 | + **kwargs |
| 13 | + ): |
| 14 | + session: LLMStreamSession = kwargs.pop('session', None) |
| 15 | + if session is None: |
| 16 | + raise ValueError("session parameter is required") |
| 17 | + """ |
| 18 | + 重写父类方法,增加 session 参数,实现流式内容的合并与处理 |
| 19 | +
|
| 20 | + Args: |
| 21 | + body: 请求体,包含 content、use_stream 和 msg_id 等信息 |
| 22 | + session: LLM 流会话对象,用于记录流式内容和状态 |
| 23 | + **kwargs: 其他参数,包括 async_req 等 |
| 24 | +
|
| 25 | + Returns: |
| 26 | + CheckLLMResponseStreamResponse: 处理结果 |
| 27 | + """ |
| 28 | + # 1. 拼接内容到 session.stream_buf |
| 29 | + if body.content: |
| 30 | + session.append_stream_buf(body.content) |
| 31 | + |
| 32 | + # 2. 处理 use_stream 为 2 或 没有msgid (第一次调用) 的情况(直接发送,不累计长度) |
| 33 | + if (body.use_stream == 2 |
| 34 | + or not session.msg_id): |
| 35 | + # 准备请求体 |
| 36 | + body.content = session.get_stream_buf() |
| 37 | + body.msg_id = session.get_msg_id() |
| 38 | + # 调用原始 API |
| 39 | + # 同步调用并处理结果 |
| 40 | + resp = self.check_llm_response_stream_with_http_info(body, **kwargs) |
| 41 | + if isinstance(resp, tuple): |
| 42 | + response = resp[0] # 获取元组的第一个元素 |
| 43 | + else: |
| 44 | + response = resp |
| 45 | + |
| 46 | + # 首次调用时(没有msgid) ,保存 msg_id 到 session |
| 47 | + if (not session.msg_id) and response.msg_id: |
| 48 | + session.set_msg_id(response.msg_id) |
| 49 | + |
| 50 | + # 保存默认响应 |
| 51 | + session.set_default_body(response) |
| 52 | + |
| 53 | + # 重置流缓冲区和发送长度 |
| 54 | + session.set_stream_send_len(0) |
| 55 | + |
| 56 | + return response |
| 57 | + |
| 58 | + # 3. 处理 use_stream 为其他值的情况(累计长度,超过阈值才发送) |
| 59 | + else: |
| 60 | + # 如果未发送长度超过 10 个字符,调用 API |
| 61 | + if session.get_stream_send_len() > global_llm_send_len: |
| 62 | + # 准备请求体,使用 session 中的完整流内容 |
| 63 | + body.content = session.get_stream_buf() |
| 64 | + body.msg_id = session.get_msg_id() |
| 65 | + # 调用原始 API |
| 66 | + if kwargs.get('async_req'): |
| 67 | + return self.check_llm_response_stream_with_http_info(body, **kwargs) |
| 68 | + |
| 69 | + # 同步调用并处理结果 |
| 70 | + resp = self.check_llm_response_stream_with_http_info(body, **kwargs) |
| 71 | + if isinstance(resp, tuple) and len(resp) > 0: |
| 72 | + response = resp[0] # 获取元组的第一个元素 |
| 73 | + else: |
| 74 | + response = resp |
| 75 | + |
| 76 | + # 保存默认响应 |
| 77 | + session.set_default_body(response) |
| 78 | + |
| 79 | + # 重置发送长度,保留完整流内容(因为可能还有后续数据) |
| 80 | + session.set_stream_send_len(0) |
| 81 | + return response |
| 82 | + # 如果未发送长度不足 10 个字符,返回上次的结果 |
| 83 | + else: |
| 84 | + default_body = session.get_default_body() |
| 85 | + if default_body: |
| 86 | + return default_body |
| 87 | + else: |
| 88 | + # 如果没有默认结果,调用一次 API(这种情况理论上不会发生,因为首次调用 use_stream 应为 0) |
| 89 | + return self.check_llm_response_stream(body, **kwargs) |
0 commit comments