Skip to content

Commit 27d4fb2

Browse files
committed
fix: ensure GeniusOpponentModel initializes BaseUtilityFunction attributes
Add __attrs_post_init__ to GeniusOpponentModel to call BaseUtilityFunction.__init__, ensuring parent class attributes (_reserved_value, _invalid_value, _cached_inverse, _cached_inverse_type) are properly initialized when attrs-based subclasses like GSmithFrequencyModel are instantiated. Also add tests for GSmithFrequencyModel initialization and integration with BOA negotiators.
1 parent 458dac8 commit 27d4fb2

File tree

2 files changed

+140
-0
lines changed

2 files changed

+140
-0
lines changed

src/negmas/gb/components/genius/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ class GeniusOpponentModel(VolatileUFunMixin, GBComponent, BaseUtilityFunction):
3838
private_info with learned opponent utility function estimates.
3939
"""
4040

41+
def __attrs_post_init__(self) -> None:
42+
"""Initialize parent classes after attrs initialization."""
43+
BaseUtilityFunction.__init__(self)
44+
4145
def _update_private_info(self, partner_id: str | None = None) -> None:
4246
"""Update the negotiator's private_info with this model.
4347

tests/core/test_genius_models.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from negmas.gb.negotiators.timebased import AspirationNegotiator
2+
from negmas.outcomes.base_issue import make_issue
3+
from negmas.outcomes.outcome_space import make_os
4+
from negmas.preferences.base_ufun import BaseUtilityFunction
5+
from negmas.preferences.crisp.linear import LinearAdditiveUtilityFunction
6+
from negmas.preferences.ops import compare_ufuns
7+
from negmas.preferences.value_fun import AffineFun, LinearFun
8+
from negmas.sao.mechanism import SAOMechanism
9+
from negmas.sao.negotiators.modular import BOANegotiator
10+
from negmas.sao.components.offering import TimeBasedOfferingPolicy
11+
from negmas.sao.components.acceptance import ACNext
12+
from negmas.gb.components.genius.models import GSmithFrequencyModel
13+
14+
15+
def test_gsmith_frequency_model_initializes_base_ufun_attributes():
16+
"""Test that GSmithFrequencyModel properly initializes BaseUtilityFunction attributes.
17+
18+
This ensures that the attrs-based GeniusOpponentModel correctly calls
19+
BaseUtilityFunction.__init__ via __attrs_post_init__.
20+
"""
21+
model = GSmithFrequencyModel()
22+
23+
# Verify all attributes set by BaseUtilityFunction.__init__ are present and have correct defaults
24+
assert hasattr(model, "_reserved_value"), "Missing _reserved_value attribute"
25+
assert hasattr(model, "_invalid_value"), "Missing _invalid_value attribute"
26+
assert hasattr(model, "_cached_inverse"), "Missing _cached_inverse attribute"
27+
assert hasattr(model, "_cached_inverse_type"), (
28+
"Missing _cached_inverse_type attribute"
29+
)
30+
31+
# Check default values match BaseUtilityFunction defaults
32+
assert model._reserved_value == float("-inf"), (
33+
f"Expected -inf, got {model._reserved_value}"
34+
)
35+
assert model._invalid_value is None, f"Expected None, got {model._invalid_value}"
36+
assert model._cached_inverse is None, f"Expected None, got {model._cached_inverse}"
37+
assert model._cached_inverse_type is None, (
38+
f"Expected None, got {model._cached_inverse_type}"
39+
)
40+
41+
# Verify the model is an instance of BaseUtilityFunction
42+
assert isinstance(model, BaseUtilityFunction)
43+
44+
# Verify the reserved_value property works (uses _reserved_value internally)
45+
assert model.reserved_value == float("-inf")
46+
47+
48+
def calc_scores(m: SAOMechanism) -> dict[str, dict[str, float]]:
49+
"""Compute scores for the given agreement according the ANL 2026 rules."""
50+
51+
# extract the agreement
52+
agreement = m.agreement
53+
54+
# extract negotiator names
55+
negotiators = [_.__class__.__name__ for _ in m.negotiators]
56+
57+
# find advantages (utility above reserved value)
58+
advantages = [
59+
float(_.ufun(agreement)) - float(_.ufun.reserved_value) if _.ufun else 0.0
60+
for _ in m.negotiators
61+
]
62+
63+
# calculate modeling accuracies
64+
ufuns = [_.ufun for _ in m.negotiators]
65+
models = [_.opponent_ufun for _ in m.negotiators]
66+
models.reverse()
67+
accuracies = [
68+
(1 + compare_ufuns(u, model, method="kendall", outcome_space=m.outcome_space))
69+
/ 2
70+
for u, model in zip(ufuns, models)
71+
]
72+
73+
# normalize accuracies so that we divide one point among all negotiators with
74+
# negotiators with higher accuracy getting higher part of this point.
75+
accsum = sum(accuracies)
76+
if accsum > 0:
77+
accuracies = [_ / accsum for _ in accuracies]
78+
else:
79+
accuracies = [0] * len(negotiators)
80+
accuracies.reverse()
81+
# return final scores. You can improve your score in one of three ways:
82+
# 1. Increase your advantage (negotiating a better deal for yourself)
83+
# 2. Increase your modeling accuracy (better opponent modeling)
84+
# 3. Decrease your opponent's accuracy (confuse their opponent modeling)
85+
return dict(
86+
zip(
87+
negotiators,
88+
(
89+
dict(Advavntage=adv, Accuracy=acc, Score=adv + acc)
90+
for adv, acc in zip(advantages, accuracies)
91+
),
92+
)
93+
)
94+
95+
96+
class BOANeg(BOANegotiator):
97+
def __init__(self, *args, **kwargs):
98+
offering = TimeBasedOfferingPolicy()
99+
kwargs |= dict(
100+
acceptance=ACNext(offering), offering=offering, model=GSmithFrequencyModel()
101+
)
102+
super().__init__(*args, **kwargs)
103+
104+
105+
def test_smith_frequency_model():
106+
os = make_os([make_issue(10, "i1"), make_issue(10, "i2")])
107+
m = SAOMechanism(
108+
n_steps=100,
109+
outcome_space=os,
110+
ignore_negotiator_exceptions=False,
111+
one_offer_per_step=True,
112+
)
113+
m.add(
114+
BOANeg(
115+
ufun=LinearAdditiveUtilityFunction(
116+
values=[LinearFun(slope=0.1), LinearFun(slope=0.1)],
117+
weights=[0.5, 0.5],
118+
outcome_space=os,
119+
),
120+
id="boa",
121+
)
122+
)
123+
m.add(
124+
AspirationNegotiator(
125+
ufun=LinearAdditiveUtilityFunction(
126+
values=[AffineFun(slope=-0.1, bias=10), LinearFun(slope=0.1)],
127+
weights=[0.8, 0.2],
128+
outcome_space=os,
129+
),
130+
id="asp",
131+
)
132+
)
133+
m.run()
134+
print(calc_scores(m))
135+
trace = m.extended_trace
136+
assert len(trace) > 2, f"{trace}"

0 commit comments

Comments
 (0)