Skip to content

Commit 29f6733

Browse files
Add redeemer stuff (#184)
* Add memoize * Add async_batched * Add os token converter * Add META_VAULT_ID * Fix lint * Fix func metadata
1 parent 8f5e4a7 commit 29f6733

File tree

6 files changed

+110
-3
lines changed

6 files changed

+110
-3
lines changed

sw_utils/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .common import InterruptHandler
1+
from .common import InterruptHandler, async_batched
22
from .consensus import (
33
PENDING_STATUSES,
44
ExtendedAsyncBeacon,
@@ -9,7 +9,7 @@
99
get_chain_latest_head,
1010
get_consensus_client,
1111
)
12-
from .decorators import retry_aiohttp_errors, retry_ipfs_exception, safe
12+
from .decorators import memoize, retry_aiohttp_errors, retry_ipfs_exception, safe
1313
from .event_scanner import EventProcessor, EventScanner
1414
from .exceptions import IpfsException
1515
from .execution import GasManager, get_execution_client
@@ -30,6 +30,7 @@
3030
NETWORKS,
3131
BaseNetworkConfig,
3232
)
33+
from .os_token_converter import OsTokenConverter
3334
from .password import generate_password
3435
from .protocol_config import build_protocol_config
3536
from .signing import (

sw_utils/common.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import asyncio
22
import logging
33
import signal
4-
from typing import Any
4+
from typing import Any, AsyncGenerator, TypeVar
55
from urllib.parse import urlparse, urlunparse
66

7+
T = TypeVar('T')
8+
79
logger = logging.getLogger(__name__)
810

911

@@ -65,3 +67,17 @@ def urljoin(base: str, *args: str) -> str:
6567

6668
def _join_paths(*args: str) -> str:
6769
return '/'.join(str(x).strip('/') for x in args)
70+
71+
72+
async def async_batched(
73+
async_gen: AsyncGenerator[T, None], batch_size: int
74+
) -> AsyncGenerator[list[T], None]:
75+
"""Batch items from an async iterator. Replacement for itertools.batched."""
76+
batch = []
77+
async for item in async_gen:
78+
batch.append(item)
79+
if len(batch) == batch_size:
80+
yield batch
81+
batch = []
82+
if batch:
83+
yield batch

sw_utils/decorators.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,36 @@
1818
default_logger = logging.getLogger(__name__)
1919

2020

21+
def memoize(func: Callable) -> Callable:
22+
"""
23+
Helper to memoize both sync and async functions.
24+
Main usage is for async functions because `functools.cache` won't work with them.
25+
"""
26+
cache: dict = {}
27+
28+
@wraps(func)
29+
async def memoized_async_func(*args, **kwargs): # type: ignore
30+
key = (args, frozenset(sorted(kwargs.items())))
31+
if key in cache:
32+
return cache[key]
33+
result = await func(*args, **kwargs)
34+
cache[key] = result
35+
return result
36+
37+
@wraps(func)
38+
def memoized_sync_func(*args, **kwargs): # type: ignore
39+
key = (args, frozenset(sorted(kwargs.items())))
40+
if key in cache:
41+
return cache[key]
42+
result = func(*args, **kwargs)
43+
cache[key] = result
44+
return result
45+
46+
if asyncio.iscoroutinefunction(func):
47+
return memoized_async_func
48+
return memoized_sync_func
49+
50+
2151
def safe(func: Callable) -> Callable:
2252
if asyncio.iscoroutinefunction(func):
2353

sw_utils/networks.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class BaseNetworkConfig:
5050
PECTRA_VAULT_VERSION: int
5151
OS_TOKEN_VAULT_CONTROLLER_CONTRACT_ADDRESS: ChecksumAddress
5252
MIN_EFFECTIVE_PRIORITY_FEE_PER_GAS: Wei
53+
META_VAULT_ID: HexStr
5354

5455
@property
5556
def SECONDS_PER_BLOCK(self) -> int:
@@ -128,6 +129,7 @@ def PECTRA_SLOT(self) -> int:
128129
'0x2A261e60FB14586B474C208b1B7AC6D0f5000306'
129130
),
130131
MIN_EFFECTIVE_PRIORITY_FEE_PER_GAS=Web3.to_wei(0, 'gwei'),
132+
META_VAULT_ID=HexStr('0xcfece609e9557b5f0b085dadb8b7c99d43a9052d28ab3d68c9e3c1a9c3ab85c0'),
131133
),
132134
HOODI: BaseNetworkConfig(
133135
SLOTS_PER_EPOCH=32,
@@ -179,6 +181,7 @@ def PECTRA_SLOT(self) -> int:
179181
'0x140Fc69Eabd77fFF91d9852B612B2323256f7Ac1'
180182
),
181183
MIN_EFFECTIVE_PRIORITY_FEE_PER_GAS=Web3.to_wei(0, 'gwei'),
184+
META_VAULT_ID=HexStr('0xcfece609e9557b5f0b085dadb8b7c99d43a9052d28ab3d68c9e3c1a9c3ab85c0'),
182185
),
183186
GNOSIS: BaseNetworkConfig(
184187
SLOTS_PER_EPOCH=16,
@@ -234,5 +237,6 @@ def PECTRA_SLOT(self) -> int:
234237
'0x60B2053d7f2a0bBa70fe6CDd88FB47b579B9179a'
235238
),
236239
MIN_EFFECTIVE_PRIORITY_FEE_PER_GAS=Web3.to_wei(1, 'gwei'),
240+
META_VAULT_ID=HexStr('0xfb5cee5ecc2ff8d1a7a5ad55f89156551c040a926bec689720ab06063922454d'),
237241
),
238242
}

sw_utils/os_token_converter.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from web3.types import Wei
2+
3+
4+
class OsTokenConverter:
5+
"""
6+
Convert between shares and assets based on total assets and total shares.
7+
Helps to avoid repeating calls to the contract.
8+
"""
9+
10+
def __init__(self, total_assets: Wei, total_shares: Wei):
11+
self.total_assets = total_assets
12+
self.total_shares = total_shares
13+
14+
def to_shares(self, assets: Wei) -> Wei:
15+
if self.total_assets == 0:
16+
return Wei(0)
17+
return Wei((assets * self.total_shares) // self.total_assets)
18+
19+
def to_assets(self, shares: Wei) -> Wei:
20+
if self.total_shares == 0:
21+
return Wei(0)
22+
return Wei((shares * self.total_assets) // self.total_shares)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import pytest
2+
from web3.types import Wei
3+
4+
from sw_utils.os_token_converter import OsTokenConverter
5+
6+
7+
class TestOsTokenConverter:
8+
@pytest.mark.parametrize(
9+
'total_assets,total_shares,assets,expected_shares',
10+
[
11+
(Wei(1000), Wei(100), Wei(500), Wei(50)),
12+
(Wei(2000), Wei(200), Wei(1000), Wei(100)),
13+
(Wei(0), Wei(100), Wei(500), Wei(0)), # Edge case: total_assets is 0
14+
(Wei(1000), Wei(0), Wei(500), Wei(0)), # Edge case: total_shares is 0
15+
],
16+
)
17+
def test_to_shares(self, total_assets, total_shares, assets, expected_shares):
18+
converter = OsTokenConverter(total_assets=total_assets, total_shares=total_shares)
19+
shares = converter.to_shares(assets)
20+
assert shares == expected_shares
21+
22+
@pytest.mark.parametrize(
23+
'total_assets,total_shares,shares,expected_assets',
24+
[
25+
(Wei(1000), Wei(100), Wei(50), Wei(500)),
26+
(Wei(2000), Wei(200), Wei(100), Wei(1000)),
27+
(Wei(1000), Wei(0), Wei(50), Wei(0)), # Edge case: total_shares is 0
28+
(Wei(0), Wei(100), Wei(50), Wei(0)), # Edge case: total_assets is 0
29+
],
30+
)
31+
def test_to_assets(self, total_assets, total_shares, shares, expected_assets):
32+
converter = OsTokenConverter(total_assets=total_assets, total_shares=total_shares)
33+
assets = converter.to_assets(shares)
34+
assert assets == expected_assets

0 commit comments

Comments
 (0)