// ob_vector_embedding_handler.h
/*
* Copyright (c) 2025 OceanBase.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef OCEANBASE_SHARE_VECTOR_INDEX_OB_VECTOR_EMBEDDING_HANDLER_H_
#define OCEANBASE_SHARE_VECTOR_INDEX_OB_VECTOR_EMBEDDING_HANDLER_H_
#include "ob_vector_index_util.h"
#include "lib/lock/ob_thread_cond.h"
#include "lib/lock/ob_latch.h"
#include "lib/thread/thread_mgr_interface.h"
#include "lib/allocator/ob_allocator.h"
#include "share/io/ob_io_define.h"
namespace oceanbase
{
namespace storage
{
class ObEmbeddingIOCallbackHandle;
}
namespace share
{
// Internal processing phase of an embedding task.
// The numeric values are spelled out explicitly; the externally visible status
// is derived from the phase by ObEmbeddingTaskPhaseManager::map_phase_to_status().
enum ObEmbeddingTaskPhase {
  // Initial phase; reported as OB_EMBEDDING_TASK_PENDING.
  OB_EMBEDDING_TASK_INIT = 0,
  // The next three phases are all reported as OB_EMBEDDING_TASK_RUNNING.
  OB_EMBEDDING_TASK_HTTP_SENT = 1,
  OB_EMBEDDING_TASK_HTTP_COMPLETED = 2,
  OB_EMBEDDING_TASK_PARSED = 3,
  // Terminal phase; reported as FINISHED or FAILED depending on the result code.
  OB_EMBEDDING_TASK_DONE = 4
};
// Externally visible status of an embedding task (the value stored in
// ObEmbeddingTaskInfo::status_).
enum ObEmbeddingTaskStatus {
  // Task is still in its initial phase.
  OB_EMBEDDING_TASK_PENDING = 0,
  // Task is in one of the in-flight phases (HTTP sent / completed / parsed).
  OB_EMBEDDING_TASK_RUNNING = 1,
  // Task ended and its result code was OB_SUCCESS.
  OB_EMBEDDING_TASK_FINISHED = 2,
  // Task ended with a result code other than OB_SUCCESS.
  OB_EMBEDDING_TASK_FAILED = 3,
  OB_EMBEDDING_TASK_CANCELLED = 4,
  // Default value after ObEmbeddingTaskInfo::reset(); also used for unknown phases.
  OB_EMBEDDING_TASK_INVALID = 5
};
struct ObEmbeddingTaskInfo {
uint64_t tenant_id_;
int64_t task_id_;
ObEmbeddingTaskStatus status_;
int64_t task_start_time_;
int64_t task_end_time_;
ObString model_name_;
ObString model_url_;
int64_t processed_chunks_;
int64_t total_chunks_;
int64_t error_code_;
ObString error_message_;
ObString http_error_message_;
int64_t http_error_code_;
int64_t processing_time_us_;
int64_t retry_count_;
ObEmbeddingTaskInfo() {
reset();
}
void reset() {
tenant_id_ = OB_INVALID_TENANT_ID;
task_id_ = OB_INVALID_ID;
status_ = OB_EMBEDDING_TASK_INVALID;
task_start_time_ = 0;
task_end_time_ = 0;
model_name_.reset();
model_url_.reset();
processed_chunks_ = 0;
total_chunks_ = 0;
error_code_ = 0;
error_message_.reset();
processing_time_us_ = 0;
retry_count_ = 0;
}
bool is_valid() const {
return tenant_id_ != OB_INVALID_TENANT_ID &&
task_id_ != OB_INVALID_ID &&
status_ != OB_EMBEDDING_TASK_INVALID;
}
TO_STRING_KV(K_(tenant_id), K_(task_id), K_(status),
K_(task_start_time), K_(task_end_time), K_(model_name),
K_(model_url), K_(processed_chunks), K_(total_chunks),
K_(error_code), K_(error_message), K_(http_error_message), K_(http_error_code),
K_(processing_time_us), K_(retry_count));
};
// Central place for phase handling: transition validation, phase names and
// the phase -> status mapping. All members are static.
class ObEmbeddingTaskPhaseManager {
private:
  // Transition tables of the (strict) phase state machine. Each array lists
  // the phases that may legally follow the phase named in its identifier.
  // Note: DONE is allowed from any state for failure handling.
  static const ObEmbeddingTaskPhase VALID_TRANSITIONS_FROM_INIT[];
  static const ObEmbeddingTaskPhase VALID_TRANSITIONS_FROM_HTTP_SENT[];
  static const ObEmbeddingTaskPhase VALID_TRANSITIONS_FROM_HTTP_COMPLETED[];
  static const ObEmbeddingTaskPhase VALID_TRANSITIONS_FROM_PARSED[];
  static const ObEmbeddingTaskPhase VALID_TRANSITIONS_FROM_DONE[];

public:
  // Validates a from_phase -> to_phase transition of the internal phase
  // machine (defined out of line).
  static bool is_valid_transition(ObEmbeddingTaskPhase from_phase, ObEmbeddingTaskPhase to_phase);

  // Returns a static, human-readable label for the phase; values outside the
  // enum yield "UNKNOWN".
  static const char* get_phase_string(ObEmbeddingTaskPhase phase) {
    const char *label = "UNKNOWN";
    switch (phase) {
      case OB_EMBEDDING_TASK_INIT:
        label = "INIT";
        break;
      case OB_EMBEDDING_TASK_HTTP_SENT:
        label = "HTTP_SENT";
        break;
      case OB_EMBEDDING_TASK_HTTP_COMPLETED:
        label = "HTTP_COMPLETED";
        break;
      case OB_EMBEDDING_TASK_PARSED:
        label = "PARSED";
        break;
      case OB_EMBEDDING_TASK_DONE:
        label = "DONE";
        break;
      default:
        break;
    }
    return label;
  }

  // Translates the internal phase into the status shown in the system table.
  // A finished task, or one in the DONE phase, is FINISHED when result_code is
  // OB_SUCCESS and FAILED otherwise. INIT maps to PENDING, the intermediate
  // phases map to RUNNING, anything else maps to INVALID.
  static ObEmbeddingTaskStatus map_phase_to_status(ObEmbeddingTaskPhase phase, bool is_finished, int result_code) {
    ObEmbeddingTaskStatus terminal_status = OB_EMBEDDING_TASK_FAILED;
    if (OB_SUCCESS == result_code) {
      terminal_status = OB_EMBEDDING_TASK_FINISHED;
    }

    ObEmbeddingTaskStatus mapped = OB_EMBEDDING_TASK_INVALID;
    if (is_finished || OB_EMBEDDING_TASK_DONE == phase) {
      mapped = terminal_status;
    } else if (OB_EMBEDDING_TASK_INIT == phase) {
      mapped = OB_EMBEDDING_TASK_PENDING;
    } else if (OB_EMBEDDING_TASK_HTTP_SENT == phase
               || OB_EMBEDDING_TASK_HTTP_COMPLETED == phase
               || OB_EMBEDDING_TASK_PARSED == phase) {
      mapped = OB_EMBEDDING_TASK_RUNNING;
    }
    return mapped;
  }
};
class ObEmbeddingTaskHandler;
// Constants for field lengths
class ObEmbeddingTask
{
public:
ObEmbeddingTask();
explicit ObEmbeddingTask(ObArenaAllocator &allocator);
~ObEmbeddingTask();
int init(const ObString &model_url,
const ObString &model_name,
const ObString &provider,
const ObString &user_key,
const ObIArray<ObString> &input_chunks,
int64_t dimension,
int64_t http_timeout_us,
int64_t http_max_retries,
storage::ObEmbeddingIOCallbackHandle *cb_handle = nullptr);
template <typename ThreadPoolType>
int do_work(ThreadPoolType *thread_pool);
int64_t get_task_id() const { return task_id_; }
// 线程安全的访问方法,用于虚拟表查询
int get_task_info_for_virtual_table(ObEmbeddingTaskInfo &task_info);
TO_STRING_KV(K_(is_inited),
K_(task_id),
K_(model_url),
K_(model_name),
K_(user_key),
K(input_chunks_.count()),
K(output_vectors_.count()),
K_(dimension),
K_(batch_size),
K_(processed_chunks),
K_(total_chunks),
K_(process_callback_offset));
bool is_completed();
void retain_if_managed();
void release_if_managed();
int get_async_result(ObArray<float*> &output_vectors);
// 公共方法用于外部设置任务失败
int mark_task_failed(int error_code);
int maybe_callback();
int wait_for_completion();
int wake_up();
void disable_callback();
void set_callback_done();
bool need_callback() { return cb_handle_ != nullptr ? true : false; };
public:
static const ObString MODEL_URL_NAME;
static const ObString MODEL_NAME_NAME;
static const ObString ENCODING_FORMAT_NAME;
static const ObString DATA_NAME;
static const ObString EMBEDDING_NAME;
static const ObString FLOAT_FORMAT;
static const ObString BASE64_FORMAT;
static const ObString USER_KEY_NAME;
static const ObString INPUT_NAME;
static const ObString DIMENSIONS_NAME;
static const int64_t HTTP_REQUEST_TIMEOUT; // 20 seconds
// Reschedule related constants
static const int64_t MAX_RESCHEDULE_RETRY_CNT;
static const int64_t RESCHEDULE_RETRY_INTERVAL_US;
// HTTP retry related constants
static const int64_t MAX_HTTP_RETRY_CNT;
static const int64_t HTTP_RETRY_BASE_INTERVAL_US;
static const int64_t HTTP_RETRY_MAX_INTERVAL_US;
static const int64_t HTTP_RETRY_MULTIPLIER;
// Callback related constants
static const int64_t CALLBACK_BATCH_SIZE;
private:
void reset();
bool is_finished() const; // Internal use only - no lock needed
void set_stop();
int set_phase(ObEmbeddingTaskPhase new_phase);
int complete_task(ObEmbeddingTaskPhase new_phase, int result_code, bool finished = true);
int start_async_work();
int check_async_progress();
int send_http_request_async(const char *json_data, int64_t json_len);
int process_http_response();
bool is_http_response_ready() const;
int init_http_request(const char *json_data, int64_t json_len);
int check_http_progress();
void cleanup_async_http();
void log_phase_transition(ObEmbeddingTaskPhase from_phase, ObEmbeddingTaskPhase to_phase);
int reschedule(ObEmbeddingTaskHandler *thread_pool);
int handle_reschedule_failure(ObEmbeddingTaskHandler *thread_pool, int error_code);
bool is_http_sent() const { return curl_request_in_progress_; }
void set_task_id(int64_t task_id) { task_id_ = task_id; }
int parse_embedding_response(const char *response_data, size_t response_size);
// Helper methods for retry logic
bool should_retry_http_request(int64_t http_error_code) const;
bool is_batch_size_related_error(int64_t http_error_code) const;
int64_t calculate_retry_interval() const;
int adjust_batch_size_for_retry();
void reset_retry_state();
int map_http_error_to_internal_error(int64_t http_error_code) const;
void try_increase_batch_size();
int init_curl_handler(const ObString &model_url, const ObString &user_key, const int64_t http_timeout_us);
struct HttpResponseData {
HttpResponseData(ObIAllocator &allocator) : data(nullptr), size(0), allocator(allocator) {}
~HttpResponseData() { reset(); }
void reset()
{
if (OB_NOT_NULL(data)) {
allocator.free(data);
data = nullptr;
size = 0;
}
}
char *data;
size_t size;
ObIAllocator &allocator;
};
static size_t WriteMemoryCallback(void *contents, size_t size, size_t nmemb, void *userp)
{
size_t realsize = size * nmemb;
struct HttpResponseData *mem = (struct HttpResponseData *)userp;
char *ptr = (char*)mem->allocator.alloc(mem->size + realsize + 1);
if (OB_ISNULL(ptr)) {
return 0;
} else {
if (OB_NOT_NULL(mem->data)) {
memcpy(ptr, mem->data, mem->size);
}
memcpy(&(ptr[mem->size]), contents, realsize);
mem->data = ptr;
mem->size += realsize;
mem->data[mem->size] = '\0';
}
return realsize;
}
private:
ObArenaAllocator local_allocator_;
ObArenaAllocator &allocator_;
// request model required fields
ObString model_url_;
ObString model_name_;
ObString provider_;
bool use_base64_format_;
ObString user_key_;
int64_t dimension_;
ObArray<ObString> input_chunks_;
ObArray<float*> output_vectors_;
// IO callback for progress notification
storage::ObEmbeddingIOCallbackHandle *cb_handle_;
// Callback tracking variables
int64_t process_callback_offset_;
// task status and progress
bool is_inited_;
int64_t tenant_id_;
int64_t task_id_;
ObEmbeddingTaskPhase phase_;
int64_t process_start_time_us_;
int64_t process_end_time_us_;
int64_t processed_chunks_;
int64_t total_chunks_;
int64_t http_send_count_;
int64_t http_error_code_;
ObString http_error_message_;
int64_t internal_error_code_;
ObString internal_error_message_;
// flow control
mutable ObLatch task_lock_;
uint32_t batch_size_;
int64_t current_batch_idx_;
// Async processing related members
int64_t http_timeout_us_;
int64_t http_send_time_us_; // Unified HTTP send time
int64_t http_response_data_size_;
char *http_response_data_;
// Async HTTP processing members
CURLM *curl_multi_handle_;
CURL *curl_easy_handle_;
bool curl_request_in_progress_;
HttpResponseData *curl_response_data_; // Store response data for async HTTP processing
struct curl_slist *curl_headers_; // Store HTTP headers for cleanup
// HTTP retry related members
int64_t http_retry_count_;
int64_t http_total_retry_count_;
int64_t http_retry_start_time_us_;
int64_t http_last_retry_time_us_;
int64_t http_max_retry_count_;
int64_t wait_for_completion_timeout_us_; // For controlling the maximum timeout of waiting for completion
bool need_retry_flag_;
// Batch size adjustment for retry
uint32_t original_batch_size_;
bool batch_size_adjusted_;
uint32_t current_batch_size_;
uint32_t successful_requests_count_;
ObThreadCond task_cond_;
bool callback_done_;
// TODO(fanfangyao.ffy): use taskhandle to manage task reference count
// ref_cnt_ is only used to track the reference count of the post create embedding task
int64_t ref_cnt_;
private:
DISALLOW_COPY_AND_ASSIGN(ObEmbeddingTask);
};
class ObEmbeddingTaskHandler : public lib::TGTaskHandler
{
public:
ObEmbeddingTaskHandler();
virtual ~ObEmbeddingTaskHandler();
int init();
int start();
void stop();
void wait();
void destroy();
virtual void handle(void *task) override;
virtual void handle_drop(void *task) override;
int push_task(ObEmbeddingTask &task);
int get_tg_id() { return tg_id_; }
void inc_task_ref() { ATOMIC_INC(&task_ref_cnt_); }
void dec_task_ref() { ATOMIC_DEC(&task_ref_cnt_); }
void inc_dropped_task_cnt() { ATOMIC_INC(&dropped_task_cnt_); }
int add_task_to_tracking(ObEmbeddingTask *task);
int remove_task_from_tracking(ObEmbeddingTask *task);
// wait for all tasks to finish with timeout
int wait_all_tasks_finished(int64_t timeout_us = ObEmbeddingTask::HTTP_REQUEST_TIMEOUT); // default 20s timeout
// force drop all remaining tasks and update system table
int force_drop_all_remaining_tasks();
// check if all tasks are finished
bool is_all_tasks_finished() const { return ATOMIC_LOAD(&task_ref_cnt_) == 0; }
// get current task count for monitoring
int64_t get_current_task_count() const { return ATOMIC_LOAD(&task_ref_cnt_); }
// get dropped task count for monitoring
int64_t get_dropped_task_count() const { return ATOMIC_LOAD(&dropped_task_cnt_); }
// get tracked task count for monitoring
int64_t get_tracked_task_count() const {
common::ObSpinLockGuard guard(const_cast<common::ObSpinLock&>(task_list_lock_));
return active_tasks_.count();
}
// get all active tasks for virtual table query
int get_all_active_tasks(common::ObArray<ObEmbeddingTask*> &task_list) const;
// reset statistics counters
void reset_statistics() {
ATOMIC_STORE(&task_ref_cnt_, 0);
ATOMIC_STORE(&dropped_task_cnt_, 0);
}
// get task reference count (alias for get_current_task_count for backward compatibility)
int64_t get_task_ref() const { return ATOMIC_LOAD(&task_ref_cnt_); }
bool is_inited() const { return is_inited_; }
public:
common::ObSpinLock lock_; // lock for init
public:
// push task max wait time: 1s * 10 = 10s
const static int64_t MAX_RETRY_PUSH_TASK_CNT = 10;
static const int64_t INVALID_TG_ID = -1;
// dynamic thread cnt, max cnt is THREAD_FACTOR * tenent_cpu_cnt
constexpr static const float THREAD_FACTOR = 0.6;
// 1s
const static int64_t WAIT_RETRY_PUSH_TASK_TIME = 1 * 1000 * 1000; // us
static const int64_t MIN_THREAD_COUNT = 1;
private:
bool is_inited_;
int tg_id_;
volatile int64_t task_ref_cnt_;
// track dropped tasks for monitoring
volatile int64_t dropped_task_cnt_;
// task tracking for force drop functionality
common::ObSpinLock task_list_lock_;
common::ObArray<ObEmbeddingTask*> active_tasks_;
private:
DISALLOW_COPY_AND_ASSIGN(ObEmbeddingTaskHandler);
};
}
}
#endif