Skip to content

Commit 7501f10

Browse files
authored
60% state replay speedup (#4434)
* 60% state replay speedup
* don't use HashList for epoch participation - in addition to the code currently clearing the caches several times redundantly, clearing has to be done each block, nullifying the benefit (35%)
* introduce active balance cache - computing it is slow due to cache unfriendliness in the random access pattern and bounds checking, and we do it for every block - this cache follows the same update pattern as the active validator index cache (20%)
* avoid recomputing base reward several times per attestation (5%)

Applying 1024 blocks goes from 20s to ~8s on my laptop - these kinds of requests happen on historical REST queries but also whenever there's a reorg.

* fix test and diffs
1 parent 064d164 commit 7501f10

File tree

8 files changed

+68
-40
lines changed

8 files changed

+68
-40
lines changed

beacon_chain/spec/beaconstate.nim

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -637,8 +637,13 @@ func get_total_active_balance*(state: ForkyBeaconState, cache: var StateCache):
637637

638638
let epoch = state.get_current_epoch()
639639

640-
get_total_balance(
641-
state, cache.get_shuffled_active_validator_indices(state, epoch))
640+
cache.total_active_balance.withValue(epoch, tab) do:
641+
return tab[]
642+
do:
643+
let tab = get_total_balance(
644+
state, cache.get_shuffled_active_validator_indices(state, epoch))
645+
cache.total_active_balance[epoch] = tab
646+
return tab
642647

643648
# https://github.com/ethereum/consensus-specs/blob/v1.3.0-alpha.2/specs/altair/beacon-chain.md#get_base_reward_per_increment
644649
func get_base_reward_per_increment_sqrt*(
@@ -704,15 +709,15 @@ func get_proposer_reward*(state: ForkyBeaconState,
704709
state, attestation.data, state.slot - attestation.data.slot)
705710
for index in get_attesting_indices(
706711
state, attestation.data, attestation.aggregation_bits, cache):
712+
let
713+
base_reward = get_base_reward(state, index, base_reward_per_increment)
707714
for flag_index, weight in PARTICIPATION_FLAG_WEIGHTS:
708715
if flag_index in participation_flag_indices and
709716
not has_flag(epoch_participation.item(index), flag_index):
710-
epoch_participation[index] =
717+
asList(epoch_participation)[index] =
711718
add_flag(epoch_participation.item(index), flag_index)
712719
# these are all valid; TODO statically verify or do it type-safely
713-
result += get_base_reward(
714-
state, index, base_reward_per_increment) * weight.uint64
715-
epoch_participation.asHashList.clearCache()
720+
result += base_reward * weight.uint64
716721

717722
let proposer_reward_denominator =
718723
(WEIGHT_DENOMINATOR.uint64 - PROPOSER_WEIGHT.uint64) *
@@ -860,8 +865,7 @@ func upgrade_to_altair*(cfg: RuntimeConfig, pre: phase0.BeaconState):
860865
empty_participation: EpochParticipationFlags
861866
inactivity_scores = HashList[uint64, Limit VALIDATOR_REGISTRY_LIMIT]()
862867

863-
doAssert empty_participation.data.setLen(pre.validators.len)
864-
empty_participation.asHashList.resetCache()
868+
doAssert empty_participation.asList.setLen(pre.validators.len)
865869

866870
doAssert inactivity_scores.data.setLen(pre.validators.len)
867871
inactivity_scores.resetCache()

beacon_chain/spec/datatypes/altair.nim

Lines changed: 32 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ type
7878
ParticipationFlags* = uint8
7979

8080
EpochParticipationFlags* =
81-
distinct HashList[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]
81+
distinct List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]
8282

8383
# https://github.com/ethereum/consensus-specs/blob/v1.3.0-alpha.2/specs/altair/beacon-chain.md#syncaggregate
8484
SyncAggregate* = object
@@ -558,10 +558,8 @@ type
558558
# Represent in full; for the next epoch, current_epoch_participation in
559559
# epoch n is previous_epoch_participation in epoch n+1 but this doesn't
560560
# generalize.
561-
previous_epoch_participation*:
562-
List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]
563-
current_epoch_participation*:
564-
List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]
561+
previous_epoch_participation*: EpochParticipationFlags
562+
current_epoch_participation*: EpochParticipationFlags
565563

566564
justification_bits*: JustificationBits
567565
previous_justified_checkpoint*: Checkpoint
@@ -589,26 +587,44 @@ template `[]`*(arr: array[SYNC_COMMITTEE_SIZE, auto] | seq;
589587
makeLimitedU8(SyncSubcommitteeIndex, SYNC_COMMITTEE_SUBNET_COUNT)
590588
makeLimitedU16(IndexInSyncCommittee, SYNC_COMMITTEE_SIZE)
591589

592-
template asHashList*(epochFlags: EpochParticipationFlags): untyped =
593-
HashList[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT] epochFlags
590+
template asList*(epochFlags: EpochParticipationFlags): untyped =
591+
List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT] epochFlags
592+
template asList*(epochFlags: var EpochParticipationFlags): untyped =
593+
let tmp = cast[ptr List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]](addr epochFlags)
594+
tmp[]
595+
596+
template asSeq*(epochFlags: EpochParticipationFlags): untyped =
597+
seq[ParticipationFlags] asList(epochFlags)
598+
599+
template asSeq*(epochFlags: var EpochParticipationFlags): untyped =
600+
let tmp = cast[ptr seq[ParticipationFlags]](addr epochFlags)
601+
tmp[]
594602

595603
template item*(epochFlags: EpochParticipationFlags, idx: ValidatorIndex): ParticipationFlags =
596-
asHashList(epochFlags).item(idx)
604+
asList(epochFlags)[idx]
597605

598-
template `[]`*(epochFlags: EpochParticipationFlags, idx: ValidatorIndex|uint64): ParticipationFlags =
599-
asHashList(epochFlags)[idx]
606+
template `[]`*(epochFlags: EpochParticipationFlags, idx: ValidatorIndex|uint64|int): ParticipationFlags =
607+
asList(epochFlags)[idx]
600608

601609
template `[]=`*(epochFlags: EpochParticipationFlags, idx: ValidatorIndex, flags: ParticipationFlags) =
602-
asHashList(epochFlags)[idx] = flags
610+
asList(epochFlags)[idx] = flags
603611

604612
template add*(epochFlags: var EpochParticipationFlags, flags: ParticipationFlags): bool =
605-
asHashList(epochFlags).add flags
613+
asList(epochFlags).add flags
606614

607615
template len*(epochFlags: EpochParticipationFlags): int =
608-
asHashList(epochFlags).len
609-
610-
template data*(epochFlags: EpochParticipationFlags): untyped =
611-
asHashList(epochFlags).data
616+
asList(epochFlags).len
617+
618+
template low*(epochFlags: EpochParticipationFlags): int =
619+
asSeq(epochFlags).low
620+
template high*(epochFlags: EpochParticipationFlags): int =
621+
asSeq(epochFlags).high
622+
623+
template assign*(v: var EpochParticipationFlags, src: EpochParticipationFlags) =
624+
# TODO https://github.com/nim-lang/Nim/issues/21123
625+
mixin assign
626+
var tmp = cast[ptr seq[ParticipationFlags]](addr v)
627+
assign(tmp[], distinctBase src)
612628

613629
func shortLog*(v: SomeBeaconBlock): auto =
614630
(

beacon_chain/spec/datatypes/base.nim

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -406,6 +406,7 @@ type
406406
# This doesn't know about forks or branches in the DAG. It's for straight,
407407
# linear chunks of the chain.
408408
StateCache* = object
409+
total_active_balance*: Table[Epoch, Gwei]
409410
shuffled_active_validator_indices*: Table[Epoch, seq[ValidatorIndex]]
410411
beacon_proposer_indices*: Table[Slot, Option[ValidatorIndex]]
411412
sync_committees*: Table[SyncCommitteePeriod, SyncCommitteeCache]
@@ -923,6 +924,14 @@ func prune*(cache: var StateCache, epoch: Epoch) =
923924
pruneEpoch = epoch - 2
924925

925926
var drops: seq[Slot]
927+
block:
928+
for k in cache.total_active_balance.keys:
929+
if k < pruneEpoch:
930+
drops.add pruneEpoch.start_slot
931+
for drop in drops:
932+
cache.total_active_balance.del drop.epoch
933+
drops.setLen(0)
934+
926935
block:
927936
for k in cache.shuffled_active_validator_indices.keys:
928937
if k < pruneEpoch:
@@ -948,6 +957,7 @@ func prune*(cache: var StateCache, epoch: Epoch) =
948957
drops.setLen(0)
949958

950959
func clear*(cache: var StateCache) =
960+
cache.total_active_balance.clear
951961
cache.shuffled_active_validator_indices.clear
952962
cache.beacon_proposer_indices.clear
953963
cache.sync_committees.clear

beacon_chain/spec/eth2_apis/eth2_rest_serialization.nim

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -612,15 +612,12 @@ proc readValue*(reader: var JsonReader[RestJson], value: var Epoch) {.
612612
proc writeValue*(writer: var JsonWriter[RestJson],
613613
epochFlags: EpochParticipationFlags)
614614
{.raises: [IOError, Defect].} =
615-
for e in writer.stepwiseArrayCreation(epochFlags.asHashList):
615+
for e in writer.stepwiseArrayCreation(epochFlags.asList):
616616
writer.writeValue $e
617617

618618
proc readValue*(reader: var JsonReader[RestJson],
619619
epochFlags: var EpochParticipationFlags)
620620
{.raises: [SerializationError, IOError, Defect].} =
621-
# Please note that this function won't compute the cached hash tree roots
622-
# immediately. They will be computed on the first HTR attempt.
623-
624621
for e in reader.readArray(string):
625622
let parsed = try:
626623
parseBiggestUInt(e)
@@ -632,7 +629,7 @@ proc readValue*(reader: var JsonReader[RestJson],
632629
reader.raiseUnexpectedValue(
633630
"The usigned integer value should fit in 8 bits")
634631

635-
if not epochFlags.data.add(uint8(parsed)):
632+
if not epochFlags.asList.add(uint8(parsed)):
636633
reader.raiseUnexpectedValue(
637634
"The participation flags list size exceeds limit")
638635

beacon_chain/spec/ssz_codec.nim

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ import
1717
./datatypes/base
1818

1919
from ./datatypes/altair import
20-
ParticipationFlags, EpochParticipationFlags, asHashList
20+
ParticipationFlags, EpochParticipationFlags
2121

2222
export codec, base, typetraits, EpochParticipationFlags
2323

@@ -28,7 +28,7 @@ template toSszType*(v: BlsCurveType): auto = toRaw(v)
2828
template toSszType*(v: ForkDigest|GraffitiBytes): auto = distinctBase(v)
2929
template toSszType*(v: Version): auto = distinctBase(v)
3030
template toSszType*(v: JustificationBits): auto = distinctBase(v)
31-
template toSszType*(epochFlags: EpochParticipationFlags): auto = asHashList epochFlags
31+
template toSszType*(v: EpochParticipationFlags): auto = asList v
3232

3333
func fromSszBytes*(T: type GraffitiBytes, data: openArray[byte]): T {.raisesssz.} =
3434
if data.len != sizeof(result):
@@ -60,4 +60,6 @@ func fromSszBytes*(T: type JustificationBits, bytes: openArray[byte]): T {.raise
6060
copyMem(result.addr, unsafeAddr bytes[0], sizeof(result))
6161

6262
func fromSszBytes*(T: type EpochParticipationFlags, bytes: openArray[byte]): T {.raisesssz.} =
63-
readSszValue(bytes, HashList[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT] result)
63+
# TODO https://github.com/nim-lang/Nim/issues/21123
64+
let tmp = cast[ptr List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]](addr result)
65+
readSszValue(bytes, tmp[])

beacon_chain/spec/state_transition.nim

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -120,6 +120,7 @@ func process_slot*(
120120
hash_tree_root(state.latest_block_header)
121121

122122
func clear_epoch_from_cache(cache: var StateCache, epoch: Epoch) =
123+
cache.total_active_balance.del epoch
123124
cache.shuffled_active_validator_indices.del epoch
124125

125126
for slot in epoch.slots():

beacon_chain/spec/state_transition_epoch.nim

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1005,13 +1005,11 @@ func process_participation_flag_updates*(
10051005

10061006
const zero = 0.ParticipationFlags
10071007
for i in 0 ..< state.current_epoch_participation.len:
1008-
state.current_epoch_participation.data[i] = zero
1008+
asList(state.current_epoch_participation)[i] = zero
10091009

10101010
# Shouldn't be wasted zeroing, because state.current_epoch_participation only
10111011
# grows. New elements are automatically initialized to 0, as required.
1012-
doAssert state.current_epoch_participation.data.setLen(state.validators.len)
1013-
1014-
state.current_epoch_participation.asHashList.resetCache()
1012+
doAssert state.current_epoch_participation.asList.setLen(state.validators.len)
10151013

10161014
# https://github.com/ethereum/consensus-specs/blob/v1.3.0-alpha.2/specs/altair/beacon-chain.md#sync-committee-updates
10171015
func process_sync_committee_updates*(

beacon_chain/statediff.nim

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -143,8 +143,8 @@ func diffStates*(state0, state1: bellatrix.BeaconState): BeaconStateDiff =
143143
slashing: state1.slashings[state0.slot.epoch.uint64 mod
144144
EPOCHS_PER_HISTORICAL_VECTOR.uint64],
145145

146-
previous_epoch_participation: state1.previous_epoch_participation.data,
147-
current_epoch_participation: state1.current_epoch_participation.data,
146+
previous_epoch_participation: state1.previous_epoch_participation,
147+
current_epoch_participation: state1.current_epoch_participation,
148148

149149
justification_bits: state1.justification_bits,
150150
previous_justified_checkpoint: state1.previous_justified_checkpoint,
@@ -192,9 +192,9 @@ func applyDiff*(
192192
assign(state.slashings.mitem(epochIndex), stateDiff.slashing)
193193

194194
assign(
195-
state.previous_epoch_participation.data, stateDiff.previous_epoch_participation)
195+
state.previous_epoch_participation, stateDiff.previous_epoch_participation)
196196
assign(
197-
state.current_epoch_participation.data, stateDiff.current_epoch_participation)
197+
state.current_epoch_participation, stateDiff.current_epoch_participation)
198198

199199
state.justification_bits = stateDiff.justification_bits
200200
assign(

0 commit comments

Comments
 (0)