Skip to content

Commit befdae7

Browse files
talgalili authored and facebook-github-bot committed
Increase test coverage from 98% to near 100% (facebookresearch#281)
Summary: Adds comprehensive tests covering previously untested edge cases identified in the coverage report. The 7 files with coverage gaps now have tests for: - **CLI**: Exception handling paths for weighting failures - **Sample class**: Design effect diagnostics and IPW model parameters - **CBPS**: Optimization convergence warnings and constraint violation exceptions - **Plotting**: Functions with missing values, default parameters, and various dist_types - **Distance metrics**: Empty numeric columns edge case This improves the overall test coverage reliability and catches potential regressions in error handling paths. Differential Revision: D90946146
1 parent 6417fb0 commit befdae7

File tree

12 files changed

+1967
-181
lines changed

12 files changed

+1967
-181
lines changed

.github/workflows/build-and-test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ jobs:
5353
flake8 .
5454
- name: ufmt (formatting check)
5555
run: |
56+
echo "Checking for formatting issues..."
57+
ufmt diff .
5658
ufmt check .
5759
5860
pyre:

balance/stats_and_plots/weighted_comparisons_plots.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -653,10 +653,6 @@ def seaborn_plot_dist(
653653
# With limiting the y axis range to (0,1)
654654
seaborn_plot_dist(dfs1, names=["self", "unadjusted", "target"], dist_type = "kde", ylim = (0,1))
655655
"""
656-
# Provide default names if not specified
657-
if names is None:
658-
names = [f"df_{i}" for i in range(len(dfs))]
659-
660656
# Set default dist_type
661657
dist_type_resolved: Literal["qq", "hist", "kde", "ecdf"]
662658
if dist_type is None:
@@ -671,10 +667,6 @@ def seaborn_plot_dist(
671667
if names is None:
672668
names = [f"df_{i}" for i in range(len(dfs))]
673669

674-
# Type narrowing for names parameter
675-
if names is None:
676-
names = []
677-
678670
# Choose set of variables to plot
679671
variables = choose_variables(*(d["df"] for d in dfs), variables=variables)
680672
logger.debug(f"plotting variables {variables}")
@@ -1348,12 +1340,13 @@ def naming_legend(object_name: str, names_of_dfs: List[str]) -> str:
13481340
naming_legend('self', ['self', 'target']) #'sample'
13491341
naming_legend('other_name', ['self', 'target']) #'other_name'
13501342
"""
1351-
if object_name in names_of_dfs:
1352-
return {
1353-
"unadjusted": "sample",
1354-
"self": "adjusted" if "unadjusted" in names_of_dfs else "sample",
1355-
"target": "population",
1356-
}[object_name]
1343+
name_mapping = {
1344+
"unadjusted": "sample",
1345+
"self": "adjusted" if "unadjusted" in names_of_dfs else "sample",
1346+
"target": "population",
1347+
}
1348+
if object_name in name_mapping:
1349+
return name_mapping[object_name]
13571350
else:
13581351
return object_name
13591352

balance/stats_and_plots/weighted_comparisons_stats.py

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from balance.stats_and_plots.weights_stats import _check_weights_are_valid
2424
from balance.util import _safe_groupby_apply, _safe_replace_and_infer
2525
from balance.utils.input_validation import (
26+
_coerce_to_numeric_and_validate,
2627
_extract_series_and_weights,
2728
_is_discrete_series,
2829
)
@@ -826,16 +827,16 @@ def emd(
826827
)
827828
)
828829
else:
829-
sample_vals = pd.to_numeric(sample_series, errors="coerce").dropna()
830-
target_vals = pd.to_numeric(target_series, errors="coerce").dropna()
831-
if sample_vals.empty or target_vals.empty:
832-
raise ValueError("Numeric columns must contain at least one value.")
833-
sample_w_numeric = sample_w[sample_series.index.isin(sample_vals.index)]
834-
target_w_numeric = target_w[target_series.index.isin(target_vals.index)]
830+
sample_vals, sample_w_numeric = _coerce_to_numeric_and_validate(
831+
sample_series, sample_w, "Sample numeric column"
832+
)
833+
target_vals, target_w_numeric = _coerce_to_numeric_and_validate(
834+
target_series, target_w, "Target numeric column"
835+
)
835836
out_dict[col] = float(
836837
wasserstein_distance(
837-
sample_vals.to_numpy(),
838-
target_vals.to_numpy(),
838+
sample_vals,
839+
target_vals,
839840
u_weights=sample_w_numeric,
840841
v_weights=target_w_numeric,
841842
)
@@ -962,21 +963,17 @@ def cvmd(
962963
np.sum((sample_cdf - target_cdf) ** 2 * combined_pmf.to_numpy())
963964
)
964965
else:
965-
sample_vals = pd.to_numeric(sample_series, errors="coerce").dropna()
966-
target_vals = pd.to_numeric(target_series, errors="coerce").dropna()
967-
if sample_vals.empty or target_vals.empty:
968-
raise ValueError("Numeric columns must contain at least one value.")
969-
sample_w_numeric = sample_w[sample_series.index.isin(sample_vals.index)]
970-
target_w_numeric = target_w[target_series.index.isin(target_vals.index)]
971-
972-
sample_sorted, sample_cdf = _weighted_ecdf(
973-
sample_vals.to_numpy(), sample_w_numeric
966+
sample_vals, sample_w_numeric = _coerce_to_numeric_and_validate(
967+
sample_series, sample_w, "Sample numeric column"
974968
)
975-
target_sorted, target_cdf = _weighted_ecdf(
976-
target_vals.to_numpy(), target_w_numeric
969+
target_vals, target_w_numeric = _coerce_to_numeric_and_validate(
970+
target_series, target_w, "Target numeric column"
977971
)
972+
973+
sample_sorted, sample_cdf = _weighted_ecdf(sample_vals, sample_w_numeric)
974+
target_sorted, target_cdf = _weighted_ecdf(target_vals, target_w_numeric)
978975
combined_values, combined_weights = _combined_weights(
979-
np.concatenate((sample_vals.to_numpy(), target_vals.to_numpy())),
976+
np.concatenate((sample_vals, target_vals)),
980977
np.concatenate((sample_w_numeric, target_w_numeric)),
981978
)
982979
sample_eval = _evaluate_ecdf(sample_sorted, sample_cdf, combined_values)
@@ -1101,22 +1098,16 @@ def ks(
11011098
)
11021099
out_dict[col] = float(np.max(np.abs(sample_cdf - target_cdf)))
11031100
else:
1104-
sample_vals = pd.to_numeric(sample_series, errors="coerce").dropna()
1105-
target_vals = pd.to_numeric(target_series, errors="coerce").dropna()
1106-
if sample_vals.empty or target_vals.empty:
1107-
raise ValueError("Numeric columns must contain at least one value.")
1108-
sample_w_numeric = sample_w[sample_series.index.isin(sample_vals.index)]
1109-
target_w_numeric = target_w[target_series.index.isin(target_vals.index)]
1110-
1111-
sample_sorted, sample_cdf = _weighted_ecdf(
1112-
sample_vals.to_numpy(), sample_w_numeric
1101+
sample_vals, sample_w_numeric = _coerce_to_numeric_and_validate(
1102+
sample_series, sample_w, "Sample numeric column"
11131103
)
1114-
target_sorted, target_cdf = _weighted_ecdf(
1115-
target_vals.to_numpy(), target_w_numeric
1116-
)
1117-
combined_values = np.unique(
1118-
np.concatenate((sample_vals.to_numpy(), target_vals.to_numpy()))
1104+
target_vals, target_w_numeric = _coerce_to_numeric_and_validate(
1105+
target_series, target_w, "Target numeric column"
11191106
)
1107+
1108+
sample_sorted, sample_cdf = _weighted_ecdf(sample_vals, sample_w_numeric)
1109+
target_sorted, target_cdf = _weighted_ecdf(target_vals, target_w_numeric)
1110+
combined_values = np.unique(np.concatenate((sample_vals, target_vals)))
11201111
sample_eval = _evaluate_ecdf(sample_sorted, sample_cdf, combined_values)
11211112
target_eval = _evaluate_ecdf(target_sorted, target_cdf, combined_values)
11221113
out_dict[col] = float(np.max(np.abs(sample_eval - target_eval)))

balance/stats_and_plots/weighted_stats.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -279,12 +279,8 @@ def ci_of_weighted_mean(
279279
var_weighed_mean_of_v = var_of_weighted_mean(v, w, inf_rm)
280280
z_value = norm.ppf((1 + conf_level) / 2)
281281

282-
if isinstance(v, pd.Series):
283-
ci_index = v.index
284-
elif isinstance(v, pd.DataFrame):
285-
ci_index = v.columns
286-
else:
287-
ci_index = None
282+
# After _prepare_weighted_stat_args, v is always a DataFrame
283+
ci_index = v.columns
288284

289285
ci = pd.Series(
290286
[

balance/utils/input_validation.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,56 @@ def _extract_series_and_weights(
114114
return filtered_series, filtered_weights
115115

116116

117+
def _coerce_to_numeric_and_validate(
118+
series: pd.Series,
119+
weights: np.ndarray,
120+
label: str,
121+
) -> tuple[np.ndarray, np.ndarray]:
122+
"""
123+
Convert series to numeric, drop NaN values, and validate non-empty.
124+
125+
This function handles series that may contain values that cannot be
126+
converted to numeric (e.g., non-numeric strings in an object dtype series).
127+
It coerces such values to NaN and drops them, then validates that at least
128+
one valid numeric value remains.
129+
130+
Args:
131+
series (pd.Series): Input series to convert to numeric.
132+
weights (np.ndarray): Weights aligned to the series.
133+
label (str): Label for error messages.
134+
135+
Returns:
136+
Tuple[np.ndarray, np.ndarray]: Numeric values and corresponding weights.
137+
138+
Raises:
139+
ValueError: If no valid numeric values remain after conversion.
140+
141+
Examples:
142+
.. code-block:: python
143+
144+
import numpy as np
145+
import pandas as pd
146+
from balance.utils.input_validation import _coerce_to_numeric_and_validate
147+
148+
vals, w = _coerce_to_numeric_and_validate(
149+
pd.Series([1.0, 2.0, 3.0]),
150+
np.array([1.0, 1.0, 2.0]),
151+
"example",
152+
)
153+
vals.tolist()
154+
# [1.0, 2.0, 3.0]
155+
w.tolist()
156+
# [1.0, 1.0, 2.0]
157+
"""
158+
numeric_series = pd.to_numeric(series, errors="coerce").dropna()
159+
if numeric_series.empty:
160+
raise ValueError(
161+
f"{label} must contain at least one valid numeric value after conversion."
162+
)
163+
numeric_weights = weights[series.index.isin(numeric_series.index)]
164+
return numeric_series.to_numpy(), numeric_weights
165+
166+
117167
def _is_discrete_series(series: pd.Series) -> bool:
118168
"""
119169
Determine whether a series should be treated as discrete for comparisons.

0 commit comments

Comments (0)