Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions candle-binding/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@ crate-type = ["staticlib", "cdylib"]

[features]
# No optional features are enabled by default; `cuda` and `flash-attn` are opt-in.
default = []
# CUDA support (enables GPU acceleration)
# Forwards the `cuda` feature to candle-core, candle-nn and candle-transformers.
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
# Flash Attention 2 support (requires CUDA and compatible GPU)
# Enable with: cargo build --features flash-attn
# Note: Requires CUDA Compute Capability >= 8.0 (Ampere or newer)
# NOTE(review): `flash-attn` is defined twice below — presumably the removed and
# added sides of a diff; a valid manifest can keep only one definition. TODO confirm.
flash-attn = ["candle-flash-attn"]
flash-attn = ["cuda", "candle-flash-attn"]

[dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-core = { version = "0.8.4", features = ["cuda"] }
candle-nn = { version = "0.8.4", features = ["cuda"] }
candle-transformers = { version = "0.8.4", features = ["cuda"] }
candle-core = "0.8.4"
candle-nn = "0.8.4"
candle-transformers = "0.8.4"
# Flash Attention 2 (optional, requires CUDA)
# Reference: https://github.com/huggingface/candle/tree/main/candle-flash-attn
candle-flash-attn = { version = "0.8.4", optional = true }
Expand Down
35 changes: 35 additions & 0 deletions tools/make/rust.mk
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ test-rust: rust
@echo "Running Rust unit tests (release mode, sequential on GPU $(TEST_GPU_DEVICE))"
@cd candle-binding && CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) cargo test --release --lib -- --test-threads=1 --nocapture

# Run the Rust unit tests with the `flash-attn` feature enabled.
# Builds the library first via the `rust-flash-attn` prerequisite.
# Requirements (GPU plus a CUDA environment configured in your shell, e.g. ~/.bashrc):
#   - PATH includes nvcc (e.g., /usr/local/cuda/bin)
#   - LD_LIBRARY_PATH includes the CUDA libs (e.g., /usr/local/cuda/lib64, /usr/lib/wsl/lib for WSL)
#   - CUDA_HOME and CUDA_PATH point to the CUDA installation
# The GPU is selected through TEST_GPU_DEVICE (default: 2), which avoids GPU 0/1
# as those may be busy. Tests run on a single thread (--test-threads=1).
test-rust-flash-attn: rust-flash-attn
	@$(LOG_TARGET)
	@echo "Running Rust unit tests with Flash Attention 2 (GPU $(TEST_GPU_DEVICE))"
	@cd candle-binding && \
		CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) \
		cargo test --release --features flash-attn --lib \
		-- --test-threads=1 --nocapture

# Test specific Rust module (with release optimization for performance)
# Example: make test-rust-module MODULE=classifiers::lora::pii_lora_test
# Example: make test-rust-module MODULE=classifiers::lora::pii_lora_test::test_pii_lora_pii_lora_classifier_new
Expand All @@ -26,6 +37,19 @@ test-rust-module: rust
@echo "Running Rust tests for module: $(MODULE) (release mode, GPU $(TEST_GPU_DEVICE))"
@cd candle-binding && CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) cargo test --release $(MODULE) --lib -- --test-threads=1 --nocapture

# Test a specific module (or single test) with Flash Attention enabled.
# Requires a GPU and a configured CUDA environment; builds via `rust-flash-attn` first.
# MODULE is mandatory: the recipe prints usage and exits with status 1 when it is empty.
# Tests run sequentially (--test-threads=1), matching the other GPU test targets,
# so that the selected tests do not run concurrently on the GPU chosen by TEST_GPU_DEVICE.
# Example: make test-rust-flash-attn-module MODULE=model_architectures::embedding::qwen3_embedding_test
# Example: make test-rust-flash-attn-module MODULE=model_architectures::embedding::qwen3_embedding_test::test_qwen3_embedding_forward
test-rust-flash-attn-module: rust-flash-attn
	@$(LOG_TARGET)
	@if [ -z "$(MODULE)" ]; then \
		echo "Usage: make test-rust-flash-attn-module MODULE=<module_name>"; \
		echo "Example: make test-rust-flash-attn-module MODULE=model_architectures::embedding::qwen3_embedding_test"; \
		exit 1; \
	fi
	@echo "Running Rust Flash Attention tests for module: $(MODULE) (GPU $(TEST_GPU_DEVICE))"
	@cd candle-binding && CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) cargo test --release --features flash-attn $(MODULE) --lib -- --test-threads=1 --nocapture

# Test the Rust library (Go binding tests)
test-binding: rust
@$(LOG_TARGET)
Expand Down Expand Up @@ -66,3 +90,14 @@ rust: ## Ensure Rust is installed and build the Rust library
fi && \
echo "Building Rust library..." && \
cd candle-binding && cargo build --release'

# Build the release library in candle-binding with the `flash-attn` feature.
# Fails early (exit 1) when nvcc is not on PATH, before invoking cargo.
rust-flash-attn: ## Build Rust library with Flash Attention 2 (requires CUDA environment)
	@$(LOG_TARGET)
	@echo "Building Rust library with Flash Attention 2 (requires CUDA)..."
	@if ! command -v nvcc >/dev/null 2>&1; then \
		echo "❌ nvcc not found in PATH. Please configure CUDA environment."; \
		exit 1; \
	fi; \
	echo "✅ nvcc found: $$(nvcc --version | grep release)"
	@cd candle-binding && cargo build --release --features flash-attn
Loading