|
| 1 | +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import json |
| 16 | +import threading |
| 17 | + |
| 18 | +from volcengine.ApiInfo import ApiInfo |
| 19 | +from volcengine.auth.SignerV4 import SignerV4 |
| 20 | +from volcengine.base.Service import Service |
| 21 | +from volcengine.Credentials import Credentials |
| 22 | +from volcengine.ServiceInfo import ServiceInfo |
| 23 | + |
| 24 | + |
| 25 | +class VikingDBMemoryException(Exception): |
| 26 | + def __init__(self, code, request_id, message=None): |
| 27 | + self.code = code |
| 28 | + self.request_id = request_id |
| 29 | + self.message = "{}, code:{},request_id:{}".format( |
| 30 | + message, self.code, self.request_id |
| 31 | + ) |
| 32 | + |
| 33 | + def __str__(self): |
| 34 | + return self.message |
| 35 | + |
| 36 | + |
| 37 | +class VikingDBMemoryClient(Service): |
| 38 | + _instance_lock = threading.Lock() |
| 39 | + |
| 40 | + def __new__(cls, *args, **kwargs): |
| 41 | + if not hasattr(VikingDBMemoryClient, "_instance"): |
| 42 | + with VikingDBMemoryClient._instance_lock: |
| 43 | + if not hasattr(VikingDBMemoryClient, "_instance"): |
| 44 | + VikingDBMemoryClient._instance = object.__new__(cls) |
| 45 | + return VikingDBMemoryClient._instance |
| 46 | + |
| 47 | + def __init__( |
| 48 | + self, |
| 49 | + host="api-knowledgebase.mlp.cn-beijing.volces.com", |
| 50 | + region="cn-beijing", |
| 51 | + ak="", |
| 52 | + sk="", |
| 53 | + sts_token="", |
| 54 | + scheme="http", |
| 55 | + connection_timeout=30, |
| 56 | + socket_timeout=30, |
| 57 | + ): |
| 58 | + self.service_info = VikingDBMemoryClient.get_service_info( |
| 59 | + host, region, scheme, connection_timeout, socket_timeout |
| 60 | + ) |
| 61 | + self.api_info = VikingDBMemoryClient.get_api_info() |
| 62 | + super(VikingDBMemoryClient, self).__init__(self.service_info, self.api_info) |
| 63 | + if ak: |
| 64 | + self.set_ak(ak) |
| 65 | + if sk: |
| 66 | + self.set_sk(sk) |
| 67 | + if sts_token: |
| 68 | + self.set_session_token(session_token=sts_token) |
| 69 | + try: |
| 70 | + self.get_body("Ping", {}, json.dumps({})) |
| 71 | + except Exception as e: |
| 72 | + raise VikingDBMemoryException( |
| 73 | + 1000028, "missed", "host or region is incorrect: {}".format(str(e)) |
| 74 | + ) from None |
| 75 | + |
| 76 | + def setHeader(self, header): |
| 77 | + api_info = VikingDBMemoryClient.get_api_info() |
| 78 | + for key in api_info: |
| 79 | + for item in header: |
| 80 | + api_info[key].header[item] = header[item] |
| 81 | + self.api_info = api_info |
| 82 | + |
| 83 | + @staticmethod |
| 84 | + def get_service_info(host, region, scheme, connection_timeout, socket_timeout): |
| 85 | + service_info = ServiceInfo( |
| 86 | + host, |
| 87 | + {"Host": host}, |
| 88 | + Credentials("", "", "air", region), |
| 89 | + connection_timeout, |
| 90 | + socket_timeout, |
| 91 | + scheme=scheme, |
| 92 | + ) |
| 93 | + return service_info |
| 94 | + |
| 95 | + @staticmethod |
| 96 | + def get_api_info(): |
| 97 | + api_info = { |
| 98 | + "CreateCollection": ApiInfo( |
| 99 | + "POST", |
| 100 | + "/api/memory/collection/create", |
| 101 | + {}, |
| 102 | + {}, |
| 103 | + {"Accept": "application/json", "Content-Type": "application/json"}, |
| 104 | + ), |
| 105 | + "GetCollection": ApiInfo( |
| 106 | + "POST", |
| 107 | + "/api/memory/collection/info", |
| 108 | + {}, |
| 109 | + {}, |
| 110 | + {"Accept": "application/json", "Content-Type": "application/json"}, |
| 111 | + ), |
| 112 | + "DropCollection": ApiInfo( |
| 113 | + "POST", |
| 114 | + "/api/memory/collection/delete", |
| 115 | + {}, |
| 116 | + {}, |
| 117 | + {"Accept": "application/json", "Content-Type": "application/json"}, |
| 118 | + ), |
| 119 | + "UpdateCollection": ApiInfo( |
| 120 | + "POST", |
| 121 | + "/api/memory/collection/update", |
| 122 | + {}, |
| 123 | + {}, |
| 124 | + {"Accept": "application/json", "Content-Type": "application/json"}, |
| 125 | + ), |
| 126 | + "SearchMemory": ApiInfo( |
| 127 | + "POST", |
| 128 | + "/api/memory/search", |
| 129 | + {}, |
| 130 | + {}, |
| 131 | + {"Accept": "application/json", "Content-Type": "application/json"}, |
| 132 | + ), |
| 133 | + "AddMessages": ApiInfo( |
| 134 | + "POST", |
| 135 | + "/api/memory/messages/add", |
| 136 | + {}, |
| 137 | + {}, |
| 138 | + {"Accept": "application/json", "Content-Type": "application/json"}, |
| 139 | + ), |
| 140 | + "Ping": ApiInfo( |
| 141 | + "GET", |
| 142 | + "/api/memory/ping", |
| 143 | + {}, |
| 144 | + {}, |
| 145 | + {"Accept": "application/json", "Content-Type": "application/json"}, |
| 146 | + ), |
| 147 | + } |
| 148 | + return api_info |
| 149 | + |
| 150 | + def get_body(self, api, params, body): |
| 151 | + if api not in self.api_info: |
| 152 | + raise Exception("no such api") |
| 153 | + api_info = self.api_info[api] |
| 154 | + r = self.prepare_request(api_info, params) |
| 155 | + r.headers["Content-Type"] = "application/json" |
| 156 | + r.headers["Traffic-Source"] = "SDK" |
| 157 | + r.body = body |
| 158 | + |
| 159 | + SignerV4.sign(r, self.service_info.credentials) |
| 160 | + |
| 161 | + url = r.build() |
| 162 | + resp = self.session.get( |
| 163 | + url, |
| 164 | + headers=r.headers, |
| 165 | + data=r.body, |
| 166 | + timeout=( |
| 167 | + self.service_info.connection_timeout, |
| 168 | + self.service_info.socket_timeout, |
| 169 | + ), |
| 170 | + ) |
| 171 | + if resp.status_code == 200: |
| 172 | + return json.dumps(resp.json()) |
| 173 | + else: |
| 174 | + raise Exception(resp.text.encode("utf-8")) |
| 175 | + |
| 176 | + def get_body_exception(self, api, params, body): |
| 177 | + try: |
| 178 | + res = self.get_body(api, params, body) |
| 179 | + except Exception as e: |
| 180 | + try: |
| 181 | + res_json = json.loads(e.args[0].decode("utf-8")) |
| 182 | + except Exception as e: |
| 183 | + raise VikingDBMemoryException( |
| 184 | + 1000028, "missed", "json load res error, res:{}".format(str(e)) |
| 185 | + ) from None |
| 186 | + code = res_json.get("code", 1000028) |
| 187 | + request_id = res_json.get("request_id", 1000028) |
| 188 | + message = res_json.get("message", None) |
| 189 | + |
| 190 | + raise VikingDBMemoryException(code, request_id, message) |
| 191 | + |
| 192 | + if res == "": |
| 193 | + raise VikingDBMemoryException( |
| 194 | + 1000028, |
| 195 | + "missed", |
| 196 | + "empty response due to unknown error, please contact customer service", |
| 197 | + ) from None |
| 198 | + return res |
| 199 | + |
| 200 | + def get_exception(self, api, params): |
| 201 | + try: |
| 202 | + res = self.get(api, params) |
| 203 | + except Exception as e: |
| 204 | + try: |
| 205 | + res_json = json.loads(e.args[0].decode("utf-8")) |
| 206 | + except Exception as e: |
| 207 | + raise VikingDBMemoryException( |
| 208 | + 1000028, "missed", "json load res error, res:{}".format(str(e)) |
| 209 | + ) from None |
| 210 | + code = res_json.get("code", 1000028) |
| 211 | + request_id = res_json.get("request_id", 1000028) |
| 212 | + message = res_json.get("message", None) |
| 213 | + raise VikingDBMemoryException(code, request_id, message) |
| 214 | + if res == "": |
| 215 | + raise VikingDBMemoryException( |
| 216 | + 1000028, |
| 217 | + "missed", |
| 218 | + "empty response due to unknown error, please contact customer service", |
| 219 | + ) from None |
| 220 | + return res |
| 221 | + |
| 222 | + def create_collection( |
| 223 | + self, |
| 224 | + collection_name, |
| 225 | + description="", |
| 226 | + custom_event_type_schemas=[], |
| 227 | + custom_entity_type_schemas=[], |
| 228 | + builtin_event_types=[], |
| 229 | + builtin_entity_types=[], |
| 230 | + ): |
| 231 | + params = { |
| 232 | + "CollectionName": collection_name, |
| 233 | + "Description": description, |
| 234 | + "CustomEventTypeSchemas": custom_event_type_schemas, |
| 235 | + "CustomEntityTypeSchemas": custom_entity_type_schemas, |
| 236 | + "BuiltinEventTypes": builtin_event_types, |
| 237 | + "BuiltinEntityTypes": builtin_entity_types, |
| 238 | + } |
| 239 | + res = self.json("CreateCollection", {}, json.dumps(params)) |
| 240 | + return json.loads(res) |
| 241 | + |
| 242 | + def get_collection(self, collection_name): |
| 243 | + params = {"CollectionName": collection_name} |
| 244 | + res = self.json("GetCollection", {}, json.dumps(params)) |
| 245 | + return json.loads(res) |
| 246 | + |
| 247 | + def drop_collection(self, collection_name): |
| 248 | + params = {"CollectionName": collection_name} |
| 249 | + res = self.json("DropCollection", {}, json.dumps(params)) |
| 250 | + return json.loads(res) |
| 251 | + |
| 252 | + def update_collection( |
| 253 | + self, |
| 254 | + collection_name, |
| 255 | + custom_event_type_schemas=[], |
| 256 | + custom_entity_type_schemas=[], |
| 257 | + builtin_event_types=[], |
| 258 | + builtin_entity_types=[], |
| 259 | + ): |
| 260 | + params = { |
| 261 | + "CollectionName": collection_name, |
| 262 | + "CustomEventTypeSchemas": custom_event_type_schemas, |
| 263 | + "CustomEntityTypeSchemas": custom_entity_type_schemas, |
| 264 | + "BuiltinEventTypes": builtin_event_types, |
| 265 | + "BuiltinEntityTypes": builtin_entity_types, |
| 266 | + } |
| 267 | + res = self.json("UpdateCollection", {}, json.dumps(params)) |
| 268 | + return json.loads(res) |
| 269 | + |
| 270 | + def search_memory(self, collection_name, query, filter, limit=10): |
| 271 | + params = { |
| 272 | + "collection_name": collection_name, |
| 273 | + "limit": limit, |
| 274 | + "filter": filter, |
| 275 | + } |
| 276 | + if query: |
| 277 | + params["query"] = query |
| 278 | + res = self.json("SearchMemory", {}, json.dumps(params)) |
| 279 | + return json.loads(res) |
| 280 | + |
| 281 | + def add_messages( |
| 282 | + self, collection_name, session_id, messages, metadata, entities=None |
| 283 | + ): |
| 284 | + params = { |
| 285 | + "collection_name": collection_name, |
| 286 | + "session_id": session_id, |
| 287 | + "messages": messages, |
| 288 | + "metadata": metadata, |
| 289 | + } |
| 290 | + if entities is not None: |
| 291 | + params["entities"] = entities |
| 292 | + res = self.json("AddMessages", {}, json.dumps(params)) |
| 293 | + return json.loads(res) |
0 commit comments