Skip to content

Commit a81c29c

Browse files
OneZero-Yrootfs
authored and committed
make CUDA and Flash Attention 2 optional features (#511)
Signed-off-by: OneZero-Y <[email protected]>
1 parent c83d109 commit a81c29c

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

candle-binding/Cargo.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,19 @@ name = "candle_semantic_router"
1010
crate-type = ["staticlib", "cdylib"]
1111

1212
[features]
13-
default = ["cuda"]
13+
default = []
14+
# CUDA support (enables GPU acceleration)
1415
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
1516
# Flash Attention 2 support (requires CUDA and compatible GPU)
1617
# Enable with: cargo build --features flash-attn
1718
# Note: Requires CUDA Compute Capability >= 8.0 (Ampere or newer)
18-
flash-attn = ["candle-flash-attn"]
19+
flash-attn = ["cuda", "candle-flash-attn"]
1920

2021
[dependencies]
2122
anyhow = { version = "1", features = ["backtrace"] }
22-
candle-core = { version = "0.8.4", features = ["cuda"] }
23-
candle-nn = { version = "0.8.4", features = ["cuda"] }
24-
candle-transformers = { version = "0.8.4", features = ["cuda"] }
23+
candle-core = "0.8.4"
24+
candle-nn = "0.8.4"
25+
candle-transformers = "0.8.4"
2526
# Flash Attention 2 (optional, requires CUDA)
2627
# Reference: https://github.com/huggingface/candle/tree/main/candle-flash-attn
2728
candle-flash-attn = { version = "0.8.4", optional = true }

tools/make/rust.mk

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,17 @@ test-rust: rust
1313
@echo "Running Rust unit tests (release mode, sequential on GPU $(TEST_GPU_DEVICE))"
1414
@cd candle-binding && CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) cargo test --release --lib -- --test-threads=1 --nocapture
1515

16+
# Test Flash Attention (requires GPU and CUDA environment configured in system)
17+
# Note: Ensure CUDA paths are set in your shell environment (e.g., ~/.bashrc)
18+
# - PATH should include nvcc (e.g., /usr/local/cuda/bin)
19+
# - LD_LIBRARY_PATH should include CUDA libs (e.g., /usr/local/cuda/lib64, /usr/lib/wsl/lib for WSL)
20+
# - CUDA_HOME, CUDA_PATH should point to CUDA installation
21+
# Note: Uses TEST_GPU_DEVICE env var (default: 2) to avoid GPU 0/1 which may be busy
22+
test-rust-flash-attn: rust-flash-attn
23+
@$(LOG_TARGET)
24+
@echo "Running Rust unit tests with Flash Attention 2 (GPU $(TEST_GPU_DEVICE))"
25+
@cd candle-binding && CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) cargo test --release --features flash-attn --lib -- --test-threads=1 --nocapture
26+
1627
# Test specific Rust module (with release optimization for performance)
1728
# Example: make test-rust-module MODULE=classifiers::lora::pii_lora_test
1829
# Example: make test-rust-module MODULE=classifiers::lora::pii_lora_test::test_pii_lora_pii_lora_classifier_new
@@ -26,6 +37,19 @@ test-rust-module: rust
2637
@echo "Running Rust tests for module: $(MODULE) (release mode, GPU $(TEST_GPU_DEVICE))"
2738
@cd candle-binding && CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) cargo test --release $(MODULE) --lib -- --test-threads=1 --nocapture
2839

40+
# Test specific Flash Attention module (requires GPU and CUDA environment; tests run sequentially on GPU $(TEST_GPU_DEVICE), matching the other GPU test targets)
41+
# Example: make test-rust-flash-attn-module MODULE=model_architectures::embedding::qwen3_embedding_test
42+
# Example: make test-rust-flash-attn-module MODULE=model_architectures::embedding::qwen3_embedding_test::test_qwen3_embedding_forward
43+
test-rust-flash-attn-module: rust-flash-attn
44+
@$(LOG_TARGET)
45+
@if [ -z "$(MODULE)" ]; then \
46+
echo "Usage: make test-rust-flash-attn-module MODULE=<module_name>"; \
47+
echo "Example: make test-rust-flash-attn-module MODULE=model_architectures::embedding::qwen3_embedding_test"; \
48+
exit 1; \
49+
fi
50+
@echo "Running Rust Flash Attention tests for module: $(MODULE) (GPU $(TEST_GPU_DEVICE))"
51+
@cd candle-binding && CUDA_VISIBLE_DEVICES=$(TEST_GPU_DEVICE) cargo test --release --features flash-attn $(MODULE) --lib -- --test-threads=1 --nocapture
52+
2953
# Test the Rust library (conditionally use rust-ci in CI environments)
3054
test-binding: $(if $(CI),rust-ci,rust) ## Run Go tests with the Rust static library
3155
@$(LOG_TARGET)
@@ -83,3 +107,14 @@ rust-ci: ## Build the Rust library without CUDA support (for GitHub Actions/CI)
83107
fi && \
84108
echo "Building Rust library without CUDA (CPU-only)..." && \
85109
cd candle-binding && cargo build --release --no-default-features'
110+
111+
rust-flash-attn: ## Build Rust library with Flash Attention 2 (requires CUDA environment)
112+
@$(LOG_TARGET)
113+
@echo "Building Rust library with Flash Attention 2 (requires CUDA)..."
114+
@if command -v nvcc >/dev/null 2>&1; then \
115+
echo "✅ nvcc found: $$(nvcc --version | grep release)"; \
116+
else \
117+
echo "❌ nvcc not found in PATH. Please configure CUDA environment."; \
118+
exit 1; \
119+
fi
120+
@cd candle-binding && cargo build --release --features flash-attn

0 commit comments

Comments
 (0)