-
-
Notifications
You must be signed in to change notification settings - Fork 871
feat(flag): enable iteration via Flag.__values__ #4739
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 2 commits
974dd55
26eb76c
2bee3dc
3ca2abc
bc44129
5b31a52
262af81
77fd597
2fde4bd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,143 @@ | ||
| def test_iterate_over_flag_type(get_contract): | ||
| code = """ | ||
| flag Permission: | ||
| A | ||
| B | ||
| C | ||
|
|
||
| @pure | ||
| @external | ||
| def sum_mask() -> uint256: | ||
| acc: uint256 = 0 | ||
| for p: Permission in Permission.__values__: | ||
| acc = acc | convert(p, uint256) | ||
| return acc | ||
| """ | ||
| c = get_contract(code) | ||
| # 1 | 2 | 4 = 7 | ||
| assert c.sum_mask() == 7 | ||
|
|
||
|
|
||
| def test_iterate_over_flag_type_count(get_contract): | ||
| code = """ | ||
| flag Permission: | ||
| A | ||
| B | ||
| C | ||
| D | ||
|
|
||
| @pure | ||
| @external | ||
| def count() -> uint256: | ||
| cnt: uint256 = 0 | ||
| for p: Permission in Permission.__values__: | ||
| cnt += 1 | ||
| return cnt | ||
| """ | ||
| c = get_contract(code) | ||
| assert c.count() == 4 | ||
|
|
||
|
|
||
| def test_iterate_over_flag_type_order(get_contract): | ||
| code = """ | ||
| flag Permission: | ||
| A | ||
| B | ||
| C | ||
| D | ||
|
|
||
| @pure | ||
| @external | ||
| def order_sum() -> uint256: | ||
| acc: uint256 = 0 | ||
| idx: uint256 = 0 | ||
| for p: Permission in Permission.__values__: | ||
| acc = acc + (convert(p, uint256) << idx) | ||
| idx += 1 | ||
| return acc | ||
| """ | ||
| c = get_contract(code) | ||
| # 1 + (2<<1) + (4<<2) + (8<<3) = 1 + 4 + 16 + 64 = 85 | ||
| assert c.order_sum() == 85 | ||
|
|
||
|
|
||
| def test_flag_iter_target_type_mismatch(assert_compile_failed, get_contract): | ||
| from vyper.exceptions import TypeMismatch | ||
|
|
||
| code = """ | ||
| flag A: | ||
| X | ||
| flag B: | ||
| Y | ||
|
|
||
| @pure | ||
| @external | ||
| def f() -> uint256: | ||
| s: uint256 = 0 | ||
| for p: B in A.__values__: | ||
| s += convert(p, uint256) | ||
| return s | ||
| """ | ||
| assert_compile_failed(lambda: get_contract(code), TypeMismatch) | ||
|
|
||
|
|
||
| def test_flag_iter_invalid_iterator(assert_compile_failed, get_contract): | ||
| from vyper.exceptions import InvalidType | ||
|
|
||
| code = """ | ||
| flag P: | ||
| A | ||
|
|
||
| @pure | ||
| @external | ||
| def f() -> uint256: | ||
| s: uint256 = 0 | ||
| for p: P in 5: | ||
| s += 1 | ||
| return s | ||
| """ | ||
| assert_compile_failed(lambda: get_contract(code), InvalidType) | ||
|
|
||
|
|
||
| def test_flag_iter_wrong_target_type(assert_compile_failed, get_contract): | ||
| from vyper.exceptions import TypeMismatch | ||
|
|
||
| code = """ | ||
| flag P: | ||
| A | ||
| B | ||
|
|
||
| @pure | ||
| @external | ||
| def f() -> uint256: | ||
| s: uint256 = 0 | ||
| for p: uint256 in P.__values__: | ||
| s += p # wrong type; loop var must be P | ||
| return s | ||
| """ | ||
| assert_compile_failed(lambda: get_contract(code), TypeMismatch) | ||
|
|
||
|
|
||
| def test_nested_flag_type_iteration(get_contract): | ||
| code = """ | ||
| flag A: | ||
| X | ||
| Y | ||
| Z | ||
|
|
||
| flag B: | ||
| P | ||
| Q | ||
|
|
||
| @pure | ||
| @external | ||
| def product_sum() -> uint256: | ||
| s: uint256 = 0 | ||
| for a: A in A.__values__: | ||
| for b: B in B.__values__: | ||
| s += convert(a, uint256) * convert(b, uint256) | ||
| return s | ||
| """ | ||
| c = get_contract(code) | ||
| # a in {1,2,4}, b in {1,2} => (1+2+4)*(1+2) = 7*3 = 21 | ||
| assert c.product_sum() == 21 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,7 +21,7 @@ | |
| ) | ||
| from vyper.semantics.data_locations import DataLocation | ||
| from vyper.semantics.types.base import VyperType | ||
| from vyper.semantics.types.subscriptable import HashMapT | ||
| from vyper.semantics.types.subscriptable import HashMapT, SArrayT | ||
| from vyper.semantics.types.utils import type_from_abi, type_from_annotation | ||
| from vyper.utils import keccak256 | ||
| from vyper.warnings import Deprecation, vyper_warn | ||
|
|
@@ -75,6 +75,13 @@ | |
| self._helper._id = name | ||
|
|
||
| def get_type_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType": | ||
| # Special iterator helper for flags: `Flag.__values__` | ||
| # Returns a static array type of all flag values in declaration order. | ||
| if key == "__values__": | ||
|
||
| return SArrayT(self, len(self._flag_members)) | ||
|
|
||
| # Regular flag member access (e.g., `Flag.FOO`) validates the member name | ||
| # and yields the flag type in expression position. | ||
| self._helper.get_member(key, node) | ||
| return self | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.