Skip to content

Commit 6fc60de

Browse files
committed
update to pubsub and ConditionalAgentSet
1 parent 58f2676 commit 6fc60de

File tree

7 files changed

+235
-105
lines changed

7 files changed

+235
-105
lines changed

mesa/experimental/datacollection/collectors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,10 @@ class Measure:
193193
# FIXME:: doing so would turn measure into a descriptor
194194
# FIXME:: what about callable vs. attribute based Measures?
195195

196-
def __init__(self, model, obj: Any, callable: Callable):
196+
def __init__(self, model, obj: Any, fn: Callable):
197197
super().__init__()
198198
self.obj = obj
199-
self.callable = callable
199+
self.fn = fn
200200
self._update_step = -1
201201
self._cached_value = None
202202
self.model = model
@@ -210,15 +210,15 @@ def get_value(self, force_update: bool = False):
210210
"""
211211

212212
if force_update or (self.model.step != self._update_step):
213-
self._cached_value = self.callable(self.obj)
213+
self._cached_value = self.fn(self.obj)
214214
return self._cached_value
215215

216216

217217

218218
class MeasureDescriptor:
219219

220220
def some_test_method(self, obj, *args, **kwargs):
221-
print("blaat")
221+
print(obj)
222222

223223
def __set_name__(self, owner, name):
224224
self.public_name = name

mesa/experimental/datacollection/examples/boltzmann.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from mesa.time import RandomActivation
55

66
from mesa.experimental.datacollection.mesa_classes import ObservableModel, ObservableAgent
7-
from mesa.experimental.datacollection.pubsub import Events, ObservableNumber
7+
from mesa.experimental.datacollection.pubsub import MessageType, ObservableState
88
from mesa.experimental.datacollection.collectors import collect, DataCollector, Measure, MeasureDescriptor
99
from mesa.experimental.datacollection.pubsub import AgentSetObserver
1010

@@ -44,7 +44,6 @@ def __init__(self, N=100, width=10, height=10):
4444

4545
self.running = True
4646

47-
self.gini.some_test_method()
4847
self.gini = Measure(self, self.agents, compute_gini)
4948

5049
def step(self):

mesa/experimental/datacollection/examples/epstein.py

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import math
2+
import enum
23

34
from mesa.time import RandomActivation
45
from mesa.space import SingleGrid
56

67
from mesa.experimental.datacollection.mesa_classes import ObservableModel, ObservableAgent
78
from mesa.experimental.datacollection.collectors import DataCollector, collect, Measure
9+
from mesa.experimental.datacollection.mesa_classes import ConditionalAgentSet
10+
from mesa.experimental.datacollection.pubsub import ObservableState
811

912

1013
class EpsteinAgent(ObservableAgent):
@@ -13,6 +16,12 @@ def __init__(self, unique_id, model, vision):
1316
self.vision = vision
1417

1518

19+
class CitizenState(enum.StrEnum):
20+
ACTIVE = "active"
21+
QUIESCENT = "quiescent"
22+
JAILED = "jailed"
23+
24+
1625
class Citizen(EpsteinAgent):
1726
"""
1827
A member of the general population, may or may not be in active rebellion.
@@ -37,6 +46,7 @@ class Citizen(EpsteinAgent):
3746
arrest_probability: agent's assessment of arrest probability, given
3847
rebellion
3948
"""
49+
condition = ObservableState()
4050

4151
def __init__(
4252
self,
@@ -68,7 +78,7 @@ def __init__(
6878
self.regime_legitimacy = regime_legitimacy
6979
self.risk_aversion = risk_aversion
7080
self.threshold = threshold
71-
self.condition = "Quiescent"
81+
self.condition = CitizenState.QUIESCENT
7282
self.jail_sentence = 0
7383
self.grievance = self.hardship * (1 - self.regime_legitimacy)
7484
self.arrest_probability = None
@@ -77,16 +87,20 @@ def step(self):
7787
"""
7888
Decide whether to activate, then move if applicable.
7989
"""
80-
if self.jail_sentence:
90+
if self.condition == CitizenState.JAILED:
8191
self.jail_sentence -= 1
8292
return # no other changes or movements if agent is in jail.
93+
8394
self.update_neighbors()
8495
self.update_estimated_arrest_probability()
96+
8597
net_risk = self.risk_aversion * self.arrest_probability
86-
if self.grievance - net_risk > self.threshold:
87-
self.condition = "Active"
88-
else:
89-
self.condition = "Quiescent"
98+
if (self.grievance - net_risk > self.threshold) and (self.condition != CitizenState.ACTIVE):
99+
self.condition = CitizenState.ACTIVE
100+
elif self.condition == CitizenState.ACTIVE:
101+
self.condition = CitizenState.QUIESCENT
102+
# else, agent is quiescent and stays that way
103+
90104
if self.model.movement and self.empty_neighbors:
91105
new_pos = self.random.choice(self.empty_neighbors)
92106
self.model.grid.move_agent(self, new_pos)
@@ -113,14 +127,17 @@ def update_estimated_arrest_probability(self):
113127
for c in self.neighbors:
114128
if (
115129
isinstance(c, Citizen)
116-
and c.condition == "Active"
117-
and c.jail_sentence == 0
130+
and c.condition == CitizenState.ACTIVE
118131
):
119132
actives_in_vision += 1
120133
self.arrest_probability = 1 - math.exp(
121134
-1 * self.model.arrest_prob_constant * (cops_in_vision / actives_in_vision)
122135
)
123136

137+
def sent_to_jail(self, sentence):
138+
self.condition = CitizenState.JAILED
139+
self.jail_sentence = sentence
140+
124141

125142
class Cop(EpsteinAgent):
126143
"""
@@ -145,14 +162,12 @@ def step(self):
145162
if (
146163
isinstance(agent, Citizen)
147164
and agent.condition == "Active"
148-
and agent.jail_sentence == 0
149165
):
150166
active_neighbors.append(agent)
151167
if active_neighbors:
152168
arrestee = self.random.choice(active_neighbors)
153169
sentence = self.random.randint(0, self.model.max_jail_term)
154-
arrestee.jail_sentence = sentence
155-
arrestee.condition = "Quiescent"
170+
arrestee.sent_to_jail(sentence)
156171
if self.model.movement and self.empty_neighbors:
157172
new_pos = self.random.choice(self.empty_neighbors)
158173
self.model.grid.move_agent(self, new_pos)
@@ -226,12 +241,6 @@ def __init__(
226241
self.schedule = RandomActivation(self)
227242
self.grid = SingleGrid(width, height, torus=True)
228243

229-
230-
citizens = m.get_agents_of_type(Citizen)
231-
232-
self.quiescent = Measure(citizens, lambda agent_set:
233-
agent_set.select(lambda agent: agent.condition == "quiescent"))
234-
235244
if self.cop_density + self.citizen_density > 1:
236245
raise ValueError("Cop density + citizen density must be less than 1")
237246
for contents, pos in self.grid.coord_iter():
@@ -252,6 +261,22 @@ def __init__(
252261
self.grid.place_agent(agent, pos)
253262
self.schedule.add(agent)
254263

264+
# static groups
265+
citizens = self.get_agents_of_type(Citizen)
266+
cops = self.get_agents_of_type(Cop)
267+
268+
# conditional groups
269+
self.quiescent = ConditionalAgentSet(citizens, self,
270+
condition=lambda agent: agent.condition == CitizenState.QUIESCENT)
271+
self.active = ConditionalAgentSet(citizens, self,
272+
condition=lambda agent: agent.condition == CitizenState.ACTIVE)
273+
self.jailed = ConditionalAgentSet(citizens, self, condition=lambda agent: agent.jail_sentence > 0)
274+
275+
# measures
276+
self.n_quiescent = Measure(self, self.quiescent, lambda obj: len(obj))
277+
self.n_active = Measure(self, self.active, lambda obj: len(obj))
278+
self.n_jailed = Measure(self, self.jailed, lambda obj: len(obj))
279+
255280
self.running = True
256281

257282
def step(self):
@@ -265,26 +290,17 @@ def step(self):
265290

266291

267292
if __name__ == '__main__':
268-
model = EpsteinCivilViolence()
269-
270-
citizens = model.get_agents_of_type(Citizen)
271-
cops = model.get_agents_of_type(Cop)
293+
model = EpsteinCivilViolence(seed=15)
272294

273295
dc = DataCollector(model, [
274-
collect("n_quiescent", citizens, attributes="condition",
275-
callable=lambda d: len([e for e in d if e["condition"] == "Quiescent"])),
276-
collect("n_active", citizens, attributes="condition",
277-
callable=lambda d: len([e for e in d if e["condition"] == "Active"])),
278-
collect("jail_sentence", citizens,
279-
callable=lambda d: len([e for e in d if e["jail_sentence"] > 0])),
280-
collect("data", citizens, ["jail_sentence", "arrest_probability"])
296+
collect("model_data", model, attributes=["n_quiescent", "n_active", "n_jailed"]),
297+
collect("jail_sentence", model.jailed),
298+
collect("citizen_data", model.get_agents_of_type(Citizen), ["jail_sentence", "arrest_probability"])
281299
])
282300

283301
dc.collect_all()
284-
for _ in range(10):
302+
for _ in range(50):
285303
model.step()
286304
dc.collect_all()
287305

288-
print(dc.jail_sentence.to_dataframe().head())
289-
print(dc.n_quiescent.to_dataframe().head())
290-
print(dc.data.to_dataframe().head())
306+
print(dc.data.to_dataframe())
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from mesa.experimental.datacollection.mesa_classes import ObservableModel
2+
from mesa.experimental.datacollection.pubsub import MessageType
3+
4+
if __name__ == '__main__':
5+
model = ObservableModel()
6+
7+
8+
print(MessageType.message_types)

mesa/experimental/datacollection/mesa_classes.py

Lines changed: 98 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,120 @@
1-
from typing import Any
1+
from typing import Any, Iterable, Callable, List
22

33
import itertools
44

55
from ...agent import Agent, AgentSet
66
from ...model import Model
7-
from .pubsub import EventProducer, Events
7+
from .pubsub import MessageProducer, MessageType
8+
from .collectors import Measure, MeasureDescriptor
9+
810

911
class UpdatedAgentSet(AgentSet):
1012

11-
def get(self, attr_names: str | list[str]) -> list[Any]:
13+
def get(self, attr_names: str | List[str]) -> list[Any]:
14+
"""
15+
Retrieve the specified attribute(s) from each agent in the AgentSet.
16+
17+
Args:
18+
attr_names (str | List[str]): The name(s) of the attribute(s) to retrieve from each agent.
19+
20+
Returns:
21+
list[Any]: A list of attribute values for each agent in the set.
22+
"""
23+
1224
if isinstance(attr_names, str):
13-
attr_names = [attr_names]
14-
a = [[getattr(agent, attr_name) for attr_name in attr_names] for agent in self._agents]
15-
return a
25+
return [getattr(agent, attr_names) for agent in self._agents]
26+
else:
27+
return [[getattr(agent, attr_name) for attr_name in attr_names] for agent in self._agents]
28+
29+
30+
class ConditionalAgentSet(UpdatedAgentSet):
31+
"""This is a dynamic agent set where membership depends on a specified condition
32+
33+
For this agent set, memberships depends on a user specified condition, and it can change over the course of the
34+
simulation. Agent membership is evaluated everytime the agent sends a STATE_CHANGE message.
35+
36+
If the user passes an initial set of agents, only these agents are considered to be potentially part of the
37+
agent set
38+
39+
If the suer does not pass an initial set of agents, it defaults to model.agents
40+
41+
42+
"""
43+
44+
def __init__(self, agents: Iterable[Agent] | None, model: Model, condition: Callable[[Agent], bool]) -> None:
45+
"""
46+
47+
Args:
48+
agents (Iterable[Agent]): An iterable of agents. These form the basis of the agents in the set
49+
model (Model): A model instance
50+
condition (Callable[[Agent], bool]): a function that takes an agent and returns boolean. If true, the agent
51+
is considered part of the agent set, otherwise the is currently not part of the agent set.
52+
53+
54+
"""
55+
56+
super().__init__({}, model)
57+
self._condition = condition
58+
59+
if agents is None:
60+
agents = model.agents
61+
model.subscribe(model.AGENT_ADDED.name)
62+
63+
for agent in agents:
64+
self.add_permanently(agent)
65+
66+
def add_permanently(self, agent: Agent) -> None:
67+
agent.subscribe(agent.STATE_CHANGE.name, self.state_change_handler)
68+
self._apply_condition(agent)
69+
def remove_permanently(self, agent):
70+
self.remove(agent)
71+
agent.unsubscribe(agent.STATE_CHANGE.name, self.state_change_handler)
72+
73+
def _apply_condition(self, agent: Agent):
74+
if self._condition(agent):
75+
self.add(agent)
76+
else:
77+
self.discard(agent)
78+
79+
def state_change_handler(self, message):
80+
self._apply_condition(message.sender)
81+
82+
def agent_added_handler(self, message):
83+
agent = message.agent
84+
agent.subscribe(agent.STATE_CHANGE)
85+
self._apply_condition(agent)
1686

1787

1888
class ObservableModel(Model):
89+
AGENT_ADDED = MessageType("agent")
90+
AGENT_REMOVED = MessageType("agent")
91+
STATE_CHANGED = MessageType("state")
92+
93+
def __setattr__(self, name, value):
94+
if isinstance(value, Measure) and not name.startswith("_"):
95+
klass = type(self)
96+
descr = MeasureDescriptor()
97+
descr.__set_name__(klass, name)
98+
setattr(klass, name, descr)
99+
descr.__set__(self, value)
100+
else:
101+
super().__setattr__(name, value)
19102

20103
def __init__(self, *args: Any, **kwargs: Any) -> None:
21104
super().__init__(*args, **kwargs)
22-
self.event_producer = EventProducer(self)
105+
self.event_producer = MessageProducer(self)
23106

24107
@property
25108
def time(self):
26109
return self._time
27110

28-
def add_agent(self, agent:Agent) -> None:
111+
def add_agent(self, agent: Agent) -> None:
29112
self.agents_[type(agent)][agent] = None
30-
self.event_producer.fire_event(Events.AGENT_ADDED, agent)
113+
self.event_producer.send_message(self.AGENT_ADDED, agent=agent)
31114

32-
def remove_agent(self, agent:Agent) -> None:
115+
def remove_agent(self, agent: Agent) -> None:
33116
self.agents_[type(agent)].pop(agent, default=None)
34-
self.event_producer.fire_event(Events.AGENT_REMOVED, agent)
117+
self.event_producer.send_message(self.AGENT_REMOVED, agent=agent)
35118

36119
def get_agents_of_type(self, agenttype: type[Agent]) -> AgentSet:
37120
"""Retrieves an AgentSet containing all agents of the specified type."""
@@ -54,6 +137,8 @@ def unsubscribe(self, event: str, event_handler: callable):
54137

55138

56139
class ObservableAgent(Agent):
140+
STATE_CHANGE = MessageType("state")
141+
57142
def __init__(self, unique_id: int, model: ObservableModel) -> None:
58143
"""
59144
Create a new agent.
@@ -62,10 +147,11 @@ def __init__(self, unique_id: int, model: ObservableModel) -> None:
62147
unique_id (int): A unique identifier for this agent.
63148
model (Model): The model instance in which the agent exists.
64149
"""
150+
super().__init__(unique_id, model)
65151
self.unique_id = unique_id
66152
self.model = model
67153
self.pos = None
68-
self.event_producer = EventProducer(self)
154+
self.event_producer = MessageProducer(self)
69155

70156
# register agent
71157
self.model.add_agent(self)
@@ -79,9 +165,3 @@ def subscribe(self, event: str, event_handler: callable):
79165
def unsubscribe(self, event: str, event_handler: callable):
80166
# or try except pass, which is slightly faster
81167
self.event_producer.unsubscribe(event, event_handler)
82-
83-
84-
85-
86-
87-

0 commit comments

Comments
 (0)