Skip to content

Commit 3ec40fe

Browse files
authored
make CUDA and Flash Attention 2 optional features (#511)
Signed-off-by: OneZero-Y <[email protected]>
1 parent 1a82152 commit 3ec40fe

File tree

2 files changed

+41
-4
lines changed

2 files changed

+41
-4
lines changed

candle-binding/Cargo.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,18 @@ crate-type = ["staticlib", "cdylib"]
1111

1212
[features]
1313
default = []
14+
# CUDA support (enables GPU acceleration)
15+
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
1416
# Flash Attention 2 support (requires CUDA and compatible GPU)
1517
# Enable with: cargo build --features flash-attn
1618
# Note: Requires CUDA Compute Capability >= 8.0 (Ampere or newer)
17-
flash-attn = ["candle-flash-attn"]
19+
flash-attn = ["cuda", "candle-flash-attn"]
1820

1921
[dependencies]
2022
anyhow = { version = "1", features = ["backtrace"] }
21-
candle-core = { version = "0.8.4", features = ["cuda"] }
22-
candle-nn = { version = "0.8.4", features = ["cuda"] }
23-
candle-transformers = { version = "0.8.4", features = ["cuda"] }
23+
candle-core = "0.8.4"
24+
candle-nn = "0.8.4"
25+
candle-transformers = "0.8.4"
2426
# Flash Attention 2 (optional, requires CUDA)
2527
# Reference: https://github.com/huggingface/candle/tree/main/candle-flash-attn
2628
candle-flash-attn = { version = "0.8.4", optional = true }

tools/make/rust.mk

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,17 @@ test-rust: rust
1313
@echo "Running Rust unit tests (release mode, sequential on GPU $(TEST_GPU_DEVICE))"
1414
@cd candle-binding && CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) cargo test --release --lib -- --test-threads=1 --nocapture
1515

16+
# Test Flash Attention(requires GPU and CUDA environment configured in system)
17+
# Note: Ensure CUDA paths are set in your shell environment (e.g., ~/.bashrc)
18+
# - PATH should include nvcc (e.g., /usr/local/cuda/bin)
19+
# - LD_LIBRARY_PATH should include CUDA libs (e.g., /usr/local/cuda/lib64, /usr/lib/wsl/lib for WSL)
20+
# - CUDA_HOME, CUDA_PATH should point to CUDA installation
21+
# Note: Uses TEST_GPU_DEVICE env var (default: 2) to avoid GPU 0/1 which may be busy
22+
test-rust-flash-attn: rust-flash-attn
23+
@$(LOG_TARGET)
24+
@echo "Running Rust unit tests with Flash Attention 2 (GPU $(TEST_GPU_DEVICE))"
25+
@cd candle-binding && CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) cargo test --release --features flash-attn --lib -- --test-threads=1 --nocapture
26+
1627
# Test specific Rust module (with release optimization for performance)
1728
# Example: make test-rust-module MODULE=classifiers::lora::pii_lora_test
1829
# Example: make test-rust-module MODULE=classifiers::lora::pii_lora_test::test_pii_lora_pii_lora_classifier_new
@@ -26,6 +37,19 @@ test-rust-module: rust
2637
@echo "Running Rust tests for module: $(MODULE) (release mode, GPU $(TEST_GPU_DEVICE))"
2738
@cd candle-binding && CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) cargo test --release $(MODULE) --lib -- --test-threads=1 --nocapture
2839

40+
# Test specific Flash Attention module (requires GPU and CUDA environment)
41+
# Example: make test-rust-flash-attn-module MODULE=model_architectures::embedding::qwen3_embedding_test
42+
# Example: make test-rust-flash-attn-module MODULE=model_architectures::embedding::qwen3_embedding_test::test_qwen3_embedding_forward
43+
test-rust-flash-attn-module: rust-flash-attn
44+
@$(LOG_TARGET)
45+
@if [ -z "$(MODULE)" ]; then \
46+
echo "Usage: make test-rust-flash-attn-module MODULE=<module_name>"; \
47+
echo "Example: make test-rust-flash-attn-module MODULE=model_architectures::embedding::qwen3_embedding_test"; \
48+
exit 1; \
49+
fi
50+
@echo "Running Rust Flash Attention tests for module: $(MODULE) (GPU $(TEST_GPU_DEVICE))"
51+
@cd candle-binding && CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) cargo test --release --features flash-attn $(MODULE) --lib -- --nocapture
52+
2953
# Test the Rust library (Go binding tests)
3054
test-binding: rust
3155
@$(LOG_TARGET)
@@ -66,3 +90,14 @@ rust: ## Ensure Rust is installed and build the Rust library
6690
fi && \
6791
echo "Building Rust library..." && \
6892
cd candle-binding && cargo build --release'
93+
94+
rust-flash-attn: ## Build Rust library with Flash Attention 2 (requires CUDA environment)
95+
@$(LOG_TARGET)
96+
@echo "Building Rust library with Flash Attention 2 (requires CUDA)..."
97+
@if command -v nvcc >/dev/null 2>&1; then \
98+
echo "✅ nvcc found: $$(nvcc --version | grep release)"; \
99+
else \
100+
echo "❌ nvcc not found in PATH. Please configure CUDA environment."; \
101+
exit 1; \
102+
fi
103+
@cd candle-binding && cargo build --release --features flash-attn

0 commit comments

Comments
 (0)