Skip to content

Commit 84fc022

Browse files
authored
fix: add panic catch on operator calling FFI (#1196)
2 parents e06df5c + 20d03fb commit 84fc022

File tree

14 files changed

+176
-42
lines changed

14 files changed

+176
-42
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
#include <stdbool.h>
2+
#include <stdint.h>
23

3-
bool verify_merkle_tree_batch_ffi(unsigned char *batch_bytes, unsigned int batch_len, unsigned char *merkle_root);
4+
int32_t verify_merkle_tree_batch_ffi(unsigned char *batch_bytes, unsigned int batch_len, unsigned char *merkle_root);

operator/merkle_tree/lib/src/lib.rs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ use aligned_sdk::core::types::{
44
use lambdaworks_crypto::merkle_tree::merkle::MerkleTree;
55
use log::error;
66

7-
#[no_mangle]
8-
pub extern "C" fn verify_merkle_tree_batch_ffi(
7+
fn inner_verify_merkle_tree_batch_ffi(
98
batch_ptr: *const u8,
109
batch_len: usize,
1110
merkle_root: &[u8; 32],
@@ -53,6 +52,22 @@ pub extern "C" fn verify_merkle_tree_batch_ffi(
5352
computed_batch_merkle_tree.root == *merkle_root
5453
}
5554

55+
#[no_mangle]
56+
pub extern "C" fn verify_merkle_tree_batch_ffi(
57+
batch_ptr: *const u8,
58+
batch_len: usize,
59+
merkle_root: &[u8; 32],
60+
) -> i32 {
61+
let result = std::panic::catch_unwind(|| {
62+
inner_verify_merkle_tree_batch_ffi(batch_ptr, batch_len, merkle_root)
63+
});
64+
65+
match result {
66+
Ok(v) => v as i32,
67+
Err(_) => -1,
68+
}
69+
}
70+
5671
#[cfg(test)]
5772
mod tests {
5873
use super::*;
@@ -75,7 +90,7 @@ mod tests {
7590
let result =
7691
verify_merkle_tree_batch_ffi(bytes_vec.as_ptr(), bytes_vec.len(), &merkle_root);
7792

78-
assert_eq!(result, true);
93+
assert_eq!(result, 1);
7994
}
8095

8196
#[test]
@@ -92,7 +107,7 @@ mod tests {
92107
let result =
93108
verify_merkle_tree_batch_ffi(bytes_vec.as_ptr(), bytes_vec.len(), &merkle_root);
94109

95-
assert_eq!(result, false);
110+
assert_eq!(result, 0);
96111
}
97112

98113
#[test]
@@ -109,6 +124,6 @@ mod tests {
109124
let result =
110125
verify_merkle_tree_batch_ffi(bytes_vec.as_ptr(), bytes_vec.len(), &merkle_root);
111126

112-
assert_eq!(result, false);
127+
assert_eq!(result, 0);
113128
}
114129
}

operator/merkle_tree/merkle_tree.go

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,35 @@ package merkle_tree
88
*/
99
import "C"
1010
import "unsafe"
11+
import "fmt"
1112

12-
func VerifyMerkleTreeBatch(batchBuffer []byte, merkleRootBuffer [32]byte) bool {
13+
func VerifyMerkleTreeBatch(batchBuffer []byte, merkleRootBuffer [32]byte) (isVerified bool, err error) {
14+
// Here we define the return value on failure
15+
isVerified = false
16+
err = nil
1317
if len(batchBuffer) == 0 {
14-
return false
18+
return isVerified, err
1519
}
1620

21+
// This will catch any go panic
22+
defer func() {
23+
rec := recover()
24+
if rec != nil {
25+
err = fmt.Errorf("Panic was caught while verifying merkle tree batch: %s", rec)
26+
}
27+
}()
28+
1729
batchPtr := (*C.uchar)(unsafe.Pointer(&batchBuffer[0]))
1830
merkleRootPtr := (*C.uchar)(unsafe.Pointer(&merkleRootBuffer[0]))
19-
return (bool)(C.verify_merkle_tree_batch_ffi(batchPtr, (C.uint)(len(batchBuffer)), merkleRootPtr))
31+
32+
r := (C.int32_t)(C.verify_merkle_tree_batch_ffi(batchPtr, (C.uint)(len(batchBuffer)), merkleRootPtr))
33+
34+
if r == -1 {
35+
err = fmt.Errorf("Panic happened on FFI while verifying merkle tree batch")
36+
return isVerified, err
37+
}
38+
39+
isVerified = (r == 1)
40+
41+
return isVerified, err
2042
}

operator/merkle_tree/merkle_tree_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ func TestVerifyMerkleTreeBatch(t *testing.T) {
3232
var merkleRoot [32]byte
3333
copy(merkleRoot[:], merkle_root)
3434

35-
if !VerifyMerkleTreeBatch(batchByteValue, merkleRoot) {
35+
verified, err := VerifyMerkleTreeBatch(batchByteValue, merkleRoot)
36+
if err != nil || !verified {
3637
t.Errorf("Batch did not verify Merkle Root")
3738
}
3839

operator/pkg/operator.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -496,13 +496,13 @@ func (o *Operator) verify(verificationData VerificationData, results chan bool)
496496
results <- verificationResult
497497

498498
case common.SP1:
499-
verificationResult := sp1.VerifySp1Proof(verificationData.Proof, verificationData.VmProgramCode)
500-
o.Logger.Infof("SP1 proof verification result: %t", verificationResult)
501-
results <- verificationResult
499+
verificationResult, err := sp1.VerifySp1Proof(verificationData.Proof, verificationData.VmProgramCode)
500+
o.handleVerificationResult(results, verificationResult, err, "SP1 proof verification")
502501

503502
case common.Risc0:
504-
verificationResult := risc_zero.VerifyRiscZeroReceipt(verificationData.Proof,
503+
verificationResult, err := risc_zero.VerifyRiscZeroReceipt(verificationData.Proof,
505504
verificationData.VmProgramCode, verificationData.PubInput)
505+
o.handleVerificationResult(results, verificationResult, err, "RiscZero proof verification")
506506

507507
o.Logger.Infof("Risc0 proof verification result: %t", verificationResult)
508508
results <- verificationResult
@@ -512,6 +512,16 @@ func (o *Operator) verify(verificationData VerificationData, results chan bool)
512512
}
513513
}
514514

515+
func (o *Operator) handleVerificationResult(results chan bool, isVerified bool, err error, name string) {
516+
if err != nil {
517+
o.Logger.Errorf("%v failed %v", name, err)
518+
results <- false
519+
} else {
520+
o.Logger.Infof("%v result: %t", name, isVerified)
521+
results <- isVerified
522+
}
523+
}
524+
515525
// VerifyPlonkProofBLS12_381 verifies a PLONK proof using BLS12-381 curve.
516526
func (o *Operator) verifyPlonkProofBLS12_381(proofBytes []byte, pubInputBytes []byte, verificationKeyBytes []byte) bool {
517527
return o.verifyPlonkProof(proofBytes, pubInputBytes, verificationKeyBytes, ecc.BLS12_381)

operator/pkg/s3.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ func (o *Operator) getBatchFromDataService(ctx context.Context, batchURL string,
9090

9191
// Checks if downloaded merkle root is the same as the expected one
9292
o.Logger.Infof("Verifying batch merkle tree...")
93-
merkle_root_check := merkle_tree.VerifyMerkleTreeBatch(batchBytes, expectedMerkleRoot)
94-
if !merkle_root_check {
95-
return nil, fmt.Errorf("merkle root check failed")
93+
merkle_root_check, err := merkle_tree.VerifyMerkleTreeBatch(batchBytes, expectedMerkleRoot)
94+
if err != nil || !merkle_root_check {
95+
return nil, fmt.Errorf("Error while verifying merkle tree batch")
9696
}
9797
o.Logger.Infof("Batch merkle tree verified")
9898

operator/risc_zero/lib/risc_zero.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
#include <stdbool.h>
22
#include <stdint.h>
33

4-
bool verify_risc_zero_receipt_ffi(unsigned char *inner_receipt_bytes, uint32_t inner_receipt_len, unsigned char *image_id, uint32_t image_id_len, unsigned char *public_input, uint32_t public_input_len);
4+
int32_t verify_risc_zero_receipt_ffi(unsigned char *inner_receipt_bytes, uint32_t inner_receipt_len, unsigned char *image_id, uint32_t image_id_len, unsigned char *public_input, uint32_t public_input_len);

operator/risc_zero/lib/src/lib.rs

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
use log::error;
22
use risc0_zkvm::{InnerReceipt, Receipt};
33

4-
#[no_mangle]
5-
pub extern "C" fn verify_risc_zero_receipt_ffi(
4+
fn inner_verify_risc_zero_receipt_ffi(
65
inner_receipt_bytes: *const u8,
76
inner_receipt_len: u32,
87
image_id: *const u8,
@@ -43,6 +42,32 @@ pub extern "C" fn verify_risc_zero_receipt_ffi(
4342
false
4443
}
4544

45+
#[no_mangle]
46+
pub extern "C" fn verify_risc_zero_receipt_ffi(
47+
inner_receipt_bytes: *const u8,
48+
inner_receipt_len: u32,
49+
image_id: *const u8,
50+
image_id_len: u32,
51+
public_input: *const u8,
52+
public_input_len: u32,
53+
) -> i32 {
54+
let result = std::panic::catch_unwind(|| {
55+
inner_verify_risc_zero_receipt_ffi(
56+
inner_receipt_bytes,
57+
inner_receipt_len,
58+
image_id,
59+
image_id_len,
60+
public_input,
61+
public_input_len,
62+
)
63+
});
64+
65+
match result {
66+
Ok(v) => v as i32,
67+
Err(_) => -1,
68+
}
69+
}
70+
4671
#[cfg(test)]
4772
mod tests {
4873
use super::*;
@@ -69,7 +94,7 @@ mod tests {
6994
public_input,
7095
PUBLIC_INPUT.len() as u32,
7196
);
72-
assert!(result)
97+
assert_eq!(result, 1)
7398
}
7499

75100
#[test]
@@ -86,7 +111,7 @@ mod tests {
86111
public_input,
87112
PUBLIC_INPUT.len() as u32,
88113
);
89-
assert!(!result)
114+
assert_eq!(result, 0)
90115
}
91116

92117
#[test]
@@ -103,6 +128,6 @@ mod tests {
103128
public_input,
104129
0,
105130
);
106-
assert!(!result)
131+
assert_eq!(result, 0)
107132
}
108133
}

operator/risc_zero/risc_zero.go

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,44 @@ package risc_zero
77
#include "lib/risc_zero.h"
88
*/
99
import "C"
10-
import (
11-
"unsafe"
12-
)
10+
import "unsafe"
11+
import "fmt"
12+
13+
func VerifyRiscZeroReceipt(innerReceiptBuffer []byte, imageIdBuffer []byte, publicInputBuffer []byte) (isVerified bool, err error) {
14+
// Here we define the return value on failure
15+
isVerified = false
16+
err = nil
1317

14-
func VerifyRiscZeroReceipt(innerReceiptBuffer []byte, imageIdBuffer []byte, publicInputBuffer []byte) bool {
1518
if len(innerReceiptBuffer) == 0 || len(imageIdBuffer) == 0 {
16-
return false
19+
return isVerified, err
1720
}
1821

22+
// This will catch any go panic
23+
defer func() {
24+
rec := recover()
25+
if rec != nil {
26+
err = fmt.Errorf("Panic was caught while verifying risc0 proof: %s", rec)
27+
}
28+
}()
29+
1930
receiptPtr := (*C.uchar)(unsafe.Pointer(&innerReceiptBuffer[0]))
2031
imageIdPtr := (*C.uchar)(unsafe.Pointer(&imageIdBuffer[0]))
2132

33+
r := (C.int32_t)(0)
34+
2235
if len(publicInputBuffer) == 0 { // allow empty public input
23-
return (bool)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), nil, (C.uint32_t)(0)))
36+
r = (C.int32_t)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), nil, (C.uint32_t)(0)))
37+
} else {
38+
publicInputPtr := (*C.uchar)(unsafe.Pointer(&publicInputBuffer[0]))
39+
r = (C.int32_t)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), publicInputPtr, (C.uint32_t)(len(publicInputBuffer))))
2440
}
2541

26-
publicInputPtr := (*C.uchar)(unsafe.Pointer(&publicInputBuffer[0]))
27-
return (bool)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), publicInputPtr, (C.uint32_t)(len(publicInputBuffer))))
42+
if r == -1 {
43+
err = fmt.Errorf("Panic happened on FFI while verifying risc0 proof")
44+
return isVerified, err
45+
}
46+
47+
isVerified = (r == 1)
48+
49+
return isVerified, err
2850
}

operator/risc_zero/risc_zero_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ func TestFibonacciRiscZeroProofVerifies(t *testing.T) {
2222
if err != nil {
2323
t.Errorf("could not open public input file: %s", err)
2424
}
25-
26-
if !risc_zero.VerifyRiscZeroReceipt(innerReceiptBytes, imageIdBytes, publicInputBytes) {
25+
verified, err := risc_zero.VerifyRiscZeroReceipt(innerReceiptBytes, imageIdBytes, publicInputBytes)
26+
if err != nil || !verified {
2727
t.Errorf("proof did not verify")
2828
}
2929
}

0 commit comments

Comments
 (0)