Skip to content

Commit f20e378

Browse files
: v1: value_mesh serialization (meta-pytorch#1471)
Summary: this change enables serde serialization and deserialization for `ValueMesh<T>` and its internal representations. it introduces a stable wire format for run-length encoded (RLE) meshes by defining a `Run` struct (`u64` bounds, `u32` id) to avoid platform-dependent `usize` encoding. both dense and compressed representations serialize deterministically and retain their form on round trips Reviewed By: dulinriley Differential Revision: D84197759
1 parent 5268cb1 commit f20e378

File tree

1 file changed

+131
-13
lines changed

1 file changed

+131
-13
lines changed

hyperactor_mesh/src/v1/value_mesh.rs

Lines changed: 131 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ use futures::Future;
1717
use ndslice::view;
1818
use ndslice::view::Ranked;
1919
use ndslice::view::Region;
20+
use serde::Deserialize;
21+
use serde::Serialize;
2022

2123
/// A mesh of values, one per rank in `region`.
2224
///
@@ -27,7 +29,7 @@ use ndslice::view::Region;
2729
/// # Invariants
2830
/// - Complete: every rank in `region` has exactly one value.
2931
/// - Order: iteration and indexing follow the region's linearization.
30-
#[derive(Clone, Debug, PartialEq, Eq, Hash)] // only if T implements
32+
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] // only if T implements
3133
pub struct ValueMesh<T> {
3234
/// The logical multidimensional domain of the mesh.
3335
///
@@ -48,6 +50,58 @@ pub struct ValueMesh<T> {
4850
rep: Rep<T>,
4951
}
5052

53+
/// A single run-length–encoded (RLE) segment within a [`ValueMesh`].
54+
///
55+
/// Each `Run` represents a contiguous range of ranks `[start, end)`
56+
/// that all share the same value, referenced indirectly via a table
57+
/// index `id`. This allows compact storage of large regions with
58+
/// repeated values.
59+
///
60+
/// Runs are serialized in a stable, portable format using `u64` for
61+
/// range bounds (`start`, `end`) to avoid platform‐dependent `usize`
62+
/// encoding differences.
63+
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
64+
struct Run {
65+
/// Inclusive start of the contiguous range of ranks (0-based).
66+
start: u64,
67+
/// Exclusive end of the contiguous range of ranks (0-based).
68+
end: u64,
69+
/// Index into the value table for this run's shared value.
70+
id: u32,
71+
}
72+
73+
impl Run {
74+
/// Creates a new `Run` covering ranks `[start, end)` that all
75+
/// share the same table entry `id`.
76+
///
77+
/// Converts `usize` bounds to `u64` for stable serialization.
78+
fn new(start: usize, end: usize, id: u32) -> Self {
79+
Self {
80+
start: start as u64,
81+
end: end as u64,
82+
id,
83+
}
84+
}
85+
}
86+
87+
impl TryFrom<Run> for (Range<usize>, u32) {
88+
type Error = &'static str;
89+
90+
/// Converts a serialized [`Run`] back into its in-memory form
91+
/// `(Range<usize>, u32)`.
92+
///
93+
/// Performs checked conversion of the 64-bit wire fields back
94+
/// into `usize` indices, returning an error if either bound
95+
/// exceeds the platform’s addressable range. This ensures safe
96+
/// round-tripping between the serialized wire format and native
97+
/// representation.
98+
fn try_from(r: Run) -> Result<Self, Self::Error> {
99+
let start = usize::try_from(r.start).map_err(|_| "run.start too large")?;
100+
let end = usize::try_from(r.end).map_err(|_| "run.end too large")?;
101+
Ok((start..end, r.id))
102+
}
103+
}
104+
51105
/// Internal storage representation for a [`ValueMesh`].
52106
///
53107
/// This enum abstracts how the per-rank values are stored.
@@ -61,7 +115,8 @@ pub struct ValueMesh<T> {
61115
/// Users of [`ValueMesh`] normally never interact with `Rep`
62116
/// directly; all iteration and slicing APIs present a dense logical
63117
/// view.
64-
#[derive(Clone, Debug, PartialEq, Eq, Hash)] // only if T implements
118+
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] // only if T implements
119+
#[serde(tag = "rep", rename_all = "snake_case")]
65120
enum Rep<T> {
66121
/// Fully expanded representation: one element per rank.
67122
///
@@ -92,7 +147,7 @@ enum Rep<T> {
92147

93148
/// List of `(range, table_id)` pairs describing contiguous
94149
/// runs of identical values in region order.
95-
runs: Vec<(Range<usize>, u32)>,
150+
runs: Vec<Run>,
96151
},
97152
}
98153

@@ -202,13 +257,15 @@ impl<T: 'static> view::Ranked for ValueMesh<T> {
202257
Rep::Dense { values } => values.get(rank),
203258

204259
Rep::Compressed { table, runs } => {
260+
let rank = rank as u64;
261+
205262
// Binary search over runs: find the one whose range
206263
// contains `rank`.
207264
let idx = runs
208-
.binary_search_by(|(r, _)| {
209-
if r.end <= rank {
265+
.binary_search_by(|run| {
266+
if run.end <= rank {
210267
Ordering::Less
211-
} else if r.start > rank {
268+
} else if run.start > rank {
212269
Ordering::Greater
213270
} else {
214271
Ordering::Equal
@@ -217,7 +274,7 @@ impl<T: 'static> view::Ranked for ValueMesh<T> {
217274
.ok()?;
218275

219276
// Map the run's table ID to its actual value.
220-
let id = runs[idx].1 as usize;
277+
let id = runs[idx].id as usize;
221278
table.get(id)
222279
}
223280
}
@@ -581,10 +638,7 @@ impl<T: Clone> ValueMesh<T> {
581638
/// # Returns
582639
/// A tuple `(table, runs)` that together form the compressed
583640
/// representation. Expanding the runs reproduces the original data.
584-
fn compress_adjacent_with<T: Clone, F>(
585-
values: Vec<T>,
586-
mut same: F,
587-
) -> (Vec<T>, Vec<(Range<usize>, u32)>)
641+
fn compress_adjacent_with<T: Clone, F>(values: Vec<T>, mut same: F) -> (Vec<T>, Vec<Run>)
588642
where
589643
F: FnMut(&T, &T) -> bool,
590644
{
@@ -605,7 +659,7 @@ where
605659
for (i, _value) in values.iter().enumerate().skip(1) {
606660
if !same(&values[i], &table[cur_id as usize]) {
607661
// Close current run [start, i)
608-
runs.push((start..i, cur_id));
662+
runs.push(Run::new(start, i, cur_id));
609663

610664
// Start a new run
611665
start = i;
@@ -615,7 +669,7 @@ where
615669
}
616670

617671
// Close the final run
618-
runs.push((start..values.len(), cur_id));
672+
runs.push(Run::new(start, values.len(), cur_id));
619673

620674
(table, runs)
621675
}
@@ -644,6 +698,7 @@ mod tests {
644698
use ndslice::view::ViewExt;
645699
use proptest::prelude::*;
646700
use proptest::strategy::ValueTree;
701+
use serde_json;
647702

648703
use super::*;
649704

@@ -1296,4 +1351,67 @@ mod tests {
12961351
assert_eq!(vm.get(0), Some(&123));
12971352
assert_eq!(vm.get(1), None);
12981353
}
1354+
1355+
#[test]
1356+
fn test_dense_round_trip() {
1357+
// Build a simple dense mesh of 5 integers.
1358+
let region: Region = extent!(x = 5).into();
1359+
let dense = ValueMesh::new(region.clone(), vec![1, 2, 3, 4, 5]).unwrap();
1360+
1361+
let json = serde_json::to_string_pretty(&dense).unwrap();
1362+
let restored: ValueMesh<i32> = serde_json::from_str(&json).unwrap();
1363+
1364+
assert_eq!(dense, restored);
1365+
1366+
// Dense meshes should stay dense on the wire: check the
1367+
// tagged variant.
1368+
let v: serde_json::Value = serde_json::from_str(&json).unwrap();
1369+
// enum tag is nested: {"rep": {"rep":"dense", ...}}
1370+
let tag = v
1371+
.get("rep")
1372+
.and_then(|o| o.get("rep"))
1373+
.and_then(|s| s.as_str());
1374+
assert_eq!(tag, Some("dense"));
1375+
}
1376+
1377+
#[test]
1378+
fn test_compressed_round_trip() {
1379+
// Build a dense mesh, compress it, and verify it stays
1380+
// compressed on the wire.
1381+
let region: Region = extent!(x = 10).into();
1382+
let mut mesh = ValueMesh::new(region.clone(), vec![1, 1, 1, 2, 2, 3, 3, 3, 3, 3]).unwrap();
1383+
mesh.compress_adjacent_in_place();
1384+
1385+
let json = serde_json::to_string_pretty(&mesh).unwrap();
1386+
let restored: ValueMesh<i32> = serde_json::from_str(&json).unwrap();
1387+
1388+
// Logical equality preserved.
1389+
assert_eq!(mesh, restored);
1390+
1391+
// Compressed meshes should stay compressed on the wire.
1392+
let v: serde_json::Value = serde_json::from_str(&json).unwrap();
1393+
// enum tag is nested: {"rep": {"rep":"compressed", ...}}
1394+
let tag = v
1395+
.get("rep")
1396+
.and_then(|o| o.get("rep"))
1397+
.and_then(|s| s.as_str());
1398+
assert_eq!(tag, Some("compressed"));
1399+
}
1400+
1401+
#[test]
1402+
fn test_stable_run_encoding() {
1403+
let run = Run::new(0, 10, 42);
1404+
let json = serde_json::to_string(&run).unwrap();
1405+
let decoded: Run = serde_json::from_str(&json).unwrap();
1406+
1407+
assert_eq!(run, decoded);
1408+
assert_eq!(run.start, 0);
1409+
assert_eq!(run.end, 10);
1410+
assert_eq!(run.id, 42);
1411+
1412+
// Ensure conversion back to Range<usize> works.
1413+
let (range, id): (Range<usize>, u32) = run.try_into().unwrap();
1414+
assert_eq!(range, 0..10);
1415+
assert_eq!(id, 42);
1416+
}
12991417
}

0 commit comments

Comments
 (0)