Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
626 changes: 626 additions & 0 deletions tests/functional/codegen/features/test_variable_initialization.py

Large diffs are not rendered by default.

148 changes: 148 additions & 0 deletions tests/functional/syntax/test_variable_initialization_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import pytest

from vyper import compile_code
from vyper.exceptions import (
CallViolation,
ImmutableViolation,
StateAccessViolation,
TypeMismatch,
UndeclaredDefinition,
VariableDeclarationException,
)


@pytest.mark.parametrize(
"bad_code,exc",
[
(
"""
# Cannot use function calls in initializer
@external
@view
def some_func() -> uint256:
return 42

x: uint256 = self.some_func()
""",
CallViolation,
),
(
"""
# Cannot use self attributes in initializer
y: uint256 = 10
x: uint256 = self.y
""",
StateAccessViolation,
),
],
)
def test_invalid_initializers(bad_code, exc):
with pytest.raises(exc):
compile_code(bad_code)


@pytest.mark.parametrize(
"bad_code,exc",
[
(
"""
# Type mismatch in initialization
x: uint256 = -1 # negative number for unsigned
""",
TypeMismatch,
),
(
"""
# Type mismatch with wrong literal type
x: address = 123
""",
TypeMismatch,
),
(
"""
# Boolean type mismatch
x: bool = 1
""",
TypeMismatch,
),
(
"""
# String literal not allowed for numeric type
x: uint256 = "hello"
""",
TypeMismatch,
),
],
)
def test_type_mismatch_in_initialization(bad_code, exc):
with pytest.raises(exc):
compile_code(bad_code)


def test_constant_requires_value():
"""Constants must have an initializer"""
bad_code = """
X: constant(uint256) # Missing initializer
"""
with pytest.raises(VariableDeclarationException):
compile_code(bad_code)


def test_immutable_requires_constructor_assignment_without_initializer():
"""Immutables without initializer must be set in constructor"""
bad_code = """
X: immutable(uint256) # No initializer

@deploy
def __init__():
pass # Forgot to set X
"""
with pytest.raises(ImmutableViolation):
compile_code(bad_code)


def test_initializer_cannot_reference_other_storage_vars():
"""Initializers cannot reference other storage variables"""
bad_code = """
a: uint256 = 100
b: uint256 = self.a + 50 # Cannot reference self.a
"""
with pytest.raises(StateAccessViolation):
compile_code(bad_code)


def test_circular_reference_in_constants():
"""Constants cannot have circular references"""
bad_code = """
A: constant(uint256) = B
B: constant(uint256) = A
"""
# This will raise VyperException with multiple UndeclaredDefinition errors
from vyper.exceptions import VyperException

with pytest.raises((UndeclaredDefinition, VyperException)):
compile_code(bad_code)


def test_initializer_cannot_use_pure_function_calls():
"""Cannot call even pure functions in initializers"""
bad_code = """
@internal
@pure
def helper() -> uint256:
return 42

x: uint256 = self.helper()
"""
with pytest.raises(StateAccessViolation):
compile_code(bad_code)


def test_initializer_cannot_reference_other_vars():
"""Cannot reference other storage variables regardless of order"""
bad_code = """
y: uint256 = 100
x: uint256 = self.y # Cannot reference self.y even though it's declared first
"""
with pytest.raises(StateAccessViolation):
compile_code(bad_code)
13 changes: 3 additions & 10 deletions vyper/ast/grammar.lark
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ module: ( DOCSTRING
| import
| struct_def
| interface_def
| constant_def
| variable_def
| enum_def // TODO deprecate at some point in favor of flag
| flag_def
Expand All @@ -34,17 +33,11 @@ import: _IMPORT DOT* _import_path [import_alias]
| _import_from _IMPORT ( WILDCARD | _import_name [import_alias] )
| _import_from _IMPORT "(" import_list ")"

// Constant definitions
// Variable definitions (including constants)
// NOTE: Temporary until decorators used
constant: "constant" "(" type ")"
constant_private: NAME ":" constant
constant_with_getter: NAME ":" "public" "(" constant ")"
constant_def: (constant_private | constant_with_getter) "=" expr

variable_annotation: ("public" | "reentrant" | "immutable" | "transient" | "constant") "(" (variable_annotation | type) ")"
variable_def: NAME ":" (variable_annotation | type) ["=" expr]
variable: NAME ":" type
// NOTE: Temporary until decorators used
variable_annotation: ("public" | "reentrant" | "immutable" | "transient") "(" (variable_annotation | type) ")"
variable_def: NAME ":" (variable_annotation | type)

// A decorator "wraps" a method, modifying it's context.
// NOTE: One or more can be applied (some combos might conflict)
Expand Down
5 changes: 1 addition & 4 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,10 +1449,7 @@ def validate(self):
"Only public variables can be marked `reentrant`!", self
)

if not self.is_constant and self.value is not None:
raise VariableDeclarationException(
f"{self._pretty_location} variables cannot have an initial value", self.value
)
# Allow initialization values for all variable types
if not isinstance(self.target, Name):
raise VariableDeclarationException("Invalid variable declaration", self.target)

Expand Down
7 changes: 7 additions & 0 deletions vyper/codegen/function_definitions/external_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,13 @@ def generate_ir_for_external_function(code, compilation_target):

body += nonreentrant_pre

# if this is a constructor, inject storage variable initializations
if func_t.is_constructor:
from vyper.codegen.stmt import generate_variable_initializations

init_ir = generate_variable_initializations(compilation_target._module, context)
body.append(init_ir)

body += [parse_body(code.body, context, ensure_terminated=True)]

# wrap the body in labeled block
Expand Down
19 changes: 18 additions & 1 deletion vyper/codegen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@

import vyper.ast as vy_ast
from vyper.codegen import core, jumptable_utils
from vyper.codegen.context import Constancy, Context
from vyper.codegen.core import shr
from vyper.codegen.function_definitions import (
generate_ir_for_external_function,
generate_ir_for_internal_function,
)
from vyper.codegen.ir_node import IRnode
from vyper.codegen.memory_allocator import MemoryAllocator
from vyper.codegen.stmt import generate_variable_initializations
from vyper.compiler.settings import _is_debug_mode
from vyper.exceptions import CompilerPanic
from vyper.semantics.types.module import ModuleT
from vyper.utils import OrderedSet, method_id_int
from vyper.utils import MemoryPositions, OrderedSet, method_id_int


# calculate globally reachable functions to see which
Expand Down Expand Up @@ -510,6 +513,20 @@ def generate_ir_for_module(module_t: ModuleT) -> tuple[IRnode, IRnode]:
deploy_code.extend(ctor_internal_func_irs)

else:
# Generate initialization code for variables even without explicit constructor
# Create a minimal constructor context
memory_allocator = MemoryAllocator(MemoryPositions.RESERVED_MEMORY)
context = Context(
vars_=None,
module_ctx=module_t,
memory_allocator=memory_allocator,
constancy=Constancy.Mutable,
is_ctor_context=True,
)

init_ir = generate_variable_initializations(module_t._module, context)
deploy_code.append(init_ir)

if immutables_len != 0: # pragma: nocover
raise CompilerPanic("unreachable")
deploy_code.append(["deploy", 0, runtime, 0])
Expand Down
38 changes: 37 additions & 1 deletion vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
IRnode,
add_ofst,
clamp_le,
data_location_to_address_space,
get_dyn_array_count,
get_element_ptr,
get_type_for_exact_size,
Expand Down Expand Up @@ -356,7 +357,9 @@ def _is_terminated(code):


# codegen a list of statements
def parse_body(code, context, ensure_terminated=False):
def parse_body(
code: list[vy_ast.VyperNode], context: Context, ensure_terminated: bool = False
) -> IRnode:
ir_node = ["seq"]
for stmt in code:
ir = parse_stmt(stmt, context)
Expand All @@ -369,3 +372,36 @@ def parse_body(code, context, ensure_terminated=False):
# force zerovalent, even last statement
ir_node.append("pass") # CMC 2022-01-16 is this necessary?
return IRnode.from_list(ir_node)


def generate_variable_initializations(module_ast: vy_ast.Module, context: Context) -> IRnode:
"""
Generate initialization IR for storage variables with default values.
Returns an IRnode sequence containing all initialization statements.
"""
assert context.is_ctor_context, "Variable initialization must happen in constructor context"

init_stmts = []

for node in module_ast.body:
if isinstance(node, vy_ast.VariableDecl) and node.value is not None:
# skip constants - they are compile-time only
if node.is_constant:
continue

# generate assignment: self.var = value
varinfo = node.target._metadata["varinfo"]
location = data_location_to_address_space(varinfo.location, context.is_ctor_context)

lhs = IRnode.from_list(
varinfo.position.position,
typ=varinfo.typ,
location=location,
annotation=f"self.{node.target.id}",
)

rhs = Expr(node.value, context).ir_node
init_stmt = make_setter(lhs, rhs)
init_stmts.append(init_stmt)

return IRnode.from_list(["seq"] + init_stmts)
14 changes: 5 additions & 9 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
get_exact_type_from_node,
get_expr_info,
get_possible_types_from_node,
is_naked_self_reference,
uses_state,
validate_expected_type,
)
Expand All @@ -52,7 +53,6 @@
HashMapT,
IntegerT,
SArrayT,
SelfT,
StringT,
StructT,
TupleT,
Expand Down Expand Up @@ -183,17 +183,13 @@ def _validate_pure_access(node: vy_ast.Attribute | vy_ast.Name, typ: VyperType)
if isinstance(parent_info.typ, AddressT) and node.attr in AddressT._type_members:
raise StateAccessViolation("not allowed to query address members in pure functions")

if is_naked_self_reference(node):
raise StateAccessViolation("not allowed to query `self` in pure functions")

if (varinfo := info.var_info) is None:
return
# self is magic. we only need to check it if it is not the root of an Attribute
# node. (i.e. it is bare like `self`, not `self.foo`)
is_naked_self = isinstance(varinfo.typ, SelfT) and not isinstance(
node.get_ancestor(), vy_ast.Attribute
)
if is_naked_self:
raise StateAccessViolation("not allowed to query `self` in pure functions")

if varinfo.is_state_variable() or is_naked_self:
if varinfo.is_state_variable():
raise StateAccessViolation("not allowed to query state variables in pure functions")


Expand Down
18 changes: 16 additions & 2 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,8 @@ def visit_VariableDecl(self, node):
assignments = self.ast.get_descendants(
vy_ast.Assign, filters={"target.id": node.target.id}
)
if not assignments:
# immutables with initialization values don't require assignment
if not assignments and node.value is None:
# Special error message for common wrong usages via `self.<immutable name>`
wrong_self_attribute = self.ast.get_descendants(
vy_ast.Attribute, {"value.id": "self", "attr": node.target.id}
Expand Down Expand Up @@ -688,7 +689,20 @@ def _validate_self_namespace():

return _finalize()

assert node.value is None # checked in VariableDecl.validate()
# allow initialization for storage variables
if node.value is not None:
# validate the initialization expression
ExprVisitor().visit(node.value, type_) # performs validate_expected_type

# ensure the initialization expression is constant or runtime constant
# (allows literals, constants, msg.sender, self, etc.)
if not check_modifiability(node.value, Modifiability.RUNTIME_CONSTANT):
raise StateAccessViolation(
"Storage variable initializer must be a literal or runtime constant"
" (e.g. msg.sender, self)",
node.value,
)

if node.is_immutable:
_validate_self_namespace()
return _finalize()
Expand Down
Loading
Loading