Skip to content

Commit 48da792

Browse files
talgalilifacebook-github-bot
authored andcommitted
Increase balance package test coverage (facebookresearch#281)
Summary: The coverage report showed 98% overall coverage with 7 files having gaps. These tests cover previously untested edge cases including: - CLI exception handling paths for weighting failures - Sample class design effect diagnostics and IPW model parameters - CBPS optimization convergence warnings and constraint violation exceptions - Plotting functions with missing values, default parameters, and various dist_types - Distance metrics with empty numeric columns Differential Revision: D90946146
1 parent 6417fb0 commit 48da792

File tree

5 files changed

+779
-0
lines changed

5 files changed

+779
-0
lines changed

tests/test_cbps.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
)
1515

1616
import warnings
17+
from typing import Any, Callable, Dict, List, Tuple, Union
18+
from unittest.mock import MagicMock
1719

1820
import balance.testutil
1921
import numpy as np
@@ -1352,6 +1354,8 @@ def test_cbps_over_method_with_extreme_data_logs_warning(self) -> None:
13521354
transformations=None,
13531355
cbps_method="exact",
13541356
)
1357+
1358+
def test_cbps_over_method_logs_warnings(self) -> None:
13551359
"""Test CBPS over method logs warnings when optimization fails (lines 713, 747, 765).
13561360
13571361
Verifies that when optimization algorithms fail to converge, appropriate
@@ -1477,3 +1481,134 @@ def test_cbps_alpha_function_convergence_warning(self) -> None:
14771481
),
14781482
msg=f"Expected exception to contain convergence-related message, got: {e}",
14791483
)
1484+
1485+
1486+
class TestCbpsOptimizationConvergenceWithMocking(balance.testutil.BalanceTestCase):
1487+
"""Test CBPS optimization convergence warning branches using mocking (lines 689, 713, 726, 747, 765, 778).
1488+
1489+
These tests use unittest.mock to directly control scipy.optimize.minimize return values,
1490+
ensuring the specific warning and exception branches in _cbps_optimization are executed.
1491+
"""
1492+
1493+
def _create_simple_test_data(
1494+
self,
1495+
) -> Tuple[pd.DataFrame, pd.Series, pd.DataFrame, pd.Series]:
1496+
"""Create simple test data for CBPS testing."""
1497+
sample_df = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
1498+
target_df = pd.DataFrame({"a": [2.0, 3.0, 4.0, 5.0, 6.0]})
1499+
sample_weights = pd.Series([1.0] * 5)
1500+
target_weights = pd.Series([1.0] * 5)
1501+
return sample_df, sample_weights, target_df, target_weights
1502+
1503+
def test_exact_method_constraint_violation_exception(self) -> None:
1504+
"""Test line 726: Exception when exact method constraints can't be satisfied.
1505+
1506+
Uses mocking to simulate scipy.optimize.minimize returning success=False
1507+
with a specific constraint violation message.
1508+
"""
1509+
from unittest.mock import patch
1510+
1511+
sample_df, sample_weights, target_df, target_weights = (
1512+
self._create_simple_test_data()
1513+
)
1514+
1515+
def mock_minimize(fun: Callable[..., Any], x0: Any, **kwargs: Any) -> MagicMock:
1516+
result = MagicMock()
1517+
# Simulate constraint violation failure
1518+
result.__getitem__ = lambda self, key: {
1519+
"success": np.bool_(False),
1520+
"message": "Did not converge to a solution satisfying the constraints",
1521+
"x": x0,
1522+
"fun": 100.0,
1523+
}[key]
1524+
return result
1525+
1526+
def mock_minimize_scalar(
1527+
fun: Callable[..., Any], **kwargs: Any
1528+
) -> Dict[str, Union[np.bool_, np.ndarray, str]]:
1529+
return {
1530+
"success": np.bool_(True),
1531+
"message": "Success",
1532+
"x": np.array([1.0]),
1533+
}
1534+
1535+
with patch(
1536+
"scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar
1537+
), patch("scipy.optimize.minimize", side_effect=mock_minimize):
1538+
with self.assertRaises(Exception) as context:
1539+
balance_cbps.cbps(
1540+
sample_df,
1541+
sample_weights,
1542+
target_df,
1543+
target_weights,
1544+
transformations=None,
1545+
cbps_method="exact",
1546+
)
1547+
1548+
self.assertIn(
1549+
"no solution satisfying the constraints",
1550+
str(context.exception).lower(),
1551+
msg="Expected exception about constraint violation",
1552+
)
1553+
1554+
def test_over_method_both_gmm_constraint_violation_exception(self) -> None:
1555+
"""Test line 778: Exception when over method both GMM optimizations fail with constraint violation.
1556+
1557+
Uses mocking to simulate both gmm_loss optimizations failing with constraint messages.
1558+
"""
1559+
from unittest.mock import patch
1560+
1561+
sample_df, sample_weights, target_df, target_weights = (
1562+
self._create_simple_test_data()
1563+
)
1564+
1565+
call_count: List[int] = [0]
1566+
1567+
def mock_minimize(fun: Callable[..., Any], x0: Any, **kwargs: Any) -> MagicMock:
1568+
call_count[0] += 1
1569+
result = MagicMock()
1570+
if call_count[0] == 1:
1571+
# First call is balance_optimize - succeed
1572+
result.__getitem__ = lambda self, key: {
1573+
"success": np.bool_(True),
1574+
"message": "Success",
1575+
"x": x0,
1576+
"fun": 1.0,
1577+
}[key]
1578+
else:
1579+
# Both GMM optimizations fail with constraint violation
1580+
result.__getitem__ = lambda self, key: {
1581+
"success": np.bool_(False),
1582+
"message": "Did not converge to a solution satisfying the constraints",
1583+
"x": x0,
1584+
"fun": 100.0,
1585+
}[key]
1586+
return result
1587+
1588+
def mock_minimize_scalar(
1589+
fun: Callable[..., Any], **kwargs: Any
1590+
) -> Dict[str, Union[np.bool_, np.ndarray, str]]:
1591+
return {
1592+
"success": np.bool_(True),
1593+
"message": "Success",
1594+
"x": np.array([1.0]),
1595+
}
1596+
1597+
with patch(
1598+
"scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar
1599+
), patch("scipy.optimize.minimize", side_effect=mock_minimize):
1600+
with self.assertRaises(Exception) as context:
1601+
balance_cbps.cbps(
1602+
sample_df,
1603+
sample_weights,
1604+
target_df,
1605+
target_weights,
1606+
transformations=None,
1607+
cbps_method="over",
1608+
)
1609+
1610+
self.assertIn(
1611+
"no solution satisfying the constraints",
1612+
str(context.exception).lower(),
1613+
msg="Expected exception about constraint violation in over method",
1614+
)

tests/test_cli.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,3 +1377,147 @@ def test_adapt_output_empty_df_returns_empty(self) -> None:
13771377
df = pd.DataFrame()
13781378
result = cli.adapt_output(df)
13791379
self.assertTrue(result.empty)
1380+
1381+
def test_cli_succeed_on_weighting_failure_with_return_df_with_original_dtypes(
1382+
self,
1383+
) -> None:
1384+
"""Test succeed_on_weighting_failure flag with return_df_with_original_dtypes.
1385+
1386+
Verifies lines 757-794 in cli.py - the exception handling path when
1387+
weighting fails and return_df_with_original_dtypes is True.
1388+
"""
1389+
with (
1390+
tempfile.TemporaryDirectory() as temp_dir,
1391+
tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as in_file,
1392+
):
1393+
in_contents = "x,y,is_respondent,id,weight\na,b,1,1,1\na,b,0,1,1"
1394+
in_file.write(in_contents)
1395+
in_file.close()
1396+
out_file = os.path.join(temp_dir, "out.csv")
1397+
diagnostics_out_file = os.path.join(temp_dir, "diagnostics_out.csv")
1398+
1399+
parser = make_parser()
1400+
1401+
args = parser.parse_args(
1402+
[
1403+
"--input_file",
1404+
in_file.name,
1405+
"--output_file",
1406+
out_file,
1407+
"--diagnostics_output_file",
1408+
diagnostics_out_file,
1409+
"--covariate_columns",
1410+
"x,y",
1411+
"--succeed_on_weighting_failure",
1412+
"--return_df_with_original_dtypes",
1413+
]
1414+
)
1415+
cli = BalanceCLI(args)
1416+
cli.update_attributes_for_main_used_by_adjust()
1417+
cli.main()
1418+
1419+
self.assertTrue(os.path.isfile(out_file))
1420+
self.assertTrue(os.path.isfile(diagnostics_out_file))
1421+
1422+
diagnostics_df = pd.read_csv(diagnostics_out_file)
1423+
self.assertIn("adjustment_failure", diagnostics_df["metric"].values)
1424+
1425+
def test_cli_ipw_method_with_model_in_adjusted_kwargs(self) -> None:
1426+
"""Test CLI with IPW method to verify model is passed to adjust.
1427+
1428+
Verifies line 719 in cli.py where model is added to adjusted_kwargs.
1429+
"""
1430+
input_dataset = _create_sample_and_target_data()
1431+
1432+
with (
1433+
tempfile.TemporaryDirectory() as temp_dir,
1434+
tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as input_file,
1435+
):
1436+
input_dataset.to_csv(path_or_buf=input_file)
1437+
input_file.close()
1438+
output_file = os.path.join(temp_dir, "weights_out.csv")
1439+
diagnostics_output_file = os.path.join(temp_dir, "diagnostics_out.csv")
1440+
features = "age,gender"
1441+
1442+
parser = make_parser()
1443+
args = parser.parse_args(
1444+
[
1445+
"--input_file",
1446+
input_file.name,
1447+
"--output_file",
1448+
output_file,
1449+
"--diagnostics_output_file",
1450+
diagnostics_output_file,
1451+
"--covariate_columns",
1452+
features,
1453+
"--method=ipw",
1454+
"--ipw_logistic_regression_kwargs",
1455+
'{"solver": "lbfgs", "max_iter": 200}',
1456+
]
1457+
)
1458+
cli = BalanceCLI(args)
1459+
cli.update_attributes_for_main_used_by_adjust()
1460+
cli.main()
1461+
1462+
self.assertTrue(os.path.isfile(output_file))
1463+
self.assertTrue(os.path.isfile(diagnostics_output_file))
1464+
1465+
def test_cli_batch_columns_empty_batches(self) -> None:
1466+
"""Test CLI batch processing with empty batches.
1467+
1468+
Verifies lines 1082-1099, 1101-1106 in cli.py - batch processing
1469+
path including the empty results case.
1470+
"""
1471+
with (
1472+
tempfile.TemporaryDirectory() as temp_dir,
1473+
tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as in_file,
1474+
):
1475+
in_contents = (
1476+
"x,y,is_respondent,id,weight,batch\n"
1477+
+ ("1.0,50,1,1,1,A\n" * 50)
1478+
+ ("2.0,60,0,1,1,A\n" * 50)
1479+
+ ("1.0,50,1,2,1,B\n" * 50)
1480+
+ ("2.0,60,0,2,1,B\n" * 50)
1481+
)
1482+
in_file.write(in_contents)
1483+
in_file.close()
1484+
out_file = os.path.join(temp_dir, "out.csv")
1485+
diagnostics_out_file = os.path.join(temp_dir, "diagnostics_out.csv")
1486+
1487+
parser = make_parser()
1488+
args = parser.parse_args(
1489+
[
1490+
"--input_file",
1491+
in_file.name,
1492+
"--output_file",
1493+
out_file,
1494+
"--diagnostics_output_file",
1495+
diagnostics_out_file,
1496+
"--covariate_columns",
1497+
"x,y",
1498+
"--batch_columns",
1499+
"batch",
1500+
]
1501+
)
1502+
cli = BalanceCLI(args)
1503+
cli.update_attributes_for_main_used_by_adjust()
1504+
cli.main()
1505+
1506+
self.assertTrue(os.path.isfile(out_file))
1507+
self.assertTrue(os.path.isfile(diagnostics_out_file))
1508+
1509+
output_df = pd.read_csv(out_file)
1510+
self.assertTrue(len(output_df) > 0)
1511+
1512+
1513+
class TestCliMainFunction(balance.testutil.BalanceTestCase):
1514+
"""Test cases for CLI main() entry point function (lines 1421-1425)."""
1515+
1516+
def test_main_is_callable(self) -> None:
1517+
"""Test that main function is callable.
1518+
1519+
Verifies lines 1421-1425 in cli.py.
1520+
"""
1521+
from balance.cli import main
1522+
1523+
self.assertTrue(callable(main))

0 commit comments

Comments
 (0)