Skip to content

Commit 444176c

Browse files
committed
Dag in node structure
1 parent 5bd62cc commit 444176c

File tree

1 file changed

+82
-13
lines changed
  • crates/multilinear_extensions/src/expression

1 file changed

+82
-13
lines changed

crates/multilinear_extensions/src/expression/utils.rs

Lines changed: 82 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::{Fixed, Instance, WitnessId, combine_cumulative_either, monomial::Ter
55
use either::Either;
66
use ff_ext::ExtensionField;
77
use itertools::Itertools;
8+
use serde::{Deserialize, Serialize};
89

910
impl WitIn {
1011
pub fn assign<E: ExtensionField>(&self, instance: &mut [E::BaseField], value: E::BaseField) {
@@ -198,13 +199,23 @@ pub const DagLoadScalar: usize = 1;
198199
pub const DagAdd: usize = 2;
199200
pub const DagMul: usize = 3;
200201

202+
#[derive(Clone, Debug, Serialize, Deserialize)]
203+
#[repr(C)]
204+
pub struct Node {
205+
pub op: u32,
206+
pub left_id: u32,
207+
pub right_id: u32,
208+
pub out: u32,
209+
}
210+
201211
pub fn expr_compression_to_dag<E: ExtensionField>(
202212
expr: &Expression<E>,
203213
) -> (
204-
Vec<u32>,
214+
Vec<Node>,
205215
Vec<Instance>,
206216
Vec<Expression<E>>,
207217
Vec<Either<E::BaseField, E>>,
218+
u32,
208219
(usize, usize)
209220
) {
210221
let mut constant_dedup = HashMap::new();
@@ -213,6 +224,7 @@ pub fn expr_compression_to_dag<E: ExtensionField>(
213224
let mut constant = vec![];
214225
let mut instance_scalar = vec![];
215226
let mut challenges = vec![];
227+
let mut stack_pos: u32 = 0;
216228
// traverse first time to collect offset
217229
let _ = expr_compression_to_dag_helper(
218230
&mut dag,
@@ -223,6 +235,7 @@ pub fn expr_compression_to_dag<E: ExtensionField>(
223235
&mut constant,
224236
&mut challenges_dedup,
225237
&mut constant_dedup,
238+
&mut stack_pos,
226239
expr,
227240
);
228241

@@ -235,6 +248,7 @@ pub fn expr_compression_to_dag<E: ExtensionField>(
235248
challenges.truncate(0);
236249
challenges_dedup.clear();
237250
constant_dedup.clear();
251+
stack_pos = 0;
238252
let (max_degree, max_depth) = expr_compression_to_dag_helper(
239253
&mut dag,
240254
&mut instance_scalar,
@@ -244,34 +258,48 @@ pub fn expr_compression_to_dag<E: ExtensionField>(
244258
&mut constant,
245259
&mut challenges_dedup,
246260
&mut constant_dedup,
261+
&mut stack_pos,
247262
expr,
248263
);
249-
(dag, instance_scalar, challenges, constant, (max_degree, max_depth))
264+
(dag, instance_scalar, challenges, constant, stack_pos, (max_degree, max_depth))
250265
}
251266

252267
fn expr_compression_to_dag_helper<E: ExtensionField>(
253-
dag: &mut Vec<u32>,
268+
dag: &mut Vec<Node>,
254269
instance_scalar: &mut Vec<Instance>,
255270
challenges_offset: usize,
256271
challenges: &mut Vec<Expression<E>>,
257272
constant_offset: usize,
258273
constant: &mut Vec<Either<E::BaseField, E>>,
259274
challenges_dedup: &mut HashMap<Expression<E>, u32>,
260275
constant_dedup: &mut HashMap<Either<E::BaseField, E>, u32>,
276+
stack_pos: &mut u32,
261277
expr: &Expression<E>,
262278
) -> (usize, usize) {
263279
// (max_degree, max_depth)
264280
match expr {
265281
Expression::Fixed(_) => unimplemented!(),
266282
Expression::WitIn(wit_id) => {
267-
dag.extend(vec![DagLoadWit as u32, *wit_id as u32]);
283+
dag.push(Node {
284+
op: DagLoadWit as u32,
285+
left_id: *wit_id as u32,
286+
right_id: 0,
287+
out: *stack_pos,
288+
});
289+
*stack_pos += 1;
268290
(1, 1)
269291
}
270292
Expression::StructuralWitIn(_, ..) => unimplemented!(),
271293
Expression::Instance(_) => unimplemented!(),
272294
Expression::InstanceScalar(inst) => {
273295
instance_scalar.push(inst.clone());
274-
dag.extend(vec![DagLoadScalar as u32, instance_scalar.len() as u32 - 1]);
296+
dag.push(Node {
297+
op: DagLoadScalar as u32,
298+
left_id: instance_scalar.len() as u32 - 1,
299+
right_id: 0,
300+
out: *stack_pos,
301+
});
302+
*stack_pos += 1;
275303
(0, 1)
276304
}
277305
Expression::Constant(value) => {
@@ -284,8 +312,13 @@ fn expr_compression_to_dag_helper<E: ExtensionField>(
284312
id
285313
}
286314
};
287-
288-
dag.extend([DagLoadScalar as u32, constant_id]);
315+
dag.push(Node {
316+
op: DagLoadScalar as u32,
317+
left_id: constant_id,
318+
right_id: 0,
319+
out: *stack_pos,
320+
});
321+
*stack_pos += 1;
289322
(0, 1)
290323
}
291324
Expression::Sum(a, b) => {
@@ -298,6 +331,7 @@ fn expr_compression_to_dag_helper<E: ExtensionField>(
298331
constant,
299332
challenges_dedup,
300333
constant_dedup,
334+
stack_pos,
301335
a,
302336
);
303337
let (max_degree_b, max_depth_b) = expr_compression_to_dag_helper(
@@ -309,9 +343,16 @@ fn expr_compression_to_dag_helper<E: ExtensionField>(
309343
constant,
310344
challenges_dedup,
311345
constant_dedup,
346+
stack_pos,
312347
b,
313348
);
314-
dag.extend(vec![DagAdd as u32]);
349+
dag.push(Node {
350+
op: DagAdd as u32,
351+
left_id: *stack_pos-2,
352+
right_id: *stack_pos-1,
353+
out: *stack_pos-2,
354+
});
355+
*stack_pos -= 1;
315356
(
316357
max_degree_a.max(max_degree_b),
317358
max_depth_a.max(max_depth_b + 1),
@@ -327,6 +368,7 @@ fn expr_compression_to_dag_helper<E: ExtensionField>(
327368
constant,
328369
challenges_dedup,
329370
constant_dedup,
371+
stack_pos,
330372
a,
331373
);
332374
let (max_degree_b, max_depth_b) = expr_compression_to_dag_helper(
@@ -338,9 +380,16 @@ fn expr_compression_to_dag_helper<E: ExtensionField>(
338380
constant,
339381
challenges_dedup,
340382
constant_dedup,
383+
stack_pos,
341384
b,
342385
);
343-
dag.extend(vec![DagMul as u32]);
386+
dag.push(Node {
387+
op: DagMul as u32,
388+
left_id: *stack_pos-2,
389+
right_id: *stack_pos-1,
390+
out: *stack_pos-2,
391+
});
392+
*stack_pos -= 1;
344393
(
345394
max_degree_a + max_degree_b,
346395
max_depth_a.max(max_depth_b + 1),
@@ -356,6 +405,7 @@ fn expr_compression_to_dag_helper<E: ExtensionField>(
356405
constant,
357406
challenges_dedup,
358407
constant_dedup,
408+
stack_pos,
359409
x,
360410
);
361411
let (max_degree_a, max_depth_a) = expr_compression_to_dag_helper(
@@ -367,11 +417,18 @@ fn expr_compression_to_dag_helper<E: ExtensionField>(
367417
constant,
368418
challenges_dedup,
369419
constant_dedup,
420+
stack_pos,
370421
a,
371422
);
372423
let xa_degree = max_degree_x + max_degree_a;
373424
let ax_max_depth = max_depth_x.max(max_depth_a + 1);
374-
dag.extend(vec![DagMul as u32]);
425+
dag.push(Node {
426+
op: DagMul as u32,
427+
left_id: *stack_pos-2,
428+
right_id: *stack_pos-1,
429+
out: *stack_pos-2,
430+
});
431+
*stack_pos -= 1;
375432
let (max_degree_b, max_depth_b) = expr_compression_to_dag_helper(
376433
dag,
377434
instance_scalar,
@@ -381,9 +438,16 @@ fn expr_compression_to_dag_helper<E: ExtensionField>(
381438
constant,
382439
challenges_dedup,
383440
constant_dedup,
441+
stack_pos,
384442
b,
385443
);
386-
dag.extend(vec![DagAdd as u32]);
444+
dag.push(Node {
445+
op: DagAdd as u32,
446+
left_id: *stack_pos-2,
447+
right_id: *stack_pos-1,
448+
out: *stack_pos-2,
449+
});
450+
*stack_pos -= 1;
387451
(
388452
xa_degree.max(max_degree_b),
389453
(ax_max_depth).max(max_depth_b + 1),
@@ -399,8 +463,13 @@ fn expr_compression_to_dag_helper<E: ExtensionField>(
399463
id
400464
}
401465
};
402-
403-
dag.extend([DagLoadScalar as u32, challenge_id]);
466+
dag.push(Node {
467+
op: DagLoadScalar as u32,
468+
left_id: challenge_id,
469+
right_id: 0,
470+
out: *stack_pos,
471+
});
472+
*stack_pos += 1;
404473
(0, 1)
405474
}
406475
}

0 commit comments

Comments
 (0)