Skip to content

Commit 1f687fa

Browse files
committed
bump cudarc
1 parent 74916fe commit 1f687fa

File tree

3 files changed

+14
-4
lines changed

3 files changed

+14
-4
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -138,7 +138,7 @@ clap = { version = "~3.1", features = [ "cargo" ] }
138138
colorous = "1.0.5"
139139
core_affinity = "0.8.0"
140140
criterion = "0.6"
141-
cudarc = { version = "0.16.4", features = ["dynamic-loading", "cuda-12060", "f16"] }
141+
cudarc = { version = "0.17", features = ["dynamic-loading", "cuda-version-from-build-system", "f16"] }
142142
derive-new = "0.5.9"
143143
dinghy-test = "0.6"
144144
downcast-rs = "1.2.0"

cuda/src/kernels/mod.rs

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,8 @@ mod unary;
99
mod utils;
1010

1111
use crate::ops::GgmlQuantQ81Fact;
12+
use std::mem::transmute;
13+
1214
use crate::tensor::{CudaBuffer, CudaTensor};
1315
use anyhow::{bail, ensure};
1416
pub use binary::BinOps;
@@ -159,7 +161,9 @@ pub fn get_sliced_cuda_view_mut(
159161
len: usize,
160162
) -> TractResult<CudaViewMut<'_, u8>> {
161163
ensure!(offset + len <= t.len() * t.datum_type().size_of());
162-
let mut buffer = t.device_buffer().downcast_ref::<CudaBuffer>().unwrap();
164+
let buffer: &CudaBuffer = t.device_buffer().downcast_ref::<CudaBuffer>().unwrap();
163165
let offset = t.buffer_offset::<usize>() + offset;
164-
Ok(buffer.as_view_mut().slice_mut(offset..(offset + len)))
166+
let ptr: *const CudaBuffer = buffer;
167+
let mut_buffer: &mut CudaBuffer = unsafe { (ptr as *mut CudaBuffer).as_mut().unwrap() };
168+
Ok(mut_buffer.as_view_mut().slice_mut(offset..(offset + len)))
165169
}

cuda/src/tensor.rs

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
use std::ops::Deref;
1+
use std::ops::{Deref, DerefMut};
22

33
use cudarc::driver::{CudaSlice, DevicePtr};
44
use tract_core::internal::tract_smallvec::ToSmallVec;
@@ -30,6 +30,12 @@ impl Deref for CudaBuffer {
3030
}
3131
}
3232

33+
impl DerefMut for CudaBuffer {
34+
fn deref_mut(&mut self) -> &mut Self::Target {
35+
&mut self.inner
36+
}
37+
}
38+
3339
#[derive(Clone)]
3440
pub struct CudaTensor {
3541
buffer: Arc<CudaBuffer>,

0 commit comments

Comments
 (0)