diff --git a/candle-binding/Cargo.toml b/candle-binding/Cargo.toml index e638ecde..122388ac 100644 --- a/candle-binding/Cargo.toml +++ b/candle-binding/Cargo.toml @@ -11,16 +11,18 @@ crate-type = ["staticlib", "cdylib"] [features] default = [] +# CUDA support (enables GPU acceleration) +cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"] # Flash Attention 2 support (requires CUDA and compatible GPU) # Enable with: cargo build --features flash-attn # Note: Requires CUDA Compute Capability >= 8.0 (Ampere or newer) -flash-attn = ["candle-flash-attn"] +flash-attn = ["cuda", "candle-flash-attn"] [dependencies] anyhow = { version = "1", features = ["backtrace"] } -candle-core = { version = "0.8.4", features = ["cuda"] } -candle-nn = { version = "0.8.4", features = ["cuda"] } -candle-transformers = { version = "0.8.4", features = ["cuda"] } +candle-core = "0.8.4" +candle-nn = "0.8.4" +candle-transformers = "0.8.4" # Flash Attention 2 (optional, requires CUDA) # Reference: https://github.com/huggingface/candle/tree/main/candle-flash-attn candle-flash-attn = { version = "0.8.4", optional = true } diff --git a/tools/make/rust.mk b/tools/make/rust.mk index aa7f4e57..0ad7031f 100644 --- a/tools/make/rust.mk +++ b/tools/make/rust.mk @@ -13,6 +13,17 @@ test-rust: rust @echo "Running Rust unit tests (release mode, sequential on GPU $(TEST_GPU_DEVICE))" @cd candle-binding && CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) cargo test --release --lib -- --test-threads=1 --nocapture +# Test Flash Attention(requires GPU and CUDA environment configured in system) +# Note: Ensure CUDA paths are set in your shell environment (e.g., ~/.bashrc) +# - PATH should include nvcc (e.g., /usr/local/cuda/bin) +# - LD_LIBRARY_PATH should include CUDA libs (e.g., /usr/local/cuda/lib64, /usr/lib/wsl/lib for WSL) +# - CUDA_HOME, CUDA_PATH should point to CUDA installation +# Note: Uses TEST_GPU_DEVICE env var (default: 2) to avoid GPU 0/1 which may be busy +test-rust-flash-attn: rust-flash-attn + @$(LOG_TARGET) + @echo "Running Rust unit tests with Flash Attention 2 (GPU $(TEST_GPU_DEVICE))" + @cd candle-binding && CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) cargo test --release --features flash-attn --lib -- --test-threads=1 --nocapture + # Test specific Rust module (with release optimization for performance) # Example: make test-rust-module MODULE=classifiers::lora::pii_lora_test # Example: make test-rust-module MODULE=classifiers::lora::pii_lora_test::test_pii_lora_pii_lora_classifier_new @@ -26,6 +37,19 @@ test-rust-module: rust @echo "Running Rust tests for module: $(MODULE) (release mode, GPU $(TEST_GPU_DEVICE))" @cd candle-binding && CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) cargo test --release $(MODULE) --lib -- --test-threads=1 --nocapture +# Test specific Flash Attention module (requires GPU and CUDA environment) +# Example: make test-rust-flash-attn-module MODULE=model_architectures::embedding::qwen3_embedding_test +# Example: make test-rust-flash-attn-module MODULE=model_architectures::embedding::qwen3_embedding_test::test_qwen3_embedding_forward +test-rust-flash-attn-module: rust-flash-attn + @$(LOG_TARGET) + @if [ -z "$(MODULE)" ]; then \ + echo "Usage: make test-rust-flash-attn-module MODULE="; \ + echo "Example: make test-rust-flash-attn-module MODULE=model_architectures::embedding::qwen3_embedding_test"; \ + exit 1; \ + fi + @echo "Running Rust Flash Attention tests for module: $(MODULE) (GPU $(TEST_GPU_DEVICE))" + @cd candle-binding && CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) cargo test --release --features flash-attn $(MODULE) --lib -- --nocapture + # Test the Rust library (Go binding tests) test-binding: rust @$(LOG_TARGET) @@ -66,3 +90,14 @@ rust: ## Ensure Rust is installed and build the Rust library fi && \ echo "Building Rust library..." && \ cd candle-binding && cargo build --release' + +rust-flash-attn: ## Build Rust library with Flash Attention 2 (requires CUDA environment) + @$(LOG_TARGET) + @echo "Building Rust library with Flash Attention 2 (requires CUDA)..." + @if command -v nvcc >/dev/null 2>&1; then \ + echo "✅ nvcc found: $$(nvcc --version | grep release)"; \ + else \ + echo "❌ nvcc not found in PATH. Please configure CUDA environment."; \ + exit 1; \ + fi + @cd candle-binding && cargo build --release --features flash-attn