
Commit e65b738

enable bert lora adapter support
Signed-off-by: Huamin Chen <[email protected]>
1 parent 2723b8c commit e65b738

File tree

13 files changed, +2887 −0 lines


openvino-binding/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -103,6 +103,8 @@ set(SOURCES
     # Classifiers module
     cpp/src/classifiers/text_classifier.cpp
     cpp/src/classifiers/token_classifier.cpp
+    cpp/src/classifiers/lora_adapter.cpp
+    cpp/src/classifiers/lora_classifier.cpp

     # Embeddings module
     cpp/src/embeddings/embedding_generator.cpp
@@ -123,6 +125,8 @@ set(HEADERS
     # Classifier headers
     cpp/include/classifiers/text_classifier.h
     cpp/include/classifiers/token_classifier.h
+    cpp/include/classifiers/lora_adapter.h
+    cpp/include/classifiers/lora_classifier.h

     # Embedding headers
     cpp/include/embeddings/embedding_generator.h

openvino-binding/README.md

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ High-performance Go bindings for semantic routing using Intel® OpenVINO™ Tool
 - 🚀 **High Performance**: Optimized inference with OpenVINO on Intel hardware
 - 🔍 **Semantic Search**: BERT embeddings and cosine similarity
 - 📊 **Classification**: Text classification with confidence scores
+- 🧩 **LoRA Adapter Support**: Parameter-efficient fine-tuning for BERT and ModernBERT
 - 🏷️ **Token Classification**: Named entity recognition and PII detection
 - 🔄 **Batch Processing**: Efficient batch similarity computation
 - 💻 **Multi-Device**: Support for CPU, GPU, VPU, and other Intel accelerators
openvino-binding/cpp/include/classifiers/lora_adapter.h

Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
#pragma once

#include <openvino/openvino.hpp>
#include <vector>
#include <memory>
#include <string>

namespace openvino_sr {
namespace classifiers {

/**
 * @brief LoRA configuration
 */
struct LoRAConfig {
    size_t rank = 16;        // LoRA rank
    double alpha = 32.0;     // LoRA alpha for scaling
    double dropout = 0.1;    // Dropout rate (used during training)
    bool use_bias = false;   // Whether to use bias in LoRA layers

    double get_scaling() const {
        return alpha / static_cast<double>(rank);
    }
};

/**
 * @brief LoRA adapter for parameter-efficient fine-tuning
 *
 * Implements Low-Rank Adaptation by applying:
 *   output = input + LoRA_B(LoRA_A(input)) * scaling
 */
class LoRAAdapter {
public:
    LoRAAdapter() = default;

    /**
     * @brief Load LoRA adapter from OpenVINO IR model
     * @param adapter_model_path Path to LoRA adapter model (.xml file)
     * @param config LoRA configuration
     * @param device Device name ("CPU", "GPU", etc.)
     * @return true if successful
     */
    bool load(
        const std::string& adapter_model_path,
        const LoRAConfig& config,
        const std::string& device
    );

    /**
     * @brief Apply LoRA adapter to input tensor
     * @param input Input tensor (pooled output from BERT/ModernBERT)
     * @return Output tensor after LoRA transformation
     */
    ov::Tensor forward(const ov::Tensor& input);

    /**
     * @brief Check if adapter is loaded
     */
    bool isLoaded() const { return compiled_model_ != nullptr; }

    /**
     * @brief Get LoRA configuration
     */
    const LoRAConfig& getConfig() const { return config_; }

private:
    std::shared_ptr<ov::CompiledModel> compiled_model_;
    LoRAConfig config_;
    ov::InferRequest infer_request_;
};

} // namespace classifiers
} // namespace openvino_sr
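The adapter's forward() executes LoRA_A and LoRA_B as a compiled OpenVINO model, so the actual math is opaque in this header. As a minimal sketch of the update it documents, output = input + LoRA_B(LoRA_A(input)) * scaling with scaling = alpha / rank, here is an illustrative standalone C++ function; lora_update and the dense A/B matrices are hypothetical and not part of this commit.

#include <cstddef>
#include <vector>

// Illustrative only: the real adapter runs the A and B projections as a
// compiled OpenVINO model; this spells out the arithmetic the header describes.
std::vector<float> lora_update(const std::vector<float>& x,                  // input, size d
                               const std::vector<std::vector<float>>& A,     // down-projection, r x d
                               const std::vector<std::vector<float>>& B,     // up-projection, d x r
                               double alpha) {
    const size_t d = x.size();
    const size_t r = A.size();
    const double scaling = alpha / static_cast<double>(r);  // LoRAConfig::get_scaling()

    // h = A * x  (project down to rank r)
    std::vector<float> h(r, 0.0f);
    for (size_t i = 0; i < r; ++i)
        for (size_t j = 0; j < d; ++j)
            h[i] += A[i][j] * x[j];

    // y = x + (B * h) * scaling  (project back up and add the residual)
    std::vector<float> y(x);
    for (size_t i = 0; i < d; ++i) {
        float acc = 0.0f;
        for (size_t k = 0; k < r; ++k)
            acc += B[i][k] * h[k];
        y[i] += static_cast<float>(scaling) * acc;
    }
    return y;
}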
openvino-binding/cpp/include/classifiers/lora_classifier.h

Lines changed: 177 additions & 0 deletions

@@ -0,0 +1,177 @@
#pragma once

#include "../core/types.h"
#include "../core/tokenizer.h"
#include "lora_adapter.h"
#include <string>
#include <memory>
#include <mutex>
#include <unordered_map>

namespace openvino_sr {
namespace classifiers {

/**
 * @brief Task types for LoRA multi-task classification
 */
enum class TaskType {
    Intent,
    PII,
    Security,
    Classification
};

/**
 * @brief Token-level prediction for token classification models
 */
struct TokenPrediction {
    std::string token;   // The token text
    int class_id;        // Predicted class ID
    float confidence;    // Confidence score (0.0 to 1.0)
};

/**
 * @brief Detected entity from BIO tagging
 */
struct DetectedEntity {
    std::string type;    // Entity type (e.g., "EMAIL_ADDRESS", "PERSON")
    std::string text;    // The detected entity text
    int start_token;     // Start token index
    int end_token;       // End token index (inclusive)
    float confidence;    // Average confidence of tokens in entity
};

/**
 * @brief Token classification result
 */
struct TokenClassificationResult {
    std::vector<TokenPrediction> token_predictions;  // Per-token predictions
    std::vector<DetectedEntity> entities;            // Detected entities (aggregated from BIO tags)
    float processing_time_ms;                        // Processing time in milliseconds
};

/**
 * @brief LoRA-enabled classifier for BERT and ModernBERT
 *
 * Supports multi-task classification with parameter-efficient LoRA adapters.
 * Each task has its own LoRA adapter and classification head.
 */
class LoRAClassifier {
public:
    LoRAClassifier() = default;

    /**
     * @brief Initialize LoRA classifier with base model and adapters
     * @param base_model_path Path to base BERT/ModernBERT model (.xml file)
     * @param lora_adapters_path Path to directory containing LoRA adapter models
     * @param task_configs Map of task types to number of classes
     * @param device Device name ("CPU", "GPU", etc.)
     * @param model_type "bert" or "modernbert"
     * @return true if successful
     */
    bool initialize(
        const std::string& base_model_path,
        const std::string& lora_adapters_path,
        const std::unordered_map<TaskType, int>& task_configs,
        const std::string& device = "CPU",
        const std::string& model_type = "bert"
    );

    /**
     * @brief Classify text for a specific task (sequence classification)
     * @param text Input text
     * @param task Task type
     * @return Classification result
     */
    core::ClassificationResult classifyTask(const std::string& text, TaskType task);

    /**
     * @brief Classify tokens for token-level classification (e.g., NER, PII detection)
     * @param text Input text
     * @param task Task type (should be PII or similar token classification task)
     * @return Token classification result with per-token predictions and detected entities
     */
    TokenClassificationResult classifyTokens(const std::string& text, TaskType task);

    /**
     * @brief Check if initialized
     */
    bool isInitialized() const {
        return base_model_ && base_model_->compiled_model != nullptr;
    }

    /**
     * @brief Get supported tasks
     */
    std::vector<TaskType> getSupportedTasks() const;

private:
    /**
     * @brief Get pooled output from base model
     */
    ov::Tensor getPooledOutput(const std::string& text);

    /**
     * @brief Apply task-specific LoRA adapter and classification head
     */
    core::ClassificationResult applyLoRAAndClassify(
        const ov::Tensor& pooled_output,
        TaskType task
    );

    /**
     * @brief Load task-specific LoRA adapter and classification head
     */
    bool loadTaskAdapter(
        const std::string& lora_adapters_path,
        TaskType task,
        int num_classes,
        const std::string& device
    );

    /**
     * @brief Get task name as string
     */
    std::string getTaskName(TaskType task) const;

    /**
     * @brief Get maximum sequence length for the model type
     * @return Max sequence length (8192 for ModernBERT, 512 for BERT)
     */
    int getMaxSequenceLength() const;

    /**
     * @brief Aggregate BIO tags into detected entities
     * @param original_text The original input text
     * @param tokens Vector of token strings
     * @param predictions Vector of token predictions
     * @param labels Map of class IDs to label names
     * @return Vector of detected entities
     */
    std::vector<DetectedEntity> aggregateBIOTags(
        const std::string& original_text,
        const std::vector<std::string>& tokens,
        const std::vector<TokenPrediction>& predictions,
        const std::unordered_map<int, std::string>& labels
    ) const;

    /**
     * @brief Load label mapping from JSON file
     * @param adapters_path Path to adapters directory containing label_mapping.json
     * @return Map of class IDs to label names
     */
    std::unordered_map<int, std::string> loadLabelMapping(const std::string& adapters_path) const;

    std::shared_ptr<core::ModelInstance> base_model_;                              // Frozen base model
    std::unordered_map<TaskType, LoRAAdapter> lora_adapters_;                      // Task-specific LoRA adapters
    std::unordered_map<TaskType, std::shared_ptr<ov::CompiledModel>> task_heads_;  // Classification heads
    std::unordered_map<TaskType, int> task_num_classes_;                           // Number of classes per task
    std::string adapters_path_;                                                    // Path to adapters directory
    core::OVNativeTokenizer tokenizer_;
    std::mutex mutex_;
    std::string model_type_;                                                       // "bert" or "modernbert"
};

} // namespace classifiers
} // namespace openvino_sr
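A hedged usage sketch of the classifier declared above: the model and adapter paths and the per-task class counts are made up for illustration, and the fields of core::ClassificationResult are declared in core/types.h rather than in this commit, so the sequence-level result is not inspected here.

#include "classifiers/lora_classifier.h"  // include path assumed from the repo layout
#include <iostream>
#include <unordered_map>

int main() {
    using namespace openvino_sr::classifiers;

    LoRAClassifier clf;
    // Hypothetical class counts per task; real values depend on the trained adapters.
    std::unordered_map<TaskType, int> tasks = {
        {TaskType::Intent, 7},
        {TaskType::PII, 17},
    };
    if (!clf.initialize("models/bert-base.xml", "models/lora_adapters",
                        tasks, "CPU", "bert")) {
        std::cerr << "failed to initialize LoRA classifier\n";
        return 1;
    }

    // Sequence-level classification for one task.
    auto intent = clf.classifyTask("book a flight to Boston", TaskType::Intent);
    (void)intent;  // members of core::ClassificationResult are defined in core/types.h

    // Token-level classification (e.g., PII detection) with BIO entity aggregation.
    TokenClassificationResult pii =
        clf.classifyTokens("Contact me at jane@example.com", TaskType::PII);
    for (const auto& e : pii.entities) {
        std::cout << e.type << ": " << e.text
                  << " (confidence " << e.confidence << ")\n";
    }
    return 0;
}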

openvino-binding/cpp/include/openvino_semantic_router.h

Lines changed: 85 additions & 0 deletions
@@ -356,6 +356,91 @@ OVTokenClassificationResult ov_classify_modernbert_tokens(const char* text, cons
  */
 OVEmbeddingResult ov_get_modernbert_embedding(const char* text, int max_length);

+// ================================================================================================
+// LORA ADAPTER SUPPORT (BERT AND MODERNBERT)
+// ================================================================================================
+
+/**
+ * @brief Task type enumeration for LoRA multi-task classification
+ */
+typedef enum {
+    OV_TASK_INTENT = 0,
+    OV_TASK_PII = 1,
+    OV_TASK_SECURITY = 2,
+    OV_TASK_CLASSIFICATION = 3
+} OVTaskType;
+
+/**
+ * @brief Initialize BERT LoRA classifier
+ * @param base_model_path Path to base BERT model (.xml file)
+ * @param lora_adapters_path Path to directory containing LoRA adapter models
+ * @param device Device name ("CPU", "GPU", etc.)
+ * @return true if initialization succeeded, false otherwise
+ */
+bool ov_init_bert_lora_classifier(
+    const char* base_model_path,
+    const char* lora_adapters_path,
+    const char* device
+);
+
+/**
+ * @brief Check if BERT LoRA classifier is initialized
+ * @return true if initialized, false otherwise
+ */
+bool ov_is_bert_lora_classifier_initialized();
+
+/**
+ * @brief Initialize ModernBERT LoRA classifier
+ * @param base_model_path Path to base ModernBERT model (.xml file)
+ * @param lora_adapters_path Path to directory containing LoRA adapter models
+ * @param device Device name ("CPU", "GPU", etc.)
+ * @return true if initialization succeeded, false otherwise
+ */
+bool ov_init_modernbert_lora_classifier(
+    const char* base_model_path,
+    const char* lora_adapters_path,
+    const char* device
+);
+
+/**
+ * @brief Check if ModernBERT LoRA classifier is initialized
+ * @return true if initialized, false otherwise
+ */
+bool ov_is_modernbert_lora_classifier_initialized();
+
+/**
+ * @brief Classify text using BERT LoRA adapter for a specific task
+ * @param text Input text
+ * @param task Task type
+ * @return Classification result
+ */
+OVClassificationResult ov_classify_bert_lora_task(const char* text, OVTaskType task);
+
+/**
+ * @brief Classify text using ModernBERT LoRA adapter for a specific task
+ * @param text Input text
+ * @param task Task type
+ * @return Classification result
+ */
+OVClassificationResult ov_classify_modernbert_lora_task(const char* text, OVTaskType task);
+
+/**
+ * @brief Token classification using BERT LoRA (for PII detection, NER, etc.)
+ * @param text Input text
+ * @param task Task type (should be PII or similar token classification task)
+ * @return Token classification result (caller must free using ov_free_token_classification_result)
+ */
+OVTokenClassificationResult ov_classify_bert_lora_tokens(const char* text, OVTaskType task);
+
+/**
+ * @brief Token classification using ModernBERT LoRA (for PII detection, NER, etc.)
+ * @param text Input text
+ * @param task Task type (should be PII or similar token classification task)
+ * @return Token classification result (caller must free using ov_free_token_classification_result)
+ */
+OVTokenClassificationResult ov_classify_modernbert_lora_tokens(const char* text, OVTaskType task);
+
 // ================================================================================================
 // UTILITY FUNCTIONS
 // ================================================================================================
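From the Go-binding side, the new C entry points would be exercised roughly as follows. Paths are placeholders, the members of OVClassificationResult and OVTokenClassificationResult are declared earlier in the header (not in this hunk) and are left untouched, and the by-value call to ov_free_token_classification_result is an assumption about a signature this hunk only references in documentation.

#include "openvino_semantic_router.h"
#include <stdio.h>

int use_bert_lora(void) {
    // Initialize the BERT LoRA classifier with hypothetical model/adapter paths.
    if (!ov_init_bert_lora_classifier("models/bert-base.xml",
                                      "models/lora_adapters", "CPU")) {
        fprintf(stderr, "LoRA classifier init failed\n");
        return 1;
    }
    if (!ov_is_bert_lora_classifier_initialized()) {
        return 1;
    }

    /* Sequence classification for the intent task; result fields are declared
       earlier in the header, so they are not inspected in this sketch. */
    OVClassificationResult intent =
        ov_classify_bert_lora_task("book a flight to Boston", OV_TASK_INTENT);
    (void)intent;

    /* Token classification for PII; the header says the result must be released
       with ov_free_token_classification_result (by-value signature assumed here). */
    OVTokenClassificationResult pii =
        ov_classify_bert_lora_tokens("my card number is 4111 1111 1111 1111",
                                     OV_TASK_PII);
    ov_free_token_classification_result(pii);
    return 0;
}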
