Skip to content

Commit 5fbfbd6

Browse files
authored
feat: add devnet parm to fixture (#119)
1 parent eb32ff7 commit 5fbfbd6

File tree

2 files changed

+53
-17
lines changed

2 files changed

+53
-17
lines changed

python/test_utils/fixtures.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from test_utils.starknet_test_utils import StarknetTestUtils
33
from typing import Iterator
4+
import contextlib
45
import time
56

67

@@ -10,3 +11,14 @@ def starknet_test_utils() -> Iterator[StarknetTestUtils]:
1011
# TODO: replace the sleep with await.
1112
time.sleep(2)
1213
yield val
14+
15+
16+
@pytest.fixture
17+
def starknet_test_utils_factory():
18+
@contextlib.contextmanager
19+
def _factory(**kwargs):
20+
with StarknetTestUtils.context_manager(**kwargs) as val:
21+
time.sleep(2)
22+
yield val
23+
24+
return _factory

python/test_utils/starknet_test_utils.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55
import subprocess
66
import tempfile
77
from starknet_py.net.account.account import Account
8-
from starknet_py.net.full_node_client import FullNodeClient
8+
from starknet_py.devnet_utils import DevnetClient
99
from starknet_py.net.models.chains import StarknetChainId
1010
from starknet_py.net.signer.key_pair import KeyPair
1111
from starknet_py.net.signer.stark_curve_signer import StarkCurveSigner
12-
import requests
1312
from starknet_py.contract import Contract
1413
import re
1514
from pathlib import Path
@@ -91,24 +90,57 @@ class StarknetTestUtils:
9190

9291
MAX_RETRIES = 5
9392

94-
def __init__(self, port: int):
95-
self.starknet = Starknet(port=port)
93+
def __init__(
94+
self,
95+
port: int,
96+
seed: int,
97+
initial_balance: int,
98+
accounts: int,
99+
starknet_chain_id: StarknetChainId,
100+
fork_network: Optional[str],
101+
fork_block: Optional[int],
102+
):
103+
self.starknet = Starknet(
104+
port=port,
105+
seed=seed,
106+
initial_balance=initial_balance,
107+
accounts=accounts,
108+
starknet_chain_id=starknet_chain_id,
109+
fork_network=fork_network,
110+
fork_block=fork_block,
111+
)
96112
self.accounts = self.starknet.accounts
97113

98114
def stop(self):
99115
self.starknet.stop()
100116

101117
@classmethod
102118
@contextlib.contextmanager
103-
def context_manager(cls, port: int | None = None, backoff: float = 0.1):
119+
def context_manager(
120+
cls,
121+
port: int | None = None,
122+
seed: int = 500,
123+
initial_balance: int = 10**30,
124+
starknet_chain_id: StarknetChainId = StarknetChainId.SEPOLIA,
125+
fork_network: Optional[str] = None,
126+
fork_block: Optional[int] = None,
127+
backoff: float = 0.1,
128+
):
104129
"""
105130
Retry creating a Starknet instance if port is already in use.
106131
If port is None, will pick random free port.
107132
"""
108133
for attempt in range(cls.MAX_RETRIES):
109134
try:
110135
actual_port = port or get_free_port()
111-
res = cls(port=actual_port)
136+
res = cls(
137+
port=actual_port,
138+
seed=seed,
139+
initial_balance=initial_balance,
140+
starknet_chain_id=starknet_chain_id,
141+
fork_network=fork_network,
142+
fork_block=fork_block,
143+
)
112144
yield res
113145
return
114146
except OSError as e:
@@ -124,15 +156,7 @@ def context_manager(cls, port: int | None = None, backoff: float = 0.1):
124156
pass
125157

126158
def advance_time(self, n_seconds: int):
127-
payload = {
128-
"jsonrpc": "2.0",
129-
"id": "1",
130-
"method": "devnet_increaseTime",
131-
"params": {"time": n_seconds},
132-
}
133-
rpc_url = f"{self.starknet.get_client().url}\rpc"
134-
response = requests.post(rpc_url, json=payload)
135-
response.raise_for_status()
159+
self.starknet.get_client().increase_time(n_seconds)
136160

137161

138162
class Starknet:
@@ -206,9 +230,9 @@ def __init__(
206230
def __del__(self):
207231
self.stop()
208232

209-
def get_client(self) -> FullNodeClient:
233+
def get_client(self) -> DevnetClient:
210234
node_url = f"http://localhost:{self.port}"
211-
return FullNodeClient(node_url=node_url)
235+
return DevnetClient(node_url=node_url)
212236

213237
def stop(self):
214238
if not self.is_alive:

0 commit comments

Comments
 (0)