Skip to content

Commit a4713ac

Browse files
authored
Allow chain/chain_id to take str/int and convert to int (#12)
* Ensure that chain_id is always an int * Allow str types, and call normalize function * allow str | int for any chain param, and handle with normalization * move normalize_chain_id to common/utils.py * remove normalize_chain_ids and ensure types are correct * allow str chain_id/chain and use noramlize function to convert to int * feat: add regex to normalize_chain_id to extract digits * feat: add initial tests for utils.py * docs: add information on running pytest * build: add pytest to deps * fix: rebase to main * docs: adjust uv commands
1 parent 02e8358 commit a4713ac

File tree

7 files changed

+193
-72
lines changed

7 files changed

+193
-72
lines changed

python/thirdweb-ai/README.md

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,46 @@ def adapt_to_my_framework(tools: list[Tool]):
146146

147147
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
148148

149-
## Contributing
149+
## Development and Testing
150150

151-
Contributions are welcome! Please feel free to submit a Pull Request.
151+
### Setting up development environment
152152

153-
## Support
153+
```bash
154+
# Clone the repository
155+
git clone https://github.com/thirdweb-dev/ai.git
156+
cd ai/python/thirdweb-ai
157+
158+
# Install dependencies with UV
159+
uv sync
160+
```
161+
162+
### Running tests
163+
164+
We use pytest for testing. You can run the tests with:
165+
166+
```bash
167+
# Run all tests
168+
uv run pytest
169+
170+
# Run tests with verbose output
171+
uv run pytest -v
154172

155-
If you need help or have questions, please contact [[email protected]](mailto:[email protected]).
173+
# Run specific test file
174+
uv run pytest tests/common/test_utils.py
175+
176+
# Run tests with coverage report
177+
uv run pytest --cov=thirdweb_ai
178+
179+
# Run tests and generate HTML coverage report
180+
uv run pytest --cov=thirdweb_ai --cov-report=html
181+
```
182+
183+
### Linting and Type Checking
184+
185+
```bash
186+
# Run the ruff linter
187+
uv run ruff check .
188+
189+
# Run type checking with pyright
190+
uv run pyright
191+
```
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import re
2+
3+
4+
def extract_digits(value: int | str) -> int:
5+
value_str = str(value).strip("\"'")
6+
digit_match = re.search(r"\d+", value_str)
7+
8+
if not digit_match:
9+
raise ValueError(f"Chain ID '{value}' does not contain any digits")
10+
11+
extracted_digits = digit_match.group()
12+
13+
if not extracted_digits.isdigit():
14+
raise ValueError(
15+
f"Extracted value '{extracted_digits}' is not a valid digit string"
16+
)
17+
return int(extracted_digits)
18+
19+
20+
def normalize_chain_id(
21+
chain_id: int | str | list[int | str] | None,
22+
) -> int | list[int] | None:
23+
"""Normalize chain IDs to integers."""
24+
25+
if chain_id is None:
26+
return None
27+
28+
if isinstance(chain_id, list):
29+
return [extract_digits(c) for c in chain_id]
30+
31+
return extract_digits(chain_id)

python/thirdweb-ai/src/thirdweb_ai/services/engine.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Annotated, Any
22

3+
from thirdweb_ai.common.utils import normalize_chain_id
34
from thirdweb_ai.services.service import Service
45
from thirdweb_ai.tools.tool import tool
56

@@ -9,15 +10,15 @@ def __init__(
910
self,
1011
engine_url: str,
1112
engine_auth_jwt: str,
12-
chain_id: int | None = None,
13+
chain_id: int | str | None = None,
1314
backend_wallet_address: str | None = None,
1415
secret_key: str = "",
1516
):
1617
super().__init__(base_url=engine_url, secret_key=secret_key)
1718
self.engine_url = engine_url
1819
self.engine_auth_jwt = engine_auth_jwt
1920
self.backend_wallet_address = backend_wallet_address
20-
self.chain_id = str(chain_id) if chain_id else None
21+
self.chain_id = normalize_chain_id(chain_id)
2122

2223
def _make_headers(self):
2324
headers = super()._make_headers()
@@ -75,7 +76,7 @@ def get_all_backend_wallet(
7576
def get_wallet_balance(
7677
self,
7778
chain_id: Annotated[
78-
str,
79+
str | int,
7980
"The numeric blockchain network ID to query (e.g., '1' for Ethereum mainnet, '137' for Polygon). If not provided, uses the default chain ID configured in the Engine instance.",
8081
],
8182
backend_wallet_address: Annotated[
@@ -84,9 +85,9 @@ def get_wallet_balance(
8485
] = None,
8586
) -> dict[str, Any]:
8687
"""Get wallet balance for native or ERC20 tokens."""
87-
chain_id = chain_id or self.chain_id
88+
normalized_chain = normalize_chain_id(chain_id) or self.chain_id
8889
backend_wallet_address = backend_wallet_address or self.backend_wallet_address
89-
return self._get(f"backend-wallet/{chain_id}/{backend_wallet_address}/get-balance")
90+
return self._get(f"backend-wallet/{normalized_chain}/{backend_wallet_address}/get-balance")
9091

9192
@tool(
9293
description="Send an on-chain transaction. This powerful function can transfer native currency (ETH, MATIC), ERC20 tokens, or execute any arbitrary contract interaction. The transaction is signed and broadcast to the blockchain automatically."
@@ -122,10 +123,10 @@ def send_transaction(
122123
"data": data or "0x",
123124
}
124125

125-
chain_id = chain_id or self.chain_id
126+
normalized_chain = normalize_chain_id(chain_id) or self.chain_id
126127
backend_wallet_address = backend_wallet_address or self.backend_wallet_address
127128
return self._post(
128-
f"backend-wallet/{chain_id}/send-transaction",
129+
f"backend-wallet/{normalized_chain}/send-transaction",
129130
payload,
130131
headers={"X-Backend-Wallet-Address": backend_wallet_address},
131132
)
@@ -170,8 +171,8 @@ def read_contract(
170171
"functionName": function_name,
171172
"args": function_args or [],
172173
}
173-
chain_id = chain_id or self.chain_id
174-
return self._get(f"contract/{chain_id!s}/{contract_address}/read", payload)
174+
normalized_chain = normalize_chain_id(chain_id) or self.chain_id
175+
return self._get(f"contract/{normalized_chain}/{contract_address}/read", payload)
175176

176177
@tool(
177178
description="Execute a state-changing function on a smart contract by sending a transaction. This allows you to modify on-chain data, such as transferring tokens, minting NFTs, or updating contract configuration. The transaction is automatically signed by your backend wallet and submitted to the blockchain."
@@ -208,9 +209,9 @@ def write_contract(
208209
if value and value != "0":
209210
payload["txOverrides"] = {"value": value}
210211

211-
chain_id = chain_id or self.chain_id
212+
normalized_chain = normalize_chain_id(chain_id) or self.chain_id
212213
return self._post(
213-
f"contract/{chain_id!s}/{contract_address}/write",
214+
f"contract/{normalized_chain}/{contract_address}/write",
214215
payload,
215216
headers={"X-Backend-Wallet-Address": self.backend_wallet_address},
216217
)

0 commit comments

Comments
 (0)