Skip to content

Commit 3359434

Browse files
take output from runner instead of from proof
1 parent 15c6a89 commit 3359434

File tree

2 files changed

+40
-34
lines changed

2 files changed

+40
-34
lines changed

.github/workflows/upload_artifacts_workflow.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: Upload-Artifacts
33
on:
44
push:
55
branches:
6-
- nitsan/upload_stwo_run_and_prove_to_gcp
6+
- nitsan/take_output_from_runner
77

88
jobs:
99
artifacts-push:

crates/stwo_run_and_prove/src/main.rs

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@ use cairo_air::utils::ProofFormat;
33
use cairo_air::verifier::verify_cairo;
44
use cairo_program_runner_lib::cairo_run_program;
55
use cairo_program_runner_lib::utils::{get_cairo_run_config, get_program, get_program_input};
6+
use cairo_vm::Felt252;
67
use cairo_vm::types::errors::program_errors::ProgramError;
78
use cairo_vm::types::layout_name::LayoutName;
89
use cairo_vm::vm::errors::cairo_run_errors::CairoRunError;
910
use cairo_vm::vm::errors::runner_errors::RunnerError;
11+
use cairo_vm::vm::errors::vm_errors::VirtualMachineError;
12+
use cairo_vm::vm::runners::cairo_runner::CairoRunner;
1013
use clap::Parser;
1114
#[cfg(test)]
1215
use mockall::automock;
@@ -34,8 +37,6 @@ use stwo_cairo_utils::file_utils::{IoErrorWithPath, create_file, read_to_string}
3437
use thiserror::Error;
3538
use tracing::{error, info, warn};
3639

37-
type OutputVec = Vec<[u32; 8]>;
38-
3940
fn parse_usize_ge1(s: &str) -> Result<usize, String> {
4041
let v: usize = s.parse().map_err(|_| "must be a number".to_string())?;
4142
if v >= 1 {
@@ -121,8 +122,12 @@ enum StwoRunAndProveError {
121122
Serializing(#[from] sonic_rs::error::Error),
122123
#[error(transparent)]
123124
Proving(#[from] ProvingError),
125+
#[error(transparent)]
126+
VM(#[from] VirtualMachineError),
124127
#[error("cairo verification failed.")]
125128
Verification,
129+
#[error("Failed to parse output line as Felt decimal.")]
130+
OutputParsing,
126131
}
127132

128133
// Implement From<Box<CairoRunError>> manually
@@ -199,12 +204,11 @@ fn stwo_run_and_prove(
199204
let program_input = get_program_input(&program_input)?;
200205
info!("Running cairo run program.");
201206
let runner = cairo_run_program(&program, program_input, cairo_run_config)?;
207+
info!("Adapting prover input.");
202208
let prover_input = adapter(&runner);
203-
let (successful_proof_attempt, output_vec) =
204-
prove_with_retries(prover_input, prove_config, prover)?;
205-
209+
let successful_proof_attempt = prove_with_retries(prover_input, prove_config, prover)?;
206210
if let Some(output_path) = program_output {
207-
save_output_to_file(output_vec, output_path)?;
211+
write_output_to_file(runner, output_path)?;
208212
}
209213

210214
Ok(successful_proof_attempt)
@@ -218,7 +222,7 @@ fn prove_with_retries(
218222
prover_input: ProverInput,
219223
prove_config: ProveConfig,
220224
prover: Box<dyn ProverTrait>,
221-
) -> Result<(usize, OutputVec), StwoRunAndProveError> {
225+
) -> Result<usize, StwoRunAndProveError> {
222226
let ProverParameters {
223227
channel_hash,
224228
pcs_config,
@@ -238,7 +242,7 @@ fn prove_with_retries(
238242
std::fs::create_dir_all(&prove_config.proofs_dir)?;
239243
let proof_format = prove_config.proof_format;
240244

241-
for i in 1..=prove_config.n_proof_attempts + 1 {
245+
for i in 1..=prove_config.n_proof_attempts {
242246
info!(
243247
"Attempting to generate proof {}/{}.",
244248
i, prove_config.n_proof_attempts
@@ -252,12 +256,12 @@ fn prove_with_retries(
252256
channel_hash,
253257
prove_config.verify,
254258
) {
255-
Ok(output_values) => {
259+
Ok(()) => {
256260
info!(
257261
"Proof generated and verified successfully on attempt {}/{}",
258262
i, prove_config.n_proof_attempts
259263
);
260-
return Ok((i, output_values));
264+
return Ok(i);
261265
}
262266

263267
Err(StwoRunAndProveError::Verification) => {
@@ -289,7 +293,7 @@ fn choose_channel_and_prove(
289293
proof_format: &ProofFormat,
290294
channel_hash: ChannelHash,
291295
verify: bool,
292-
) -> Result<OutputVec, StwoRunAndProveError> {
296+
) -> Result<(), StwoRunAndProveError> {
293297
match channel_hash {
294298
ChannelHash::Blake2s => prove::<Blake2sMerkleChannel>(
295299
cairo_prover_inputs,
@@ -315,7 +319,7 @@ trait ProverTrait {
315319
proof_format: &ProofFormat,
316320
channel_hash: ChannelHash,
317321
verify: bool,
318-
) -> Result<OutputVec, StwoRunAndProveError>;
322+
) -> Result<(), StwoRunAndProveError>;
319323
}
320324

321325
struct StwoProverEntryPoint;
@@ -328,7 +332,7 @@ impl ProverTrait for StwoProverEntryPoint {
328332
proof_format: &ProofFormat,
329333
channel_hash: ChannelHash,
330334
verify: bool,
331-
) -> Result<OutputVec, StwoRunAndProveError> {
335+
) -> Result<(), StwoRunAndProveError> {
332336
choose_channel_and_prove(
333337
cairo_prover_inputs,
334338
proof_file_path,
@@ -348,7 +352,7 @@ fn prove<MC: MerkleChannel>(
348352
proof_file_path: PathBuf,
349353
proof_format: &ProofFormat,
350354
verify: bool,
351-
) -> Result<OutputVec, StwoRunAndProveError>
355+
) -> Result<(), StwoRunAndProveError>
352356
where
353357
SimdBackend: BackendForChannel<MC>,
354358
MC::H: Serialize,
@@ -379,8 +383,6 @@ where
379383
}
380384
}
381385

382-
let output_addresses_and_values = proof.claim.public_data.public_memory.output.clone();
383-
384386
if verify {
385387
// We want to map this error to `StwoRunAndProveError::Verification` because we intend to
386388
// retry the proof generation in case of a verification failure. In the calling function we
@@ -390,23 +392,27 @@ where
390392
.map_err(|_| StwoRunAndProveError::Verification)?;
391393
}
392394

393-
let output_values = output_addresses_and_values
394-
.into_iter()
395-
.map(|(_, value)| value)
396-
.collect();
397-
398-
Ok(output_values)
395+
Ok(())
399396
}
400397

401-
/// Saves the program output to the specified output path as [u32; 8] values,
402-
/// that will be converted to [u256] in the Prover service.
403-
fn save_output_to_file(
404-
output_vec: OutputVec,
398+
/// Write the program output to the specified output path as Felt252 values.
399+
fn write_output_to_file(
400+
mut runner: CairoRunner,
405401
output_path: PathBuf,
406402
) -> Result<(), StwoRunAndProveError> {
407403
info!("Saving program output to: {:?}", output_path);
408-
let serialized_output = sonic_rs::to_string(&output_vec)?;
409-
std::fs::write(output_path, serialized_output)?;
404+
// TODO(Nitsan): move this function to cairo_program_runner_lib or a new utils lib,
405+
// and call it from here and from cairo_program_runner.
406+
407+
let mut output_buffer = String::new();
408+
runner.vm.write_output(&mut output_buffer)?;
409+
let output_lines = output_buffer
410+
.lines()
411+
.map(|line: &str| {
412+
Felt252::from_dec_str(line).map_err(|_| StwoRunAndProveError::OutputParsing)
413+
})
414+
.collect::<Result<Vec<Felt252>, _>>()?;
415+
std::fs::write(output_path, sonic_rs::to_string_pretty(&output_lines)?)?;
410416
Ok(())
411417
}
412418

@@ -417,7 +423,7 @@ mod tests {
417423
use std::fs;
418424
use tempfile::{NamedTempFile, TempDir, TempPath};
419425

420-
const ARRAY_SUM_EXPECTED_OUTPUT: [u32; 8] = [50, 0, 0, 0, 0, 0, 0, 0];
426+
const ARRAY_SUM_EXPECTED_OUTPUT: [Felt252; 1] = [Felt252::from_hex_unchecked("0x32")];
421427
const RESOURCES_PATH: &str = "resources";
422428
const PROGRAM_FILE_NAME: &str = "array_sum.json";
423429
const PROVER_PARAMS_FILE_NAME: &str = "prover_params.json";
@@ -484,7 +490,7 @@ mod tests {
484490
.returning(move |_, proof_file, _, _, _| {
485491
let expected_proof_file = get_path(EXPECTED_PROOF_FILE_NAME);
486492
fs::copy(&expected_proof_file, &proof_file).expect("Failed to copy proof file.");
487-
Ok(vec![ARRAY_SUM_EXPECTED_OUTPUT])
493+
Ok(())
488494
});
489495

490496
let successful_proof_attempt =
@@ -530,7 +536,7 @@ mod tests {
530536
// for the last attempt.
531537
let mut results = (0..n_proof_attempts.saturating_sub(1))
532538
.map(|_| Err(StwoRunAndProveError::Verification))
533-
.chain(std::iter::once(Ok(vec![ARRAY_SUM_EXPECTED_OUTPUT])));
539+
.chain(std::iter::once(Ok(())));
534540

535541
mock_prover
536542
.expect_choose_channel_and_prove()
@@ -574,10 +580,10 @@ mod tests {
574580
// Verifying the proof output.
575581
let output_content =
576582
std::fs::read_to_string(output_temp_file).expect("Failed to read output file");
577-
let output_vec: OutputVec =
583+
let output: Vec<Felt252> =
578584
sonic_rs::from_str(&output_content).expect("Failed to parse output");
579585
assert_eq!(
580-
output_vec[0], ARRAY_SUM_EXPECTED_OUTPUT,
586+
output, ARRAY_SUM_EXPECTED_OUTPUT,
581587
"Expected output to be {:?}",
582588
ARRAY_SUM_EXPECTED_OUTPUT
583589
);

0 commit comments

Comments
 (0)