Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions candle-binding/src/core/config_loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,32 @@ impl Default for RouterConfig {
}
}

/// Reinforcement Learning configuration for classifier training.
///
/// Derives `PartialEq` so that two configurations can be compared directly
/// (for example, a loaded configuration against `RLConfig::default()`).
#[derive(Debug, Clone, PartialEq)]
pub struct RLConfig {
    /// Whether RL training is enabled.
    pub enabled: bool,
    /// Name of the RL algorithm, e.g., "ppo", "a2c", "dqn".
    pub algorithm: String,
    /// Learning rate (presumably the optimizer step size; confirm against the training code).
    pub learning_rate: f32,
    /// Presumably the reward discount factor; confirm against the training code.
    pub gamma: f32,
    /// Batch size (presumably samples per update; confirm against the training code).
    pub batch_size: usize,
    /// Number of update epochs (presumably optimization passes per batch; confirm).
    pub update_epochs: usize,
    /// Name of the reward metric, e.g., "accuracy", "f1", "custom".
    pub reward_metric: String,
}

impl Default for RLConfig {
fn default() -> Self {
Self {
enabled: false,
algorithm: "ppo".to_string(),
learning_rate: 1e-5,
gamma: 0.99,
batch_size: 16,
update_epochs: 4,
reward_metric: "accuracy".to_string(),
}
}
}

impl GlobalConfigLoader {
/// Load router configuration from config/config.yaml
pub fn load_router_config() -> Result<RouterConfig, UnifiedError> {
Expand Down Expand Up @@ -663,6 +689,60 @@ impl GlobalConfigLoader {
Ok(router_config)
}

/// Load RL configuration for classifier training from config/config.yaml
pub fn load_classifier_rl_config() -> Result<RLConfig, UnifiedError> {
let config_path = "config/config.yaml";
let config_str = std::fs::read_to_string(config_path)
.map_err(|_| config_errors::file_not_found(config_path))?;

let mut rl_config = RLConfig::default();

if let Some(value) = Self::extract_yaml_value(&config_str, &["classifier", "rl_training", "enabled"]) {
if let Ok(b) = value.parse::<bool>() {
rl_config.enabled = b;
}
}

if let Some(value) = Self::extract_yaml_value(&config_str, &["classifier", "rl_training", "algorithm"]) {
rl_config.algorithm = value;
}

if let Some(value) = Self::extract_yaml_value(&config_str, &["classifier", "rl_training", "learning_rate"]) {
if let Ok(lr) = value.parse::<f32>() {
rl_config.learning_rate = lr;
}
}

if let Some(value) = Self::extract_yaml_value(&config_str, &["classifier", "rl_training", "gamma"]) {
if let Ok(g) = value.parse::<f32>() {
rl_config.gamma = g;
}
}

if let Some(value) = Self::extract_yaml_value(&config_str, &["classifier", "rl_training", "batch_size"]) {
if let Ok(bs) = value.parse::<usize>() {
rl_config.batch_size = bs;
}
}

if let Some(value) = Self::extract_yaml_value(&config_str, &["classifier", "rl_training", "update_epochs"]) {
if let Ok(ep) = value.parse::<usize>() {
rl_config.update_epochs = ep;
}
}

if let Some(value) = Self::extract_yaml_value(&config_str, &["classifier", "rl_training", "reward_metric"]) {
rl_config.reward_metric = value;
}

Ok(rl_config)
}

/// Safe loader for RL config
pub fn load_classifier_rl_config_safe() -> RLConfig {
Self::load_classifier_rl_config().unwrap_or_default()
}

/// Load router configuration with fallback to defaults
pub fn load_router_config_safe() -> RouterConfig {
Self::load_router_config().unwrap_or_default()
Expand Down
10 changes: 10 additions & 0 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ classifier:
use_cpu: true
pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"

# Optional: Reinforcement Learning options for classifier model training.
# Read by GlobalConfigLoader::load_classifier_rl_config; a missing key, or a
# boolean/numeric value that fails to parse, keeps the loader's built-in default.
rl_training:
enabled: false # Enable RL fine-tuning for classifiers
algorithm: "ppo" # Algorithm: ppo | a2c | dqn
learning_rate: 1e-05 # Parsed as f32 by the Rust loader
gamma: 0.99 # Parsed as f32; presumably the RL discount factor (confirm)
batch_size: 16 # Parsed as usize (non-negative integer)
update_epochs: 4 # Parsed as usize (non-negative integer)
reward_metric: "accuracy" # Metric used to compute reward (accuracy|f1|custom)

# Categories with new use_reasoning field structure
categories:
- name: business
Expand Down
Loading
Loading