Skip to content

Commit 4ec39a7

Browse files
committed
fix: Reformat
1 parent ffb2fcb commit 4ec39a7

File tree

3 files changed

+35
-23
lines changed

3 files changed

+35
-23
lines changed

volcenginesdkllmshield/api/llm_shield_sdk_v2.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pydantic import BaseModel, field_validator, Field
2-
from typing import List, Optional, Any,Union
2+
from typing import List, Optional, Any, Union
33
import requests
44
import json
55

@@ -8,6 +8,7 @@
88
LLM_STREAM_SEND_BASE_WINDOW_V2 = 10
99
LLM_STREAM_SEND_EXPONENT_V2 = 2
1010

11+
1112
# 定义内容类型常量
1213
class ContentTypeV2:
1314
TEXT = 1
@@ -69,7 +70,7 @@ def __init__(self, other=None, **data):
6970
# 1. 将其他实例序列化为字典(包含嵌套对象)
7071
other_dict = other.model_dump(by_alias=True) # 使用 alias 键名
7172
# 2. 用序列化后的字典初始化当前实例(实现深拷贝)
72-
super().__init__(** other_dict)
73+
super().__init__(**other_dict)
7374
else:
7475
# 正常初始化逻辑
7576
super().__init__(**data)
@@ -201,7 +202,7 @@ class Config:
201202

202203
# 定义响应元数据结构体
203204
class ResponseMetadata(BaseModel):
204-
error:Union[ErrorInfo, None] = Field(default_factory=ErrorInfo, alias="Error")
205+
error: Union[ErrorInfo, None] = Field(default_factory=ErrorInfo, alias="Error")
205206
requestId: str = Field(..., alias="RequestId") # 添加requestId字段,映射自RequestId
206207
service: Union[str, None] = Field(None, alias="Service")
207208
action: Union[str, None] = Field(None, alias="Action")
@@ -216,14 +217,15 @@ class Config:
216217
# 定义审核响应结构体
217218
class ModerateV2Response(BaseModel):
218219
response_metadata: ResponseMetadata = Field(default_factory=ResponseMetadata, alias="ResponseMetadata")
219-
result: Union[ModerateV2Result, None] = Field(default_factory=ModerateV2Result, alias="Result")
220+
result: Union[ModerateV2Result, None] = Field(default_factory=ModerateV2Result, alias="Result")
220221

221222
class Config:
222223
populate_by_name = True
223224

224225

225226
class ModerateV2StreamSession:
226227
"""流式会话结构体,用于积累流式请求、存储未发送长度和默认响应体"""
228+
227229
def __init__(self):
228230
# 用于积累流式的请求(初始为 None,对应 Go 中的指针)
229231
self.request: Optional[ModerateV2Request] = None
@@ -263,6 +265,7 @@ class GenerateStreamResult(BaseModel):
263265
"""生成流V2版本的结果模型"""
264266
message: Optional[MessageV2] = Field(None, alias="Message", description="优化内容,isFinished为true时为空")
265267
is_finished: bool = Field(False, alias="IsFinished", description="标识是否结束")
268+
266269
# summarize: Optional[GenerateSummarizeV2] = Field(None, alias="Summarize", description="总结信息,isFinished为true时有值")
267270

268271
class Config:
@@ -280,7 +283,7 @@ class Config:
280283

281284
# 定义客户端类
282285
class ClientV2:
283-
def __init__(self, url: str, ak: str,sk:str, region :str, timeout: float):
286+
def __init__(self, url: str, ak: str, sk: str, region: str, timeout: float):
284287
self.url = url
285288
self.ak = ak
286289
self.sk = sk
@@ -300,11 +303,11 @@ def Moderate(self, request: Optional[ModerateV2Request] = None) -> ModerateV2Res
300303
header = {
301304
}
302305

303-
sign_header = request_sign(header, self.ak, self.sk , self.region, self.url, path,action, request_body)
306+
sign_header = request_sign(header, self.ak, self.sk, self.region, self.url, path, action, request_body)
304307

305308
try:
306309
resp = self.http_client.post(
307-
url=self.url+path+"?Action="+action+"&Version="+ Version,
310+
url=self.url + path + "?Action=" + action + "&Version=" + Version,
308311
data=request_body,
309312
headers=sign_header
310313
)
@@ -317,7 +320,8 @@ def Moderate(self, request: Optional[ModerateV2Request] = None) -> ModerateV2Res
317320
except Exception as e:
318321
raise Exception(f"处理响应失败: {e}")
319322

320-
def ModerateStream(self, request: ModerateV2Request, session: ModerateV2StreamSession) -> Optional[ModerateV2Response]:
323+
def ModerateStream(self, request: ModerateV2Request, session: ModerateV2StreamSession) -> Optional[
324+
ModerateV2Response]:
321325
"""
322326
处理流式审核请求
323327
:param request: 当前流式请求片段(ModerateV2Request 类型)
@@ -353,7 +357,8 @@ def ModerateStream(self, request: ModerateV2Request, session: ModerateV2StreamSe
353357

354358
# 3. 判断是否需要发送请求到后端
355359
# 只有当未检测长度 >= 10 或者是第一次或者是最后一次请求时,才发送请求
356-
need_send_request = is_last_request or is_first_request or (session.stream_send_len >= session.CurrentSendWindow)
360+
need_send_request = is_last_request or is_first_request or (
361+
session.stream_send_len >= session.CurrentSendWindow)
357362

358363
# 如果不需要发送请求,直接返回上次的默认响应(如果有)
359364
if not need_send_request:
@@ -373,7 +378,7 @@ def ModerateStream(self, request: ModerateV2Request, session: ModerateV2StreamSe
373378
sign_header = request_sign(headers, self.ak, self.sk, self.region, self.url, path, action, request_body)
374379
try:
375380
response = requests.post(
376-
url=self.url+path+"?Action="+action+"&Version="+ Version,
381+
url=self.url + path + "?Action=" + action + "&Version=" + Version,
377382
data=request_body,
378383
headers=sign_header
379384
)
@@ -395,7 +400,7 @@ def ModerateStream(self, request: ModerateV2Request, session: ModerateV2StreamSe
395400
# 7. 若为最后一次流式请求(use_stream == 2),打印最终内容
396401
if session.request.use_stream == 2:
397402
final_content = session.request.message.content if (
398-
session.request.message and session.request.message.content) else ""
403+
session.request.message and session.request.message.content) else ""
399404
print(f"最终检测内容: {final_content}")
400405

401406
return moderate_response
@@ -415,7 +420,8 @@ def GenerateV2Stream(self, request):
415420
try:
416421
sign_header = request_sign(headers, self.ak, self.sk, self.region, self.url, path, action, requestBody)
417422
# 发送 HTTP 请求
418-
resp = self.http_client.post(url=self.url+path+"?Action="+action+"&Version="+ Version, data=requestBody, headers=sign_header, stream=True)
423+
resp = self.http_client.post(url=self.url + path + "?Action=" + action + "&Version=" + Version,
424+
data=requestBody, headers=sign_header, stream=True)
419425
if resp.status_code != 200:
420426
raise Exception("bad response code: %d" % resp.status_code)
421427

@@ -459,4 +465,4 @@ def default(self, obj):
459465
elif hasattr(obj, '__dict__'):
460466
return obj.__dict__ # 返回对象的属性字典
461467
# 调用默认处理(会抛出TypeError)
462-
return super().default(obj)
468+
return super().default(obj)

volcenginesdkllmshield/models/llm_shield_sign.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ContentType = "application/json"
2828
Method = "POST"
2929

30+
3031
def norm_query(params):
3132
query = ""
3233
for key in sorted(params.keys()):
@@ -53,7 +54,7 @@ def hash_sha256(content: bytes):
5354

5455

5556
# 第二步:签名请求函数
56-
def request_sign(header, ak, sk,region, url, path, action, body):
57+
def request_sign(header, ak, sk, region, url, path, action, body):
5758
host = urlparse(url).netloc
5859
date = utc_now()
5960
# 第三步:创建身份证明。其中的 Service 和 Region 字段是固定的。ak 和 sk 分别代表
@@ -103,7 +104,7 @@ def request_sign(header, ak, sk,region, url, path, action, body):
103104
"host:" + request_param["host"],
104105
"x-content-sha256:" + x_content_sha256,
105106
"x-date:" + x_date,
106-
]
107+
]
107108
),
108109
"",
109110
signed_headers_str,
@@ -112,16 +113,16 @@ def request_sign(header, ak, sk,region, url, path, action, body):
112113
)
113114

114115
# 打印正规化的请求用于调试比对
115-
#print(canonical_request_str)
116+
# print(canonical_request_str)
116117
hashed_canonical_request = hash_sha256(canonical_request_str.encode("utf-8"))
117118

118119
# 打印hash值用于调试比对
119-
#print(hashed_canonical_request)
120+
# print(hashed_canonical_request)
120121
credential_scope = "/".join([short_x_date, credential["region"], credential["service"], "request"])
121122
string_to_sign = "\n".join(["HMAC-SHA256", x_date, credential_scope, hashed_canonical_request])
122123

123124
# 打印最终计算的签名字符串用于调试比对
124-
#print(string_to_sign)
125+
# print(string_to_sign)
125126
k_date = hmac_sha256(credential["secret_access_key"].encode("utf-8"), short_x_date)
126127
k_region = hmac_sha256(k_date, credential["region"])
127128
k_service = hmac_sha256(k_region, credential["service"])
@@ -132,24 +133,27 @@ def request_sign(header, ak, sk,region, url, path, action, body):
132133
credential["access_key_id"] + "/" + credential_scope,
133134
signed_headers_str,
134135
signature,
135-
)
136-
header = {**header, **sign_result ,"X-Top-Service": Service, "X-Top-Region":region }
136+
)
137+
header = {**header, **sign_result, "X-Top-Service": Service, "X-Top-Region": region}
137138
# header = {**header, **{"X-Security-Token": SessionToken}}
138139
# 第六步:将 Signature 签名写入 HTTP Header 中,并发送 HTTP 请求。
139140
return header
140141

142+
141143
# datetime.utcnow() 在 3.12+ 已经过期,使用如下方法兼容
142144
def utc_now():
143-
144145
try:
145146
from datetime import timezone
146147
return datetime.datetime.now(timezone.utc)
147148
except ImportError:
148149
class UTC(datetime.tzinfo):
149150
def utcoffset(self, dt):
150151
return datetime.timedelta(0)
152+
151153
def tzname(self, dt):
152154
return "UTC"
155+
153156
def dst(self, dt):
154157
return datetime.timedelta(0)
158+
155159
return datetime.datetime.now(UTC())

volcenginesdkwafruntime/api/waf_runtime_api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
21
from volcenginesdkwaf import WAFApi, CheckLLMResponseStreamRequest
3-
from volcenginesdkwafruntime.models.llm_stream_session import LLMStreamSession,LLM_STREAM_SEND_EXPONENT,LLM_STREAM_SEND_BASE_WINDOW
2+
from volcenginesdkwafruntime.models.llm_stream_session import LLMStreamSession, LLM_STREAM_SEND_EXPONENT, \
3+
LLM_STREAM_SEND_BASE_WINDOW
44

55
global_llm_send_len = 10
66

7+
78
class WAFRuntimeApi(WAFApi):
89
"""继承自 WAFApi 并重写 check_llm_response_stream 方法"""
10+
911
def check_llm_response_stream(
1012
self,
1113
body: CheckLLMResponseStreamRequest,

0 commit comments

Comments
 (0)