Skip to content

Commit 8ebfedd

Browse files
committed
update to boltzman
1 parent 3891b5b commit 8ebfedd

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

mesa/experimental/datacollection/examples/boltzmann.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def handler(subject, state):
9595

9696
observer = AgentSetObserver(model.agents, Events.STATE_CHANGE, handler)
9797

98-
datacollector = DataCollector()
98+
datacollector = DataCollector(model)
9999
datacollector.add_collector(AgentSetCollector("wealth", model.agents, "wealth"))
100100

101101

mesa/experimental/datacollection/mesa_classes.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
from typing import Any
22

3+
import itertools
4+
35
from ...agent import Agent, AgentSet
46
from ...model import Model
57
from .pubsub import EventProducer, Events
68

9+
class UpdatedAgentSet(AgentSet):
10+
11+
def get(self, attr_names: str | list[str]) -> list[Any]:
12+
if isinstance(attr_names, str):
13+
attr_names = [attr_names]
14+
a = [[getattr(agent, attr_name) for attr_name in attr_names] for agent in self._agents]
15+
return a
16+
717

818
class ObservableModel(Model):
919
def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -26,6 +36,14 @@ def get_agents_of_type(self, agenttype: type[Agent]) -> AgentSet:
2636
"""Retrieves an AgentSet containing all agents of the specified type."""
2737
return UpdatedAgentSet(self.agents_[agenttype].keys(), self)
2838

39+
@property
40+
def agents(self) -> AgentSet:
41+
if hasattr(self, "_agents"):
42+
return self._agents
43+
else:
44+
all_agents = itertools.chain.from_iterable(self.agents_.values())
45+
return UpdatedAgentSet(all_agents, self)
46+
2947
def subscribe(self, event: str, event_handler: callable):
3048
self.event_producer.subscribe(event, event_handler)
3149

@@ -62,13 +80,7 @@ def unsubscribe(self, event: str, event_handler: callable):
6280
self.event_producer.unsubscribe(event, event_handler)
6381

6482

65-
class UpdatedAgentSet(AgentSet):
6683

67-
def get(self, attr_names: str | list[str]) -> list[Any]:
68-
if isinstance(attr_names, str):
69-
attr_names = [attr_names]
70-
a = [[getattr(agent, attr_name) for attr_name in attr_names] for agent in self._agents]
71-
return a
7284

7385

7486

0 commit comments

Comments
 (0)