Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions src/lambkin/core/decorators/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
@option, parses them once before the loop using an internal click parser, and
injects the resulting values into a Context class on each (variant, iteration)
pair.

Input hooks registered via "@nominal.input" are managed by "InputRegistry"
instance and resolved before the benchmark function runs on each iteration,
injecting their return values into "ctx.inputs".

Raises ValueError if variants is empty.
"""

Expand All @@ -30,6 +35,7 @@

from lambkin.core.ctx.context import Context
from lambkin.core.ctx.source import Source
from lambkin.core.decorators.input import InputRegistry


def _parse_options(fn, cli_args):
Expand All @@ -45,18 +51,19 @@ def _parse_options(fn, cli_args):
def benchmark(variants, num_iterations):
"""Drive the benchmark execution loop over all variants and iterations.

Parses CLI options registered by @lambkin.option once before the loop,
then creates a Context for each (variant, iteration) pair and calls
the decorated function with it.
Parses CLI options registered by @lambkin.option once before the loop, then
creates a Context for each (variant, iteration) pair and calls the decorated
function with it. Input hooks registered via @nominal.input are resolved
before each call, injecting their return values into ctx.inputs.

Parameters
----------
variants : iterable of dict
Sequence of variant dicts to sweep over. Each dict is exposed
as attributes on ctx.variant.
Sequence of variant dicts to sweep over. Each dict is exposed as
attributes on ctx.variant.
num_iterations : int
Number of times to repeat each variant. Controls the iter_<N>
subfolders under each variant directory.
Number of times to repeat each variant. Controls the iter_<N> subfolders
under each variant directory.

Raises:
------
Expand All @@ -70,11 +77,23 @@ def benchmark(variants, num_iterations):
)

def decorator(fn):
inputs = InputRegistry()

@functools.wraps(fn)
def wrapper(args=None, output_dir=None):
cli_args = sys.argv[1:] if args is None else args
options = _parse_options(fn, cli_args)
source = Source(path=inspect.getfile(fn))
base_ctx = Context(
variant={},
iteration=0,
options=options,
source=source,
variant_index=0,
output_dir=output_dir,
)
inputs.resolve(base_ctx)
resolved_inputs = base_ctx.inputs
for variant_index, variant in enumerate(variants):
for iteration in range(num_iterations):
ctx = Context(
Expand All @@ -85,8 +104,10 @@ def wrapper(args=None, output_dir=None):
variant_index=variant_index,
output_dir=output_dir,
)
ctx.inputs = resolved_inputs
fn(ctx)

wrapper.input = inputs.register
return wrapper

return decorator
56 changes: 55 additions & 1 deletion src/lambkin/core/decorators/input.py
Comment thread
teresa-ortega marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input decorator for lambkin."""
"""Input decorator for lambkin.

Provides the class "InputRegistry" class, which manages the registration and
resolution of input hooks for a benchmark function. Hooks are registered via the
"InputRegistry.register" method and resolved before the benchmark function runs,
injecting their return values into "ctx.inputs" under the hook's function name.
"""

import inspect


def _validate_result(hook_fn, result) -> None:
"""Validate the return value of a hook function."""
if result is None:
raise ValueError(
f"Hook '{hook_fn.__name__}' returned None or did not return a value."
)
if isinstance(result, str) and not result.strip():
raise ValueError(
f"Hook '{hook_fn.__name__}' returned an empty or blank string."
)


def _validate_hook_signature(hook_fn) -> None:
"""Validate that the hook function accepts a single 'ctx' parameter."""
params = list(inspect.signature(hook_fn).parameters.keys())

if len(params) != 1:
raise ValueError(f"Hook '{hook_fn.__name__}' must have exactly 1 parameter, ")
Comment thread
teresa-ortega marked this conversation as resolved.
Outdated


class InputRegistry:
"""Manages the registration and resolution of input hooks for a benchmark."""

def __init__(self):
"""Initialize the InputRegistry."""
self._hooks = []

def register(self, hook_fn):
"""Decorator used to register a function as an input provider."""
existing_names = [h.__name__ for h in self._hooks]
if hook_fn.__name__ in existing_names:
raise ValueError(
f"Hook name conflict: '{hook_fn.__name__}' is already registered."
)
self._hooks.append(hook_fn)
return hook_fn
Comment thread
teresa-ortega marked this conversation as resolved.

def resolve(self, ctx):
"""Resolve all registered input hooks."""
for hook in self._hooks:
_validate_hook_signature(hook)
Comment thread
teresa-ortega marked this conversation as resolved.
Outdated
result = hook(ctx)
_validate_result(hook, result)
setattr(ctx.inputs, hook.__name__, result)
164 changes: 164 additions & 0 deletions test/core/test_input.py
Comment thread
teresa-ortega marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright 2026 Ekumen, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for the input decorator in lambkin.core.decorators."""

import pytest

from lambkin.core.ctx import Context
from lambkin.core.decorators.benchmark import benchmark
from lambkin.core.decorators.input import InputRegistry


@pytest.fixture
def variant():
"""Base variants for testing."""
return [
{"sensor_model": "beam", "num_particles": 10},
]


def test_register_returns_original_function():
"""register() returns the original function unchanged."""
registry = InputRegistry()

def dataset(ctx):
return "data.mcap"

assert registry.register(dataset) is dataset


def test_registered_hook_name_is_preserved():
"""The ctx.inputs attribute name is the hook's __name__."""
registry = InputRegistry()

def my_dataset(ctx):
return "x"

registry.register(my_dataset)
assert registry._hooks[0].__name__ == "my_dataset"


def test_hook_with_no_parameters_raises():
"""A hook with no parameters raises ValueError at resolve time."""
registry = InputRegistry()
ctx = Context.__new__(Context)

def bad_hook():
return "value"

registry.register(bad_hook)
with pytest.raises(ValueError, match="exactly 1 parameter"):
registry.resolve(ctx)


def test_hook_with_extra_parameters_raises():
"""A hook with more than 1 parameter raises ValueError at resolve time."""
registry = InputRegistry()
ctx = Context.__new__(Context)

def bad_hook(ctx, extra):
return "value"

registry.register(bad_hook)
with pytest.raises(ValueError, match="exactly 1 parameter"):
registry.resolve(ctx)


def test_hook_returning_none_raises():
"""A hook that returns None raises ValueError."""
registry = InputRegistry()
ctx = Context.__new__(Context)

def dataset(ctx):
return None

registry.register(dataset)
with pytest.raises(ValueError, match="None"):
registry.resolve(ctx)


def test_hook_with_no_return_raises():
"""A hook with no return statement raises ValueError."""
registry = InputRegistry()
ctx = Context.__new__(Context)

def dataset(ctx):
pass

registry.register(dataset)
with pytest.raises(ValueError, match="None"):
registry.resolve(ctx)


def test_hook_returning_empty_string_raises():
"""A hook that returns an empty string raises ValueError."""
registry = InputRegistry()
ctx = Context.__new__(Context)

def dataset(ctx):
return ""

registry.register(dataset)
with pytest.raises(ValueError, match="empty"):
registry.resolve(ctx)


def test_hook_returning_blank_string_raises():
"""A hook that returns a whitespace-only string raises ValueError."""
registry = InputRegistry()
ctx = Context.__new__(Context)

def dataset(ctx):
return " "

registry.register(dataset)
with pytest.raises(ValueError, match="empty"):
registry.resolve(ctx)


def test_ctx_inputs_populated_before_fn_runs(variant):
"""ctx.inputs.dataset is available inside the benchmark function after resolve."""
seen = []

@benchmark(variants=variant, num_iterations=1)
def nominal(ctx):
seen.append(ctx.inputs.dataset)

@nominal.input
def dataset(ctx):
return "path/to/dataset.mcap"

nominal(output_dir="/tmp")
assert seen == ["path/to/dataset.mcap"]


def test_multiple_inputs_all_injected(variant):
"""All registered inputs are injected into ctx.inputs before the benchmark runs."""
seen = []

@benchmark(variants=variant, num_iterations=1)
def nominal(ctx):
seen.append((ctx.inputs.dataset, ctx.inputs.map))

@nominal.input
def dataset(ctx):
return "path/to/dataset.mcap"

@nominal.input
def map(ctx):
return "path/to/map.yaml"

nominal(output_dir="/tmp")
assert seen == [("path/to/dataset.mcap", "path/to/map.yaml")]
Loading