diff --git a/execution_chain/db/aristo/aristo_proof.nim b/execution_chain/db/aristo/aristo_proof.nim index 45541a60ec..47e19a9bd3 100644 --- a/execution_chain/db/aristo/aristo_proof.nim +++ b/execution_chain/db/aristo/aristo_proof.nim @@ -15,15 +15,11 @@ {.push raises: [].} import + std/[tables, sets, sequtils], eth/common/hashes, results, ./[aristo_desc, aristo_fetch, aristo_get, aristo_serialise, aristo_utils] -# ------------------------------------------------------------------------------ -# Public functions -# ------------------------------------------------------------------------------ - - const ChainRlpNodesNoEntry* = { PartChnLeafPathMismatch, PartChnExtPfxMismatch, PartChnBranchVoidEdge} @@ -133,16 +129,13 @@ proc trackRlpNodes( return err(PartTrkLinkExpected) chain.toOpenArray(1,chain.len-1).trackRlpNodes(nextKey, path.slice nChewOff) -# ------------------------------------------------------------------------------ -# Public functions -# ------------------------------------------------------------------------------ - proc makeProof( db: AristoTxRef; root: VertexID; path: NibblesBuf; nodesCache: var NodesCache; - ): Result[(seq[seq[byte]], bool), AristoError] = + chain: var seq[seq[byte]]; + ): Result[bool, AristoError] = ## This function returns a chain of rlp-encoded nodes along the argument ## path `(root,path)` followed by a `true` value if the `path` argument ## exists in the database. If the argument `path` is not on the database, @@ -150,12 +143,11 @@ proc makeProof( ## ## Errors will only be returned for invalid paths. ## - var chain: seq[seq[byte]] let rc = db.chainRlpNodes((root,root), path, chain, nodesCache) if rc.isOk: - ok((chain, true)) + ok(true) elif rc.error in ChainRlpNodesNoEntry: - ok((chain, false)) + ok(false) else: err(rc.error) @@ -163,8 +155,11 @@ proc makeAccountProof*( db: AristoTxRef; accPath: Hash32; ): Result[(seq[seq[byte]], bool), AristoError] = - var nodesCache: NodesCache - db.makeProof(STATE_ROOT_VID, NibblesBuf.fromBytes accPath.data, nodesCache) + var + nodesCache: NodesCache + proof: seq[seq[byte]] + let exists = ?db.makeProof(STATE_ROOT_VID, NibblesBuf.fromBytes accPath.data, nodesCache, proof) + ok((proof, exists)) proc makeStorageProof*( db: AristoTxRef; @@ -177,8 +172,11 @@ proc makeStorageProof*( if error == FetchPathStoRootMissing: return ok((@[],false)) return err(error) - var nodesCache: NodesCache - db.makeProof(vid, NibblesBuf.fromBytes stoPath.data, nodesCache) + var + nodesCache: NodesCache + proof: seq[seq[byte]] + let exists = ?db.makeProof(vid, NibblesBuf.fromBytes stoPath.data, nodesCache, proof) + ok((proof, exists)) proc makeStorageProofs*( db: AristoTxRef; @@ -196,14 +194,56 @@ proc makeStorageProofs*( var nodesCache: NodesCache proofs = newSeqOfCap[seq[seq[byte]]](stoPaths.len()) - for stoPath in stoPaths: - let (proof, _) = ?db.makeProof(vid, NibblesBuf.fromBytes stoPath.data, nodesCache) + var proof: seq[seq[byte]] + discard ?db.makeProof(vid, NibblesBuf.fromBytes stoPath.data, nodesCache, proof) proofs.add(proof) ok(proofs) -# ---------- +proc makeStorageMultiProof( + db: AristoTxRef; + accPath: Hash32; + stoPaths: openArray[Hash32]; + nodesCache: var NodesCache; + multiProof: var HashSet[seq[byte]] + ): Result[void, AristoError] = + ## Note that the function returns an error unless + ## the argument `accPath` is valid. + let vid = db.fetchStorageID(accPath).valueOr: + if error == FetchPathStoRootMissing: + return ok() + return err(error) + + for stoPath in stoPaths: + var proof: seq[seq[byte]] + discard ?db.makeProof(vid, NibblesBuf.fromBytes stoPath.data, nodesCache, proof) + for node in proof: + multiProof.incl(node) + + ok() + +proc makeMultiProof*( + db: AristoTxRef; + paths: Table[Hash32, seq[Hash32]], # maps each account path to a list of storage paths + multiProof: var seq[seq[byte]] + ): Result[void, AristoError] = + var + nodesCache: NodesCache + proofNodes: HashSet[seq[byte]] + + for accPath, stoPaths in paths: + var accProof: seq[seq[byte]] + let exists = ?db.makeProof(STATE_ROOT_VID, NibblesBuf.fromBytes accPath.data, nodesCache, accProof) + for node in accProof: + proofNodes.incl(node) + + if exists: + ?db.makeStorageMultiProof(accPath, stoPaths, nodesCache, proofNodes) + + multiProof = proofNodes.toSeq() + + ok() proc verifyProof*( chain: openArray[seq[byte]]; @@ -222,7 +262,3 @@ proc verifyProof*( return err(rc.error) except RlpError: return err(PartTrkRlpError) - -# ------------------------------------------------------------------------------ -# End -# ------------------------------------------------------------------------------ diff --git a/execution_chain/db/core_db/base.nim b/execution_chain/db/core_db/base.nim index ad43693c10..8715cd9503 100644 --- a/execution_chain/db/core_db/base.nim +++ b/execution_chain/db/core_db/base.nim @@ -307,6 +307,20 @@ proc getStateRoot*(acc: CoreDbTxRef): CoreDbRc[Hash32] = ok(rc) +proc multiProof*( + acc: CoreDbTxRef; + paths: Table[Hash32, seq[Hash32]]; + multiProof: var seq[seq[byte]] + ): CoreDbRc[void] = + ## Returns a multiproof for every account and storage path specified + ## in the paths table. All rlp-encoded trie nodes from all account + ## and storage proofs are returned in a single list. + + acc.aTx.makeMultiProof(paths, multiProof).isOkOr: + return err(error.toError("", ProofCreate)) + + ok() + # ------------ storage --------------- proc slotProofs*( diff --git a/execution_chain/stateless/witness_generation.nim b/execution_chain/stateless/witness_generation.nim index 3466a40814..522eaf1360 100644 --- a/execution_chain/stateless/witness_generation.nim +++ b/execution_chain/stateless/witness_generation.nim @@ -25,42 +25,55 @@ proc build*( T: type Witness, witnessKeys: WitnessTable, preStateLedger: LedgerRef): T = + var + proofPaths: Table[Hash32, seq[Hash32]] + addedCodeHashes: HashSet[Hash32] + accPreimages: Table[Hash32, array[20, byte]] + stoPreimages: Table[Hash32, array[32, byte]] witness = Witness.init() - addedState = initHashSet[seq[byte]]() - addedCodeHashes = initHashSet[Hash32]() for key, codeTouched in witnessKeys: - if key.slot.isNone(): # Is an account key - witness.addKey(key.address.data()) + let + addressBytes = key.address.data() + accPath = keccak256(addressBytes) + accPreimages[accPath] = addressBytes - let proof = preStateLedger.getAccountProof(key.address) - for trieNode in proof: - addedState.incl(trieNode) + if key.slot.isNone(): # Is an account key + proofPaths.withValue(accPath, v): + discard + do: + proofPaths[accPath] = @[] + # codeTouched is only set for account keys if codeTouched: let codeHash = preStateLedger.getCodeHash(key.address) if codeHash != EMPTY_CODE_HASH and codeHash notin addedCodeHashes: witness.addCodeHash(codeHash) addedCodeHashes.incl(codeHash) - # Add the storage slots for this account - var slots: seq[UInt256] - for key2, codeTouched2 in witnessKeys: - if key2.address == key.address and key2.slot.isSome(): - let slot = key2.slot.get() - slots.add(slot) - witness.addKey(slot.toBytesBE()) - - if slots.len() > 0: - let proofs = preStateLedger.getStorageProof(key.address, slots) - doAssert(proofs.len() == slots.len()) - for proof in proofs: - for trieNode in proof: - addedState.incl(trieNode) - - for s in addedState.items(): - witness.addState(s) + else: # Is a slot key + let + slotBytes = key.slot.get().toBytesBE() + slotPath = keccak256(slotBytes) + stoPreimages[slotPath] = slotBytes + + proofPaths.withValue(accPath, v): + v[].add(slotPath) + do: + var paths: seq[Hash32] + paths.add(slotPath) + proofPaths[accPath] = paths + + var multiProof: seq[seq[byte]] + preStateLedger.txFrame.multiProof(proofPaths, multiProof).isOkOr: + raiseAssert "Failed to get multiproof: " & $$error + witness.state = move(multiProof) + + for accPath, stoPaths in proofPaths: + witness.addKey(accPreimages.getOrDefault(accPath)) + for stoPath in stoPaths: + witness.addKey(stoPreimages.getOrDefault(stoPath)) witness @@ -92,10 +105,11 @@ proc build*( let blockHashes = ledger.getBlockHashesCache() earliestBlockNumber = getEarliestCachedBlockNumber(blockHashes) + if earliestBlockNumber.isSome(): - var n = parent.number - 1 + var n = parent.number while n >= earliestBlockNumber.get(): + dec n let blockHash = ledger.getBlockHash(BlockNumber(n)) doAssert(blockHash != default(Hash32)) witness.addHeaderHash(blockHash) - dec n diff --git a/tests/test_stateless_witness_generation.nim b/tests/test_stateless_witness_generation.nim index 9dd7623235..7a1c54790e 100644 --- a/tests/test_stateless_witness_generation.nim +++ b/tests/test_stateless_witness_generation.nim @@ -144,8 +144,8 @@ suite "Stateless: Witness Generation": check: witness.keys.len() == 5 - witness.keys[0] == addr1.data() - witness.keys[1] == slot1.toBytesBE() - witness.keys[2] == slot2.toBytesBE() - witness.keys[3] == slot3.toBytesBE() - witness.keys[4] == addr2.data() + witness.keys[0] == addr2.data() + witness.keys[1] == addr1.data() + witness.keys[2] == slot1.toBytesBE() + witness.keys[3] == slot2.toBytesBE() + witness.keys[4] == slot3.toBytesBE()