diff --git a/candle-binding/src/classifiers/lora/parallel_engine_test.rs b/candle-binding/src/classifiers/lora/parallel_engine_test.rs index b7ba8c3d..1f11e9ee 100644 --- a/candle-binding/src/classifiers/lora/parallel_engine_test.rs +++ b/candle-binding/src/classifiers/lora/parallel_engine_test.rs @@ -1,6 +1,7 @@ //! Tests for Parallel LoRA Engine with performance benchmarks use crate::test_fixtures::fixtures::*; +use rayon::prelude::*; use rstest::*; use serial_test::serial; use std::sync::Arc; @@ -207,19 +208,14 @@ fn test_performance_concurrent_requests( println!("\n🔢 Testing with {} concurrent requests", num_threads); let start = Instant::now(); - let handles: Vec<_> = (0..*num_threads) - .map(|_| { - let classifier = Arc::clone(classifier); - std::thread::spawn(move || classifier.classify_intent(test_text)) - }) + + // Use rayon for parallel execution - simpler and more efficient + let results: Vec<_> = (0..*num_threads) + .into_par_iter() + .map(|_| classifier.classify_intent(test_text)) .collect(); - let mut success_count = 0; - for handle in handles { - if handle.join().is_ok() { - success_count += 1; - } - } + let success_count = results.iter().filter(|r| r.is_ok()).count(); let duration = start.elapsed(); println!( diff --git a/candle-binding/src/core/mod.rs b/candle-binding/src/core/mod.rs index 013d1263..5f24560b 100644 --- a/candle-binding/src/core/mod.rs +++ b/candle-binding/src/core/mod.rs @@ -34,4 +34,8 @@ pub use tokenization::{ #[cfg(test)] pub mod config_loader_test; #[cfg(test)] +pub mod similarity_test; +#[cfg(test)] +pub mod tokenization_test; +#[cfg(test)] pub mod unified_error_test; diff --git a/candle-binding/src/core/similarity_test.rs b/candle-binding/src/core/similarity_test.rs new file mode 100644 index 00000000..a0a0480b --- /dev/null +++ b/candle-binding/src/core/similarity_test.rs @@ -0,0 +1,452 @@ +//! Tests for core similarity module + +use super::similarity::*; +use candle_core::{Device, Tensor}; +use rayon::prelude::*; +use rstest::*; +use std::path::PathBuf; + +// Test model paths +const TEST_MODEL_BASE: &str = "../models"; +const BERT_MODEL: &str = "lora_intent_classifier_bert-base-uncased_model"; + +/// Fixture to create a BertSimilarity instance +#[fixture] +fn bert_similarity() -> BertSimilarity { + let model_path = PathBuf::from(TEST_MODEL_BASE).join(BERT_MODEL); + + if model_path.exists() { + BertSimilarity::new(model_path.to_str().unwrap(), true) + .expect("Failed to create BertSimilarity") + } else { + // Skip test if model not available + panic!("Test model not found at {:?}", model_path); + } +} + +// ============================================================================ +// Initialization Tests +// ============================================================================ + +#[rstest] +fn test_bert_similarity_new(bert_similarity: BertSimilarity) { + assert!(bert_similarity.device().is_cpu(), "Should use CPU device"); +} + +#[rstest] +fn test_bert_similarity_tokenizer(bert_similarity: BertSimilarity) { + let tokenizer = bert_similarity.tokenizer(); + assert!( + tokenizer.get_vocab_size(true) > 0, + "Tokenizer should have vocabulary" + ); +} + +#[rstest] +fn test_bert_similarity_is_gpu(bert_similarity: BertSimilarity) { + assert!(!bert_similarity.is_gpu(), "Should be using CPU"); +} + +// ============================================================================ +// Tokenization Tests +// ============================================================================ + +#[rstest] +fn test_tokenize_text_basic(bert_similarity: BertSimilarity) { + let text = "Hello, world!"; + let result = bert_similarity.tokenize_text(text, None); + + assert!(result.is_ok(), "Should tokenize simple text"); + + let (token_ids, tokens) = result.unwrap(); + assert!(!token_ids.is_empty(), "Token IDs should not be empty"); + assert!(!tokens.is_empty(), "Tokens should not be empty"); +} + +#[rstest] +fn test_tokenize_text_empty(bert_similarity: BertSimilarity) { + let text = ""; + let result = bert_similarity.tokenize_text(text, None); + + assert!(result.is_ok(), "Should handle empty text"); +} + +#[rstest] +#[case("Simple text", None)] +#[case( + "A longer text that might need truncation when the max length is set", + Some(20) +)] +#[case("Short", Some(512))] +fn test_tokenize_text_with_max_length( + bert_similarity: BertSimilarity, + #[case] text: &str, + #[case] max_length: Option, +) { + let result = bert_similarity.tokenize_text(text, max_length); + + assert!( + result.is_ok(), + "Should tokenize text with max_length {:?}", + max_length + ); + + let (token_ids, _tokens) = result.unwrap(); + + if let Some(max_len) = max_length { + assert!( + token_ids.len() <= max_len, + "Token IDs length should be <= max_length" + ); + } +} + +// ============================================================================ +// Embedding Generation Tests +// ============================================================================ + +#[rstest] +fn test_get_embedding(bert_similarity: BertSimilarity) { + let text = "This is a test sentence for embedding."; + let result = bert_similarity.get_embedding(text, None); + + assert!(result.is_ok(), "Should generate embedding"); + + let embedding = result.unwrap(); + let dims = embedding.dims(); + // get_embedding returns [batch_size, hidden_dim] = [1, 768] + assert_eq!( + dims.len(), + 2, + "Embedding should be 2D tensor (batch format)" + ); + assert_eq!(dims[0], 1, "Batch size should be 1"); + assert!(dims[1] > 0, "Hidden dimension should be positive"); +} + +#[rstest] +fn test_get_embedding_consistency(bert_similarity: BertSimilarity) { + let text = "Consistency test sentence."; + + // Generate embedding twice + let embedding1 = bert_similarity + .get_embedding(text, None) + .expect("First embedding"); + let embedding2 = bert_similarity + .get_embedding(text, None) + .expect("Second embedding"); + + // Should produce identical embeddings for same input + assert_eq!( + embedding1.dims(), + embedding2.dims(), + "Embeddings should have same dimensions" + ); + + // Convert to Vec for comparison (squeeze batch dimension) + let vec1: Vec = embedding1 + .squeeze(0) + .expect("Squeeze") + .to_vec1() + .expect("Convert to vec1"); + let vec2: Vec = embedding2 + .squeeze(0) + .expect("Squeeze") + .to_vec1() + .expect("Convert to vec2"); + + for (i, (v1, v2)) in vec1.iter().zip(vec2.iter()).enumerate() { + assert!( + (v1 - v2).abs() < 1e-6, + "Embeddings should be identical at position {}: {} vs {}", + i, + v1, + v2 + ); + } +} + +#[rstest] +fn test_get_embedding_different_texts(bert_similarity: BertSimilarity) { + let text1 = "The cat sits on the mat."; + let text2 = "A dog runs in the park."; + + let embedding1 = bert_similarity + .get_embedding(text1, None) + .expect("First embedding"); + let embedding2 = bert_similarity + .get_embedding(text2, None) + .expect("Second embedding"); + + // Embeddings should be different for different texts (squeeze batch dimension) + let vec1: Vec = embedding1 + .squeeze(0) + .expect("Squeeze") + .to_vec1() + .expect("Convert to vec1"); + let vec2: Vec = embedding2 + .squeeze(0) + .expect("Squeeze") + .to_vec1() + .expect("Convert to vec2"); + + let mut differences = 0; + for (v1, v2) in vec1.iter().zip(vec2.iter()) { + if (v1 - v2).abs() > 1e-6 { + differences += 1; + } + } + + assert!( + differences > vec1.len() / 10, + "Embeddings should be substantially different (found {} differences out of {})", + differences, + vec1.len() + ); +} + +#[rstest] +fn test_get_embedding_with_max_length(bert_similarity: BertSimilarity) { + let long_text = "This is a very long text that will be truncated. ".repeat(20); + let result = bert_similarity.get_embedding(&long_text, Some(128)); + + assert!(result.is_ok(), "Should generate embedding with max_length"); +} + +// ============================================================================ +// Similarity Calculation Tests +// ============================================================================ + +#[rstest] +fn test_calculate_similarity_identical(bert_similarity: BertSimilarity) { + let text = "Identical text"; + + let similarity = bert_similarity + .calculate_similarity(text, text, None) + .expect("Calculate similarity"); + + assert!( + (similarity - 1.0).abs() < 0.01, + "Identical text should have similarity ~1.0, got {}", + similarity + ); +} + +#[rstest] +fn test_calculate_similarity_similar_texts(bert_similarity: BertSimilarity) { + let text1 = "Machine learning is fascinating."; + let text2 = "AI and machine learning are interesting."; + + let similarity = bert_similarity + .calculate_similarity(text1, text2, None) + .expect("Calculate similarity"); + + assert!( + similarity > 0.3, + "Similar texts should have reasonable similarity, got {}", + similarity + ); +} + +#[rstest] +fn test_calculate_similarity_dissimilar_texts(bert_similarity: BertSimilarity) { + let text1 = "The weather is sunny today."; + let text2 = "Quantum physics is complex."; + + let similarity = bert_similarity + .calculate_similarity(text1, text2, None) + .expect("Calculate similarity"); + + assert!( + similarity < 0.9 && similarity > -1.0, + "Dissimilar texts should have lower similarity, got {}", + similarity + ); +} + +#[rstest] +#[case("Hello", "Hi", 0.0)] // Should be somewhat similar +#[case("Cat", "Dog", 0.0)] // Should be somewhat similar (both animals) +#[case("Apple", "Computer", -1.0)] // Can vary greatly +fn test_calculate_similarity_various_pairs( + bert_similarity: BertSimilarity, + #[case] text1: &str, + #[case] text2: &str, + #[case] min_similarity: f32, +) { + let similarity = bert_similarity + .calculate_similarity(text1, text2, None) + .expect("Calculate similarity"); + + assert!( + similarity >= min_similarity && similarity <= 1.0, + "Similarity should be between {} and 1.0, got {}", + min_similarity, + similarity + ); +} + +// ============================================================================ +// Most Similar Finding Tests +// ============================================================================ + +#[rstest] +fn test_find_most_similar(bert_similarity: BertSimilarity) { + let query = "Machine learning algorithms"; + let candidates = vec![ + "AI and deep learning", + "Cooking recipes", + "Neural networks", + "Weather forecast", + ]; + + let result = bert_similarity.find_most_similar(query, &candidates, None); + + assert!(result.is_ok(), "Should find most similar"); + + let (most_similar_idx, similarity) = result.unwrap(); + + // Should find either "AI and deep learning" (0) or "Neural networks" (2) + assert!( + most_similar_idx == 0 || most_similar_idx == 2, + "Should find AI-related text, got index {}", + most_similar_idx + ); + + assert!( + similarity > 0.3, + "Similarity should be reasonably high, got {}", + similarity + ); +} + +#[rstest] +fn test_find_most_similar_single_candidate(bert_similarity: BertSimilarity) { + let query = "Test query"; + let candidates = vec!["Single candidate"]; + + let result = bert_similarity.find_most_similar(query, &candidates, None); + + assert!(result.is_ok(), "Should handle single candidate"); + + let (most_similar_idx, _) = result.unwrap(); + assert_eq!(most_similar_idx, 0, "Should return the only candidate"); +} + +#[rstest] +fn test_find_most_similar_with_max_length(bert_similarity: BertSimilarity) { + let query = "Short query"; + let long_text = "This is a very long candidate text that will be truncated. ".repeat(10); + let candidates_data = vec![long_text.as_str(), "Short match"]; + + let result = bert_similarity.find_most_similar(query, &candidates_data, Some(64)); + + assert!(result.is_ok(), "Should handle max_length parameter"); +} + +// ============================================================================ +// Error Handling Tests +// ============================================================================ + +#[test] +fn test_new_with_invalid_path() { + let result = BertSimilarity::new("/nonexistent/path", true); + assert!(result.is_err(), "Should fail with invalid path"); +} + +#[rstest] +fn test_find_most_similar_empty_candidates(bert_similarity: BertSimilarity) { + let query = "Test query"; + let candidates: Vec<&str> = vec![]; + + let result = bert_similarity.find_most_similar(query, &candidates, None); + + // Depending on implementation, this might error or return None + // Adjust assertion based on actual behavior + assert!( + result.is_err() || result.unwrap().1 == 0.0, + "Should handle empty candidates" + ); +} + +// ============================================================================ +// L2 Normalization Tests +// ============================================================================ + +#[test] +fn test_normalize_l2() { + let device = Device::Cpu; + let data = vec![3.0_f32, 4.0_f32]; // L2 norm = 5.0 + // normalize_l2 expects 2D tensor (batch format: [batch_size, dim]) + let tensor = Tensor::from_slice(&data, (1, 2), &device).expect("Create tensor"); + + let normalized = normalize_l2(&tensor).expect("Normalize"); + let vec: Vec = normalized + .squeeze(0) + .expect("Squeeze") + .to_vec1() + .expect("To vec"); + + // After normalization: [3/5, 4/5] = [0.6, 0.8] + assert!( + (vec[0] - 0.6).abs() < 0.01, + "First component should be ~0.6" + ); + assert!( + (vec[1] - 0.8).abs() < 0.01, + "Second component should be ~0.8" + ); + + // Check L2 norm is 1.0 + let l2_norm: f32 = vec.iter().map(|x| x * x).sum::().sqrt(); + assert!((l2_norm - 1.0).abs() < 0.01, "L2 norm should be ~1.0"); +} + +#[test] +fn test_normalize_l2_zero_vector() { + let device = Device::Cpu; + let data = vec![0.0_f32, 0.0_f32]; + let tensor = Tensor::from_slice(&data, 2, &device).expect("Create tensor"); + + let result = normalize_l2(&tensor); + + // Should handle zero vector gracefully (either error or return zeros) + match result { + Ok(normalized) => { + let vec: Vec = normalized.to_vec1().expect("To vec"); + assert!( + vec.iter().all(|x| x.is_nan() || *x == 0.0), + "Should handle zero vector" + ); + } + Err(_) => { + // Also acceptable to return an error + } + } +} + +// ============================================================================ +// Concurrency Tests +// ============================================================================ + +#[rstest] +fn test_bert_similarity_thread_safety(bert_similarity: BertSimilarity) { + use std::sync::Arc; + + let similarity = Arc::new(bert_similarity); + + // Use rayon for parallel execution - simpler and more efficient + let embeddings: Vec<_> = (0..4) + .into_par_iter() + .map(|i| { + let text = format!("Thread {} test text", i); + similarity + .get_embedding(&text, None) + .expect("Generate embedding in thread") + }) + .collect(); + + for embedding in embeddings { + assert!(embedding.dims()[0] > 0, "Should generate valid embedding"); + } +} diff --git a/candle-binding/src/core/tokenization_test.rs b/candle-binding/src/core/tokenization_test.rs new file mode 100644 index 00000000..1bf8f0b3 --- /dev/null +++ b/candle-binding/src/core/tokenization_test.rs @@ -0,0 +1,401 @@ +//! Tests for core tokenization module + +use super::tokenization::*; +use candle_core::Device; +use rayon::prelude::*; +use rstest::*; +use std::path::PathBuf; +use tokenizers::{TruncationDirection, TruncationStrategy}; + +// Test model paths +const TEST_MODEL_BASE: &str = "../models"; +const BERT_MODEL: &str = "lora_intent_classifier_bert-base-uncased_model"; + +/// Fixture to create a UnifiedTokenizer instance +#[fixture] +fn unified_tokenizer() -> UnifiedTokenizer { + let model_path = PathBuf::from(TEST_MODEL_BASE).join(BERT_MODEL); + let tokenizer_path = model_path.join("tokenizer.json"); + + if tokenizer_path.exists() { + UnifiedTokenizer::from_file( + tokenizer_path.to_str().unwrap(), + TokenizationStrategy::BERT, + Device::Cpu, + ) + .expect("Failed to create UnifiedTokenizer") + } else { + // Skip test if tokenizer not available + panic!("Test tokenizer not found at {:?}", tokenizer_path); + } +} + +// ============================================================================ +// Configuration Tests +// ============================================================================ + +#[rstest] +fn test_tokenization_config_default() { + let config = TokenizationConfig::default(); + + assert_eq!(config.max_length, 512); + assert!(config.add_special_tokens); + assert_eq!(config.truncation_strategy, TruncationStrategy::LongestFirst); + assert_eq!(config.pad_token_id, 0); + assert_eq!(config.tokenization_strategy, TokenizationStrategy::BERT); + assert_eq!(config.token_data_type, TokenDataType::I32); +} + +#[rstest] +fn test_tokenization_config_custom() { + let config = TokenizationConfig { + max_length: 256, + add_special_tokens: false, + truncation_strategy: TruncationStrategy::OnlyFirst, + truncation_direction: TruncationDirection::Left, + pad_token_id: 1, + pad_token: "".to_string(), + tokenization_strategy: TokenizationStrategy::ModernBERT, + token_data_type: TokenDataType::U32, + }; + + assert_eq!(config.max_length, 256); + assert!(!config.add_special_tokens); + assert_eq!( + config.tokenization_strategy, + TokenizationStrategy::ModernBERT + ); + assert_eq!(config.token_data_type, TokenDataType::U32); +} + +// ============================================================================ +// UnifiedTokenizer Tests +// ============================================================================ + +#[rstest] +fn test_unified_tokenizer_new(unified_tokenizer: UnifiedTokenizer) { + // UnifiedTokenizer should be created successfully + // We can't access config directly (private field), but we can test functionality + let result = unified_tokenizer.tokenize("test"); + assert!(result.is_ok(), "Tokenizer should work"); +} + +#[rstest] +fn test_tokenize_basic(unified_tokenizer: UnifiedTokenizer) { + let text = "Hello, world!"; + let result = unified_tokenizer.tokenize(text); + + assert!(result.is_ok(), "Should tokenize simple text"); + + let tokenization_result = result.unwrap(); + assert!( + !tokenization_result.token_ids.is_empty(), + "Should have token IDs" + ); + assert_eq!( + tokenization_result.token_ids.len(), + tokenization_result.attention_mask.len(), + "Token IDs and attention mask should have same length" + ); +} + +#[rstest] +fn test_tokenize_empty(unified_tokenizer: UnifiedTokenizer) { + let text = ""; + let result = unified_tokenizer.tokenize(text); + + assert!(result.is_ok(), "Should handle empty text"); +} + +#[rstest] +#[case("Simple text")] +#[case("A longer text that needs to be tokenized properly")] +#[case("Short")] +fn test_tokenize_various_texts(unified_tokenizer: UnifiedTokenizer, #[case] text: &str) { + let result = unified_tokenizer.tokenize(text); + + assert!(result.is_ok(), "Should tokenize: {}", text); + + let tokenization_result = result.unwrap(); + assert!(!tokenization_result.tokens.is_empty(), "Should have tokens"); +} + +// ============================================================================ +// Batch Tokenization Tests +// ============================================================================ + +#[rstest] +fn test_tokenize_batch_basic(unified_tokenizer: UnifiedTokenizer) { + let texts = vec!["First text", "Second text", "Third text"]; + let result = unified_tokenizer.tokenize_batch(&texts); + + assert!(result.is_ok(), "Should tokenize batch"); + + let batch_result = result.unwrap(); + assert_eq!(batch_result.batch_size, 3, "Should have 3 texts"); + assert_eq!( + batch_result.token_ids.len(), + 3, + "Should have 3 tokenizations" + ); + assert!(batch_result.max_length > 0, "Max length should be positive"); +} + +#[rstest] +fn test_tokenize_batch_empty(unified_tokenizer: UnifiedTokenizer) { + let texts: Vec<&str> = vec![]; + let result = unified_tokenizer.tokenize_batch(&texts); + + // Should either handle gracefully or return error + match result { + Ok(batch_result) => { + assert_eq!(batch_result.batch_size, 0, "Should have 0 texts"); + } + Err(_) => { + // Also acceptable to return error + } + } +} + +#[rstest] +fn test_tokenize_batch_varying_lengths(unified_tokenizer: UnifiedTokenizer) { + let texts = vec![ + "Short", + "A medium length text here", + "This is a much longer text that will have more tokens after tokenization", + ]; + let result = unified_tokenizer.tokenize_batch(&texts); + + assert!(result.is_ok(), "Should tokenize varying length texts"); + + let batch_result = result.unwrap(); + assert_eq!(batch_result.batch_size, 3); + + // All tokenizations should be padded to max_length + for token_ids in &batch_result.token_ids { + assert_eq!(token_ids.len(), batch_result.max_length); + } +} + +// ============================================================================ +// Traditional Tokenization Tests +// ============================================================================ + +#[rstest] +fn test_tokenize_for_traditional(unified_tokenizer: UnifiedTokenizer) { + let text = "Traditional tokenization test"; + let result = unified_tokenizer.tokenize_for_traditional(text); + + assert!(result.is_ok(), "Should tokenize for traditional path"); + + let tokenization_result = result.unwrap(); + assert!(!tokenization_result.token_ids.is_empty()); +} + +// ============================================================================ +// LoRA Tokenization Tests +// ============================================================================ + +#[rstest] +fn test_tokenize_for_lora(unified_tokenizer: UnifiedTokenizer) { + let text = "LoRA tokenization test"; + let result = unified_tokenizer.tokenize_for_lora(text); + + assert!(result.is_ok(), "Should tokenize for LoRA path"); + + let tokenization_result = result.unwrap(); + assert!(!tokenization_result.token_ids.is_empty()); +} + +// ============================================================================ +// Tensor Creation Tests +// ============================================================================ + +#[rstest] +fn test_create_tensors(unified_tokenizer: UnifiedTokenizer) { + let text = "Test for tensor creation"; + let tokenization_result = unified_tokenizer.tokenize(text).expect("Tokenize text"); + + let result = unified_tokenizer.create_tensors(&tokenization_result); + + assert!(result.is_ok(), "Should create tensors"); + + let (token_ids_tensor, attention_mask_tensor) = result.unwrap(); + assert_eq!(token_ids_tensor.dims().len(), 2, "Token IDs should be 2D"); + assert_eq!( + attention_mask_tensor.dims().len(), + 2, + "Attention mask should be 2D" + ); + assert_eq!( + token_ids_tensor.dims()[1], + attention_mask_tensor.dims()[1], + "Tensors should have same sequence length" + ); +} + +#[rstest] +fn test_create_batch_tensors(unified_tokenizer: UnifiedTokenizer) { + let texts = vec!["First", "Second", "Third"]; + let batch_result = unified_tokenizer + .tokenize_batch(&texts) + .expect("Tokenize batch"); + + let result = unified_tokenizer.create_batch_tensors(&batch_result); + + assert!(result.is_ok(), "Should create batch tensors"); + + let (token_ids_tensor, attention_mask_tensor) = result.unwrap(); + let dims = token_ids_tensor.dims(); + + assert_eq!(dims.len(), 2, "Should be 2D tensor"); + assert_eq!(dims[0], 3, "Batch size should be 3"); + assert_eq!( + token_ids_tensor.dims(), + attention_mask_tensor.dims(), + "Tensors should have same dimensions" + ); +} + +// ============================================================================ +// Smart Batch Tokenization Tests +// ============================================================================ + +#[rstest] +#[case(true, "Should prefer LoRA")] +#[case(false, "Should not prefer LoRA")] +fn test_tokenize_batch_smart( + unified_tokenizer: UnifiedTokenizer, + #[case] prefer_lora: bool, + #[case] description: &str, +) { + let texts = vec!["Text one", "Text two"]; + let result = unified_tokenizer.tokenize_batch_smart(&texts, prefer_lora); + + assert!(result.is_ok(), "{}", description); + + let batch_result = result.unwrap(); + assert_eq!(batch_result.batch_size, 2); +} + +// ============================================================================ +// Helper Function Tests +// ============================================================================ + +#[test] +fn test_create_tokenizer() { + let model_path = PathBuf::from(TEST_MODEL_BASE).join(BERT_MODEL); + let tokenizer_path = model_path.join("tokenizer.json"); + + if !tokenizer_path.exists() { + return; // Skip test if model not available + } + + let result = create_tokenizer( + tokenizer_path.to_str().unwrap(), + TokenizationStrategy::BERT, + Device::Cpu, + ); + assert!(result.is_ok(), "Should create tokenizer from path"); +} + +#[test] +fn test_detect_tokenization_strategy() { + let model_path = PathBuf::from(TEST_MODEL_BASE).join(BERT_MODEL); + let tokenizer_path = model_path.join("tokenizer.json"); + + if !tokenizer_path.exists() { + return; // Skip test if model not available + } + + let result = detect_tokenization_strategy(tokenizer_path.to_str().unwrap()); + assert!(result.is_ok(), "Should detect tokenization strategy"); +} + +// ============================================================================ +// Compatibility Tokenizer Tests +// ============================================================================ + +#[test] +fn test_create_bert_compatibility_tokenizer() { + use tokenizers::Tokenizer; + + let model_path = PathBuf::from(TEST_MODEL_BASE).join(BERT_MODEL); + let tokenizer_path = model_path.join("tokenizer.json"); + + if !tokenizer_path.exists() { + return; + } + + let tokenizer = Tokenizer::from_file(tokenizer_path).expect("Load tokenizer"); + + let result = create_bert_compatibility_tokenizer(tokenizer, Device::Cpu); + + assert!(result.is_ok(), "Should create BERT compatibility tokenizer"); +} + +// ============================================================================ +// Error Handling Tests +// ============================================================================ + +#[test] +fn test_create_tokenizer_invalid_path() { + let result = create_tokenizer( + "/nonexistent/tokenizer.json", + TokenizationStrategy::BERT, + Device::Cpu, + ); + assert!(result.is_err(), "Should fail with invalid path"); +} + +#[test] +fn test_detect_strategy_invalid_path() { + let result = detect_tokenization_strategy("/nonexistent/tokenizer.json"); + assert!(result.is_err(), "Should fail with invalid path"); +} + +// ============================================================================ +// Tokenization Strategy Tests +// ============================================================================ + +#[rstest] +#[case(TokenizationStrategy::BERT, TokenDataType::I32)] +#[case(TokenizationStrategy::ModernBERT, TokenDataType::U32)] +#[case(TokenizationStrategy::LoRA, TokenDataType::I32)] +fn test_tokenization_strategy_data_types( + #[case] strategy: TokenizationStrategy, + #[case] expected_dtype: TokenDataType, +) { + let config = TokenizationConfig { + tokenization_strategy: strategy, + token_data_type: expected_dtype.clone(), + ..Default::default() + }; + + assert_eq!(config.tokenization_strategy, strategy); + assert_eq!(config.token_data_type, expected_dtype); +} + +// ============================================================================ +// Concurrency Tests +// ============================================================================ + +#[rstest] +fn test_unified_tokenizer_thread_safety(unified_tokenizer: UnifiedTokenizer) { + use std::sync::Arc; + + let tokenizer = Arc::new(unified_tokenizer); + + // Use rayon for parallel execution - simpler and more efficient + let results: Vec<_> = (0..4) + .into_par_iter() + .map(|i| { + let text = format!("Thread {} test text", i); + tokenizer.tokenize(&text).expect("Tokenize in thread") + }) + .collect(); + + for result in results { + assert!(!result.token_ids.is_empty(), "Should tokenize successfully"); + } +} diff --git a/candle-binding/src/ffi/init_test.rs b/candle-binding/src/ffi/init_test.rs new file mode 100644 index 00000000..a088ce66 --- /dev/null +++ b/candle-binding/src/ffi/init_test.rs @@ -0,0 +1,353 @@ +//! Tests for FFI initialization module + +use super::init::*; +use super::state_manager::GlobalStateManager; +use rayon::prelude::*; +use rstest::*; +use std::ffi::CString; +use std::os::raw::c_char; + +// Note: Testing FFI functions is challenging because they use C ABI and global state. +// These tests focus on verifying basic functionality without requiring actual models. + +// ============================================================================ +// Global State Tests +// ============================================================================ + +#[rstest] +fn test_global_state_variables_exist() { + // Verify that the global static variables can be accessed + // We can't directly test lazy_static! variables, but we can test that + // the state manager works, which uses them internally + + let manager = GlobalStateManager::instance(); + let _state = manager.get_system_state(); + + // If we get here without panicking, the globals exist +} + +// ============================================================================ +// Helper Function Tests +// ============================================================================ + +#[rstest] +fn test_cstring_creation() { + // Test that we can create CStrings for FFI calls + let test_string = "test_model_path"; + let c_string = CString::new(test_string).expect("CString creation failed"); + let c_ptr: *const c_char = c_string.as_ptr(); + + assert!(!c_ptr.is_null(), "CString pointer should not be null"); +} + +#[rstest] +#[case("")] +#[case("model_path")] +#[case("/path/to/model")] +fn test_cstring_from_various_inputs(#[case] input: &str) { + let result = CString::new(input); + assert!(result.is_ok(), "Should create CString from: {}", input); +} + +// ============================================================================ +// Initialization Function Signatures Tests +// ============================================================================ + +#[test] +fn test_init_similarity_model_signature() { + // Verify function signature compiles and can be called with invalid path + // Note: This will likely fail/return false, but we're testing the interface + let test_path = CString::new("/nonexistent/model/path").unwrap(); + let result = init_similarity_model(test_path.as_ptr(), true); + + // With invalid path, should return false + assert!(!result, "Should return false with invalid path"); +} + +#[test] +fn test_init_classifier_signature() { + // Test with invalid path - should fail gracefully + let test_path = CString::new("/nonexistent/model").unwrap(); + let result = init_classifier(test_path.as_ptr(), 0, true); + + assert!(!result, "Should return false with invalid path"); +} + +#[test] +fn test_init_pii_classifier_signature() { + let test_path = CString::new("/nonexistent/model").unwrap(); + let result = init_pii_classifier(test_path.as_ptr(), 0, true); + + assert!(!result, "Should return false with invalid path"); +} + +#[test] +fn test_init_jailbreak_classifier_signature() { + let test_path = CString::new("/nonexistent/model").unwrap(); + let result = init_jailbreak_classifier(test_path.as_ptr(), 0, true); + + assert!(!result, "Should return false with invalid path"); +} + +#[test] +fn test_init_modernbert_classifier_signature() { + let test_path = CString::new("/nonexistent/model").unwrap(); + let result = init_modernbert_classifier(test_path.as_ptr(), true); + assert!(!result, "Should return false with invalid path"); +} + +#[test] +fn test_init_modernbert_pii_classifier_signature() { + let test_path = CString::new("/nonexistent/model").unwrap(); + let result = init_modernbert_pii_classifier(test_path.as_ptr(), true); + assert!(!result, "Should return false with invalid path"); +} + +#[test] +fn test_init_unified_classifier_c_signature() { + let test_path = CString::new("/nonexistent/model").unwrap(); + + // Create valid (but empty) arrays for labels + // slice::from_raw_parts requires non-null, aligned pointers even if length is 0 + let empty_labels: Vec<*const c_char> = Vec::new(); + let labels_ptr = if empty_labels.is_empty() { + // Use a valid non-null pointer for empty slice + std::ptr::NonNull::<*const c_char>::dangling().as_ptr() + } else { + empty_labels.as_ptr() + }; + + let result = init_unified_classifier_c( + test_path.as_ptr(), + test_path.as_ptr(), + test_path.as_ptr(), + test_path.as_ptr(), + labels_ptr, + 0, + labels_ptr, + 0, + labels_ptr, + 0, + true, + ); + + assert!(!result, "Should return false with invalid paths"); +} + +// ============================================================================ +// State Manager Integration Tests +// ============================================================================ + +#[rstest] +fn test_state_manager_after_failed_init() { + let manager = GlobalStateManager::instance(); + + // Attempt init with invalid path (will fail) + let test_path = CString::new("/nonexistent/model").unwrap(); + let _result = init_similarity_model(test_path.as_ptr(), true); + + // State manager should still be accessible + let state = manager.get_system_state(); + + // State should be one of the valid states + assert!( + matches!( + state, + super::state_manager::SystemState::Uninitialized + | super::state_manager::SystemState::Ready + | super::state_manager::SystemState::Error(_) + | super::state_manager::SystemState::Initializing + ), + "Should have valid system state" + ); +} + +// ============================================================================ +// Thread Safety Tests for Initialization +// ============================================================================ + +#[rstest] +fn test_concurrent_init_attempts() { + // Try to initialize from multiple threads simultaneously + // This tests that the initialization locks work correctly + // Use rayon for parallel execution - simpler and more efficient + (0..4).into_par_iter().for_each(|_| { + // Attempt init with invalid path (will fail, but tests locking) + let test_path = CString::new("/nonexistent/model").unwrap(); + let _ = init_similarity_model(test_path.as_ptr(), true); + }); + + // If we get here, no deadlock occurred +} + +// ============================================================================ +// CString Safety Tests +// ============================================================================ + +#[rstest] +#[case("valid_path")] +#[case("/another/valid/path")] +#[case("model_id_123")] +fn test_cstring_for_model_paths(#[case] path: &str) { + let c_string = CString::new(path).expect("Create CString"); + let c_ptr = c_string.as_ptr(); + + // Verify pointer is not null + assert!(!c_ptr.is_null()); + + // Convert back to verify correctness + let back_to_str = unsafe { + std::ffi::CStr::from_ptr(c_ptr) + .to_str() + .expect("Convert back to str") + }; + + assert_eq!( + back_to_str, path, + "Round-trip conversion should preserve string" + ); +} + +#[test] +fn test_cstring_with_null_byte_fails() { + let invalid_string = "path\0with\0nulls"; + let result = CString::new(invalid_string); + + assert!( + result.is_err(), + "CString creation should fail with interior null bytes" + ); +} + +// ============================================================================ +// Boolean Return Value Tests +// ============================================================================ + +#[rstest] +fn test_init_functions_return_boolean() { + // All init functions should return bool + // Test that false is returned for invalid inputs + let test_path = CString::new("/nonexistent/model").unwrap(); + + assert!(!init_similarity_model(test_path.as_ptr(), true)); + assert!(!init_classifier(test_path.as_ptr(), 0, true)); + assert!(!init_pii_classifier(test_path.as_ptr(), 0, true)); + assert!(!init_jailbreak_classifier(test_path.as_ptr(), 0, true)); + assert!(!init_modernbert_classifier(test_path.as_ptr(), true)); +} + +// ============================================================================ +// Parameter Validation Tests +// ============================================================================ + +#[rstest] +#[case(true)] +#[case(false)] +fn test_use_cpu_parameter(#[case] use_cpu: bool) { + // Test that use_cpu parameter is accepted + let test_path = CString::new("/nonexistent/model").unwrap(); + let result = init_similarity_model(test_path.as_ptr(), use_cpu); + + // Should fail due to invalid path, but parameter should be processed + assert!(!result); +} + +#[rstest] +#[case(0)] +#[case(2)] +#[case(5)] +fn test_num_labels_parameter(#[case] num_labels: i32) { + // Test that num_labels parameter is accepted + let test_path = CString::new("/nonexistent/model").unwrap(); + let result = init_classifier(test_path.as_ptr(), num_labels, true); + + assert!(!result, "Should fail with invalid path"); +} + +// ============================================================================ +// Error Handling Tests +// ============================================================================ + +#[rstest] +fn test_invalid_path_handling() { + // All functions should handle invalid paths gracefully without crashing + let test_path = CString::new("/nonexistent/model").unwrap(); + + let _ = init_similarity_model(test_path.as_ptr(), true); + let _ = init_classifier(test_path.as_ptr(), 0, true); + let _ = init_pii_classifier(test_path.as_ptr(), 0, true); + let _ = init_jailbreak_classifier(test_path.as_ptr(), 0, true); + let _ = init_modernbert_classifier(test_path.as_ptr(), true); + let _ = init_modernbert_pii_classifier(test_path.as_ptr(), true); + + // If we reach here, no crashes occurred +} + +// ============================================================================ +// Integration with State Manager Tests +// ============================================================================ + +#[rstest] +fn test_state_manager_stats_after_init_attempts() { + let manager = GlobalStateManager::instance(); + + // Try various init functions + let test_path = CString::new("/nonexistent/model").unwrap(); + let _ = init_similarity_model(test_path.as_ptr(), true); + let _ = init_modernbert_classifier(test_path.as_ptr(), true); + + // Get stats - should work regardless of init success/failure + let stats = manager.get_stats(); + + // Stats should be retrievable + assert!( + stats.unified_classifier_initialized || !stats.unified_classifier_initialized, + "Should have stats" + ); +} + +// ============================================================================ +// Const Correctness Tests +// ============================================================================ + +#[test] +fn test_const_char_pointer_usage() { + // Test that const char* parameters work correctly + let test_str = CString::new("test").unwrap(); + let ptr: *const c_char = test_str.as_ptr(); + + // Verify the pointer can be used in FFI context + assert!(!ptr.is_null()); + + // Pass to a function (will fail but tests the interface) + let _result = init_similarity_model(ptr, true); +} + +// ============================================================================ +// Memory Safety Tests +// ============================================================================ + +#[rstest] +fn test_cstring_lifetime() { + // Test that CString lives long enough for FFI call + let _result = { + let model_id = CString::new("model").unwrap(); + let ptr = model_id.as_ptr(); + init_similarity_model(ptr, true) + // model_id is dropped here, but call already completed + }; + + // Should complete without memory issues +} + +#[rstest] +fn test_multiple_cstrings() { + // Test creating multiple CStrings for different parameters + let model_id = CString::new("model_id").unwrap(); + let _tokenizer_path = CString::new("tokenizer_path").unwrap(); + let _lora_path = CString::new("lora_path").unwrap(); + + let _result = init_classifier(model_id.as_ptr(), 2, true); + + // All CStrings should remain valid during the call +} diff --git a/candle-binding/src/ffi/mod.rs b/candle-binding/src/ffi/mod.rs index 83d2e079..5f94aeab 100644 --- a/candle-binding/src/ffi/mod.rs +++ b/candle-binding/src/ffi/mod.rs @@ -34,4 +34,10 @@ pub mod classify_test; #[cfg(test)] pub mod embedding_test; #[cfg(test)] +pub mod init_test; +#[cfg(test)] pub mod memory_safety_test; +#[cfg(test)] +pub mod state_manager_test; +#[cfg(test)] +pub mod validation_test; diff --git a/candle-binding/src/ffi/state_manager_test.rs b/candle-binding/src/ffi/state_manager_test.rs new file mode 100644 index 00000000..b820c5f4 --- /dev/null +++ b/candle-binding/src/ffi/state_manager_test.rs @@ -0,0 +1,383 @@ +//! Tests for global state manager + +use super::state_manager::*; +use rayon::prelude::*; +use rstest::*; + +// Note: These tests use the actual singleton instance, so they may affect each other +// In a real scenario, you might want to use a separate test instance or mock + +// ============================================================================ +// Singleton Tests +// ============================================================================ + +#[rstest] +fn test_global_state_manager_instance() { + let instance1 = GlobalStateManager::instance(); + let instance2 = GlobalStateManager::instance(); + + // Should return the same instance (singleton pattern) + assert_eq!( + instance1 as *const GlobalStateManager, instance2 as *const GlobalStateManager, + "Should return the same singleton instance" + ); +} + +// ============================================================================ +// System State Tests +// ============================================================================ + +#[rstest] +fn test_system_state_initial() { + let manager = GlobalStateManager::instance(); + let state = manager.get_system_state(); + + // System should either be Uninitialized or Ready (depending on test order) + assert!( + matches!( + state, + SystemState::Uninitialized | SystemState::Ready | SystemState::Initializing + ), + "Initial state should be Uninitialized, Initializing, or Ready" + ); +} + +#[rstest] +fn test_system_state_enum() { + // Test SystemState enum variants + let states = vec![ + SystemState::Uninitialized, + SystemState::Initializing, + SystemState::Ready, + SystemState::ShuttingDown, + SystemState::Error("Test error".to_string()), + ]; + + for state in states { + assert!( + matches!( + state, + SystemState::Uninitialized + | SystemState::Initializing + | SystemState::Ready + | SystemState::ShuttingDown + | SystemState::Error(_) + ), + "Should be valid SystemState variant" + ); + } +} + +// ============================================================================ +// Initialization Status Tests +// ============================================================================ + +#[rstest] +fn test_is_any_initialized() { + let manager = GlobalStateManager::instance(); + + // This will be true or false depending on what's initialized + let any_init = manager.is_any_initialized(); + + // Just verify it returns a boolean + assert!(any_init || !any_init, "Should return boolean"); +} + +#[rstest] +fn test_is_ready() { + let manager = GlobalStateManager::instance(); + + // Just verify the method works + let ready = manager.is_ready(); + assert!(ready || !ready, "Should return boolean"); +} + +// ============================================================================ +// Classifier Initialization Tests +// ============================================================================ + +#[rstest] +fn test_is_unified_classifier_initialized() { + let manager = GlobalStateManager::instance(); + + let is_init = manager.is_unified_classifier_initialized(); + + // Should return a boolean + assert!(is_init || !is_init, "Should return boolean"); + + // If initialized, should be able to get it + if is_init { + let classifier = manager.get_unified_classifier(); + assert!( + classifier.is_some(), + "Should return classifier when initialized" + ); + } +} + +#[rstest] +fn test_get_unified_classifier_when_not_initialized() { + let manager = GlobalStateManager::instance(); + + // Attempt to get classifier (may or may not be initialized) + let classifier = manager.get_unified_classifier(); + + // Should return Option + match classifier { + Some(_) => { + // If Some, is_initialized should be true + assert!(manager.is_unified_classifier_initialized()); + } + None => { + // If None, might not be initialized (or just wasn't set yet) + } + } +} + +// ============================================================================ +// LoRA Engine Tests +// ============================================================================ + +#[rstest] +fn test_get_parallel_lora_engine() { + let manager = GlobalStateManager::instance(); + + // Attempt to get LoRA engine + let engine = manager.get_parallel_lora_engine(); + + // Should return Option (may be None if not initialized) + match engine { + Some(_) => { + // Successfully got engine + } + None => { + // Engine not initialized yet + } + } +} + +// ============================================================================ +// Token Classifier Tests +// ============================================================================ + +#[rstest] +fn test_get_lora_token_classifier() { + let manager = GlobalStateManager::instance(); + + // Attempt to get token classifier + let classifier = manager.get_lora_token_classifier(); + + // Should return Option + match classifier { + Some(_) => { + // Successfully got classifier + } + None => { + // Classifier not initialized + } + } +} + +// ============================================================================ +// BERT Similarity Tests +// ============================================================================ + +#[rstest] +fn test_get_bert_similarity() { + let manager = GlobalStateManager::instance(); + + // Attempt to get BERT similarity + let similarity = manager.get_bert_similarity(); + + // Should return Option + match similarity { + Some(_) => { + // Successfully got similarity + } + None => { + // Similarity not initialized + } + } +} + +// ============================================================================ +// Legacy Classifier Tests +// ============================================================================ + +#[rstest] +#[case("legacy_bert")] +#[case("legacy_pii")] +#[case("legacy_jailbreak")] +#[case("nonexistent")] +fn test_get_legacy_classifier(#[case] name: &str) { + let manager = GlobalStateManager::instance(); + + // Attempt to get legacy classifier by name + let classifier = manager.get_legacy_classifier(name); + + // Should return Option (likely None for most names) + match classifier { + Some(_) => { + // Found a classifier with this name + } + None => { + // Classifier not found or not initialized + } + } +} + +// ============================================================================ +// Statistics Tests +// ============================================================================ + +#[rstest] +fn test_get_stats() { + let manager = GlobalStateManager::instance(); + + // Get statistics + let stats = manager.get_stats(); + + // Verify structure (based on actual implementation) + // Note: You may need to adjust these assertions based on actual struct fields + assert!( + stats.unified_classifier_initialized || !stats.unified_classifier_initialized, + "Should have unified_classifier_initialized field" + ); + assert!( + stats.parallel_lora_engine_initialized || !stats.parallel_lora_engine_initialized, + "Should have parallel_lora_engine_initialized field" + ); + assert!( + stats.lora_token_classifier_initialized || !stats.lora_token_classifier_initialized, + "Should have lora_token_classifier_initialized field" + ); + assert!( + stats.bert_similarity_initialized || !stats.bert_similarity_initialized, + "Should have bert_similarity_initialized field" + ); +} + +// ============================================================================ +// Cleanup Tests +// ============================================================================ + +#[rstest] +fn test_cleanup_method_exists() { + let manager = GlobalStateManager::instance(); + + // Just verify cleanup method can be called + // Note: We don't actually call it in tests as it would affect other tests + // manager.cleanup(); + + // Instead, just verify the method exists through compilation + let _ = manager; // Use the manager to avoid unused variable warning +} + +// ============================================================================ +// Thread Safety Tests +// ============================================================================ + +#[rstest] +fn test_global_state_manager_thread_safety() { + // Use rayon for parallel execution - simpler and more efficient + (0..4).into_par_iter().for_each(|_| { + let manager = GlobalStateManager::instance(); + let _ = manager.get_system_state(); + let _ = manager.is_any_initialized(); + let _ = manager.get_stats(); + }); +} + +#[rstest] +fn test_concurrent_state_access() { + // Use rayon for parallel execution - simpler and more efficient + let results: Vec<_> = (0..8) + .into_par_iter() + .map(|i| { + let manager = GlobalStateManager::instance(); + + // Perform various read operations + let _ = manager.get_system_state(); + let _ = manager.is_ready(); + let _ = manager.is_any_initialized(); + let _ = manager.get_unified_classifier(); + let _ = manager.get_parallel_lora_engine(); + let _ = manager.get_lora_token_classifier(); + let _ = manager.get_bert_similarity(); + let _ = manager.get_legacy_classifier(&format!("classifier_{}", i)); + let _ = manager.get_stats(); + + i // Return thread number + }) + .collect(); + + for (idx, result) in results.into_iter().enumerate() { + assert_eq!(result, idx, "Thread should return correct index"); + } +} + +// ============================================================================ +// Error Handling Tests +// ============================================================================ + +#[rstest] +fn test_system_state_error_variant() { + let error_state = SystemState::Error("Test error message".to_string()); + + match error_state { + SystemState::Error(msg) => { + assert_eq!(msg, "Test error message"); + } + _ => panic!("Should be Error variant"), + } +} + +// ============================================================================ +// Integration Tests +// ============================================================================ + +#[rstest] +fn test_state_consistency() { + let manager = GlobalStateManager::instance(); + + // Get initialization status + let unified_init = manager.is_unified_classifier_initialized(); + let any_init = manager.is_any_initialized(); + + // If unified classifier is initialized, any_init should be true + if unified_init { + assert!( + any_init, + "If unified classifier is initialized, any_init should be true" + ); + } + + // Get stats and verify consistency + let stats = manager.get_stats(); + assert_eq!( + stats.unified_classifier_initialized, unified_init, + "Stats should match is_initialized status" + ); +} + +#[rstest] +fn test_get_operations_consistency() { + let manager = GlobalStateManager::instance(); + + // Call get twice, should return consistent results + let classifier1 = manager.get_unified_classifier(); + let classifier2 = manager.get_unified_classifier(); + + match (classifier1, classifier2) { + (Some(_), Some(_)) => { + // Both Some - consistent + } + (None, None) => { + // Both None - consistent + } + _ => { + // This should not happen unless there's a race condition + // In practice, once initialized, it should stay initialized + } + } +} diff --git a/candle-binding/src/ffi/validation_test.rs b/candle-binding/src/ffi/validation_test.rs new file mode 100644 index 00000000..5b1e6bdd --- /dev/null +++ b/candle-binding/src/ffi/validation_test.rs @@ -0,0 +1,382 @@ +//! Tests for FFI validation functions + +use super::validation::*; +use rayon::prelude::*; +use rstest::*; +use std::ffi::CString; +use std::os::raw::c_char; + +// ============================================================================ +// Text Input Validation Tests +// ============================================================================ + +#[rstest] +fn test_validate_text_input_null_pointer() { + let result = validate_text_input(std::ptr::null(), 0); + assert!(!result.is_valid, "Should reject null pointer"); +} + +#[rstest] +#[case("Valid text for testing", 0, true)] +#[case("Another valid text", 1, true)] +#[case("Short but valid", 0, true)] +fn test_validate_text_input_valid( + #[case] text: &str, + #[case] path_type: i32, + #[case] should_be_valid: bool, +) { + let c_text = CString::new(text).unwrap(); + let result = validate_text_input(c_text.as_ptr(), path_type); + + assert_eq!( + result.is_valid, should_be_valid, + "Text validation result mismatch for: {}", + text + ); + + // Clean up + free_validation_result(result); +} + +#[rstest] +fn test_validate_text_input_empty() { + let c_text = CString::new("").unwrap(); + let result = validate_text_input(c_text.as_ptr(), 0); + + // Empty text should likely be invalid (too short) + assert!(!result.is_valid, "Empty text should be invalid"); + + free_validation_result(result); +} + +#[rstest] +fn test_validate_text_input_very_long() { + // Create a very long text + let long_text = "a".repeat(100000); + let c_text = CString::new(long_text).unwrap(); + let result = validate_text_input(c_text.as_ptr(), 0); + + // May or may not be valid depending on MAX_TEXT_LENGTH + // Just verify it doesn't crash + let _ = result.is_valid; + + free_validation_result(result); +} + +#[rstest] +#[case(0)] +#[case(1)] +fn test_validate_text_input_path_types(#[case] path_type: i32) { + let c_text = CString::new("Test text").unwrap(); + let result = validate_text_input(c_text.as_ptr(), path_type); + + // Should handle both path types + let _ = result.is_valid; + + free_validation_result(result); +} + +#[rstest] +fn test_validate_text_input_invalid_path_type() { + let c_text = CString::new("Test text").unwrap(); + let result = validate_text_input(c_text.as_ptr(), 99); + + // Invalid path type should result in error + assert!(!result.is_valid, "Invalid path type should fail"); + + free_validation_result(result); +} + +// ============================================================================ +// Batch Input Validation Tests +// ============================================================================ + +#[rstest] +fn test_validate_batch_input_null_pointer() { + let result = validate_batch_input(std::ptr::null(), 0, 0); + assert!(!result.is_valid, "Should reject null pointer"); + + free_validation_result(result); +} + +#[rstest] +fn test_validate_batch_input_zero_count() { + // Even with valid pointer, zero count should fail + let texts = vec![CString::new("test").unwrap()]; + let ptrs: Vec<*const c_char> = texts.iter().map(|s| s.as_ptr()).collect(); + + let result = validate_batch_input(ptrs.as_ptr(), 0, 0); + assert!(!result.is_valid, "Zero count should be invalid"); + + free_validation_result(result); +} + +#[rstest] +fn test_validate_batch_input_negative_count() { + let texts = vec![CString::new("test").unwrap()]; + let ptrs: Vec<*const c_char> = texts.iter().map(|s| s.as_ptr()).collect(); + + let result = validate_batch_input(ptrs.as_ptr(), -1, 0); + assert!(!result.is_valid, "Negative count should be invalid"); + + free_validation_result(result); +} + +#[rstest] +fn test_validate_batch_input_valid_small_batch() { + let texts = vec![ + CString::new("First text").unwrap(), + CString::new("Second text").unwrap(), + CString::new("Third text").unwrap(), + ]; + let ptrs: Vec<*const c_char> = texts.iter().map(|s| s.as_ptr()).collect(); + + let result = validate_batch_input(ptrs.as_ptr(), 3, 0); + + // Should be valid for small batch + assert!( + result.is_valid || !result.is_valid, + "Should complete validation" + ); + + free_validation_result(result); +} + +#[rstest] +#[case(0)] +#[case(1)] +fn test_validate_batch_input_path_types(#[case] path_type: i32) { + let texts = vec![ + CString::new("Test one").unwrap(), + CString::new("Test two").unwrap(), + ]; + let ptrs: Vec<*const c_char> = texts.iter().map(|s| s.as_ptr()).collect(); + + let result = validate_batch_input(ptrs.as_ptr(), 2, path_type); + + let _ = result.is_valid; + + free_validation_result(result); +} + +// ============================================================================ +// Model Path Validation Tests +// ============================================================================ + +#[rstest] +fn test_validate_model_path_null() { + let result = validate_model_path(std::ptr::null(), 0); + assert!(!result.is_valid, "Null path should be invalid"); + + free_validation_result(result); +} + +#[rstest] +#[case("/path/to/model", 0)] +#[case("/another/path", 1)] +fn test_validate_model_path_various_paths(#[case] path: &str, #[case] path_type: i32) { + let c_path = CString::new(path).unwrap(); + let result = validate_model_path(c_path.as_ptr(), path_type); + + // Path validation depends on actual file existence + let _ = result.is_valid; + + free_validation_result(result); +} + +// ============================================================================ +// Confidence Threshold Validation Tests +// ============================================================================ + +#[rstest] +#[case(0.5, 0, true)] // Valid for traditional +#[case(0.9, 1, true)] // Valid for LoRA +#[case(0.0, 0, false)] // Too low for traditional +#[case(1.0, 0, true)] // Maximum valid +fn test_validate_confidence_threshold_various_values( + #[case] confidence: f32, + #[case] path_type: i32, + #[case] _expected_valid: bool, +) { + let result = validate_confidence_threshold(confidence, path_type); + + // Just verify it runs without crashing + let _ = result.is_valid; + + free_validation_result(result); +} + +#[rstest] +fn test_validate_confidence_threshold_out_of_range_low() { + let result = validate_confidence_threshold(-0.1, 0); + assert!(!result.is_valid, "Negative confidence should be invalid"); + + free_validation_result(result); +} + +#[rstest] +fn test_validate_confidence_threshold_out_of_range_high() { + let result = validate_confidence_threshold(1.1, 0); + assert!(!result.is_valid, "Confidence > 1.0 should be invalid"); + + free_validation_result(result); +} + +#[rstest] +fn test_validate_confidence_threshold_boundary_values() { + let result_zero = validate_confidence_threshold(0.0, 1); + let _ = result_zero.is_valid; + free_validation_result(result_zero); + + let result_one = validate_confidence_threshold(1.0, 1); + let _ = result_one.is_valid; + free_validation_result(result_one); +} + +// ============================================================================ +// Memory Parameters Validation Tests +// ============================================================================ + +#[rstest] +#[case(1024, 16, true)] +#[case(4096, 32, true)] +#[case(0, 16, false)] // Zero size should be invalid +fn test_validate_memory_parameters( + #[case] size: usize, + #[case] alignment: usize, + #[case] _expected_valid: bool, +) { + let result = validate_memory_parameters(size, alignment); + + let _ = result.is_valid; + + free_validation_result(result); +} + +// ============================================================================ +// ValidationResult Structure Tests +// ============================================================================ + +#[rstest] +fn test_validation_result_structure() { + let c_text = CString::new("Test").unwrap(); + let result = validate_text_input(c_text.as_ptr(), 0); + + // Verify structure fields exist + let _ = result.is_valid; + let _ = result.error_code; + let _ = result.error_message; + let _ = result.suggestions; + + free_validation_result(result); +} + +// ============================================================================ +// Free Function Tests +// ============================================================================ + +#[rstest] +fn test_free_validation_result() { + let c_text = CString::new("Test").unwrap(); + let result = validate_text_input(c_text.as_ptr(), 0); + + // Should not crash when freeing + free_validation_result(result); +} + +#[rstest] +fn test_multiple_free_calls() { + let c_text = CString::new("Test").unwrap(); + + for _ in 0..10 { + let result = validate_text_input(c_text.as_ptr(), 0); + free_validation_result(result); + } + + // Should not leak memory +} + +// ============================================================================ +// Thread Safety Tests +// ============================================================================ + +#[rstest] +fn test_validation_thread_safety() { + // Use rayon for parallel execution - simpler and more efficient + (0..4).into_par_iter().for_each(|i| { + let text = format!("Thread {} test", i); + let c_text = CString::new(text).unwrap(); + let result = validate_text_input(c_text.as_ptr(), 0); + let is_valid = result.is_valid; + free_validation_result(result); + assert!(is_valid, "Thread {} should validate successfully", i); + }); +} + +// ============================================================================ +// UTF-8 Validation Tests +// ============================================================================ + +#[rstest] +fn test_validate_text_input_ascii() { + let c_text = CString::new("ASCII text only").unwrap(); + let result = validate_text_input(c_text.as_ptr(), 0); + + let _ = result.is_valid; + + free_validation_result(result); +} + +#[rstest] +fn test_validate_text_input_unicode() { + let c_text = CString::new("Unicode: 你好世界 🌍").unwrap(); + let result = validate_text_input(c_text.as_ptr(), 0); + + // Should handle valid UTF-8 + let _ = result.is_valid; + + free_validation_result(result); +} + +// ============================================================================ +// Error Code Tests +// ============================================================================ + +#[rstest] +fn test_validation_error_codes() { + // Test that error codes are set correctly + let result_null = validate_text_input(std::ptr::null(), 0); + assert_eq!(result_null.error_code, ERROR_NULL_POINTER); + free_validation_result(result_null); + + let result_invalid_confidence = validate_confidence_threshold(-1.0, 0); + assert_eq!( + result_invalid_confidence.error_code, + ERROR_INVALID_CONFIDENCE + ); + free_validation_result(result_invalid_confidence); +} + +// ============================================================================ +// Success Case Tests +// ============================================================================ + +#[rstest] +fn test_validation_success_case() { + let c_text = CString::new("This is a valid test text for validation").unwrap(); + let result = validate_text_input(c_text.as_ptr(), 0); + + if result.is_valid { + // On success, error_message and suggestion should be null or empty + assert!( + result.error_message.is_null() + || unsafe { + std::ffi::CStr::from_ptr(result.error_message) + .to_bytes() + .is_empty() + } + ); + } + + free_validation_result(result); +} diff --git a/candle-binding/src/model_architectures/lora/lora_adapter_test.rs b/candle-binding/src/model_architectures/lora/lora_adapter_test.rs new file mode 100644 index 00000000..f77b9cf7 --- /dev/null +++ b/candle-binding/src/model_architectures/lora/lora_adapter_test.rs @@ -0,0 +1,518 @@ +//! Tests for LoRA adapter module + +use super::lora_adapter::*; +use candle_core::{DType, Device}; +use candle_nn::VarBuilder; +use rstest::*; + +// ============================================================================ +// Configuration Tests +// ============================================================================ + +#[rstest] +fn test_lora_config_default() { + let config = LoRAConfig::default(); + + assert_eq!(config.rank, 16); + assert_eq!(config.alpha, 32.0); + assert_eq!(config.dropout, 0.1); + assert_eq!(config.target_modules.len(), 4); + assert!(!config.use_bias); + assert!(matches!(config.init_method, LoRAInitMethod::Kaiming)); +} + +#[rstest] +fn test_lora_config_custom() { + let config = LoRAConfig { + rank: 32, + alpha: 64.0, + dropout: 0.2, + target_modules: vec!["query".to_string(), "value".to_string()], + use_bias: true, + init_method: LoRAInitMethod::Xavier, + }; + + assert_eq!(config.rank, 32); + assert_eq!(config.alpha, 64.0); + assert_eq!(config.dropout, 0.2); + assert_eq!(config.target_modules.len(), 2); + assert!(config.use_bias); + assert!(matches!(config.init_method, LoRAInitMethod::Xavier)); +} + +#[rstest] +fn test_lora_config_clone() { + let config1 = LoRAConfig::default(); + let config2 = config1.clone(); + + assert_eq!(config1.rank, config2.rank); + assert_eq!(config1.alpha, config2.alpha); + assert_eq!(config1.dropout, config2.dropout); +} + +#[rstest] +#[case(4)] +#[case(8)] +#[case(16)] +#[case(32)] +#[case(64)] +fn test_lora_config_various_ranks(#[case] rank: usize) { + let config = LoRAConfig { + rank, + ..Default::default() + }; + + assert_eq!(config.rank, rank); + + // Scaling factor should be alpha / rank + let expected_scaling = config.alpha / rank as f64; + assert!((expected_scaling - (config.alpha / config.rank as f64)).abs() < 1e-9); +} + +// ============================================================================ +// LoRA Init Method Tests +// ============================================================================ + +#[rstest] +fn test_lora_init_method_variants() { + let methods = vec![ + LoRAInitMethod::Kaiming, + LoRAInitMethod::Xavier, + LoRAInitMethod::Normal { + mean: 0.0, + std: 0.02, + }, + LoRAInitMethod::Zero, + ]; + + // Each variant should be distinct + for (i, method1) in methods.iter().enumerate() { + for (j, method2) in methods.iter().enumerate() { + if i != j { + match (method1, method2) { + (LoRAInitMethod::Kaiming, LoRAInitMethod::Kaiming) => unreachable!(), + (LoRAInitMethod::Xavier, LoRAInitMethod::Xavier) => unreachable!(), + (LoRAInitMethod::Zero, LoRAInitMethod::Zero) => unreachable!(), + _ => { + // Different variants + } + } + } + } + } +} + +#[rstest] +fn test_lora_init_method_normal_with_custom_params() { + let method = LoRAInitMethod::Normal { + mean: 0.5, + std: 0.1, + }; + + match method { + LoRAInitMethod::Normal { mean, std } => { + assert_eq!(mean, 0.5); + assert_eq!(std, 0.1); + } + _ => panic!("Expected Normal variant"), + } +} + +// ============================================================================ +// LoRA Adapter Creation Tests +// ============================================================================ + +#[rstest] +fn test_lora_adapter_new_basic() { + let device = Device::Cpu; + let input_dim = 768; + let output_dim = 768; + let config = LoRAConfig::default(); + + // Create a simple VarMap for testing + let var_map = candle_nn::VarMap::new(); + let vb = VarBuilder::from_varmap(&var_map, DType::F32, &device); + + let result = LoRAAdapter::new(input_dim, output_dim, &config, vb, &device); + + assert!(result.is_ok(), "Should create LoRA adapter"); +} + +#[rstest] +#[case(512, 512)] +#[case(768, 768)] +#[case(1024, 1024)] +fn test_lora_adapter_various_dimensions(#[case] input_dim: usize, #[case] output_dim: usize) { + let device = Device::Cpu; + let config = LoRAConfig::default(); + + let var_map = candle_nn::VarMap::new(); + let vb = VarBuilder::from_varmap(&var_map, DType::F32, &device); + + let result = LoRAAdapter::new(input_dim, output_dim, &config, vb, &device); + + assert!( + result.is_ok(), + "Should create adapter with dims {}x{}", + input_dim, + output_dim + ); +} + +#[rstest] +fn test_lora_adapter_with_different_init_methods() { + let device = Device::Cpu; + let input_dim = 768; + let output_dim = 768; + + let init_methods = vec![ + LoRAInitMethod::Kaiming, + LoRAInitMethod::Xavier, + LoRAInitMethod::Normal { + mean: 0.0, + std: 0.02, + }, + LoRAInitMethod::Zero, + ]; + + for init_method in init_methods { + let config = LoRAConfig { + init_method, + ..Default::default() + }; + + let var_map = candle_nn::VarMap::new(); + let vb = VarBuilder::from_varmap(&var_map, DType::F32, &device); + + let result = LoRAAdapter::new(input_dim, output_dim, &config, vb, &device); + + assert!( + result.is_ok(), + "Should create adapter with init method {:?}", + config.init_method + ); + } +} + +// ============================================================================ +// LoRA Scaling Tests +// ============================================================================ + +#[rstest] +#[case(16, 32.0, 2.0)] +#[case(8, 16.0, 2.0)] +#[case(32, 64.0, 2.0)] +#[case(4, 8.0, 2.0)] +fn test_lora_scaling_calculation( + #[case] rank: usize, + #[case] alpha: f64, + #[case] expected_scaling: f64, +) { + let config = LoRAConfig { + rank, + alpha, + ..Default::default() + }; + + let scaling = config.alpha / config.rank as f64; + + assert!( + (scaling - expected_scaling).abs() < 1e-9, + "Scaling should be alpha/rank = {}, got {}", + expected_scaling, + scaling + ); +} + +// ============================================================================ +// Target Modules Tests +// ============================================================================ + +#[rstest] +fn test_lora_config_default_target_modules() { + let config = LoRAConfig::default(); + + let expected_modules = vec!["query", "value", "key", "output"]; + + assert_eq!(config.target_modules.len(), expected_modules.len()); + + for expected in expected_modules { + assert!( + config.target_modules.contains(&expected.to_string()), + "Should contain target module: {}", + expected + ); + } +} + +#[rstest] +fn test_lora_config_custom_target_modules() { + let custom_modules = vec!["query".to_string(), "key".to_string(), "dense".to_string()]; + + let config = LoRAConfig { + target_modules: custom_modules.clone(), + ..Default::default() + }; + + assert_eq!(config.target_modules.len(), 3); + assert_eq!(config.target_modules, custom_modules); +} + +// ============================================================================ +// Edge Case Tests +// ============================================================================ + +#[rstest] +fn test_lora_config_with_zero_dropout() { + let config = LoRAConfig { + dropout: 0.0, + ..Default::default() + }; + + assert_eq!(config.dropout, 0.0); +} + +#[rstest] +fn test_lora_config_with_high_dropout() { + let config = LoRAConfig { + dropout: 0.9, + ..Default::default() + }; + + assert_eq!(config.dropout, 0.9); +} + +#[rstest] +fn test_lora_config_with_small_rank() { + let config = LoRAConfig { + rank: 2, + ..Default::default() + }; + + assert_eq!(config.rank, 2); +} + +#[rstest] +fn test_lora_config_with_large_rank() { + let config = LoRAConfig { + rank: 128, + ..Default::default() + }; + + assert_eq!(config.rank, 128); +} + +// ============================================================================ +// Serialization Tests (if needed) +// ============================================================================ + +#[rstest] +fn test_lora_config_serialization() { + let config = LoRAConfig::default(); + + // Test JSON serialization + let json_result = serde_json::to_string(&config); + assert!(json_result.is_ok(), "Should serialize to JSON"); + + let json_str = json_result.unwrap(); + assert!(!json_str.is_empty(), "JSON string should not be empty"); +} + +#[rstest] +fn test_lora_config_deserialization() { + let json_str = r#"{ + "rank": 16, + "alpha": 32.0, + "dropout": 0.1, + "target_modules": ["query", "value", "key", "output"], + "use_bias": false, + "init_method": "Kaiming" + }"#; + + let result: Result = serde_json::from_str(json_str); + + assert!(result.is_ok(), "Should deserialize from JSON"); + + let config = result.unwrap(); + assert_eq!(config.rank, 16); + assert_eq!(config.alpha, 32.0); +} + +#[rstest] +fn test_lora_init_method_serialization() { + let methods = vec![ + LoRAInitMethod::Kaiming, + LoRAInitMethod::Xavier, + LoRAInitMethod::Normal { + mean: 0.0, + std: 0.02, + }, + LoRAInitMethod::Zero, + ]; + + for method in methods { + let json_result = serde_json::to_string(&method); + assert!( + json_result.is_ok(), + "Should serialize init method {:?}", + method + ); + } +} + +// ============================================================================ +// Parameter Count Tests +// ============================================================================ + +#[rstest] +#[case(768, 768, 16)] +#[case(1024, 1024, 32)] +#[case(512, 512, 8)] +fn test_lora_parameter_count( + #[case] input_dim: usize, + #[case] output_dim: usize, + #[case] rank: usize, +) { + // LoRA parameters: A (rank x input_dim) + B (output_dim x rank) + let expected_params = (rank * input_dim) + (output_dim * rank); + + // For reference: full fine-tuning would be input_dim x output_dim + let full_params = input_dim * output_dim; + + let reduction_ratio = full_params as f64 / expected_params as f64; + + assert!( + reduction_ratio > 1.0, + "LoRA should reduce parameter count (reduction: {}x)", + reduction_ratio + ); +} + +// ============================================================================ +// Configuration Validation Tests +// ============================================================================ + +#[rstest] +fn test_lora_config_alpha_positive() { + let config = LoRAConfig { + alpha: 32.0, + ..Default::default() + }; + + assert!(config.alpha > 0.0, "Alpha should be positive"); +} + +#[rstest] +fn test_lora_config_rank_positive() { + let config = LoRAConfig { + rank: 16, + ..Default::default() + }; + + assert!(config.rank > 0, "Rank should be positive"); +} + +#[rstest] +fn test_lora_config_dropout_valid_range() { + let config = LoRAConfig { + dropout: 0.1, + ..Default::default() + }; + + assert!( + config.dropout >= 0.0 && config.dropout <= 1.0, + "Dropout should be in [0, 1] range" + ); +} + +// ============================================================================ +// Bias Configuration Tests +// ============================================================================ + +#[rstest] +fn test_lora_config_with_bias() { + let config = LoRAConfig { + use_bias: true, + ..Default::default() + }; + + assert!(config.use_bias); +} + +#[rstest] +fn test_lora_config_without_bias() { + let config = LoRAConfig { + use_bias: false, + ..Default::default() + }; + + assert!(!config.use_bias); +} + +// ============================================================================ +// Memory Estimation Tests +// ============================================================================ + +#[rstest] +#[case(768, 768, 16, 4)] // F32 = 4 bytes +fn test_lora_memory_estimation( + #[case] input_dim: usize, + #[case] output_dim: usize, + #[case] rank: usize, + #[case] bytes_per_param: usize, +) { + // Memory for LoRA: (rank * input_dim + output_dim * rank) * bytes_per_param + let lora_params = (rank * input_dim) + (output_dim * rank); + let lora_memory = lora_params * bytes_per_param; + + // Memory for full fine-tuning: input_dim * output_dim * bytes_per_param + let full_params = input_dim * output_dim; + let full_memory = full_params * bytes_per_param; + + let memory_saving_ratio = full_memory as f64 / lora_memory as f64; + + assert!( + memory_saving_ratio > 1.0, + "LoRA should save memory ({}x reduction)", + memory_saving_ratio + ); +} + +// ============================================================================ +// Target Module Pattern Tests +// ============================================================================ + +#[rstest] +fn test_lora_target_modules_empty() { + let config = LoRAConfig { + target_modules: vec![], + ..Default::default() + }; + + assert_eq!(config.target_modules.len(), 0); +} + +#[rstest] +fn test_lora_target_modules_single() { + let config = LoRAConfig { + target_modules: vec!["query".to_string()], + ..Default::default() + }; + + assert_eq!(config.target_modules.len(), 1); + assert_eq!(config.target_modules[0], "query"); +} + +#[rstest] +fn test_lora_target_modules_all_attention() { + let attention_modules = vec!["query".to_string(), "key".to_string(), "value".to_string()]; + + let config = LoRAConfig { + target_modules: attention_modules.clone(), + ..Default::default() + }; + + for module in attention_modules { + assert!(config.target_modules.contains(&module)); + } +} diff --git a/candle-binding/src/model_architectures/lora/mod.rs b/candle-binding/src/model_architectures/lora/mod.rs index 7cc58f22..c7347a88 100644 --- a/candle-binding/src/model_architectures/lora/mod.rs +++ b/candle-binding/src/model_architectures/lora/mod.rs @@ -18,3 +18,5 @@ pub use lora_adapter::*; // Test modules (only compiled in test builds) #[cfg(test)] pub mod bert_lora_test; +#[cfg(test)] +pub mod lora_adapter_test;