Skip to content

Commit 9593293

Browse files
authored
Extract stable stwo-types crate containing field and struct types used downstream (#1331)
1 parent 9920aeb commit 9593293

File tree

18 files changed

+1189
-1066
lines changed

18 files changed

+1189
-1066
lines changed

Cargo.lock

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ members = [
66
"crates/constraint-framework",
77
"crates/examples",
88
"crates/std-shims",
9+
"crates/stwo-types",
910
]
1011
exclude = ["ensure-verifier-no_std"]
1112
resolver = "2"
@@ -32,6 +33,7 @@ rayon = { version = "1.10.0", optional = false }
3233
rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
3334
serde = { version = "1.0", default-features = false, features = ["derive"] }
3435
hashbrown = { version = ">=0.15.2", features = ["serde"] }
36+
stwo-types = { path = "crates/stwo-types", version = "2.0.1", default-features = false }
3537

3638
[profile.bench]
3739
codegen-units = 1

crates/stwo-types/Cargo.toml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
[package]
2+
name = "stwo-types"
3+
version = "2.0.1"
4+
edition.workspace = true
5+
license.workspace = true
6+
description = "Core field types shared across stwo crates"
7+
8+
[features]
9+
default = ["std"]
10+
std = ["serde/std", "num-traits/std", "rand/std", "std-shims/std"]
11+
parallel = ["dep:rayon"]
12+
13+
[dependencies]
14+
bytemuck = { workspace = true, features = ["derive", "extern_crate_alloc"] }
15+
num-traits.workspace = true
16+
rand.workspace = true
17+
serde.workspace = true
18+
rayon = { workspace = true, optional = true }
19+
std-shims = { path = "../std-shims", version = "1.0.0", default-features = false }
20+
21+
[lints.rust]
22+
future-incompatible = "deny"
23+
nonstandard-style = "deny"
24+
rust-2018-idioms = "deny"
25+
26+
[lints.clippy]
27+
missing_const_for_fn = "warn"
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
use serde::{Deserialize, Serialize};
2+
3+
use super::m31::M31;
4+
5+
#[derive(Copy, Clone, Debug, Serialize, Deserialize, Default, Eq, PartialEq, Hash)]
6+
pub struct CasmState {
7+
pub pc: M31,
8+
pub ap: M31,
9+
pub fp: M31,
10+
}
11+
12+
impl CasmState {
13+
pub const fn values(&self) -> [M31; 3] {
14+
[self.pc, self.ap, self.fp]
15+
}
16+
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
use core::fmt::{Debug, Display};
2+
use core::ops::{
3+
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
4+
};
5+
6+
use serde::{Deserialize, Serialize};
7+
8+
use super::m31::M31;
9+
use super::{ComplexConjugate, FieldExpOps};
10+
use crate::{impl_extension_field, impl_field};
11+
pub const P2: u64 = 4611686014132420609; // (2 ** 31 - 1) ** 2
12+
13+
/// Complex extension field of M31.
14+
/// Equivalent to M31\[x\] over (x^2 + 1) as the irreducible polynomial.
15+
/// Represented as (a, b) of a + bi.
16+
#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
17+
pub struct CM31(pub M31, pub M31);
18+
19+
impl_field!(CM31, P2);
20+
impl_extension_field!(CM31, M31);
21+
22+
impl CM31 {
23+
pub const fn from_u32_unchecked(a: u32, b: u32) -> CM31 {
24+
Self(M31::from_u32_unchecked(a), M31::from_u32_unchecked(b))
25+
}
26+
27+
pub const fn from_m31(a: M31, b: M31) -> CM31 {
28+
Self(a, b)
29+
}
30+
}
31+
32+
impl Display for CM31 {
33+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
34+
write!(f, "{} + {}i", self.0, self.1)
35+
}
36+
}
37+
38+
impl Debug for CM31 {
39+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
40+
write!(f, "{} + {}i", self.0, self.1)
41+
}
42+
}
43+
44+
impl Mul for CM31 {
45+
type Output = Self;
46+
47+
fn mul(self, rhs: Self) -> Self::Output {
48+
// (a + bi) * (c + di) = (ac - bd) + (ad + bc)i.
49+
Self(
50+
self.0 * rhs.0 - self.1 * rhs.1,
51+
self.0 * rhs.1 + self.1 * rhs.0,
52+
)
53+
}
54+
}
55+
56+
impl TryInto<M31> for CM31 {
57+
type Error = ();
58+
59+
fn try_into(self) -> Result<M31, Self::Error> {
60+
if self.1 != M31::zero() {
61+
return Err(());
62+
}
63+
Ok(self.0)
64+
}
65+
}
66+
67+
impl FieldExpOps for CM31 {
68+
fn inverse(&self) -> Self {
69+
assert!(!self.is_zero(), "0 has no inverse");
70+
// 1 / (a + bi) = (a - bi) / (a^2 + b^2).
71+
Self(self.0, -self.1) * (self.0.square() + self.1.square()).inverse()
72+
}
73+
}
74+
75+
#[cfg(test)]
76+
#[macro_export]
77+
macro_rules! cm31 {
78+
($m0:expr, $m1:expr) => {
79+
CM31::from_u32_unchecked($m0, $m1)
80+
};
81+
}
82+
83+
#[cfg(test)]
84+
mod tests {
85+
use super::CM31;
86+
use crate::fields::m31::P;
87+
use crate::fields::FieldExpOps;
88+
use crate::m31;
89+
90+
#[test]
91+
fn test_inverse() {
92+
let cm = cm31!(1, 2);
93+
let cm_inv = cm.inverse();
94+
assert_eq!(cm * cm_inv, cm31!(1, 0));
95+
}
96+
97+
#[test]
98+
fn test_ops() {
99+
let cm0 = cm31!(1, 2);
100+
let cm1 = cm31!(4, 5);
101+
let m = m31!(8);
102+
let cm = CM31::from(m);
103+
let cm0_x_cm1 = cm31!(P - 6, 13);
104+
105+
assert_eq!(cm0 + cm1, cm31!(5, 7));
106+
assert_eq!(cm1 + m, cm1 + cm);
107+
assert_eq!(cm0 * cm1, cm0_x_cm1);
108+
assert_eq!(cm1 * m, cm1 * cm);
109+
assert_eq!(-cm0, cm31!(P - 1, P - 2));
110+
assert_eq!(cm0 - cm1, cm31!(P - 3, P - 3));
111+
assert_eq!(cm1 - m, cm1 - cm);
112+
assert_eq!(cm0_x_cm1 / cm1, cm31!(1, 2));
113+
assert_eq!(cm1 / m, cm1 / cm);
114+
}
115+
}

0 commit comments

Comments
 (0)