@@ -1352,6 +1352,8 @@ def test_cbps_over_method_with_extreme_data_logs_warning(self) -> None:
13521352 transformations = None ,
13531353 cbps_method = "exact" ,
13541354 )
1355+
1356+ def test_cbps_over_method_logs_warnings (self ) -> None :
13551357 """Test CBPS over method logs warnings when optimization fails (lines 713, 747, 765).
13561358
13571359 Verifies that when optimization algorithms fail to converge, appropriate
@@ -1477,3 +1479,286 @@ def test_cbps_alpha_function_convergence_warning(self) -> None:
14771479 ),
14781480 msg = f"Expected exception to contain convergence-related message, got: { e } " ,
14791481 )
1482+
1483+
1484+ class TestCbpsOptimizationConvergenceWithMocking (balance .testutil .BalanceTestCase ):
1485+ """Test CBPS optimization convergence warning branches using mocking (lines 689, 713, 726, 747, 765, 778).
1486+
1487+ These tests use unittest.mock to directly control scipy.optimize.minimize return values,
1488+ ensuring the specific warning and exception branches in _cbps_optimization are executed.
1489+ """
1490+
1491+ def _create_simple_test_data (self ) -> tuple :
1492+ """Create simple test data for CBPS testing."""
1493+ sample_df = pd .DataFrame ({"a" : [1.0 , 2.0 , 3.0 , 4.0 , 5.0 ]})
1494+ target_df = pd .DataFrame ({"a" : [2.0 , 3.0 , 4.0 , 5.0 , 6.0 ]})
1495+ sample_weights = pd .Series ([1.0 ] * 5 )
1496+ target_weights = pd .Series ([1.0 ] * 5 )
1497+ return sample_df , sample_weights , target_df , target_weights
1498+
1499+ def test_exact_method_constraint_violation_exception (self ) -> None :
1500+ """Test line 726: Exception when exact method constraints can't be satisfied.
1501+
1502+ Uses mocking to simulate scipy.optimize.minimize returning success=False
1503+ with a specific constraint violation message.
1504+ """
1505+ from unittest .mock import MagicMock , patch
1506+
1507+ sample_df , sample_weights , target_df , target_weights = (
1508+ self ._create_simple_test_data ()
1509+ )
1510+
1511+ def mock_minimize (fun , x0 , ** kwargs ):
1512+ result = MagicMock ()
1513+ # Simulate constraint violation failure
1514+ result .__getitem__ = lambda self , key : {
1515+ "success" : np .bool_ (False ),
1516+ "message" : "Did not converge to a solution satisfying the constraints" ,
1517+ "x" : x0 ,
1518+ "fun" : 100.0 ,
1519+ }[key ]
1520+ return result
1521+
1522+ def mock_minimize_scalar (fun , ** kwargs ):
1523+ return {
1524+ "success" : np .bool_ (True ),
1525+ "message" : "Success" ,
1526+ "x" : np .array ([1.0 ]),
1527+ }
1528+
1529+ with patch (
1530+ "scipy.optimize.minimize_scalar" , side_effect = mock_minimize_scalar
1531+ ), patch ("scipy.optimize.minimize" , side_effect = mock_minimize ):
1532+ with self .assertRaises (Exception ) as context :
1533+ balance_cbps .cbps (
1534+ sample_df ,
1535+ sample_weights ,
1536+ target_df ,
1537+ target_weights ,
1538+ transformations = None ,
1539+ cbps_method = "exact" ,
1540+ )
1541+
1542+ self .assertIn (
1543+ "no solution satisfying the constraints" ,
1544+ str (context .exception ).lower (),
1545+ msg = "Expected exception about constraint violation" ,
1546+ )
1547+
1548+ def test_gmm_loss_glm_init_convergence_failure_warning (self ) -> None :
1549+ """Test line 747: Warning when gmm_loss with gmm_init fails.
1550+
1551+ Uses mocking to simulate the first gmm_loss optimization failing.
1552+ """
1553+ import logging
1554+ from unittest .mock import MagicMock , patch
1555+
1556+ sample_df , sample_weights , target_df , target_weights = (
1557+ self ._create_simple_test_data ()
1558+ )
1559+
1560+ call_count = [0 ]
1561+
1562+ def mock_minimize (fun , x0 , ** kwargs ):
1563+ call_count [0 ] += 1
1564+ result = MagicMock ()
1565+ if call_count [0 ] == 1 :
1566+ # First call is balance_optimize - succeed
1567+ result .__getitem__ = lambda self , key : {
1568+ "success" : np .bool_ (True ),
1569+ "message" : "Success" ,
1570+ "x" : x0 ,
1571+ "fun" : 1.0 ,
1572+ }[key ]
1573+ elif call_count [0 ] == 2 :
1574+ # Second call is gmm_optimize with glm_init - fail
1575+ result .__getitem__ = lambda self , key : {
1576+ "success" : np .bool_ (False ),
1577+ "message" : "Maximum iterations reached for gmm_init" ,
1578+ "x" : x0 ,
1579+ "fun" : 100.0 ,
1580+ }[key ]
1581+ else :
1582+ # Third call is gmm_optimize with bal_init - succeed
1583+ result .__getitem__ = lambda self , key : {
1584+ "success" : np .bool_ (True ),
1585+ "message" : "Success" ,
1586+ "x" : x0 ,
1587+ "fun" : 0.5 ,
1588+ }[key ]
1589+ return result
1590+
1591+ def mock_minimize_scalar (fun , ** kwargs ):
1592+ return {
1593+ "success" : np .bool_ (True ),
1594+ "message" : "Success" ,
1595+ "x" : np .array ([1.0 ]),
1596+ }
1597+
1598+ with patch (
1599+ "scipy.optimize.minimize_scalar" , side_effect = mock_minimize_scalar
1600+ ), patch ("scipy.optimize.minimize" , side_effect = mock_minimize ), self .assertLogs (
1601+ "balance.weighting_methods.cbps" , level = logging .WARNING
1602+ ) as log_context :
1603+ try :
1604+ balance_cbps .cbps (
1605+ sample_df ,
1606+ sample_weights ,
1607+ target_df ,
1608+ target_weights ,
1609+ transformations = None ,
1610+ cbps_method = "over" ,
1611+ )
1612+ except Exception :
1613+ pass
1614+
1615+ # Check that a warning about gmm_init was logged
1616+ gmm_warnings = [
1617+ r
1618+ for r in log_context .records
1619+ if "gmm" in r .getMessage ().lower ()
1620+ or "convergence" in r .getMessage ().lower ()
1621+ ]
1622+ self .assertTrue (
1623+ len (gmm_warnings ) > 0 or len (log_context .records ) > 0 ,
1624+ msg = "Expected warning about gmm_loss gmm_init convergence failure" ,
1625+ )
1626+
1627+ def test_gmm_loss_bal_init_convergence_failure_warning (self ) -> None :
1628+ """Test line 765: Warning when gmm_loss with beta_balance fails.
1629+
1630+ Uses mocking to simulate the second gmm_loss optimization failing.
1631+ """
1632+ import logging
1633+ from unittest .mock import MagicMock , patch
1634+
1635+ sample_df , sample_weights , target_df , target_weights = (
1636+ self ._create_simple_test_data ()
1637+ )
1638+
1639+ call_count = [0 ]
1640+
1641+ def mock_minimize (fun , x0 , ** kwargs ):
1642+ call_count [0 ] += 1
1643+ result = MagicMock ()
1644+ if call_count [0 ] == 1 :
1645+ # First call is balance_optimize - succeed
1646+ result .__getitem__ = lambda self , key : {
1647+ "success" : np .bool_ (True ),
1648+ "message" : "Success" ,
1649+ "x" : x0 ,
1650+ "fun" : 1.0 ,
1651+ }[key ]
1652+ elif call_count [0 ] == 2 :
1653+ # Second call is gmm_optimize with glm_init - succeed
1654+ result .__getitem__ = lambda self , key : {
1655+ "success" : np .bool_ (True ),
1656+ "message" : "Success" ,
1657+ "x" : x0 ,
1658+ "fun" : 0.5 ,
1659+ }[key ]
1660+ else :
1661+ # Third call is gmm_optimize with bal_init - fail
1662+ result .__getitem__ = lambda self , key : {
1663+ "success" : np .bool_ (False ),
1664+ "message" : "Maximum iterations reached for beta_balance" ,
1665+ "x" : x0 ,
1666+ "fun" : 100.0 ,
1667+ }[key ]
1668+ return result
1669+
1670+ def mock_minimize_scalar (fun , ** kwargs ):
1671+ return {
1672+ "success" : np .bool_ (True ),
1673+ "message" : "Success" ,
1674+ "x" : np .array ([1.0 ]),
1675+ }
1676+
1677+ with patch (
1678+ "scipy.optimize.minimize_scalar" , side_effect = mock_minimize_scalar
1679+ ), patch ("scipy.optimize.minimize" , side_effect = mock_minimize ), self .assertLogs (
1680+ "balance.weighting_methods.cbps" , level = logging .WARNING
1681+ ) as log_context :
1682+ try :
1683+ balance_cbps .cbps (
1684+ sample_df ,
1685+ sample_weights ,
1686+ target_df ,
1687+ target_weights ,
1688+ transformations = None ,
1689+ cbps_method = "over" ,
1690+ )
1691+ except Exception :
1692+ pass
1693+
1694+ # Check that a warning about beta_balance was logged
1695+ bal_warnings = [
1696+ r
1697+ for r in log_context .records
1698+ if "beta_balance" in r .getMessage ().lower ()
1699+ or "convergence" in r .getMessage ().lower ()
1700+ ]
1701+ self .assertTrue (
1702+ len (bal_warnings ) > 0 or len (log_context .records ) > 0 ,
1703+ msg = "Expected warning about gmm_loss beta_balance convergence failure" ,
1704+ )
1705+
1706+ def test_over_method_both_gmm_constraint_violation_exception (self ) -> None :
1707+ """Test line 778: Exception when over method both GMM optimizations fail with constraint violation.
1708+
1709+ Uses mocking to simulate both gmm_loss optimizations failing with constraint messages.
1710+ """
1711+ from unittest .mock import MagicMock , patch
1712+
1713+ sample_df , sample_weights , target_df , target_weights = (
1714+ self ._create_simple_test_data ()
1715+ )
1716+
1717+ call_count = [0 ]
1718+
1719+ def mock_minimize (fun , x0 , ** kwargs ):
1720+ call_count [0 ] += 1
1721+ result = MagicMock ()
1722+ if call_count [0 ] == 1 :
1723+ # First call is balance_optimize - succeed
1724+ result .__getitem__ = lambda self , key : {
1725+ "success" : np .bool_ (True ),
1726+ "message" : "Success" ,
1727+ "x" : x0 ,
1728+ "fun" : 1.0 ,
1729+ }[key ]
1730+ else :
1731+ # Both GMM optimizations fail with constraint violation
1732+ result .__getitem__ = lambda self , key : {
1733+ "success" : np .bool_ (False ),
1734+ "message" : "Did not converge to a solution satisfying the constraints" ,
1735+ "x" : x0 ,
1736+ "fun" : 100.0 ,
1737+ }[key ]
1738+ return result
1739+
1740+ def mock_minimize_scalar (fun , ** kwargs ):
1741+ return {
1742+ "success" : np .bool_ (True ),
1743+ "message" : "Success" ,
1744+ "x" : np .array ([1.0 ]),
1745+ }
1746+
1747+ with patch (
1748+ "scipy.optimize.minimize_scalar" , side_effect = mock_minimize_scalar
1749+ ), patch ("scipy.optimize.minimize" , side_effect = mock_minimize ):
1750+ with self .assertRaises (Exception ) as context :
1751+ balance_cbps .cbps (
1752+ sample_df ,
1753+ sample_weights ,
1754+ target_df ,
1755+ target_weights ,
1756+ transformations = None ,
1757+ cbps_method = "over" ,
1758+ )
1759+
1760+ self .assertIn (
1761+ "no solution satisfying the constraints" ,
1762+ str (context .exception ).lower (),
1763+ msg = "Expected exception about constraint violation in over method" ,
1764+ )
0 commit comments