---
title: "Building PyTorch in Rust to Force Myself to Learn Rust"
description: "Part 1: Tensors, Strides, NonNull, and aliases"
pubDate: "February 23, 2026"
useOutline: false
---

## Why?

[PyTorch](https://pytorch.org/) is a phenomenal piece of software. In the midst of the current AI craze, it's easy to lose sight of the really cool lower-level components that make modern ML possible. I decided to reimplement a basic version of PyTorch with support for strided tensors on CUDA and CPU, and autograd, the automatic differentiation core that enables users to easily run backpropagation.

I've also been looking for a good opportunity to learn more Rust and make something cool with it. Rust's excellent tooling, macro system, and speed make it an excellent candidate to replace the C++ backend in this project. Its thorns (\*cough\* borrow checker \*cough\*) provide an interesting challenge. This tutorial will be both an intermediate-level explanation of how PyTorch works and an intermediate exploration of the Rust programming language.

---
Then we just use the stride array to properly index through this flat `storage` array. This makes sense if you've ever worked with 2D arrays in almost any programming language.

We can manipulate the stride array to create completely different views of the same flat array. Consider `tensor.t()`, which returns the transpose of the tensor. We _could_ allocate a new backing array and copy the elements over to build the transposed matrix. Or we could just copy a pointer to the same array and modify the stride array as illustrated. Creating a new view from an index requires only a new offset and shape vector rather than allocating a whole new Tensor.
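To make the stride trick concrete, here's a minimal, self-contained sketch. The `flat_index` helper is my own illustration, not this project's actual API: swapping the two stride entries turns row-major indexing into a transposed view without copying a single element.

```rust
// Compute a flat index into the backing storage from an N-d index:
// flat = offset + sum(index[k] * stride[k])
fn flat_index(index: &[usize], stride: &[usize], offset: usize) -> usize {
    offset + index.iter().zip(stride).map(|(i, s)| i * s).sum::<usize>()
}

fn main() {
    // A 2x3 tensor stored row-major: strides are [3, 1].
    let storage = [0, 1, 2, 3, 4, 5];
    let stride = [3, 1];
    // Element (1, 2) lives at flat position 1*3 + 2*1 = 5.
    assert_eq!(storage[flat_index(&[1, 2], &stride, 0)], 5);

    // Transposed "view": same storage, strides swapped to [1, 3].
    let t_stride = [1, 3];
    // Element (2, 1) of the transpose == element (1, 2) of the original.
    assert_eq!(storage[flat_index(&[2, 1], &t_stride, 0)], 5);
}
```

The same swap generalizes to any two dimensions, which is exactly why transposes and permutes can be free.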
### `NonNull<u8>` and `Rc`

To accomplish this, we will decouple the Tensor from the underlying memory. We'll need to build two structs: one maintaining the underlying raw storage, and the actual tensor holding the stride, shape, and type information. We want to enable many tensors to point to the same raw storage container. Rust's aliasing rules disallow multiple `&mut` references to the same struct. More formally, the [Rust aliasing rules](https://doc.rust-lang.org/nomicon/aliasing.html) allow either a single mutable reference or any number of immutable references, exclusively. This kind of shared, multiple ownership isn't possible through references alone.

The solution is a reference counter, [`Rc<T>`](https://doc.rust-lang.org/book/ch15-04-rc.html), which tracks how many owners an object has and automatically drops the inner struct once the reference count hits zero.
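As a quick illustration of that ownership model (using a stand-in `Storage` type, not the real `UntypedStorage`), cloning an `Rc` bumps a reference count instead of copying the buffer, and the inner value is freed only when the last clone is dropped:

```rust
use std::rc::Rc;

// Stand-in for a raw storage container; each view would hold an Rc to it.
struct Storage(Vec<u8>);

fn main() {
    let storage = Rc::new(Storage(vec![0u8; 16]));
    assert_eq!(storage.0.len(), 16);

    // "Views" (a transpose, a slice, ...) just clone the Rc: cheap pointer
    // copy plus a count bump, no buffer copy.
    let view_a = Rc::clone(&storage);
    let view_b = Rc::clone(&storage);
    assert_eq!(Rc::strong_count(&storage), 3);

    drop(view_a);
    drop(view_b);
    assert_eq!(Rc::strong_count(&storage), 1);
    // When `storage` goes out of scope, the Vec is dropped automatically.
}
```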

42
43
43
-
Additionally, we'll use a [`NonNull<u8>`](https://doc.rust-lang.org/beta/std/ptr/struct.NonNull.html) instead of a `*mut T`, which enables some nice compiler optimizations at the expense of some type shenanigans. Because we will treat the underlying storage container as untyped storage with raw bytes, we will have to do some `unsafe` casting.
44
+
Additionally, we'll use a [`NonNull<u8>`](https://doc.rust-lang.org/beta/std/ptr/struct.NonNull.html) instead of a `*mut u8`, which enables some nice compiler optimizations at the expense of some type shenanigans. Because we will treat the underlying storage container as untyped storage with raw bytes, we will have to do some `unsafe` casting.
44
45
45
46
46
47
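The main "nice compiler optimization" here is the niche optimization: because `NonNull` guarantees the pointer is never null, the compiler can use the null bit pattern to represent `None`, so `Option<NonNull<u8>>` costs no more space than a bare raw pointer. A quick check:

```rust
use std::mem::size_of;
use std::ptr::NonNull;

fn main() {
    // Option<NonNull<u8>> is pointer-sized: None is encoded as the
    // (otherwise impossible) null bit pattern.
    assert_eq!(size_of::<Option<NonNull<u8>>>(), size_of::<*mut u8>());

    // NonNull::new rejects null up front instead of deferring the bug.
    assert!(NonNull::new(std::ptr::null_mut::<u8>()).is_none());
}
```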
```rust
#[derive(Default, Debug)]
pub struct Tensor {
    shape: Vec<usize>,
    stride: Vec<usize>,
    offset: usize,
    dtype: DataType,
    _storage: Rc<UntypedStorage>,
    // ...
}
```
We also need to implement deallocation behavior that triggers when the storage is dropped by its `Rc`. In Rust, destruction behavior comes from implementing the `Drop` trait. Side note: if you're building a storage container whose inner type requires recursive dropping, you'll have to handle that here; I'm assuming we won't need it:
```rust
impl Drop for UntypedStorage {
    fn drop(&mut self) {
        // ... deallocation body elided in this excerpt ...
    }
}
```
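Putting allocation and deallocation together, here's a self-contained sketch of how that pairing might look. The field names (`ptr`, `layout`) and the constructor are illustrative assumptions for this example, not necessarily the real struct's API:

```rust
use std::alloc::{alloc, dealloc, Layout};
use std::ptr::NonNull;

// Illustrative stand-in: field names and constructor are assumptions.
struct UntypedStorage {
    ptr: NonNull<u8>,
    layout: Layout,
}

impl UntypedStorage {
    fn new(size_bytes: usize, align: usize) -> Self {
        // Zero-sized allocations are UB with `alloc`; rule them out here.
        assert!(size_bytes > 0);
        let layout = Layout::from_size_align(size_bytes, align).expect("bad layout");
        // SAFETY: `layout` has non-zero size.
        let raw = unsafe { alloc(layout) };
        let ptr = NonNull::new(raw).expect("allocation failed");
        Self { ptr, layout }
    }
}

impl Drop for UntypedStorage {
    fn drop(&mut self) {
        // SAFETY: `ptr` came from `alloc` with this exact `layout`,
        // and Drop runs at most once, so this frees exactly once.
        unsafe { dealloc(self.ptr.as_ptr(), self.layout) };
    }
}

fn main() {
    let s = UntypedStorage::new(64, 8);
    assert_eq!(s.layout.size(), 64);
} // `s` dropped here; the buffer is freed by the Drop impl.
```

Keeping the `Layout` alongside the pointer matters: `dealloc` must be called with the same layout that was used to allocate.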
For now we will use `Rc<T>` over its thread-safe but slower sibling [`Arc<T>`](https://doc.rust-lang.org/std/sync/struct.Arc.html); a thread-safe implementation will need `Arc<T>`, lest the compiler prevent you from `Send`ing your tensors across threads.

Side note: this is similar to how [`Vec` works under the hood](https://doc.rust-lang.org/src/alloc/raw_vec/mod.rs.html#85). It uses a `Unique` wrapper, which has slightly stronger guarantees than `NonNull`.

As we've seen, using the allocation APIs requires `unsafe`, which allows us to momentarily step outside the compile-time memory-safety guarantees that regular Rust provides. As we continue to build functionality, we'll have to be careful to uphold those memory-safety contracts throughout our implementation.