Skip to content

Commit fd8d669

Browse files
Merge pull request #10 from soran-ghaderi/gaussian-sketch-kernel
[PR] gaussian-sketching-kernel Key Changes Vectorized cuRAND Generation: The kernel's unit of work has been changed from 1 thread processing 1 element to 1 thread processing 4 elements. Switched from curand_normal() to the highly optimized curand_normal4(), which generates four floats simultaneously. Optimized cuRAND State: Replaced the curandState (XORWOW) with curandStatePhilox4_32_10_t. The Philox generator is a counter-based RNG specifically designed for high-throughput, parallel workloads and is required by curand_normal4. Reduced Initialization Overhead: By having each thread process four elements, the number of expensive curand_init() calls is reduced by 75%. For a 20-million-element matrix, this reduces initializations from 20 million to 5 million. Updated Launch Configuration: The kernel launch parameters in the main test application have been updated to reflect the new 1-thread-to-4-elements strategy, ensuring the correct number of threads are launched. Performance Profiling on a 4000x5000 matrix (20 million elements) demonstrates a dramatic improvement: Metric Before (Old Kernel) After (New Kernel) Speedup Kernel Execution Time 2.53 seconds 2.9 milliseconds ~872x
2 parents 3dd9159 + bb418d5 commit fd8d669

File tree

26 files changed

+1442
-935
lines changed

26 files changed

+1442
-935
lines changed

.github/workflows/ci.yml

Lines changed: 796 additions & 695 deletions
Large diffs are not rendered by default.

CMakeLists.txt

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ if(CMAKE_CUDA_COMPILER)
1111
enable_language(CUDA)
1212
set(CUDA_FOUND TRUE)
1313
message(STATUS "CUDA found and enabled")
14-
14+
1515
# Set CUDA architectures - this fixes the CMAKE_CUDA_ARCHITECTURES error
1616
# Supporting common GPU architectures: Pascal, Volta, Turing, Ampere, Ada Lovelace
1717
set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80;86;89;90" CACHE STRING "CUDA architectures")
18-
18+
1919
# Define CUDA standard
2020
set(CMAKE_CUDA_STANDARD 17)
2121
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
22-
22+
2323
# CUDA specific settings
2424
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
2525
set(CMAKE_CUDA_RESOLVE_DEVICE_SYMBOLS ON)
@@ -50,14 +50,14 @@ if(WIN32 AND MSVC)
5050
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MD")
5151
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /MD")
5252
set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /MD")
53-
53+
5454
# Ensure consistent iterator debug levels
5555
# Use level 2 for Debug builds and level 0 for Release builds
5656
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /D_ITERATOR_DEBUG_LEVEL=2")
5757
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /D_ITERATOR_DEBUG_LEVEL=0")
5858
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /D_ITERATOR_DEBUG_LEVEL=0")
5959
set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /D_ITERATOR_DEBUG_LEVEL=0")
60-
60+
6161
message(STATUS "Windows MSVC runtime library flags configured for consistent linking")
6262
endif()
6363

@@ -69,6 +69,7 @@ option(BUILD_DOCS "Enable building of documentation" ON)
6969
if(CUDA_FOUND)
7070
find_package(CUDAToolkit REQUIRED)
7171
endif()
72+
find_package(Threads REQUIRED)
7273

7374
# compile the library
7475
add_subdirectory(src)
@@ -81,27 +82,27 @@ add_subdirectory(app)
8182
include(CTest)
8283
if(BUILD_TESTING)
8384
find_package(Catch2 3 REQUIRED)
84-
85+
8586
# Try to include Catch2 discovery functions if available
8687
if(TARGET Catch2::Catch2WithMain)
8788
# Look for the Catch2 CMake module in common locations
88-
find_file(CATCH2_CMAKE_MODULE
89-
NAMES Catch2.cmake
90-
PATHS
89+
find_file(CATCH2_CMAKE_MODULE
90+
NAMES Catch2.cmake
91+
PATHS
9192
${Catch2_DIR}
9293
${Catch2_DIR}/../../../lib/cmake/Catch2
9394
/usr/local/lib/cmake/Catch2
9495
/usr/lib/cmake/Catch2
9596
NO_DEFAULT_PATH
9697
)
97-
98+
9899
if(CATCH2_CMAKE_MODULE)
99100
include(${CATCH2_CMAKE_MODULE})
100101
else()
101102
message(STATUS "Catch2 discovery module not found, using basic test registration")
102103
endif()
103104
endif()
104-
105+
105106
add_subdirectory(tests)
106107
endif()
107108

@@ -113,7 +114,7 @@ if(BUILD_PYTHON)
113114
# Add Python bindings
114115
find_package(pybind11 REQUIRED)
115116
# Compile the Pybind11 module
116-
pybind11_add_module(_cuRBLAS python/cuRBLAS/_cuRBLAS.cpp)
117+
pybind11_add_module(_cuRBLAS python/curblas/_curblas.cpp)
117118
target_link_libraries(_cuRBLAS PUBLIC cuRBLAS)
118119

119120
# Install the Python module shared library
@@ -157,6 +158,17 @@ install(
157158
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
158159
)
159160

161+
#include_directories(${CUDA_INCLUDE_DIRS})
162+
#target_include_directories(curblas
163+
# PUBLIC
164+
# ${CMAKE_CURRENT_SOURCE_DIR}/include
165+
# ${CUDA_INCLUDE_DIRS}
166+
#)
167+
#target_link_libraries(curblas
168+
# PUBLIC
169+
# ${CUDA_LIBRARIES}
170+
#)
171+
#include_directories(/usr/local/cuda/include)
160172
# This prints a summary of found dependencies
161173
include(FeatureSummary)
162174
feature_summary(WHAT ALL)

README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ Building cuRBLAS requires:
8787
```bash
8888
# Clone the repo
8989
git clone https://github.com/soran-ghaderi/cuRBLAS.git
90-
cd cuRBLAS
90+
cd curblas
9191

9292
# Create build directory
9393
mkdir build && cd build
@@ -124,19 +124,19 @@ pip install .
124124
Or install directly from PyPI (when available):
125125

126126
```bash
127-
pip install cuRBLAS
127+
pip install curblas
128128
```
129129

130130
## Usage Example
131131

132132
```c
133133
#include <cuRBLAS/curblas.h>
134134

135-
// Create cuRBLAS context
135+
// Create curblas context
136136
curblasHandle_t handle;
137137
curblasStatus_t status = curblasCreate(&handle);
138138
if (status != CURBLAS_STATUS_SUCCESS) {
139-
printf("Failed to create cuRBLAS handle: %s\n",
139+
printf("Failed to create curblas handle: %s\n",
140140
curblasGetStatusString(status));
141141
return -1;
142142
}
@@ -153,7 +153,7 @@ curblasSetSketchType(handle, CURBLAS_SKETCH_GAUSSIAN);
153153
// Get version information
154154
int version;
155155
curblasGetVersion(handle, &version);
156-
printf("cuRBLAS Version: %d\n", version);
156+
printf("curblas Version: %d\n", version);
157157

158158
// Note: Matrix operations like curblasRgemm are declared
159159
// in headers but not yet implemented
@@ -237,7 +237,7 @@ We welcome contributions! Please see our [contribution guidelines](CONTRIBUTING.
237237

238238
```bash
239239
git clone https://github.com/soran-ghaderi/cuRBLAS.git
240-
cd cuRBLAS
240+
cd curblas
241241
pip install -r requirements-dev.txt
242242
```
243243

app/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
1-
add_executable(cuRBLAS_app cuRBLAS_app.cpp)
1+
add_executable(cuRBLAS_app ReduceSum.cu)
22
target_link_libraries(cuRBLAS_app PRIVATE cuRBLAS)
3+
4+
add_executable(generateGaussianSketch gaussian_sketch.cu)
5+
target_link_libraries(generateGaussianSketch PRIVATE cuRBLAS)

app/ReduceSum.cu

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#include "curblas/curblas.cuh"
2+
#include "cuda_runtime.h"
3+
#include <iostream>
4+
#include <vector>
5+
#include <cooperative_groups.h>
6+
7+
8+
int main() {
9+
// Example usage of the reduceSum kernel
10+
int N = 1024;
11+
float h_input[N];
12+
13+
for (int i = 0; i < N; ++i) {
14+
h_input[i] = 1.0f;
15+
}
16+
17+
float *d_input, *d_output;
18+
float h_output;
19+
20+
21+
// Allocate device memory
22+
cudaMalloc((void**)&d_input, N * sizeof(float));
23+
// cudaMalloc((void**)&d_output, sizeof(float));
24+
25+
26+
// Launch the kernel
27+
int blockSize = 256;
28+
int numBlocks = (N + blockSize - 1) / blockSize;
29+
cudaMalloc((void**)&d_output, blockSize * sizeof(float));
30+
cudaMemcpy(d_input, h_input, N * sizeof(float), cudaMemcpyHostToDevice);
31+
cudaMemset(d_output, 0, blockSize * sizeof(float));
32+
33+
// curblas::reduceSum<<<numBlocks, blockSize, blockSize * sizeof(float)>>>(d_input, d_output, N);
34+
void *args[] = {&d_input, &d_output, &N};
35+
cudaLaunchCooperativeKernel((void*)curblas::reduceSum, dim3(numBlocks), dim3(blockSize), args, blockSize * sizeof(float));
36+
37+
38+
cudaError_t err = cudaGetLastError();
39+
if (err != cudaSuccess) {
40+
std::cerr << "Kernel launch failed: " << cudaGetErrorString(err) << std::endl;
41+
}
42+
43+
// Copy the result back to host
44+
cudaMemcpy(&h_output, d_output, sizeof(float), cudaMemcpyDeviceToHost);
45+
46+
47+
cudaMemcpy(&h_output, d_output, sizeof(float), cudaMemcpyDeviceToHost);
48+
49+
std::cout << "Final Sum (from GPU reduction with Cooperative Groups): " << h_output << std::endl;
50+
51+
52+
std::vector<float> h_partialSums(numBlocks);
53+
cudaMemcpy(h_partialSums.data(), d_output, numBlocks * sizeof(float), cudaMemcpyDeviceToHost);
54+
55+
float finalSum = 0.0f;
56+
for (int i = 0; i < numBlocks; ++i) {
57+
finalSum += h_partialSums[i];
58+
std::cout << "Partial sum from block " << i << ": " << h_partialSums[i] << std::endl;
59+
}
60+
61+
62+
std::cout << "Sum: " << finalSum << std::endl;
63+
// Clean up
64+
cudaFree(d_input);
65+
cudaFree(d_output);
66+
67+
return 0;
68+
}

app/cuRBLAS_app.cpp

Lines changed: 0 additions & 7 deletions
This file was deleted.

app/gaussian_sketch.cu

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#include "curblas/curblas.cuh"
2+
//#include "curblas/curblas.h"
3+
//#include "curblas/curblas_types.h"
4+
#include <cuda_runtime.h>
5+
#include <cublas_v2.h>
6+
#include <iostream>
7+
#include <iomanip>
8+
#include <vector>
9+
10+
11+
12+
void printMatrix(const std::vector<float>& vec, int rows, int cols) {
13+
for (int i = 0; i < rows; ++i) {
14+
for (int j = 0; j < cols; ++j) {
15+
std::cout << std::fixed << std::setprecision(4) << std::setw(10) << vec[i * cols + j] << " ";
16+
}
17+
std::cout << std::endl;
18+
}
19+
}
20+
21+
int main() {
22+
int rows = 4000;
23+
int cols = 5000;
24+
25+
int totalElements = rows * cols;
26+
long long seed = 112345L;
27+
float scale = 1.0f;
28+
29+
std::cout << "Generating a" << rows << " x " << cols << " gaussian sketch matrix." << std::endl;
30+
31+
std::vector<float> h_sketch(totalElements);
32+
33+
float* d_sketch;
34+
cudaMalloc((void**)&d_sketch, totalElements * sizeof(float));
35+
36+
int blockSize = 256;
37+
38+
// int totalElements = rows * cols;
39+
int elementsPerThread = 4; //test
40+
int totalThreads = (totalElements + elementsPerThread - 1) / elementsPerThread;
41+
int numBlocks = (totalThreads + blockSize - 1) / blockSize;
42+
43+
// int numBlocks = (totalElements + blockSize - 1) / blockSize;
44+
45+
curblas::generateGaussianSketch<<<numBlocks, blockSize>>>(d_sketch, rows, cols, seed, scale);
46+
47+
cudaError err = cudaGetLastError();
48+
if (err != cudaSuccess) {
49+
std::cerr << "kernel launch failed: " << cudaGetErrorString(err) << std::endl;
50+
cudaFree(d_sketch);
51+
return -1;
52+
}
53+
54+
55+
cudaDeviceSynchronize();
56+
cudaStream_t stream;
57+
cudaStreamCreate(&stream);
58+
59+
// bring the data back:
60+
cudaMemcpyAsync(h_sketch.data(), d_sketch, totalElements * sizeof(float), cudaMemcpyDeviceToHost, stream);
61+
cudaStreamSynchronize(stream);
62+
cudaStreamDestroy(stream);
63+
64+
std::cout << "result:" << rows << 'x' << cols << std::endl;
65+
66+
// printMatrix(h_sketch, rows, cols);
67+
68+
cudaFree(d_sketch);
69+
70+
71+
}

docs/getting-started/installation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Currently cuRBLAS is in development. To install and use cuRBLAS:
1414

1515
```bash
1616
git clone https://github.com/cuRBLAS/cuRBLAS.git
17-
cd cuRBLAS
17+
cd curblas
1818
mkdir build && cd build
1919
cmake .. -DCMAKE_BUILD_TYPE=Release
2020
make -j$(nproc)

include/cuRBLAS/cuRBLAS.hpp

Lines changed: 0 additions & 16 deletions
This file was deleted.

include/curblas/curblas.cuh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#pragma once
2+
#include "cuda_runtime.h"
3+
4+
namespace curblas {
5+
6+
__global__ void reduceSum(const float *input, float *output, int n);
7+
8+
__global__ void generateGaussianSketch(float *sketch, int rows, int cols, long long seed, float scale);
9+
10+
11+
} // namespace curblas

0 commit comments

Comments
 (0)