Skip to content

Commit 6d497d3

Browse files
take output from runner instead of from proof
1 parent 3f0e22c commit 6d497d3

File tree

2 files changed

+35
-34
lines changed

2 files changed

+35
-34
lines changed

.github/workflows/upload_artifacts_workflow.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: Upload-Artifacts
33
on:
44
push:
55
branches:
6-
- nitsan/upload_stwo_run_and_prove_to_gcp
6+
- nitsan/take_output_from_runner
77

88
jobs:
99
artifacts-push:

crates/stwo_run_and_prove/src/main.rs

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@ use cairo_air::utils::ProofFormat;
33
use cairo_air::verifier::verify_cairo;
44
use cairo_program_runner_lib::cairo_run_program;
55
use cairo_program_runner_lib::utils::{get_cairo_run_config, get_program, get_program_input};
6+
use cairo_vm::Felt252;
67
use cairo_vm::types::errors::program_errors::ProgramError;
78
use cairo_vm::types::layout_name::LayoutName;
89
use cairo_vm::vm::errors::cairo_run_errors::CairoRunError;
910
use cairo_vm::vm::errors::runner_errors::RunnerError;
11+
use cairo_vm::vm::errors::vm_errors::VirtualMachineError;
12+
use cairo_vm::vm::runners::cairo_runner::CairoRunner;
1013
use clap::Parser;
1114
#[cfg(test)]
1215
use mockall::automock;
@@ -33,8 +36,6 @@ use stwo_cairo_utils::file_utils::{IoErrorWithPath, create_file, read_to_string}
3336
use thiserror::Error;
3437
use tracing::{error, info, warn};
3538

36-
type OutputVec = Vec<[u32; 8]>;
37-
3839
fn parse_usize_ge1(s: &str) -> Result<usize, String> {
3940
let v: usize = s.parse().map_err(|_| "must be a number".to_string())?;
4041
if v >= 1 {
@@ -120,8 +121,12 @@ enum StwoRunAndProveError {
120121
Serializing(#[from] sonic_rs::error::Error),
121122
#[error(transparent)]
122123
Proving(#[from] ProvingError),
124+
#[error(transparent)]
125+
VM(#[from] VirtualMachineError),
123126
#[error("cairo verification failed.")]
124127
Verification,
128+
#[error("Failed to parse output line as Felt decimal.")]
129+
OutputParsing,
125130
}
126131

127132
// Implement From<Box<CairoRunError>> manually
@@ -201,11 +206,9 @@ fn stwo_run_and_prove(
201206
let mut prover_input_info = runner.get_prover_input_info()?;
202207
info!("Adapting prover input.");
203208
let prover_input = adapter(&mut prover_input_info)?;
204-
let (successful_proof_attempt, output_vec) =
205-
prove_with_retries(prover_input, prove_config, prover)?;
206-
209+
let successful_proof_attempt = prove_with_retries(prover_input, prove_config, prover)?;
207210
if let Some(output_path) = program_output {
208-
save_output_to_file(output_vec, output_path)?;
211+
write_output_to_file(runner, output_path)?;
209212
}
210213

211214
Ok(successful_proof_attempt)
@@ -219,7 +222,7 @@ fn prove_with_retries(
219222
prover_input: ProverInput,
220223
prove_config: ProveConfig,
221224
prover: Box<dyn ProverTrait>,
222-
) -> Result<(usize, OutputVec), StwoRunAndProveError> {
225+
) -> Result<usize, StwoRunAndProveError> {
223226
let ProverParameters {
224227
channel_hash,
225228
pcs_config,
@@ -239,7 +242,7 @@ fn prove_with_retries(
239242
std::fs::create_dir_all(&prove_config.proofs_dir)?;
240243
let proof_format = prove_config.proof_format;
241244

242-
for i in 1..=prove_config.n_proof_attempts + 1 {
245+
for i in 1..prove_config.n_proof_attempts + 1 {
243246
info!(
244247
"Attempting to generate proof {}/{}.",
245248
i, prove_config.n_proof_attempts
@@ -253,12 +256,12 @@ fn prove_with_retries(
253256
channel_hash,
254257
prove_config.verify,
255258
) {
256-
Ok(output_values) => {
259+
Ok(()) => {
257260
info!(
258261
"Proof generated and verified successfully on attempt {}/{}",
259262
i, prove_config.n_proof_attempts
260263
);
261-
return Ok((i, output_values));
264+
return Ok(i);
262265
}
263266

264267
Err(StwoRunAndProveError::Verification) => {
@@ -290,7 +293,7 @@ fn choose_channel_and_prove(
290293
proof_format: &ProofFormat,
291294
channel_hash: ChannelHash,
292295
verify: bool,
293-
) -> Result<OutputVec, StwoRunAndProveError> {
296+
) -> Result<(), StwoRunAndProveError> {
294297
match channel_hash {
295298
ChannelHash::Blake2s => prove::<Blake2sMerkleChannel>(
296299
cairo_prover_inputs,
@@ -316,7 +319,7 @@ trait ProverTrait {
316319
proof_format: &ProofFormat,
317320
channel_hash: ChannelHash,
318321
verify: bool,
319-
) -> Result<OutputVec, StwoRunAndProveError>;
322+
) -> Result<(), StwoRunAndProveError>;
320323
}
321324

322325
struct StwoProverEntryPoint;
@@ -329,7 +332,7 @@ impl ProverTrait for StwoProverEntryPoint {
329332
proof_format: &ProofFormat,
330333
channel_hash: ChannelHash,
331334
verify: bool,
332-
) -> Result<OutputVec, StwoRunAndProveError> {
335+
) -> Result<(), StwoRunAndProveError> {
333336
choose_channel_and_prove(
334337
cairo_prover_inputs,
335338
proof_file_path,
@@ -349,7 +352,7 @@ fn prove<MC: MerkleChannel>(
349352
proof_file_path: PathBuf,
350353
proof_format: &ProofFormat,
351354
verify: bool,
352-
) -> Result<OutputVec, StwoRunAndProveError>
355+
) -> Result<(), StwoRunAndProveError>
353356
where
354357
SimdBackend: BackendForChannel<MC>,
355358
MC::H: Serialize,
@@ -380,8 +383,6 @@ where
380383
}
381384
}
382385

383-
let output_addresses_and_values = proof.claim.public_data.public_memory.output.clone();
384-
385386
if verify {
386387
// We want to map this error to `StwoRunAndProveError::Verification` because we intend to
387388
// retry the proof generation in case of a verification failure. In the calling function we
@@ -391,23 +392,23 @@ where
391392
.map_err(|_| StwoRunAndProveError::Verification)?;
392393
}
393394

394-
let output_values = output_addresses_and_values
395-
.into_iter()
396-
.map(|(_, value)| value)
397-
.collect();
398-
399-
Ok(output_values)
395+
Ok(())
400396
}
401397

402-
/// Saves the program output to the specified output path as [u32; 8] values,
403-
/// that will be converted to [u256] in the Prover service.
404-
fn save_output_to_file(
405-
output_vec: OutputVec,
398+
/// Write the program output to the specified output path as Felt252 values.
399+
fn write_output_to_file(
400+
mut runner: CairoRunner,
406401
output_path: PathBuf,
407402
) -> Result<(), StwoRunAndProveError> {
408403
info!("Saving program output to: {:?}", output_path);
409-
let serialized_output = sonic_rs::to_string(&output_vec)?;
410-
std::fs::write(output_path, serialized_output)?;
404+
405+
let mut output_buffer = String::new();
406+
runner.vm.write_output(&mut output_buffer)?;
407+
let output_lines = output_buffer
408+
.lines()
409+
.map(|line| Felt252::from_dec_str(line).map_err(|_| StwoRunAndProveError::OutputParsing))
410+
.collect::<Result<Vec<Felt252>, _>>()?;
411+
std::fs::write(output_path, sonic_rs::to_string_pretty(&output_lines)?)?;
411412
Ok(())
412413
}
413414

@@ -418,7 +419,7 @@ mod tests {
418419
use std::fs;
419420
use tempfile::{NamedTempFile, TempDir, TempPath};
420421

421-
const ARRAY_SUM_EXPECTED_OUTPUT: [u32; 8] = [50, 0, 0, 0, 0, 0, 0, 0];
422+
const ARRAY_SUM_EXPECTED_OUTPUT: [&str; 1] = ["0x32"];
422423
const RESOURCES_PATH: &str = "resources";
423424
const PROGRAM_FILE_NAME: &str = "array_sum.json";
424425
const PROVER_PARAMS_FILE_NAME: &str = "prover_params.json";
@@ -485,7 +486,7 @@ mod tests {
485486
.returning(move |_, proof_file, _, _, _| {
486487
let expected_proof_file = get_path(EXPECTED_PROOF_FILE_NAME);
487488
fs::copy(&expected_proof_file, &proof_file).expect("Failed to copy proof file.");
488-
Ok(vec![ARRAY_SUM_EXPECTED_OUTPUT])
489+
Ok(())
489490
});
490491

491492
let successful_proof_attempt =
@@ -531,7 +532,7 @@ mod tests {
531532
// for the last attempt.
532533
let mut results = (0..n_proof_attempts.saturating_sub(1))
533534
.map(|_| Err(StwoRunAndProveError::Verification))
534-
.chain(std::iter::once(Ok(vec![ARRAY_SUM_EXPECTED_OUTPUT])));
535+
.chain(std::iter::once(Ok(())));
535536

536537
mock_prover
537538
.expect_choose_channel_and_prove()
@@ -575,10 +576,10 @@ mod tests {
575576
// Verifying the proof output.
576577
let output_content =
577578
std::fs::read_to_string(output_temp_file).expect("Failed to read output file");
578-
let output_vec: OutputVec =
579+
let output: Vec<String> =
579580
sonic_rs::from_str(&output_content).expect("Failed to parse output");
580581
assert_eq!(
581-
output_vec[0], ARRAY_SUM_EXPECTED_OUTPUT,
582+
output, ARRAY_SUM_EXPECTED_OUTPUT,
582583
"Expected output to be {:?}",
583584
ARRAY_SUM_EXPECTED_OUTPUT
584585
);

0 commit comments

Comments
 (0)