Skip to content

Commit 2ede7b7

Browse files
Add QM31 Lambdaworks wrapper (#152)
1 parent 2e96357 commit 2ede7b7

File tree

3 files changed

+330
-0
lines changed

3 files changed

+330
-0
lines changed

crates/starknet-types-core/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub mod curve;
66
pub mod hash;
77

88
pub mod felt;
9+
pub mod qm31;
910

1011
#[cfg(any(feature = "std", feature = "alloc"))]
1112
pub mod short_string;
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
//! A value in the Degree-4 (quadruple) extension of the Mersenne 31 field.
2+
//!
3+
//! The Marsenne 31 field is used by the Stwo prover.
4+
5+
use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub};
6+
7+
use lambdaworks_math::field::{
8+
element::FieldElement,
9+
errors::FieldError,
10+
fields::mersenne31::{
11+
extensions::Degree4ExtensionField,
12+
field::{Mersenne31Field, MERSENNE_31_PRIME_FIELD_ORDER},
13+
},
14+
traits::IsField,
15+
};
16+
17+
#[cfg(feature = "num-traits")]
18+
mod num_traits_impl;
19+
20+
use crate::felt::Felt;
21+
22+
/// A value in the Degree-4 (quadruple) extension of the Mersenne 31 (M31) field.
23+
///
24+
/// Each QM31 value is represented by two values in the Degree-2 (complex)
25+
/// extension, and each of these is represented by two values in the base
26+
/// field. Thus, a QM31 is represented by four M31 coordinates.
27+
///
28+
/// An M31 coordinate fits in 31 bits, as it has a maximum value of: `(1 << 31) - 1`.
29+
#[derive(Debug, Clone, PartialEq, Eq, Default)]
30+
pub struct QM31(pub FieldElement<Degree4ExtensionField>);
31+
32+
#[derive(Debug, Clone, Copy)]
33+
pub struct InvalidQM31Packing(pub Felt);
34+
35+
#[cfg(feature = "std")]
36+
impl std::error::Error for InvalidQM31Packing {}
37+
38+
#[cfg(feature = "std")]
39+
impl std::fmt::Display for InvalidQM31Packing {
40+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41+
write!(f, "felt is not a packed QM31: {}", self.0)
42+
}
43+
}
44+
45+
impl QM31 {
46+
/// Creates a QM31 from four M31 elements.
47+
pub fn from_coefficients(a: u32, b: u32, c: u32, d: u32) -> Self {
48+
Self(Degree4ExtensionField::const_from_coefficients(
49+
Mersenne31Field::from_base_type(a),
50+
Mersenne31Field::from_base_type(b),
51+
Mersenne31Field::from_base_type(c),
52+
Mersenne31Field::from_base_type(d),
53+
))
54+
}
55+
56+
/// Extracts M31 elements from a QM31.
57+
pub fn to_coefficients(&self) -> (u32, u32, u32, u32) {
58+
// Take CM31 coordinates from QM31.
59+
let [a, b] = self.0.value();
60+
61+
// Take M31 coordinates from both CM31.
62+
let [c1, c2] = a.value();
63+
let [c3, c4] = b.value();
64+
65+
(c1.to_raw(), c2.to_raw(), c3.to_raw(), c4.to_raw())
66+
}
67+
68+
/// Packs the [QM31] into a [Felt].
69+
///
70+
/// Stores the four M31 coordinates in the first 144 bits of a Felt. Each
71+
/// coordinate takes 36 bits, and the resulting felt is equal to:
72+
/// `C1 + C2 << 36 + C3 << 72 + C4 << 108`
73+
///
74+
/// Why the stride between coordinates is 36 instead of 31? In Stwo, Felt
75+
/// elements are stored in memory as 28 M31s, each representing 9 bits
76+
/// (that representation is efficient for multiplication). 36 is the first
77+
/// multiple of 9 that is greater than 31.
78+
pub fn pack_into_felt(&self) -> Felt {
79+
let (c1, c2, c3, c4) = self.to_coefficients();
80+
81+
// Pack as: c1 + c2 << 36 + c3 << 36*2 + c4 << 36*3.
82+
let lo = c1 as u128 + ((c2 as u128) << 36);
83+
let hi = c3 as u128 + ((c4 as u128) << 36);
84+
let mut felt_bytes = [0u8; 32];
85+
felt_bytes[0..9].copy_from_slice(&lo.to_le_bytes()[0..9]);
86+
felt_bytes[9..18].copy_from_slice(&hi.to_le_bytes()[0..9]);
87+
Felt::from_bytes_le(&felt_bytes)
88+
}
89+
90+
/// Unpacks a [QM31] from the [Felt]
91+
///
92+
/// See the method [QM31::pack_into_felt] for a detailed explanation on the
93+
/// packing format.
94+
pub fn unpack_from_felt(felt: &Felt) -> Result<QM31, InvalidQM31Packing> {
95+
const MASK_36: u64 = (1 << 36) - 1;
96+
const MASK_8: u64 = (1 << 8) - 1;
97+
98+
let digits = felt.to_le_digits();
99+
100+
// The QM31 is packed in the first 144 bits,
101+
// the remaining bits must be zero.
102+
if digits[3] != 0 || digits[2] >= 1 << 16 {
103+
return Err(InvalidQM31Packing(*felt));
104+
}
105+
106+
// Unpack as: c1 + c2 << 36 + c3 << 36*2 + c4 << 36*3.
107+
let c1 = digits[0] & MASK_36;
108+
let c2 = (digits[0] >> 36) + ((digits[1] & MASK_8) << 28);
109+
let c3 = (digits[1] >> 8) & MASK_36;
110+
let c4 = (digits[1] >> 44) + (digits[2] << 20);
111+
112+
// Even though we use 36 bits for each coordinate,
113+
// the maximum value is still the field prime.
114+
for c in [c1, c2, c3, c4] {
115+
if c >= MERSENNE_31_PRIME_FIELD_ORDER as u64 {
116+
return Err(InvalidQM31Packing(*felt));
117+
}
118+
}
119+
120+
Ok(QM31(Degree4ExtensionField::const_from_coefficients(
121+
c1 as u32, c2 as u32, c3 as u32, c4 as u32,
122+
)))
123+
}
124+
125+
/// Multiplicative inverse inside field.
126+
pub fn inverse(&self) -> Result<Self, FieldError> {
127+
Ok(Self(self.0.inv()?))
128+
}
129+
}
130+
131+
impl Add for QM31 {
132+
type Output = QM31;
133+
134+
fn add(self, rhs: Self) -> Self::Output {
135+
Self(self.0.add(rhs.0))
136+
}
137+
}
138+
impl Sub for QM31 {
139+
type Output = QM31;
140+
141+
fn sub(self, rhs: Self) -> Self::Output {
142+
Self(self.0.sub(rhs.0))
143+
}
144+
}
145+
impl Mul for QM31 {
146+
type Output = QM31;
147+
148+
fn mul(self, rhs: Self) -> Self::Output {
149+
Self(self.0.mul(rhs.0))
150+
}
151+
}
152+
impl Div for QM31 {
153+
type Output = Result<QM31, FieldError>;
154+
155+
fn div(self, rhs: Self) -> Self::Output {
156+
Ok(Self(self.0.div(rhs.0)?))
157+
}
158+
}
159+
impl AddAssign for QM31 {
160+
fn add_assign(&mut self, rhs: Self) {
161+
self.0.add_assign(rhs.0);
162+
}
163+
}
164+
impl MulAssign for QM31 {
165+
fn mul_assign(&mut self, rhs: Self) {
166+
self.0.mul_assign(rhs.0);
167+
}
168+
}
169+
impl Neg for QM31 {
170+
type Output = QM31;
171+
172+
fn neg(self) -> Self::Output {
173+
Self(self.0.neg())
174+
}
175+
}
176+
177+
#[cfg(test)]
178+
mod test {
179+
use lambdaworks_math::field::fields::mersenne31::{
180+
extensions::Degree4ExtensionField, field::MERSENNE_31_PRIME_FIELD_ORDER,
181+
};
182+
use num_bigint::BigInt;
183+
184+
use crate::{felt::Felt, qm31::QM31};
185+
186+
#[test]
187+
fn qm31_packing_and_unpacking() {
188+
const MAX: u32 = MERSENNE_31_PRIME_FIELD_ORDER - 1;
189+
190+
let cases = [
191+
[1, 2, 3, 4],
192+
[MAX, 0, 0, 0],
193+
[MAX, MAX, 0, 0],
194+
[MAX, MAX, MAX, 0],
195+
[MAX, MAX, MAX, MAX],
196+
];
197+
198+
for [c1, c2, c3, c4] in cases {
199+
let qm31 = QM31(Degree4ExtensionField::const_from_coefficients(
200+
c1, c2, c3, c4,
201+
));
202+
let packed_qm31 = qm31.pack_into_felt();
203+
let unpacked_qm31 = QM31::unpack_from_felt(&packed_qm31).unwrap();
204+
205+
assert_eq!(qm31, unpacked_qm31)
206+
}
207+
}
208+
209+
#[test]
210+
fn qm31_packing() {
211+
const MAX: u32 = MERSENNE_31_PRIME_FIELD_ORDER - 2;
212+
213+
let cases = [
214+
[1, 2, 3, 4],
215+
[MAX, 0, 0, 0],
216+
[MAX, MAX, 0, 0],
217+
[MAX, MAX, MAX, 0],
218+
[MAX, MAX, MAX, MAX],
219+
];
220+
221+
for [c1, c2, c3, c4] in cases {
222+
let qm31 = QM31(Degree4ExtensionField::const_from_coefficients(
223+
c1, c2, c3, c4,
224+
));
225+
let packed_qm31 = qm31.pack_into_felt();
226+
227+
let expected_packing = BigInt::from(c1)
228+
+ (BigInt::from(c2) << 36)
229+
+ (BigInt::from(c3) << 72)
230+
+ (BigInt::from(c4) << 108);
231+
232+
assert_eq!(packed_qm31, Felt::from(expected_packing))
233+
}
234+
}
235+
236+
#[test]
237+
fn qm31_invalid_packing() {
238+
const MAX: u64 = MERSENNE_31_PRIME_FIELD_ORDER as u64 - 1;
239+
240+
let cases = [
241+
[MAX + 1, 0, 0, 0],
242+
[0, MAX + 1, 0, 0],
243+
[0, 0, MAX + 1, 0],
244+
[0, 0, 0, MAX + 1],
245+
];
246+
247+
for [c1, c2, c3, c4] in cases {
248+
let invalid_packing = Felt::from(
249+
BigInt::from(c1)
250+
+ (BigInt::from(c2) << 36)
251+
+ (BigInt::from(c3) << 72)
252+
+ (BigInt::from(c4) << 108),
253+
);
254+
255+
QM31::unpack_from_felt(&invalid_packing).unwrap_err();
256+
}
257+
}
258+
259+
#[test]
260+
fn qm31_packing_with_high_bits() {
261+
let invalid_packing = Felt::from(BigInt::from(1) << 200);
262+
263+
QM31::unpack_from_felt(&invalid_packing).unwrap_err();
264+
}
265+
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
use lambdaworks_math::field::{
2+
element::FieldElement, errors::FieldError,
3+
fields::mersenne31::extensions::Degree4ExtensionField,
4+
};
5+
use num_traits::{Inv, One, Pow, Zero};
6+
7+
use super::QM31;
8+
9+
impl Zero for QM31 {
10+
fn zero() -> Self {
11+
Self(FieldElement::<Degree4ExtensionField>::zero())
12+
}
13+
14+
fn is_zero(&self) -> bool {
15+
self == &Self::zero()
16+
}
17+
}
18+
impl One for QM31 {
19+
fn one() -> Self {
20+
Self(FieldElement::<Degree4ExtensionField>::one())
21+
}
22+
}
23+
impl Inv for QM31 {
24+
type Output = Result<QM31, FieldError>;
25+
26+
fn inv(self) -> Self::Output {
27+
self.inverse()
28+
}
29+
}
30+
impl Pow<u8> for QM31 {
31+
type Output = Self;
32+
33+
fn pow(self, rhs: u8) -> Self::Output {
34+
Self(self.0.pow(rhs as u128))
35+
}
36+
}
37+
impl Pow<u16> for QM31 {
38+
type Output = Self;
39+
40+
fn pow(self, rhs: u16) -> Self::Output {
41+
Self(self.0.pow(rhs as u128))
42+
}
43+
}
44+
impl Pow<u32> for QM31 {
45+
type Output = Self;
46+
47+
fn pow(self, rhs: u32) -> Self::Output {
48+
Self(self.0.pow(rhs as u128))
49+
}
50+
}
51+
impl Pow<u64> for QM31 {
52+
type Output = Self;
53+
54+
fn pow(self, rhs: u64) -> Self::Output {
55+
Self(self.0.pow(rhs as u128))
56+
}
57+
}
58+
impl Pow<u128> for QM31 {
59+
type Output = Self;
60+
61+
fn pow(self, rhs: u128) -> Self::Output {
62+
Self(self.0.pow(rhs))
63+
}
64+
}

0 commit comments

Comments
 (0)