Skip to content

Commit 78e103f

Browse files
authored
Merge pull request #159 from spjuhel/develop
Hotfix for a bug with rebuilding demand distribution
2 parents 6669e93 + 3ca43be commit 78e103f

File tree

3 files changed

+102
-15
lines changed

3 files changed

+102
-15
lines changed

boario/simulation.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,12 +1506,17 @@ def _normalize_distribution(
15061506
dist_sq = dist
15071507
if isinstance(dist_sq, pd.Series):
15081508
ret.loc[addressed_to, :] = (
1509-
dist_sq.loc[addressed_to].transform(lambda x: x / sum(x)).values[:, None]
1509+
dist_sq.loc[addressed_to]
1510+
.groupby(level=1)
1511+
.transform(lambda x: x / sum(x))
1512+
.values[:, None]
15101513
)
15111514
return ret
15121515
elif isinstance(dist_sq, pd.DataFrame):
1513-
ret.loc[addressed_to, affected] = dist_sq.loc[addressed_to, affected].transform(
1514-
lambda x: x / sum(x)
1516+
ret.loc[addressed_to, affected] = (
1517+
dist_sq.loc[addressed_to, affected]
1518+
.groupby(level=1)
1519+
.transform(lambda x: x / sum(x))
15151520
)
15161521
return ret
15171522
else:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "boario"
3-
version = "0.6.1"
3+
version = "0.6.2"
44
description = "BoARIO : The Adaptative Regional Input Output model in python."
55
authors = ["Samuel Juhel <pro@sjuhel.org>"]
66
license = "GNU General Public License v3 or later (GPLv3+)"

tests/test_simulation.py

Lines changed: 93 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -299,37 +299,119 @@ def test_equal_distribution():
299299

300300

301301
def test_normalize_distribution():
302-
# Case 1: Normalizing a Series distribution
303-
dist = pd.Series([2, 3, 5], index=["X", "Y", "Z"])
302+
# Case 1: Normalizing a Series distribution with MultiIndex
303+
dist = pd.Series(
304+
[2, 3, 5, 4, 6, 8],
305+
index=pd.MultiIndex.from_tuples(
306+
[
307+
("R1", "S1"),
308+
("R1", "S2"),
309+
("R1", "S3"),
310+
("R2", "S1"),
311+
("R2", "S2"),
312+
("R2", "S3"),
313+
],
314+
names=["region", "sector"],
315+
),
316+
)
304317
affected = pd.Index(["A"])
305-
addressed_to = pd.Index(["X", "Y", "Z"])
318+
addressed_to = pd.MultiIndex.from_tuples(
319+
[
320+
("R1", "S1"),
321+
("R1", "S2"),
322+
("R1", "S3"),
323+
("R2", "S1"),
324+
("R2", "S2"),
325+
("R2", "S3"),
326+
],
327+
names=["region", "sector"],
328+
)
306329
result = _normalize_distribution(dist, affected, addressed_to)
307330

308-
expected = pd.DataFrame({"A": [0.2, 0.3, 0.5]}, index=["X", "Y", "Z"])
331+
expected = pd.DataFrame(
332+
{"A": [2 / 6.0, 3 / 9.0, 5 / 13.0, 4 / 6.0, 6 / 9.0, 8 / 13.0]},
333+
index=pd.MultiIndex.from_tuples(
334+
[
335+
("R1", "S1"),
336+
("R1", "S2"),
337+
("R1", "S3"),
338+
("R2", "S1"),
339+
("R2", "S2"),
340+
("R2", "S3"),
341+
],
342+
names=["region", "sector"],
343+
),
344+
)
309345
pd.testing.assert_frame_equal(result, expected)
310346

311-
# Case 2: Normalizing a DataFrame distribution
312-
dist = pd.DataFrame({"A": [2, 3, 5], "B": [4, 6, 10]}, index=["X", "Y", "Z"])
347+
# Case 2: Normalizing a DataFrame distribution with MultiIndex
348+
dist = pd.DataFrame(
349+
{"A": [2, 3, 5, 4, 6, 8], "B": [10, 15, 25, 20, 30, 40]},
350+
index=pd.MultiIndex.from_tuples(
351+
[
352+
("R1", "S1"),
353+
("R1", "S2"),
354+
("R1", "S3"),
355+
("R2", "S1"),
356+
("R2", "S2"),
357+
("R2", "S3"),
358+
],
359+
names=["region", "sector"],
360+
),
361+
)
313362
affected = pd.Index(["A", "B"])
314-
addressed_to = pd.Index(["X", "Y", "Z"])
363+
addressed_to = pd.MultiIndex.from_tuples(
364+
[
365+
("R1", "S1"),
366+
("R1", "S2"),
367+
("R1", "S3"),
368+
("R2", "S1"),
369+
("R2", "S2"),
370+
("R2", "S3"),
371+
],
372+
names=["region", "sector"],
373+
)
315374
result = _normalize_distribution(dist, affected, addressed_to)
316375

317376
expected = pd.DataFrame(
318-
{"A": [0.2, 0.3, 0.5], "B": [0.2, 0.3, 0.5]}, index=["X", "Y", "Z"]
377+
{
378+
"A": [2 / 6.0, 3 / 9.0, 5 / 13.0, 4 / 6.0, 6 / 9.0, 8 / 13.0],
379+
"B": [10 / 30.0, 15 / 45.0, 25 / 65.0, 20 / 30.0, 30 / 45.0, 40 / 65.0],
380+
},
381+
index=pd.MultiIndex.from_tuples(
382+
[
383+
("R1", "S1"),
384+
("R1", "S2"),
385+
("R1", "S3"),
386+
("R2", "S1"),
387+
("R2", "S2"),
388+
("R2", "S3"),
389+
],
390+
names=["region", "sector"],
391+
),
319392
)
320393
pd.testing.assert_frame_equal(result, expected)
321394

322395
# Case 6: Mismatched indices in Series
323-
dist = pd.Series([2, 3], index=["X", "Y"])
396+
dist = pd.Series(
397+
[2, 3],
398+
index=pd.MultiIndex.from_tuples(
399+
[("R1", "S1"), ("R1", "S2")], names=["region", "sector"]
400+
),
401+
)
324402
affected = pd.Index(["A"])
325-
addressed_to = pd.Index(["X", "Y", "Z"])
403+
addressed_to = pd.MultiIndex.from_tuples(
404+
[("R1", "S1"), ("R1", "S2"), ("R1", "S3")], names=["region", "sector"]
405+
)
326406
with pytest.raises(KeyError):
327407
_normalize_distribution(dist, affected, addressed_to)
328408

329409
# Case 7: Invalid distribution type
330410
dist = [2, 3, 5] # Not a Series or DataFrame
331411
affected = pd.Index(["A"])
332-
addressed_to = pd.Index(["X", "Y", "Z"])
412+
addressed_to = pd.MultiIndex.from_tuples(
413+
[("R1", "S1"), ("R1", "S2"), ("R1", "S3")], names=["region", "sector"]
414+
)
333415
with pytest.raises(
334416
ValueError, match="given distribution should be a Series or a DataFrame"
335417
):

0 commit comments

Comments
 (0)