Skip to content

Commit 6352858

Browse files
committed
ongoing work
1 parent 8ebfedd commit 6352858

File tree

3 files changed

+75
-30
lines changed

3 files changed

+75
-30
lines changed

mesa/experimental/datacollection/collectors.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,26 @@
44
import pandas as pd
55

66
from mesa.agent import AgentSet
7+
from mesa import Model
78

89

910
class BaseCollector:
10-
def __init__(self, name:str, obj: Any, attributes: str|List[str]=None, callable: Callable=None, column_names: List[str]=None):
11+
def __init__(self, name:str, obj: Any, attributes: str|List[str]=None, callable: Callable=None):
1112
"""
1213
1314
Args
1415
name : name of the collector
1516
obj : object
1617
attributes :
1718
callable : callable
18-
column_names : list of column names, optional, defaults to attributes
19+
20+
Note::
21+
if a callable is passed, it is assumed that there will only be a single return value
1922
2023
"""
2124
super().__init__()
25+
if attributes is None:
26+
attributes = [name,]
2227

2328
if isinstance(attributes, str):
2429
attributes = [attributes,]
@@ -27,7 +32,6 @@ def __init__(self, name:str, obj: Any, attributes: str|List[str]=None, callable:
2732
self.obj = obj
2833
self.attributes = attributes
2934
self.callable = callable
30-
self.column_names = column_names
3135
self.data_over_time = {}
3236

3337
def collect(self, time):
@@ -54,9 +58,7 @@ def to_dataframe(self):
5458
# or name, or even something else if callable does funky stuff
5559
# so we need meaningful defaults and a way to override them
5660

57-
if self.column_names is not None:
58-
columns = self.column_names
59-
elif self.callable is not None:
61+
if self.callable is not None:
6062
columns = [self.name]
6163
else:
6264
columns = self.attributes
@@ -65,8 +67,8 @@ def to_dataframe(self):
6567

6668

6769
class AgentSetCollector(BaseCollector):
68-
def __init__(self, name, obj, attributes=None, callable=None, column_names=None):
69-
super().__init__(name, obj, attributes, callable, column_names)
70+
def __init__(self, name, obj, attributes=None, callable=None):
71+
super().__init__(name, obj, attributes=attributes, callable=callable)
7072
self.attributes.append("unique_id")
7173

7274
def collect(self, time):
@@ -163,9 +165,51 @@ def collect_all(self):
163165
collector.collect(time)
164166

165167

168+
def collect(name:str, obj:Any, attributes: str|List[str]=None, callable: Callable=None):
169+
"""
170+
171+
Args
172+
name : name of the collector
173+
obj : object form which to collect information
174+
attributes : attributes to collect, option. If not provided, attributes defaults to name
175+
callable : callable to apply to collected data.
176+
177+
FIXME:: what about callable to object directly? or simply not allow for it and solve this
178+
FIXME:: through measures?
179+
180+
"""
166181

167-
def collect_from(name, object, attributes, callable=None):
168-
if isinstance(object, AgentSet):
169-
return AgentSetCollector(name, object, attributes, callable)
182+
if isinstance(obj, AgentSet):
183+
return AgentSetCollector(name, obj, attributes, callable)
170184
else:
171-
return BaseCollector(name, object, attributes, callable)
185+
return BaseCollector(name, obj, attributes, callable)
186+
187+
188+
189+
190+
class Measure:
191+
# FIXME:: do we want AgentSet based measures?
192+
# FIXME:: can we play some property trick to enable attribute retrieval
193+
# FIXME:: doing so would turn measure into a descriptor
194+
# FIXME:: what about callable vs. attribute based Measures?
195+
196+
def __init__(self, model:Model, obj: Any, callable: Callable):
197+
super().__init__()
198+
self.obj = obj
199+
self.callable = callable
200+
self.model = model
201+
self._update_step = -1
202+
self._cached_value = None
203+
204+
def get_value(self, force_update: bool = False):
205+
"""
206+
207+
Args:
208+
force_update (bool): force recalculation of measure.
209+
210+
"""
211+
212+
if force_update or (self.model.step != self._update_step):
213+
self._cached_value = self.callable(self.obj)
214+
return self._cached_value
215+

mesa/experimental/datacollection/examples/boltzmann.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from mesa.experimental.datacollection.mesa_classes import ObservableModel, ObservableAgent
77
from mesa.experimental.datacollection.pubsub import Events, ObservableNumber
8-
from mesa.experimental.datacollection.collectors import BaseCollector, AgentSetCollector, DataCollector
8+
from mesa.experimental.datacollection.collectors import collect, DataCollector
99
from mesa.experimental.datacollection.pubsub import AgentSetObserver
1010

1111
def compute_gini(model):
@@ -93,10 +93,7 @@ def handler(subject, state):
9393
if __name__ == '__main__':
9494
model = BoltzmannWealthModel()
9595

96-
observer = AgentSetObserver(model.agents, Events.STATE_CHANGE, handler)
97-
98-
datacollector = DataCollector(model)
99-
datacollector.add_collector(AgentSetCollector("wealth", model.agents, "wealth"))
96+
datacollector = DataCollector(model, [collect("wealth", model.agents)])
10097

10198

10299
for _ in range(10):

mesa/experimental/datacollection/examples/epstein.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from mesa.space import SingleGrid
55

66
from mesa.experimental.datacollection.mesa_classes import ObservableModel, ObservableAgent
7-
from mesa.experimental.datacollection.pubsub import Events, ObservableNumber
8-
from mesa.experimental.datacollection.collectors import DataCollector, collect_from
7+
from mesa.experimental.datacollection.collectors import DataCollector, collect, Measure
98

109

1110
class EpsteinAgent(ObservableAgent):
@@ -224,10 +223,15 @@ def __init__(
224223
self.arrest_prob_constant = arrest_prob_constant
225224
self.movement = movement
226225
self.max_iters = max_iters
227-
self.iteration = 0
228226
self.schedule = RandomActivation(self)
229227
self.grid = SingleGrid(width, height, torus=True)
230228

229+
230+
citizens = m.get_agents_of_type(Citizen)
231+
232+
self.quiescent = Measure(citizens, lambda agent_set:
233+
agent_set.select(lambda agent: agent.condition == "quiescent"))
234+
231235
if self.cop_density + self.citizen_density > 1:
232236
raise ValueError("Cop density + citizen density must be less than 1")
233237
for contents, pos in self.grid.coord_iter():
@@ -256,31 +260,31 @@ def step(self):
256260
"""
257261
self.schedule.step()
258262
# collect data
259-
self.iteration += 1
260-
if self.iteration > self.max_iters:
263+
if self._steps > self.max_iters:
261264
self.running = False
262265

263266

264267
if __name__ == '__main__':
265268
model = EpsteinCivilViolence()
266269

267270
citizens = model.get_agents_of_type(Citizen)
268-
cops = model.get_agents_of_type(Citizen)
271+
cops = model.get_agents_of_type(Cop)
269272

270273
dc = DataCollector(model, [
271-
collect_from("n_quiescent", citizens, "condition",
272-
lambda d: len([e for e in d if e["condition"] == "Quiescent"])),
273-
collect_from("n_active", cops, "condition",
274-
lambda d: len([e for e in d if e["condition"] == "Active"])),
275-
collect_from("jailed", citizens, "jail_sentence",
276-
lambda d: len([e for e in d if e["jail_sentence"] > 0])),
277-
collect_from("data", citizens, ["jail_sentence", "arrest_probability"])
274+
collect("n_quiescent", citizens, attributes="condition",
275+
callable=lambda d: len([e for e in d if e["condition"] == "Quiescent"])),
276+
collect("n_active", citizens, attributes="condition",
277+
callable=lambda d: len([e for e in d if e["condition"] == "Active"])),
278+
collect("jail_sentence", citizens,
279+
callable=lambda d: len([e for e in d if e["jail_sentence"] > 0])),
280+
collect("data", citizens, ["jail_sentence", "arrest_probability"])
278281
])
279282

280283
dc.collect_all()
281284
for _ in range(10):
282285
model.step()
283286
dc.collect_all()
284287

288+
print(dc.jail_sentence.to_dataframe().head())
285289
print(dc.n_quiescent.to_dataframe().head())
286290
print(dc.data.to_dataframe().head())

0 commit comments

Comments
 (0)