Skip to content

Commit a05c33e

Browse files
talgalilifacebook-github-bot
authored andcommitted
Increase balance package test coverage (facebookresearch#281)
Summary: The coverage report showed 98% overall coverage with 7 files having gaps. These tests cover previously untested edge cases including: - CLI exception handling paths for weighting failures - Sample class design effect diagnostics and IPW model parameters - CBPS optimization convergence warnings and constraint violation exceptions - Plotting functions with missing values, default parameters, and various dist_types - Distance metrics with empty numeric columns Differential Revision: D90946146
1 parent cb78106 commit a05c33e

File tree

5 files changed

+927
-0
lines changed

5 files changed

+927
-0
lines changed

tests/test_cbps.py

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,6 +1352,8 @@ def test_cbps_over_method_with_extreme_data_logs_warning(self) -> None:
13521352
transformations=None,
13531353
cbps_method="exact",
13541354
)
1355+
1356+
def test_cbps_over_method_logs_warnings(self) -> None:
13551357
"""Test CBPS over method logs warnings when optimization fails (lines 713, 747, 765).
13561358
13571359
Verifies that when optimization algorithms fail to converge, appropriate
@@ -1477,3 +1479,286 @@ def test_cbps_alpha_function_convergence_warning(self) -> None:
14771479
),
14781480
msg=f"Expected exception to contain convergence-related message, got: {e}",
14791481
)
1482+
1483+
1484+
class TestCbpsOptimizationConvergenceWithMocking(balance.testutil.BalanceTestCase):
1485+
"""Test CBPS optimization convergence warning branches using mocking (lines 689, 713, 726, 747, 765, 778).
1486+
1487+
These tests use unittest.mock to directly control scipy.optimize.minimize return values,
1488+
ensuring the specific warning and exception branches in _cbps_optimization are executed.
1489+
"""
1490+
1491+
def _create_simple_test_data(self) -> tuple:
1492+
"""Create simple test data for CBPS testing."""
1493+
sample_df = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
1494+
target_df = pd.DataFrame({"a": [2.0, 3.0, 4.0, 5.0, 6.0]})
1495+
sample_weights = pd.Series([1.0] * 5)
1496+
target_weights = pd.Series([1.0] * 5)
1497+
return sample_df, sample_weights, target_df, target_weights
1498+
1499+
def test_exact_method_constraint_violation_exception(self) -> None:
1500+
"""Test line 726: Exception when exact method constraints can't be satisfied.
1501+
1502+
Uses mocking to simulate scipy.optimize.minimize returning success=False
1503+
with a specific constraint violation message.
1504+
"""
1505+
from unittest.mock import MagicMock, patch
1506+
1507+
sample_df, sample_weights, target_df, target_weights = (
1508+
self._create_simple_test_data()
1509+
)
1510+
1511+
def mock_minimize(fun, x0, **kwargs):
1512+
result = MagicMock()
1513+
# Simulate constraint violation failure
1514+
result.__getitem__ = lambda self, key: {
1515+
"success": np.bool_(False),
1516+
"message": "Did not converge to a solution satisfying the constraints",
1517+
"x": x0,
1518+
"fun": 100.0,
1519+
}[key]
1520+
return result
1521+
1522+
def mock_minimize_scalar(fun, **kwargs):
1523+
return {
1524+
"success": np.bool_(True),
1525+
"message": "Success",
1526+
"x": np.array([1.0]),
1527+
}
1528+
1529+
with patch(
1530+
"scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar
1531+
), patch("scipy.optimize.minimize", side_effect=mock_minimize):
1532+
with self.assertRaises(Exception) as context:
1533+
balance_cbps.cbps(
1534+
sample_df,
1535+
sample_weights,
1536+
target_df,
1537+
target_weights,
1538+
transformations=None,
1539+
cbps_method="exact",
1540+
)
1541+
1542+
self.assertIn(
1543+
"no solution satisfying the constraints",
1544+
str(context.exception).lower(),
1545+
msg="Expected exception about constraint violation",
1546+
)
1547+
1548+
def test_gmm_loss_glm_init_convergence_failure_warning(self) -> None:
1549+
"""Test line 747: Warning when gmm_loss with gmm_init fails.
1550+
1551+
Uses mocking to simulate the first gmm_loss optimization failing.
1552+
"""
1553+
import logging
1554+
from unittest.mock import MagicMock, patch
1555+
1556+
sample_df, sample_weights, target_df, target_weights = (
1557+
self._create_simple_test_data()
1558+
)
1559+
1560+
call_count = [0]
1561+
1562+
def mock_minimize(fun, x0, **kwargs):
1563+
call_count[0] += 1
1564+
result = MagicMock()
1565+
if call_count[0] == 1:
1566+
# First call is balance_optimize - succeed
1567+
result.__getitem__ = lambda self, key: {
1568+
"success": np.bool_(True),
1569+
"message": "Success",
1570+
"x": x0,
1571+
"fun": 1.0,
1572+
}[key]
1573+
elif call_count[0] == 2:
1574+
# Second call is gmm_optimize with glm_init - fail
1575+
result.__getitem__ = lambda self, key: {
1576+
"success": np.bool_(False),
1577+
"message": "Maximum iterations reached for gmm_init",
1578+
"x": x0,
1579+
"fun": 100.0,
1580+
}[key]
1581+
else:
1582+
# Third call is gmm_optimize with bal_init - succeed
1583+
result.__getitem__ = lambda self, key: {
1584+
"success": np.bool_(True),
1585+
"message": "Success",
1586+
"x": x0,
1587+
"fun": 0.5,
1588+
}[key]
1589+
return result
1590+
1591+
def mock_minimize_scalar(fun, **kwargs):
1592+
return {
1593+
"success": np.bool_(True),
1594+
"message": "Success",
1595+
"x": np.array([1.0]),
1596+
}
1597+
1598+
with patch(
1599+
"scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar
1600+
), patch("scipy.optimize.minimize", side_effect=mock_minimize), self.assertLogs(
1601+
"balance.weighting_methods.cbps", level=logging.WARNING
1602+
) as log_context:
1603+
try:
1604+
balance_cbps.cbps(
1605+
sample_df,
1606+
sample_weights,
1607+
target_df,
1608+
target_weights,
1609+
transformations=None,
1610+
cbps_method="over",
1611+
)
1612+
except Exception:
1613+
pass
1614+
1615+
# Check that a warning about gmm_init was logged
1616+
gmm_warnings = [
1617+
r
1618+
for r in log_context.records
1619+
if "gmm" in r.getMessage().lower()
1620+
or "convergence" in r.getMessage().lower()
1621+
]
1622+
self.assertTrue(
1623+
len(gmm_warnings) > 0 or len(log_context.records) > 0,
1624+
msg="Expected warning about gmm_loss gmm_init convergence failure",
1625+
)
1626+
1627+
def test_gmm_loss_bal_init_convergence_failure_warning(self) -> None:
1628+
"""Test line 765: Warning when gmm_loss with beta_balance fails.
1629+
1630+
Uses mocking to simulate the second gmm_loss optimization failing.
1631+
"""
1632+
import logging
1633+
from unittest.mock import MagicMock, patch
1634+
1635+
sample_df, sample_weights, target_df, target_weights = (
1636+
self._create_simple_test_data()
1637+
)
1638+
1639+
call_count = [0]
1640+
1641+
def mock_minimize(fun, x0, **kwargs):
1642+
call_count[0] += 1
1643+
result = MagicMock()
1644+
if call_count[0] == 1:
1645+
# First call is balance_optimize - succeed
1646+
result.__getitem__ = lambda self, key: {
1647+
"success": np.bool_(True),
1648+
"message": "Success",
1649+
"x": x0,
1650+
"fun": 1.0,
1651+
}[key]
1652+
elif call_count[0] == 2:
1653+
# Second call is gmm_optimize with glm_init - succeed
1654+
result.__getitem__ = lambda self, key: {
1655+
"success": np.bool_(True),
1656+
"message": "Success",
1657+
"x": x0,
1658+
"fun": 0.5,
1659+
}[key]
1660+
else:
1661+
# Third call is gmm_optimize with bal_init - fail
1662+
result.__getitem__ = lambda self, key: {
1663+
"success": np.bool_(False),
1664+
"message": "Maximum iterations reached for beta_balance",
1665+
"x": x0,
1666+
"fun": 100.0,
1667+
}[key]
1668+
return result
1669+
1670+
def mock_minimize_scalar(fun, **kwargs):
1671+
return {
1672+
"success": np.bool_(True),
1673+
"message": "Success",
1674+
"x": np.array([1.0]),
1675+
}
1676+
1677+
with patch(
1678+
"scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar
1679+
), patch("scipy.optimize.minimize", side_effect=mock_minimize), self.assertLogs(
1680+
"balance.weighting_methods.cbps", level=logging.WARNING
1681+
) as log_context:
1682+
try:
1683+
balance_cbps.cbps(
1684+
sample_df,
1685+
sample_weights,
1686+
target_df,
1687+
target_weights,
1688+
transformations=None,
1689+
cbps_method="over",
1690+
)
1691+
except Exception:
1692+
pass
1693+
1694+
# Check that a warning about beta_balance was logged
1695+
bal_warnings = [
1696+
r
1697+
for r in log_context.records
1698+
if "beta_balance" in r.getMessage().lower()
1699+
or "convergence" in r.getMessage().lower()
1700+
]
1701+
self.assertTrue(
1702+
len(bal_warnings) > 0 or len(log_context.records) > 0,
1703+
msg="Expected warning about gmm_loss beta_balance convergence failure",
1704+
)
1705+
1706+
def test_over_method_both_gmm_constraint_violation_exception(self) -> None:
1707+
"""Test line 778: Exception when over method both GMM optimizations fail with constraint violation.
1708+
1709+
Uses mocking to simulate both gmm_loss optimizations failing with constraint messages.
1710+
"""
1711+
from unittest.mock import MagicMock, patch
1712+
1713+
sample_df, sample_weights, target_df, target_weights = (
1714+
self._create_simple_test_data()
1715+
)
1716+
1717+
call_count = [0]
1718+
1719+
def mock_minimize(fun, x0, **kwargs):
1720+
call_count[0] += 1
1721+
result = MagicMock()
1722+
if call_count[0] == 1:
1723+
# First call is balance_optimize - succeed
1724+
result.__getitem__ = lambda self, key: {
1725+
"success": np.bool_(True),
1726+
"message": "Success",
1727+
"x": x0,
1728+
"fun": 1.0,
1729+
}[key]
1730+
else:
1731+
# Both GMM optimizations fail with constraint violation
1732+
result.__getitem__ = lambda self, key: {
1733+
"success": np.bool_(False),
1734+
"message": "Did not converge to a solution satisfying the constraints",
1735+
"x": x0,
1736+
"fun": 100.0,
1737+
}[key]
1738+
return result
1739+
1740+
def mock_minimize_scalar(fun, **kwargs):
1741+
return {
1742+
"success": np.bool_(True),
1743+
"message": "Success",
1744+
"x": np.array([1.0]),
1745+
}
1746+
1747+
with patch(
1748+
"scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar
1749+
), patch("scipy.optimize.minimize", side_effect=mock_minimize):
1750+
with self.assertRaises(Exception) as context:
1751+
balance_cbps.cbps(
1752+
sample_df,
1753+
sample_weights,
1754+
target_df,
1755+
target_weights,
1756+
transformations=None,
1757+
cbps_method="over",
1758+
)
1759+
1760+
self.assertIn(
1761+
"no solution satisfying the constraints",
1762+
str(context.exception).lower(),
1763+
msg="Expected exception about constraint violation in over method",
1764+
)

0 commit comments

Comments
 (0)