11use crate :: {
22 error:: ZKVMError ,
3- instructions:: riscv:: { DummyExtraConfig , MemPadder , MmuConfig , Rv32imConfig } ,
3+ instructions:: riscv:: {
4+ DummyExtraConfig , InstructionDispatchBuilder , MemPadder , MmuConfig , Rv32imConfig ,
5+ } ,
46 scheme:: {
57 PublicValues , ZKVMProof ,
68 constants:: SEPTIC_EXTENSION_DEGREE ,
@@ -575,6 +577,35 @@ pub trait StepCellExtractor {
575577 fn extract_cells ( & self , step : & StepRecord ) -> u64 ;
576578}
577579
580+ #[ derive( Clone , Copy , Debug , Default ) ]
581+ pub struct ShardStepSummary {
582+ pub step_count : usize ,
583+ pub first_cycle : Cycle ,
584+ pub last_cycle : Cycle ,
585+ pub first_pc_before : Addr ,
586+ pub last_pc_after : Addr ,
587+ pub first_heap_before : Addr ,
588+ pub last_heap_after : Addr ,
589+ pub first_hint_before : Addr ,
590+ pub last_hint_after : Addr ,
591+ }
592+
593+ impl ShardStepSummary {
594+ fn update ( & mut self , step : & StepRecord ) {
595+ if self . step_count == 0 {
596+ self . first_cycle = step. cycle ( ) ;
597+ self . first_pc_before = step. pc ( ) . before . 0 ;
598+ self . first_heap_before = step. heap_maxtouch_addr . before . 0 ;
599+ self . first_hint_before = step. hint_maxtouch_addr . before . 0 ;
600+ }
601+ self . step_count += 1 ;
602+ self . last_cycle = step. cycle ( ) ;
603+ self . last_pc_after = step. pc ( ) . after . 0 ;
604+ self . last_heap_after = step. heap_maxtouch_addr . after . 0 ;
605+ self . last_hint_after = step. hint_maxtouch_addr . after . 0 ;
606+ }
607+ }
608+
578609pub struct ShardContextBuilder {
579610 pub cur_shard_id : usize ,
580611 addr_future_accesses : Arc < NextCycleAccess > ,
@@ -645,9 +676,9 @@ impl ShardContextBuilder {
645676 & mut self ,
646677 steps_iter : & mut impl Iterator < Item = StepRecord > ,
647678 step_cell_extractor : impl StepCellExtractor ,
648- steps : & mut Vec < StepRecord > ,
649- ) -> Option < ShardContext < ' a > > {
650- steps . clear ( ) ;
679+ mut on_step : impl FnMut ( StepRecord ) ,
680+ ) -> Option < ( ShardContext < ' a > , ShardStepSummary ) > {
681+ let mut summary = ShardStepSummary :: default ( ) ;
651682 let target_cost_current_shard = if self . cur_shard_id == 0 {
652683 self . target_cell_first_shard
653684 } else {
@@ -666,78 +697,57 @@ impl ShardContextBuilder {
666697 let next_cycle = self . cur_acc_cycle + FullTracer :: SUBCYCLES_PER_INSN ;
667698 if next_cells >= target_cost_current_shard || next_cycle >= self . max_cycle_per_shard {
668699 assert ! (
669- !steps . is_empty ( ) ,
700+ summary . step_count > 0 ,
670701 "empty record match when splitting shards"
671702 ) ;
672703 self . pending_step = Some ( step) ;
673704 break ;
674705 }
675706 self . cur_cells = next_cells;
676707 self . cur_acc_cycle = next_cycle;
677- steps. push ( step) ;
708+ summary. update ( & step) ;
709+ on_step ( step) ;
678710 }
679711
680- if steps . is_empty ( ) {
712+ if summary . step_count == 0 {
681713 return None ;
682714 }
683715
684716 if self . cur_shard_id > 0 {
685717 assert_eq ! (
686- steps . first ( ) . map ( |step| step . cycle ( ) ) . unwrap_or_default ( ) ,
718+ summary . first_cycle ,
687719 self . prev_shard_cycle_range
688720 . last( )
689721 . copied( )
690722 . unwrap_or( FullTracer :: SUBCYCLES_PER_INSN )
691723 ) ;
692724 assert_eq ! (
693- steps
694- . first( )
695- . map( |step| step. heap_maxtouch_addr. before)
696- . unwrap_or_default( ) ,
725+ summary. first_heap_before,
697726 self . prev_shard_heap_range
698727 . last( )
699728 . copied( )
700729 . unwrap_or( self . platform. heap. start)
701- . into( )
702730 ) ;
703731 assert_eq ! (
704- steps
705- . first( )
706- . map( |step| step. hint_maxtouch_addr. before)
707- . unwrap_or_default( ) ,
732+ summary. first_hint_before,
708733 self . prev_shard_hint_range
709734 . last( )
710735 . copied( )
711736 . unwrap_or( self . platform. hints. start)
712- . into( )
713737 ) ;
714738 }
715739
716740 let shard_ctx = ShardContext {
717741 shard_id : self . cur_shard_id ,
718- cur_shard_cycle_range : steps . first ( ) . map ( |step| step . cycle ( ) as usize ) . unwrap ( )
719- ..( steps . last ( ) . unwrap ( ) . cycle ( ) + FullTracer :: SUBCYCLES_PER_INSN ) as usize ,
742+ cur_shard_cycle_range : summary . first_cycle as usize
743+ ..( summary . last_cycle + FullTracer :: SUBCYCLES_PER_INSN ) as usize ,
720744 addr_future_accesses : self . addr_future_accesses . clone ( ) ,
721745 prev_shard_cycle_range : self . prev_shard_cycle_range . clone ( ) ,
722746 prev_shard_heap_range : self . prev_shard_heap_range . clone ( ) ,
723747 prev_shard_hint_range : self . prev_shard_hint_range . clone ( ) ,
724748 platform : self . platform . clone ( ) ,
725- shard_heap_addr_range : steps
726- . first ( )
727- . map ( |step| step. heap_maxtouch_addr . before . 0 )
728- . unwrap_or_default ( )
729- ..steps
730- . last ( )
731- . map ( |step| step. heap_maxtouch_addr . after . 0 )
732- . unwrap_or_default ( ) ,
733- shard_hint_addr_range : steps
734- . first ( )
735- . map ( |step| step. hint_maxtouch_addr . before . 0 )
736- . unwrap_or_default ( )
737- ..steps
738- . last ( )
739- . map ( |step| step. hint_maxtouch_addr . after . 0 )
740- . unwrap_or_default ( ) ,
749+ shard_heap_addr_range : summary. first_heap_before ..summary. last_heap_after ,
750+ shard_hint_addr_range : summary. first_hint_before ..summary. last_hint_after ,
741751 ..Default :: default ( )
742752 } ;
743753 self . prev_shard_cycle_range
@@ -750,7 +760,7 @@ impl ShardContextBuilder {
750760 self . cur_acc_cycle = 0 ;
751761 self . cur_shard_id += 1 ;
752762
753- Some ( shard_ctx)
763+ Some ( ( shard_ctx, summary ) )
754764 }
755765}
756766
@@ -1124,6 +1134,7 @@ pub fn init_static_addrs(program: &Program) -> Vec<MemInitRecord> {
11241134pub struct ConstraintSystemConfig < E : ExtensionField > {
11251135 pub zkvm_cs : ZKVMConstraintSystem < E > ,
11261136 pub config : Rv32imConfig < E > ,
1137+ pub inst_dispatch_builder : InstructionDispatchBuilder ,
11271138 pub mmu_config : MmuConfig < E > ,
11281139 pub dummy_config : DummyExtraConfig < E > ,
11291140 pub prog_config : ProgramTableConfig ,
@@ -1134,14 +1145,15 @@ pub fn construct_configs<E: ExtensionField>(
11341145) -> ConstraintSystemConfig < E > {
11351146 let mut zkvm_cs = ZKVMConstraintSystem :: new_with_platform ( program_params) ;
11361147
1137- let config = Rv32imConfig :: < E > :: construct_circuits ( & mut zkvm_cs) ;
1148+ let ( config, inst_dispatch_builder ) = Rv32imConfig :: < E > :: construct_circuits ( & mut zkvm_cs) ;
11381149 let mmu_config = MmuConfig :: < E > :: construct_circuits ( & mut zkvm_cs) ;
11391150 let dummy_config = DummyExtraConfig :: < E > :: construct_circuits ( & mut zkvm_cs) ;
11401151 let prog_config = zkvm_cs. register_table_circuit :: < ProgramTableCircuit < E > > ( ) ;
11411152 zkvm_cs. register_global_state :: < GlobalState > ( ) ;
11421153 ConstraintSystemConfig {
11431154 zkvm_cs,
11441155 config,
1156+ inst_dispatch_builder,
11451157 mmu_config,
11461158 dummy_config,
11471159 prog_config,
@@ -1195,27 +1207,27 @@ pub fn generate_witness<'a, E: ExtensionField>(
11951207 "execution trace must contain at least one step"
11961208 ) ;
11971209
1210+ let mut instrunction_dispatch_ctx = system_config. inst_dispatch_builder . to_dispatch_ctx ( ) ;
11981211 let pi_template = emul_result. pi . clone ( ) ;
11991212 let mut step_iter = StepReplay :: new (
12001213 platform. clone ( ) ,
12011214 program. clone ( ) ,
12021215 init_mem_state,
12031216 emul_result. executed_steps ,
12041217 ) ;
1205- let mut shard_steps = Vec :: new ( ) ;
1206-
12071218 std:: iter:: from_fn ( move || {
12081219 info_span ! (
12091220 "[ceno] app_prove.generate_witness" ,
12101221 shard_id = shard_ctx_builder. cur_shard_id
12111222 )
12121223 . in_scope ( || {
1213- let mut shard_ctx = match shard_ctx_builder. position_next_shard (
1224+ instrunction_dispatch_ctx. begin_shard ( ) ;
1225+ let ( mut shard_ctx, shard_summary) = match shard_ctx_builder. position_next_shard (
12141226 & mut step_iter,
12151227 & system_config. config ,
1216- & mut shard_steps ,
1228+ |step| instrunction_dispatch_ctx . ingest_step ( step ) ,
12171229 ) {
1218- Some ( ctx ) => ctx ,
1230+ Some ( result ) => result ,
12191231 None => return None ,
12201232 } ;
12211233
@@ -1224,23 +1236,22 @@ pub fn generate_witness<'a, E: ExtensionField>(
12241236 tracing:: debug!(
12251237 "{}th shard collect {} steps, heap_addr_range {:x} - {:x}, hint_addr_range {:x} - {:x}" ,
12261238 shard_ctx. shard_id,
1227- shard_steps . len ( ) ,
1239+ shard_summary . step_count ,
12281240 shard_ctx. shard_heap_addr_range. start,
12291241 shard_ctx. shard_heap_addr_range. end,
12301242 shard_ctx. shard_hint_addr_range. start,
12311243 shard_ctx. shard_hint_addr_range. end,
12321244 ) ;
12331245
12341246 let current_shard_offset_cycle = shard_ctx. current_shard_offset_cycle ( ) ;
1235- let last_step = shard_steps. last ( ) . expect ( "shard must contain steps" ) ;
1236- let current_shard_end_cycle =
1237- last_step. cycle ( ) + FullTracer :: SUBCYCLES_PER_INSN - current_shard_offset_cycle;
1247+ let current_shard_end_cycle = shard_summary. last_cycle + FullTracer :: SUBCYCLES_PER_INSN
1248+ - current_shard_offset_cycle;
12381249 let current_shard_init_pc = if shard_ctx. is_first_shard ( ) {
12391250 program. entry
12401251 } else {
1241- shard_steps . first ( ) . unwrap ( ) . pc ( ) . before . 0
1252+ shard_summary . first_pc_before
12421253 } ;
1243- let current_shard_end_pc = last_step . pc ( ) . after . 0 ;
1254+ let current_shard_end_pc = shard_summary . last_pc_after ;
12441255
12451256 pi. init_pc = current_shard_init_pc;
12461257 pi. init_cycle = FullTracer :: SUBCYCLES_PER_INSN ;
@@ -1267,13 +1278,13 @@ pub fn generate_witness<'a, E: ExtensionField>(
12671278 }
12681279
12691280 let time = std:: time:: Instant :: now ( ) ;
1270- let dummy_records = system_config
1281+ system_config
12711282 . config
12721283 . assign_opcode_circuit (
12731284 & system_config. zkvm_cs ,
12741285 & mut shard_ctx,
1286+ & mut instrunction_dispatch_ctx,
12751287 & mut zkvm_witness,
1276- & shard_steps,
12771288 )
12781289 . unwrap ( ) ;
12791290 tracing:: debug!( "assign_opcode_circuit finish in {:?}" , time. elapsed( ) ) ;
@@ -1283,8 +1294,8 @@ pub fn generate_witness<'a, E: ExtensionField>(
12831294 . assign_opcode_circuit (
12841295 & system_config. zkvm_cs ,
12851296 & mut shard_ctx,
1297+ & instrunction_dispatch_ctx,
12861298 & mut zkvm_witness,
1287- dummy_records,
12881299 )
12891300 . unwrap ( ) ;
12901301 tracing:: debug!( "assign_dummy_config finish in {:?}" , time. elapsed( ) ) ;
@@ -1375,7 +1386,6 @@ pub fn generate_witness<'a, E: ExtensionField>(
13751386 "assign_dynamic_init_table_circuit finish in {:?}" ,
13761387 time. elapsed( )
13771388 ) ;
1378-
13791389 let time = std:: time:: Instant :: now ( ) ;
13801390 system_config
13811391 . mmu_config
@@ -2096,14 +2106,10 @@ mod tests {
20962106 let mut steps_iter = ( 0 ..executed_instruction) . map ( |i| {
20972107 StepRecord :: new_ecall_any ( FullTracer :: SUBCYCLES_PER_INSN * ( i + 1 ) as u64 , 0 . into ( ) )
20982108 } ) ;
2099- let mut steps = Vec :: new ( ) ;
2100-
21012109 let shard_ctx = std:: iter:: from_fn ( || {
2102- shard_ctx_builder. position_next_shard (
2103- & mut steps_iter,
2104- & UniformStepExtractor { } ,
2105- & mut steps,
2106- )
2110+ shard_ctx_builder
2111+ . position_next_shard ( & mut steps_iter, & UniformStepExtractor { } , |_| { } )
2112+ . map ( |( ctx, _) | ctx)
21072113 } )
21082114 . collect_vec ( ) ;
21092115
0 commit comments