Skip to content

Commit 58f2676

Browse files
committed
ongoing work
1 parent 6352858 commit 58f2676

File tree

5 files changed

+100
-27
lines changed

5 files changed

+100
-27
lines changed

mesa/experimental/datacollection/collectors.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
class BaseCollector:
11-
def __init__(self, name:str, obj: Any, attributes: str|List[str]=None, callable: Callable=None):
11+
def __init__(self, name:str, obj: Any, attributes: str|List[str]=None, fn: Callable=None):
1212
"""
1313
1414
Args
@@ -31,7 +31,7 @@ def __init__(self, name:str, obj: Any, attributes: str|List[str]=None, callable:
3131
self.name = name
3232
self.obj = obj
3333
self.attributes = attributes
34-
self.callable = callable
34+
self.callable = fn
3535
self.data_over_time = {}
3636

3737
def collect(self, time):
@@ -67,8 +67,8 @@ def to_dataframe(self):
6767

6868

6969
class AgentSetCollector(BaseCollector):
70-
def __init__(self, name, obj, attributes=None, callable=None):
71-
super().__init__(name, obj, attributes=attributes, callable=callable)
70+
def __init__(self, name, obj, attributes=None, fn=None):
71+
super().__init__(name, obj, attributes=attributes, fn=fn)
7272
self.attributes.append("unique_id")
7373

7474
def collect(self, time):
@@ -165,24 +165,24 @@ def collect_all(self):
165165
collector.collect(time)
166166

167167

168-
def collect(name:str, obj:Any, attributes: str|List[str]=None, callable: Callable=None):
168+
def collect(name:str, obj:Any, attributes: str|List[str]=None, fn: Callable=None):
169169
"""
170170
171171
Args
172172
name : name of the collector
173173
obj : object form which to collect information
174174
attributes : attributes to collect, option. If not provided, attributes defaults to name
175-
callable : callable to apply to collected data.
175+
fn : callable to apply to collected data.
176176
177177
FIXME:: what about callable to object directly? or simply not allow for it and solve this
178178
FIXME:: through measures?
179179
180180
"""
181181

182182
if isinstance(obj, AgentSet):
183-
return AgentSetCollector(name, obj, attributes, callable)
183+
return AgentSetCollector(name, obj, attributes, fn)
184184
else:
185-
return BaseCollector(name, obj, attributes, callable)
185+
return BaseCollector(name, obj, attributes, fn)
186186

187187

188188

@@ -193,13 +193,13 @@ class Measure:
193193
# FIXME:: doing so would turn measure into a descriptor
194194
# FIXME:: what about callable vs. attribute based Measures?
195195

196-
def __init__(self, model:Model, obj: Any, callable: Callable):
196+
def __init__(self, model, obj: Any, callable: Callable):
197197
super().__init__()
198198
self.obj = obj
199199
self.callable = callable
200-
self.model = model
201200
self._update_step = -1
202201
self._cached_value = None
202+
self.model = model
203203

204204
def get_value(self, force_update: bool = False):
205205
"""
@@ -213,3 +213,20 @@ def get_value(self, force_update: bool = False):
213213
self._cached_value = self.callable(self.obj)
214214
return self._cached_value
215215

216+
217+
218+
class MeasureDescriptor:
219+
220+
def some_test_method(self, obj, *args, **kwargs):
221+
print("blaat")
222+
223+
def __set_name__(self, owner, name):
224+
self.public_name = name
225+
self.private_name = "_" + name
226+
def __get__(self, obj, owner):
227+
return getattr(obj, self.private_name).get_value()
228+
229+
def __set__(self, obj, value):
230+
setattr(obj, self.private_name, value)
231+
232+

mesa/experimental/datacollection/examples/boltzmann.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@
55

66
from mesa.experimental.datacollection.mesa_classes import ObservableModel, ObservableAgent
77
from mesa.experimental.datacollection.pubsub import Events, ObservableNumber
8-
from mesa.experimental.datacollection.collectors import collect, DataCollector
8+
from mesa.experimental.datacollection.collectors import collect, DataCollector, Measure, MeasureDescriptor
99
from mesa.experimental.datacollection.pubsub import AgentSetObserver
1010

11-
def compute_gini(model):
12-
agent_wealths = [agent.wealth for agent in model.schedule.agents]
11+
def compute_gini(agents):
12+
agent_wealths = [agent.wealth for agent in agents]
1313
x = sorted(agent_wealths)
1414
N = model.num_agents
1515
B = sum(xi * (N - i) for i, xi in enumerate(x)) / (N * sum(x))
1616
return 1 + (1 / N) - 2 * B
1717

1818

19+
1920
class BoltzmannWealthModel(ObservableModel):
2021
"""A simple model of an economy where agents exchange currency at random.
2122
@@ -24,14 +25,13 @@ class BoltzmannWealthModel(ObservableModel):
2425
highly skewed distribution of wealth.
2526
"""
2627

28+
# gini = MeasureDescriptor()
29+
2730
def __init__(self, N=100, width=10, height=10):
2831
super().__init__()
2932
self.num_agents = N
3033
self.grid = MultiGrid(width, height, True)
3134
self.schedule = RandomActivation(self)
32-
# self.datacollector = DataCollector(
33-
# model_reporters={"Gini": compute_gini}, agent_reporters={"Wealth": "wealth"}
34-
# )
3535

3636
# Create agents
3737
for i in range(self.num_agents):
@@ -43,21 +43,16 @@ def __init__(self, N=100, width=10, height=10):
4343
self.grid.place_agent(a, (x, y))
4444

4545
self.running = True
46-
# self.datacollector.collect(self)
46+
47+
self.gini.some_test_method()
48+
self.gini = Measure(self, self.agents, compute_gini)
4749

4850
def step(self):
4951
self.schedule.step()
50-
# collect data
51-
# self.datacollector.collect(self)
52-
53-
def run_model(self, n):
54-
for i in range(n):
55-
self.step()
5652

5753

5854
class MoneyAgent(ObservableAgent):
5955
"""An agent with fixed initial wealth."""
60-
wealth = ObservableNumber()
6156

6257
def __init__(self, unique_id, model):
6358
super().__init__(unique_id, model)
@@ -90,14 +85,20 @@ def handler(subject, state):
9085
if state == "wealth":
9186
return getattr(subject, state)
9287

88+
89+
def some_func(obj):
90+
return obj.get_value()
91+
9392
if __name__ == '__main__':
9493
model = BoltzmannWealthModel()
9594

96-
datacollector = DataCollector(model, [collect("wealth", model.agents)])
95+
datacollector = DataCollector(model, [collect("wealth", model.agents),
96+
collect("gini", model)])
9797

9898

9999
for _ in range(10):
100100
model.step()
101101
datacollector.collect_all()
102102

103-
print(datacollector.wealth.to_dataframe().head())
103+
print(datacollector.wealth.to_dataframe().head())
104+
print(datacollector.gini.to_dataframe().head())

mesa/experimental/datacollection/mesa_classes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def get(self, attr_names: str | list[str]) -> list[Any]:
1616

1717

1818
class ObservableModel(Model):
19+
1920
def __init__(self, *args: Any, **kwargs: Any) -> None:
2021
super().__init__(*args, **kwargs)
2122
self.event_producer = EventProducer(self)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
class Measure:
2+
3+
def __init__(self, model, identifier, *args, **kwargs):
4+
self.model = model
5+
self.identifier = identifier
6+
7+
def get_value(self):
8+
return getattr(self.model, self.identifier)
9+
10+
11+
class MeasureDescriptor:
12+
def __set_name__(self, owner, name):
13+
self.public_name = name
14+
self.private_name = "_" + name
15+
16+
def __get__(self, obj, owner):
17+
return getattr(obj, self.private_name).get_value()
18+
19+
def __set__(self, obj, value):
20+
setattr(obj, self.private_name, value)
21+
22+
23+
class Model:
24+
25+
def __setattr__(self, name, value):
26+
if isinstance(value, Measure) and not name.startswith("_"):
27+
klass = type(self)
28+
descr = MeasureDescriptor()
29+
descr.__set_name__(klass, name)
30+
setattr(klass, name, descr)
31+
descr.__set__(self, value)
32+
else:
33+
super().__setattr__(name, value)
34+
35+
def __init__(self, identifier, *args, **kwargs):
36+
self.gini = Measure(self, "identifier")
37+
self.identifier = identifier
38+
39+
40+
if __name__ == '__main__':
41+
model1 = Model(1)
42+
model2 = Model(2)
43+
print(model1.gini)
44+
print(model2.gini)

mesa/model.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,17 @@
2121
TimeT = Union[float, int]
2222

2323

24-
class Model:
24+
class MetaClass(type):
25+
def __new__(cls, name, bases, attrs):
26+
return super().__new__(cls, name, bases, attrs)
27+
28+
def __init__(self, name, bases, attrs):
29+
# perform any additional initialization here...
30+
super().__init__(name, bases, attrs)
31+
print("blaat")
32+
33+
34+
class Model(metaclass=MetaClass):
2535
"""Base class for models in the Mesa ABM library.
2636
2737
This class serves as a foundational structure for creating agent-based models.

0 commit comments

Comments
 (0)