Commit 0d50a18 ("edit part1", 1 parent: c134360)

1 file changed: src/content/blog/building-pytorch-in-rust-part1.mdx (+16, -13 lines)
@@ -2,13 +2,14 @@
title: "Building PyTorch in Rust to Force Myself to Learn Rust"
description: "Part 1: Tensors, Strides, NonNull, and aliases"
pubDate: "February 23, 2026"
+useOutline: false
---

-[PyTorch](https://pytorch.org/) is pretty cool. In the midst of the current AI craze, it's easy to lose sight of the really cool lower-level components that make modern ML possible. I also think it's important to learn how these frameworks operate under the hood, because it might enable you to discover performance optimizations, understand ML frameworks' idiosyncrasies and differences, or simply satisfy your own curiosity. For this reason, I decided to reimplement a basic version of PyTorch, with support for strided tensors on CUDA and CPU, and autograd, the automatic differentiation core that enables users to easily run backpropagation. Time permitting, I'll even look at providing a basic Python interface.
+## Why?

-Additionally, I've been looking for a good opportunity to learn more Rust and make something cool with it. Rust's excellent tooling, macro system, and speed make it an excellent candidate to replace the C++ backend in this project. Its thorns (\*cough\* borrow checker \*cough\*) provide an interesting challenge. This tutorial will be both an intermediate-level explanation of how PyTorch works and an intermediate exploration of the Rust programming language. I'm certainly no expert in either, but I hope this tutorial provides some interesting breadcrumbs to someone following a similar path.
+[PyTorch](https://pytorch.org/) is a phenomenal piece of software. In the midst of the current AI craze, it's easy to lose sight of the really cool lower-level components that make modern ML possible. I decided to reimplement a basic version of PyTorch with support for strided tensors on CUDA and CPU, and autograd, the automatic differentiation core that enables users to easily run backpropagation.

-One last note: while I used GenAI tools to help with making visuals, writing docs, and debugging, I refrained from using any agentic coding tool (Claude Code) for this project.
+I've also been looking for a good opportunity to learn more Rust and make something cool with it. Rust's excellent tooling, macro system, and speed make it an excellent candidate to replace the C++ backend in this project. Its thorns (\*cough\* borrow checker \*cough\*) provide an interesting challenge. This tutorial will be both an intermediate-level explanation of how PyTorch works and an intermediate exploration of the Rust programming language.

---

@@ -27,7 +28,7 @@ The stride is what makes this, to me, quite magical. Let's use an example of a t

Then, we just use the stride array to properly index through this flat `storage` array. This makes sense if you've ever worked with 2D arrays in almost any programming language.

We can manipulate the stride array and create different views on the same flat array. Consider `tensor.t()`, which returns the transpose of the tensor. We _could_ allocate a new backing array and copy the elements over to create this matrix. Or, we could just copy a pointer to the same array and modify the stride array as illustrated. Creating new views from indexing requires only creating new offset and shape vectors rather than fully allocating a new Tensor.

![Graphic showing strided indexing on a transposed tensor](/images/crabtorch_1_transpose.png)

@@ -36,11 +37,11 @@ Consider some operations that are actually quite simple like [`.unsqueeze()`](ht

### `NonNull<u8>` and `Rc`

To accomplish this, we will decouple the Tensor from the underlying memory. We'll need to build two structs: one maintaining the underlying raw storage, and the actual tensor holding the stride, shape, and type information. We want to enable many tensors to point to the same raw storage container. Rust's aliasing rules disallow multiple `&mut` references to the same struct. More formally, the [Rust aliasing rules](https://doc.rust-lang.org/nomicon/aliasing.html) allow either a single mutable reference or multiple immutable references, exclusively. This kind of shared, multiple ownership isn't possible with plain references alone.

The solution is to use a reference counter, [`Rc<T>`](https://doc.rust-lang.org/book/ch15-04-rc.html), which provides reference counting for a given object and automatically drops the inner struct once the reference count hits zero.

![Image showing how multiple Tensors point to one Untyped Storage struct](/images/crabtorch_2_structs.png)

Additionally, we'll use a [`NonNull<u8>`](https://doc.rust-lang.org/beta/std/ptr/struct.NonNull.html) instead of a `*mut u8`, which enables some nice compiler optimizations at the expense of some type shenanigans. Because we will treat the underlying storage container as untyped storage with raw bytes, we will have to do some `unsafe` casting.

```rust
@@ -54,12 +55,12 @@ pub struct UntypedStorage {

#[derive(Default, Debug)]
pub struct Tensor {
    shape: Vec<usize>,
    stride: Vec<usize>,
    offset: usize,
    dtype: DataType,
    _storage: Rc<UntypedStorage>,
-    _dispatch_keys: DispatchKeySet,
+    //...
}
```

@@ -95,7 +96,7 @@ impl UntypedStorage {

```rust
// (allocation methods omitted between hunks)
}
```

We also need to implement deallocation behavior that triggers when the storage is dropped by its `Rc`. In Rust, destruction behavior comes from implementing the `Drop` trait. Side note: if you're building a storage container whose inner type needs recursive dropping, you'll have to invoke that drop here. I'm assuming we will not need that:

```rust
impl Drop for UntypedStorage {
    // (Drop body omitted by the diff, hunk @@ -111,6 +112,8 @@)
}
```

For now we will use `Rc<T>` over its thread-safe but slower sibling [`Arc<T>`](https://doc.rust-lang.org/std/sync/struct.Arc.html), but a thread-safe implementation will need to use `Arc<T>`, lest the compiler prevent you from `Send`ing your tensors across threads.
Side note: this is similar to how [`Vec` works under the hood](https://doc.rust-lang.org/src/alloc/raw_vec/mod.rs.html#85). It uses a `Unique` wrapper, which has slightly stronger guarantees than a `NonNull`.
As we see, using the allocation APIs requires `unsafe`, which lets us momentarily break the compile-time memory-safety guarantees that regular Rust provides. As we continue to build functionality, we'll have to be careful to uphold those memory-safety contracts throughout our implementation.
