Skip to content

Commit f451b1c

Browse files
Feat/persistent memory (#947)
1 parent 16696aa commit f451b1c

File tree

27 files changed

+720
-234
lines changed

27 files changed

+720
-234
lines changed

crates/cubecl-cpu/src/compute/server.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ pub struct CpuServer {
2929
}
3030

3131
impl CpuServer {
32-
pub fn new(ctx: CpuContext) -> Self {
32+
pub fn new(ctx: CpuContext, logger: Arc<ServerLogger>) -> Self {
3333
Self {
34-
logger: Arc::new(ServerLogger::default()),
34+
logger,
3535
scheduler: Scheduler::default(),
3636
ctx,
3737
}

crates/cubecl-cpu/src/runtime.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ use cubecl_core::{
77
};
88
use cubecl_runtime::{
99
ComputeRuntime, DeviceProperties,
10-
memory_management::{HardwareProperties, MemoryDeviceProperties, MemoryManagement},
10+
logging::ServerLogger,
11+
memory_management::{
12+
HardwareProperties, MemoryDeviceProperties, MemoryManagement, MemoryManagementOptions,
13+
},
1114
storage::BytesStorage,
1215
};
1316
use cubecl_std::tensor::is_contiguous;
@@ -45,6 +48,7 @@ fn create_client(options: RuntimeOptions) -> ComputeClient<Server, Channel> {
4548
.cgroup_limits()
4649
.map(|g| g.total_memory)
4750
.unwrap_or(system.total_memory()) as usize;
51+
let logger = cubecl_common::stub::Arc::new(ServerLogger::default());
4852

4953
let topology = HardwareProperties {
5054
plane_size_min: 1,
@@ -66,8 +70,13 @@ fn create_client(options: RuntimeOptions) -> ComputeClient<Server, Channel> {
6670
alignment: ALIGNMENT,
6771
};
6872

69-
let memory_management =
70-
MemoryManagement::from_configuration(storage, &mem_properties, options.memory_config);
73+
let memory_management = MemoryManagement::from_configuration(
74+
storage,
75+
&mem_properties,
76+
options.memory_config,
77+
logger.clone(),
78+
MemoryManagementOptions::new("test"),
79+
);
7180
let mut device_props = DeviceProperties::new(
7281
Default::default(),
7382
mem_properties,
@@ -77,7 +86,7 @@ fn create_client(options: RuntimeOptions) -> ComputeClient<Server, Channel> {
7786
register_supported_types(&mut device_props);
7887

7988
let ctx = CpuContext::new(memory_management);
80-
let server = CpuServer::new(ctx);
89+
let server = CpuServer::new(ctx, logger);
8190
ComputeClient::new(Channel::new(server), device_props, ())
8291
}
8392

crates/cubecl-cuda/src/compute/server.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,13 +464,14 @@ impl CudaServer {
464464
log::info!("Peer data transfer not available for device {device_id}");
465465
}
466466

467+
let logger = Arc::new(ServerLogger::default());
467468
Self {
468469
mem_alignment,
469470
ctx,
470471
peer_activated,
471472
streams: MultiStream::new(
472-
Arc::new(ServerLogger::default()),
473-
CudaStreamBackend::new(mem_props, mem_config, mem_alignment),
473+
logger.clone(),
474+
CudaStreamBackend::new(mem_props, mem_config, mem_alignment, logger),
474475
max_streams,
475476
),
476477
}

crates/cubecl-cuda/src/compute/stream.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use crate::compute::{
24
storage::{
35
cpu::{PINNED_MEMORY_ALIGNMENT, PinnedMemoryStorage},
@@ -7,7 +9,10 @@ use crate::compute::{
79
};
810
use cubecl_core::MemoryConfiguration;
911
use cubecl_runtime::{
10-
memory_management::{MemoryDeviceProperties, MemoryManagement},
12+
logging::ServerLogger,
13+
memory_management::{
14+
MemoryAllocationMode, MemoryDeviceProperties, MemoryManagement, MemoryManagementOptions,
15+
},
1116
stream::EventStreamBackend,
1217
};
1318

@@ -23,6 +28,7 @@ pub struct CudaStreamBackend {
2328
mem_props: MemoryDeviceProperties,
2429
mem_config: MemoryConfiguration,
2530
mem_alignment: usize,
31+
logger: Arc<ServerLogger>,
2632
}
2733

2834
impl EventStreamBackend for CudaStreamBackend {
@@ -36,8 +42,13 @@ impl EventStreamBackend for CudaStreamBackend {
3642
.expect("Can create a new stream.");
3743

3844
let storage = GpuStorage::new(self.mem_alignment, stream);
39-
let memory_management_gpu =
40-
MemoryManagement::from_configuration(storage, &self.mem_props, self.mem_config.clone());
45+
let memory_management_gpu = MemoryManagement::from_configuration(
46+
storage,
47+
&self.mem_props,
48+
self.mem_config.clone(),
49+
self.logger.clone(),
50+
MemoryManagementOptions::new("Main GPU Memory"),
51+
);
4152
// We use the same page size and memory pools configuration for CPU pinned memory, since we
4253
// expect the CPU to have at least the same amount of RAM as GPU memory.
4354
let memory_management_cpu = MemoryManagement::from_configuration(
@@ -47,6 +58,8 @@ impl EventStreamBackend for CudaStreamBackend {
4758
alignment: PINNED_MEMORY_ALIGNMENT as u64,
4859
},
4960
self.mem_config.clone(),
61+
self.logger.clone(),
62+
MemoryManagementOptions::new("Pinned CPU Memory").mode(MemoryAllocationMode::Auto),
5063
);
5164

5265
Stream {

crates/cubecl-hip/src/compute/server.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,12 +252,13 @@ impl HipServer {
252252
let config = GlobalConfig::get();
253253
let max_streams = config.streaming.max_streams;
254254

255+
let logger = Arc::new(ServerLogger::default());
255256
Self {
256257
ctx,
257258
mem_alignment,
258259
streams: MultiStream::new(
259-
Arc::new(ServerLogger::default()),
260-
HipStreamBackend::new(mem_props, mem_config, mem_alignment),
260+
logger.clone(),
261+
HipStreamBackend::new(mem_props, mem_config, mem_alignment, logger),
261262
max_streams,
262263
),
263264
}

crates/cubecl-hip/src/compute/stream.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
use std::sync::Arc;
2+
13
use cubecl_core::MemoryConfiguration;
24
use cubecl_hip_sys::HIP_SUCCESS;
35
use cubecl_runtime::{
4-
memory_management::{MemoryDeviceProperties, MemoryManagement},
6+
logging::ServerLogger,
7+
memory_management::{
8+
MemoryAllocationMode, MemoryDeviceProperties, MemoryManagement, MemoryManagementOptions,
9+
},
510
stream::EventStreamBackend,
611
};
712

@@ -23,6 +28,7 @@ pub struct HipStreamBackend {
2328
mem_props: MemoryDeviceProperties,
2429
mem_config: MemoryConfiguration,
2530
mem_alignment: usize,
31+
logger: Arc<ServerLogger>,
2632
}
2733

2834
impl EventStreamBackend for HipStreamBackend {
@@ -37,8 +43,13 @@ impl EventStreamBackend for HipStreamBackend {
3743
stream
3844
};
3945
let storage = GpuStorage::new(self.mem_alignment);
40-
let memory_management_gpu =
41-
MemoryManagement::from_configuration(storage, &self.mem_props, self.mem_config.clone());
46+
let memory_management_gpu = MemoryManagement::from_configuration(
47+
storage,
48+
&self.mem_props,
49+
self.mem_config.clone(),
50+
self.logger.clone(),
51+
MemoryManagementOptions::new("Main GPU Memory"),
52+
);
4253
// We use the same page size and memory pools configuration for CPU pinned memory, since we
4354
// expect the CPU to have at least the same amount of RAM as GPU memory.
4455
let memory_management_cpu = MemoryManagement::from_configuration(
@@ -48,6 +59,8 @@ impl EventStreamBackend for HipStreamBackend {
4859
alignment: PINNED_MEMORY_ALIGNMENT as u64,
4960
},
5061
self.mem_config.clone(),
62+
self.logger.clone(),
63+
MemoryManagementOptions::new("Pinned CPU Memory").mode(MemoryAllocationMode::Auto),
5164
);
5265

5366
Stream {

crates/cubecl-runtime/benches/dynamic.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
use std::collections::LinkedList;
1+
use std::{collections::LinkedList, sync::Arc};
22

33
use cubecl_runtime::{
4-
memory_management::{MemoryConfiguration, MemoryDeviceProperties, MemoryManagement},
4+
logging::ServerLogger,
5+
memory_management::{
6+
MemoryConfiguration, MemoryDeviceProperties, MemoryManagement, MemoryManagementOptions,
7+
},
58
storage::BytesStorage,
69
};
710

@@ -15,7 +18,14 @@ fn main() {
1518
max_page_size: 2048 * MB,
1619
alignment: 32,
1720
};
18-
let mut mm = MemoryManagement::from_configuration(storage, &mem_props, config);
21+
let logger = Arc::new(ServerLogger::default());
22+
let mut mm = MemoryManagement::from_configuration(
23+
storage,
24+
&mem_props,
25+
config,
26+
logger,
27+
MemoryManagementOptions::new("test"),
28+
);
1929
let mut handles = LinkedList::new();
2030
for _ in 0..100 * 2048 {
2131
if handles.len() >= 4000 {

crates/cubecl-runtime/src/client.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -508,27 +508,26 @@ where
508508
self.channel.allocation_mode(mode, self.stream_id())
509509
}
510510

511-
/// Use a static memory strategy to execute the provided function.
511+
/// Use a persistent memory strategy to execute the provided function.
512512
///
513513
/// # Notes
514514
///
515-
/// Using that memory strategy is beneficial for weights loading and similar workflows.
516-
/// However make sure to call [Self::memory_cleanup()] if you want to free the allocated
517-
/// memory.
518-
pub fn memory_static_allocation<Input, Output, Func: Fn(Input) -> Output>(
515+
/// - Using that memory strategy is beneficial for storing model parameters and similar workflows.
516+
/// - You can call [Self::memory_cleanup()] if you want to free persistent memory.
517+
pub fn memory_persistent_allocation<Input, Output, Func: Fn(Input) -> Output>(
519518
&self,
520519
input: Input,
521520
func: Func,
522521
) -> Output {
523522
// We use the same profiling lock to make sure no other task is currently using the current
524-
// device. Meaning that the current static memory strategy will only be used for the
523+
// device. Meaning that the current persistent memory strategy will only be used for the
525524
// provided function.
526525

527526
#[cfg(multi_threading)]
528527
let stream_id = self.profile_acquire();
529528

530529
self.channel
531-
.allocation_mode(MemoryAllocationMode::Static, self.stream_id());
530+
.allocation_mode(MemoryAllocationMode::Persistent, self.stream_id());
532531
let output = func(input);
533532
self.channel
534533
.allocation_mode(MemoryAllocationMode::Auto, self.stream_id());

crates/cubecl-runtime/src/config/base.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::config::memory::MemoryConfig;
12
use crate::config::streaming::StreamingConfig;
23

34
use super::{autotune::AutotuneConfig, compilation::CompilationConfig, profiling::ProfilingConfig};
@@ -26,6 +27,10 @@ pub struct GlobalConfig {
2627
/// Configuration for streaming settings.
2728
#[serde(default)]
2829
pub streaming: StreamingConfig,
30+
31+
/// Configuration for memory settings.
32+
#[serde(default)]
33+
pub memory: MemoryConfig,
2934
}
3035

3136
impl GlobalConfig {

crates/cubecl-runtime/src/config/logger.rs

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use super::GlobalConfig;
22
use crate::config::{
3-
autotune::AutotuneLogLevel, compilation::CompilationLogLevel, profiling::ProfilingLogLevel,
4-
streaming::StreamingLogLevel,
3+
autotune::AutotuneLogLevel, compilation::CompilationLogLevel, memory::MemoryLogLevel,
4+
profiling::ProfilingLogLevel, streaming::StreamingLogLevel,
55
};
66
use alloc::{string::ToString, sync::Arc, vec::Vec};
77
use core::fmt::Display;
@@ -118,6 +118,9 @@ pub struct Logger {
118118
/// Indices of loggers used for streaming logging.
119119
streaming_index: Vec<usize>,
120120

121+
/// Indices of loggers used for memory logging.
122+
memory_index: Vec<usize>,
123+
121124
/// Global configuration for logging settings.
122125
pub config: Arc<GlobalConfig>,
123126
}
@@ -142,6 +145,7 @@ impl Logger {
142145
let mut profiling_index = Vec::new();
143146
let mut autotune_index = Vec::new();
144147
let mut streaming_index = Vec::new();
148+
let mut memory_index = Vec::new();
145149

146150
#[derive(Hash, PartialEq, Eq)]
147151
enum LoggerId {
@@ -281,12 +285,25 @@ impl Logger {
281285
)
282286
}
283287

288+
if let MemoryLogLevel::Disabled = config.memory.logger.level {
289+
} else {
290+
register_logger(
291+
&config.memory.logger,
292+
config.memory.logger.append,
293+
config.memory.logger.log,
294+
&mut memory_index,
295+
&mut loggers,
296+
&mut logger2index,
297+
)
298+
}
299+
284300
Self {
285301
loggers,
286302
compilation_index,
287303
profiling_index,
288304
autotune_index,
289305
streaming_index,
306+
memory_index,
290307
config,
291308
}
292309
}
@@ -305,6 +322,20 @@ impl Logger {
305322
}
306323
}
307324

325+
/// Logs a message for memory, directing it to all configured memory loggers.
326+
pub fn log_memory<S: Display>(&mut self, msg: &S) {
327+
let length = self.memory_index.len();
328+
if length > 1 {
329+
let msg = msg.to_string();
330+
for i in 0..length {
331+
let index = self.memory_index[i];
332+
self.log(&msg, index)
333+
}
334+
} else if let Some(index) = self.memory_index.first() {
335+
self.log(&msg, *index)
336+
}
337+
}
338+
308339
/// Logs a message for compilation, directing it to all configured compilation loggers.
309340
pub fn log_compilation<S: Display>(&mut self, msg: &S) {
310341
let length = self.compilation_index.len();

0 commit comments

Comments
 (0)