@@ -299,37 +299,119 @@ def test_equal_distribution():
299299
300300
301301def test_normalize_distribution ():
302- # Case 1: Normalizing a Series distribution
303- dist = pd .Series ([2 , 3 , 5 ], index = ["X" , "Y" , "Z" ])
302+ # Case 1: Normalizing a Series distribution with MultiIndex
303+ dist = pd .Series (
304+ [2 , 3 , 5 , 4 , 6 , 8 ],
305+ index = pd .MultiIndex .from_tuples (
306+ [
307+ ("R1" , "S1" ),
308+ ("R1" , "S2" ),
309+ ("R1" , "S3" ),
310+ ("R2" , "S1" ),
311+ ("R2" , "S2" ),
312+ ("R2" , "S3" ),
313+ ],
314+ names = ["region" , "sector" ],
315+ ),
316+ )
304317 affected = pd .Index (["A" ])
305- addressed_to = pd .Index (["X" , "Y" , "Z" ])
318+ addressed_to = pd .MultiIndex .from_tuples (
319+ [
320+ ("R1" , "S1" ),
321+ ("R1" , "S2" ),
322+ ("R1" , "S3" ),
323+ ("R2" , "S1" ),
324+ ("R2" , "S2" ),
325+ ("R2" , "S3" ),
326+ ],
327+ names = ["region" , "sector" ],
328+ )
306329 result = _normalize_distribution (dist , affected , addressed_to )
307330
308- expected = pd .DataFrame ({"A" : [0.2 , 0.3 , 0.5 ]}, index = ["X" , "Y" , "Z" ])
331+ expected = pd .DataFrame (
332+ {"A" : [2 / 6.0 , 3 / 9.0 , 5 / 13.0 , 4 / 6.0 , 6 / 9.0 , 8 / 13.0 ]},
333+ index = pd .MultiIndex .from_tuples (
334+ [
335+ ("R1" , "S1" ),
336+ ("R1" , "S2" ),
337+ ("R1" , "S3" ),
338+ ("R2" , "S1" ),
339+ ("R2" , "S2" ),
340+ ("R2" , "S3" ),
341+ ],
342+ names = ["region" , "sector" ],
343+ ),
344+ )
309345 pd .testing .assert_frame_equal (result , expected )
310346
311- # Case 2: Normalizing a DataFrame distribution
312- dist = pd .DataFrame ({"A" : [2 , 3 , 5 ], "B" : [4 , 6 , 10 ]}, index = ["X" , "Y" , "Z" ])
347+ # Case 2: Normalizing a DataFrame distribution with MultiIndex
348+ dist = pd .DataFrame (
349+ {"A" : [2 , 3 , 5 , 4 , 6 , 8 ], "B" : [10 , 15 , 25 , 20 , 30 , 40 ]},
350+ index = pd .MultiIndex .from_tuples (
351+ [
352+ ("R1" , "S1" ),
353+ ("R1" , "S2" ),
354+ ("R1" , "S3" ),
355+ ("R2" , "S1" ),
356+ ("R2" , "S2" ),
357+ ("R2" , "S3" ),
358+ ],
359+ names = ["region" , "sector" ],
360+ ),
361+ )
313362 affected = pd .Index (["A" , "B" ])
314- addressed_to = pd .Index (["X" , "Y" , "Z" ])
363+ addressed_to = pd .MultiIndex .from_tuples (
364+ [
365+ ("R1" , "S1" ),
366+ ("R1" , "S2" ),
367+ ("R1" , "S3" ),
368+ ("R2" , "S1" ),
369+ ("R2" , "S2" ),
370+ ("R2" , "S3" ),
371+ ],
372+ names = ["region" , "sector" ],
373+ )
315374 result = _normalize_distribution (dist , affected , addressed_to )
316375
317376 expected = pd .DataFrame (
318- {"A" : [0.2 , 0.3 , 0.5 ], "B" : [0.2 , 0.3 , 0.5 ]}, index = ["X" , "Y" , "Z" ]
377+ {
378+ "A" : [2 / 6.0 , 3 / 9.0 , 5 / 13.0 , 4 / 6.0 , 6 / 9.0 , 8 / 13.0 ],
379+ "B" : [10 / 30.0 , 15 / 45.0 , 25 / 65.0 , 20 / 30.0 , 30 / 45.0 , 40 / 65.0 ],
380+ },
381+ index = pd .MultiIndex .from_tuples (
382+ [
383+ ("R1" , "S1" ),
384+ ("R1" , "S2" ),
385+ ("R1" , "S3" ),
386+ ("R2" , "S1" ),
387+ ("R2" , "S2" ),
388+ ("R2" , "S3" ),
389+ ],
390+ names = ["region" , "sector" ],
391+ ),
319392 )
320393 pd .testing .assert_frame_equal (result , expected )
321394
322395 # Case 6: Mismatched indices in Series
323- dist = pd .Series ([2 , 3 ], index = ["X" , "Y" ])
396+ dist = pd .Series (
397+ [2 , 3 ],
398+ index = pd .MultiIndex .from_tuples (
399+ [("R1" , "S1" ), ("R1" , "S2" )], names = ["region" , "sector" ]
400+ ),
401+ )
324402 affected = pd .Index (["A" ])
325- addressed_to = pd .Index (["X" , "Y" , "Z" ])
403+ addressed_to = pd .MultiIndex .from_tuples (
404+ [("R1" , "S1" ), ("R1" , "S2" ), ("R1" , "S3" )], names = ["region" , "sector" ]
405+ )
326406 with pytest .raises (KeyError ):
327407 _normalize_distribution (dist , affected , addressed_to )
328408
329409 # Case 7: Invalid distribution type
330410 dist = [2 , 3 , 5 ] # Not a Series or DataFrame
331411 affected = pd .Index (["A" ])
332- addressed_to = pd .Index (["X" , "Y" , "Z" ])
412+ addressed_to = pd .MultiIndex .from_tuples (
413+ [("R1" , "S1" ), ("R1" , "S2" ), ("R1" , "S3" )], names = ["region" , "sector" ]
414+ )
333415 with pytest .raises (
334416 ValueError , match = "given distribution should be a Series or a DataFrame"
335417 ):
0 commit comments