Skip to content

Commit 6b83686

Browse files
talgalili authored and facebook-github-bot committed
Increase balance package test coverage (facebookresearch#281)
Summary: The coverage report showed 98% overall coverage, with 7 files having gaps. These tests cover previously untested edge cases, including:
- CLI exception handling paths for weighting failures
- Sample class design effect diagnostics and IPW model parameters
- CBPS optimization convergence warnings and constraint violation exceptions
- Plotting functions with missing values, default parameters, and various dist_types
- Distance metrics with empty numeric columns

Differential Revision: D90946146
1 parent 6417fb0 commit 6b83686

File tree

5 files changed

+771
-0
lines changed

5 files changed

+771
-0
lines changed

tests/test_cbps.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,6 +1352,8 @@ def test_cbps_over_method_with_extreme_data_logs_warning(self) -> None:
13521352
transformations=None,
13531353
cbps_method="exact",
13541354
)
1355+
1356+
def test_cbps_over_method_logs_warnings(self) -> None:
13551357
"""Test CBPS over method logs warnings when optimization fails (lines 713, 747, 765).
13561358
13571359
Verifies that when optimization algorithms fail to converge, appropriate
@@ -1477,3 +1479,128 @@ def test_cbps_alpha_function_convergence_warning(self) -> None:
14771479
),
14781480
msg=f"Expected exception to contain convergence-related message, got: {e}",
14791481
)
1482+
1483+
1484+
class TestCbpsOptimizationConvergenceWithMocking(balance.testutil.BalanceTestCase):
    """Drive CBPS into its constraint-violation exception paths by stubbing scipy.

    ``scipy.optimize.minimize`` and ``scipy.optimize.minimize_scalar`` are
    patched with stand-ins whose return values are fully controlled, so that
    ``balance_cbps.cbps`` takes the warning/exception branches the original
    author identified in the CBPS module (lines 689, 713, 726, 747, 765, 778
    at the time the tests were written).
    """

    # Failure message the stubbed optimizer results carry; the tests expect
    # cbps() to react to it by raising.
    _CONSTRAINT_FAILURE_MESSAGE = (
        "Did not converge to a solution satisfying the constraints"
    )

    def _create_simple_test_data(self) -> tuple:
        """Return a tiny single-covariate fixture for CBPS.

        Returns:
            ``(sample_df, sample_weights, target_df, target_weights)`` where
            both frames hold one float column ``"a"`` with five rows and both
            weight series are five ones.
        """
        sample_df = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
        target_df = pd.DataFrame({"a": [2.0, 3.0, 4.0, 5.0, 6.0]})
        sample_weights = pd.Series([1.0] * 5)
        target_weights = pd.Series([1.0] * 5)
        return sample_df, sample_weights, target_df, target_weights

    @staticmethod
    def _fake_optimize_result(x0, *, success: bool, message: str, fun: float):
        """Build a stand-in for the value returned by ``scipy.optimize.minimize``.

        Args:
            x0: Value served under the ``"x"`` key (the tests echo back the
                starting point they were handed).
            success: Served under ``"success"`` as a ``np.bool_``.
            message: Served under ``"message"``.
            fun: Served under ``"fun"``.

        Returns:
            A ``MagicMock`` whose item access (``result[key]``) serves the four
            keys above and raises ``KeyError`` for any other key.
        """
        from unittest.mock import MagicMock

        payload = {
            "success": np.bool_(success),
            "message": message,
            "x": x0,
            "fun": fun,
        }
        result = MagicMock()
        # A function assigned to a magic method on a mock is invoked with the
        # mock itself as first argument, hence the ignored leading parameter.
        result.__getitem__ = lambda _mock, key: payload[key]
        return result

    @staticmethod
    def _fake_minimize_scalar(fun, **kwargs):
        """Stand-in for ``scipy.optimize.minimize_scalar`` that always succeeds."""
        return {
            "success": np.bool_(True),
            "message": "Success",
            "x": np.array([1.0]),
        }

    def _assert_cbps_raises_constraint_error(
        self, fake_minimize, cbps_method: str, failure_msg: str
    ) -> None:
        """Run ``cbps`` under patched optimizers and check the raised message.

        Args:
            fake_minimize: Callable installed as the side effect of the
                patched ``scipy.optimize.minimize``.
            cbps_method: Value forwarded as ``cbps_method`` to ``cbps``.
            failure_msg: Message shown if the exception text does not mention
                the constraint violation.
        """
        from unittest.mock import patch

        sample_df, sample_weights, target_df, target_weights = (
            self._create_simple_test_data()
        )

        with patch(
            "scipy.optimize.minimize_scalar",
            side_effect=self._fake_minimize_scalar,
        ), patch("scipy.optimize.minimize", side_effect=fake_minimize):
            with self.assertRaises(Exception) as context:
                balance_cbps.cbps(
                    sample_df,
                    sample_weights,
                    target_df,
                    target_weights,
                    transformations=None,
                    cbps_method=cbps_method,
                )

        self.assertIn(
            "no solution satisfying the constraints",
            str(context.exception).lower(),
            msg=failure_msg,
        )

    def test_exact_method_constraint_violation_exception(self) -> None:
        """Test line 726: exception when exact-method constraints can't be satisfied.

        Every call to the patched ``scipy.optimize.minimize`` reports
        ``success=False`` together with the constraint-violation message.
        """

        def fake_minimize(fun, x0, **kwargs):
            # Simulate a constraint-violation failure on every optimization.
            return self._fake_optimize_result(
                x0,
                success=False,
                message=self._CONSTRAINT_FAILURE_MESSAGE,
                fun=100.0,
            )

        self._assert_cbps_raises_constraint_error(
            fake_minimize,
            "exact",
            "Expected exception about constraint violation",
        )

    def test_over_method_both_gmm_constraint_violation_exception(self) -> None:
        """Test line 778: exception when both GMM optimizations of the over method fail.

        The first call to the patched ``scipy.optimize.minimize`` succeeds
        (per the original author, that call is the balance optimization);
        every later call fails with the constraint-violation message.
        """
        calls_seen = 0

        def fake_minimize(fun, x0, **kwargs):
            nonlocal calls_seen
            calls_seen += 1
            if calls_seen == 1:
                # First call is balance_optimize - succeed
                return self._fake_optimize_result(
                    x0, success=True, message="Success", fun=1.0
                )
            # Both GMM optimizations fail with constraint violation
            return self._fake_optimize_result(
                x0,
                success=False,
                message=self._CONSTRAINT_FAILURE_MESSAGE,
                fun=100.0,
            )

        self._assert_cbps_raises_constraint_error(
            fake_minimize,
            "over",
            "Expected exception about constraint violation in over method",
        )

tests/test_cli.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,3 +1377,147 @@ def test_adapt_output_empty_df_returns_empty(self) -> None:
13771377
df = pd.DataFrame()
13781378
result = cli.adapt_output(df)
13791379
self.assertTrue(result.empty)
1380+
1381+
def test_cli_succeed_on_weighting_failure_with_return_df_with_original_dtypes(
    self,
) -> None:
    """Test succeed_on_weighting_failure flag with return_df_with_original_dtypes.

    Targets lines 757-794 in cli.py (per the original author) - the
    exception handling path when weighting fails and
    return_df_with_original_dtypes is True. The test passes when both output
    files exist and the diagnostics contain an ``adjustment_failure`` metric.
    """
    # Everything, including the input CSV, lives inside the temporary
    # directory so nothing is left behind once the test finishes.
    with tempfile.TemporaryDirectory() as temp_dir:
        in_path = os.path.join(temp_dir, "in.csv")
        with open(in_path, "w") as in_file:
            in_file.write("x,y,is_respondent,id,weight\na,b,1,1,1\na,b,0,1,1")

        out_file = os.path.join(temp_dir, "out.csv")
        diagnostics_out_file = os.path.join(temp_dir, "diagnostics_out.csv")

        parser = make_parser()
        args = parser.parse_args(
            [
                "--input_file",
                in_path,
                "--output_file",
                out_file,
                "--diagnostics_output_file",
                diagnostics_out_file,
                "--covariate_columns",
                "x,y",
                "--succeed_on_weighting_failure",
                "--return_df_with_original_dtypes",
            ]
        )
        # Named balance_cli so the local does not shadow the `cli` module
        # that other tests in this file reference.
        balance_cli = BalanceCLI(args)
        balance_cli.update_attributes_for_main_used_by_adjust()
        balance_cli.main()

        self.assertTrue(os.path.isfile(out_file))
        self.assertTrue(os.path.isfile(diagnostics_out_file))

        diagnostics_df = pd.read_csv(diagnostics_out_file)
        self.assertIn("adjustment_failure", diagnostics_df["metric"].values)
1424+
1425+
def test_cli_ipw_method_with_model_in_adjusted_kwargs(self) -> None:
    """Test CLI with IPW method to verify model is passed to adjust.

    Targets line 719 in cli.py (per the original author), where model is
    added to adjusted_kwargs. The test only asserts that both output files
    are produced.
    """
    input_dataset = _create_sample_and_target_data()

    # The input CSV is written inside the temporary directory so it is
    # removed together with the outputs when the test finishes.
    with tempfile.TemporaryDirectory() as temp_dir:
        input_path = os.path.join(temp_dir, "input.csv")
        with open(input_path, "w") as input_file:
            input_dataset.to_csv(path_or_buf=input_file)

        output_file = os.path.join(temp_dir, "weights_out.csv")
        diagnostics_output_file = os.path.join(temp_dir, "diagnostics_out.csv")
        features = "age,gender"

        parser = make_parser()
        args = parser.parse_args(
            [
                "--input_file",
                input_path,
                "--output_file",
                output_file,
                "--diagnostics_output_file",
                diagnostics_output_file,
                "--covariate_columns",
                features,
                "--method=ipw",
                "--ipw_logistic_regression_kwargs",
                '{"solver": "lbfgs", "max_iter": 200}',
            ]
        )
        # Named balance_cli so the local does not shadow the `cli` module
        # that other tests in this file reference.
        balance_cli = BalanceCLI(args)
        balance_cli.update_attributes_for_main_used_by_adjust()
        balance_cli.main()

        self.assertTrue(os.path.isfile(output_file))
        self.assertTrue(os.path.isfile(diagnostics_output_file))
1464+
1465+
def test_cli_batch_columns_empty_batches(self) -> None:
    """Test CLI batch processing driven by --batch_columns.

    Targets lines 1082-1099, 1101-1106 in cli.py (per the original author) -
    the batch processing path including the empty results case.

    NOTE(review): the fixture below contains two populated batches (A and
    B), each with 50 respondent and 50 non-respondent rows; confirm whether
    this input really reaches the "empty results" branch named above.
    """
    # The input CSV is written inside the temporary directory so it is
    # removed together with the outputs when the test finishes.
    with tempfile.TemporaryDirectory() as temp_dir:
        in_contents = (
            "x,y,is_respondent,id,weight,batch\n"
            + ("1.0,50,1,1,1,A\n" * 50)
            + ("2.0,60,0,1,1,A\n" * 50)
            + ("1.0,50,1,2,1,B\n" * 50)
            + ("2.0,60,0,2,1,B\n" * 50)
        )
        in_path = os.path.join(temp_dir, "in.csv")
        with open(in_path, "w") as in_file:
            in_file.write(in_contents)

        out_file = os.path.join(temp_dir, "out.csv")
        diagnostics_out_file = os.path.join(temp_dir, "diagnostics_out.csv")

        parser = make_parser()
        args = parser.parse_args(
            [
                "--input_file",
                in_path,
                "--output_file",
                out_file,
                "--diagnostics_output_file",
                diagnostics_out_file,
                "--covariate_columns",
                "x,y",
                "--batch_columns",
                "batch",
            ]
        )
        # Named balance_cli so the local does not shadow the `cli` module
        # that other tests in this file reference.
        balance_cli = BalanceCLI(args)
        balance_cli.update_attributes_for_main_used_by_adjust()
        balance_cli.main()

        self.assertTrue(os.path.isfile(out_file))
        self.assertTrue(os.path.isfile(diagnostics_out_file))

        output_df = pd.read_csv(out_file)
        # assertGreater reports the actual length on failure.
        self.assertGreater(len(output_df), 0)
1511+
1512+
1513+
class TestCliMainFunction(balance.testutil.BalanceTestCase):
    """Checks for the CLI ``main()`` entry point (lines 1421-1425 of cli.py)."""

    def test_main_is_callable(self) -> None:
        """Import ``main`` from ``balance.cli`` and confirm it is callable.

        Relates to lines 1421-1425 in cli.py; the function is only imported
        and inspected here, never invoked.
        """
        from balance.cli import main as cli_entry_point

        entry_point_is_callable = callable(cli_entry_point)
        self.assertTrue(entry_point_is_callable)

0 commit comments

Comments
 (0)