Skip to content

Commit 6fc2c22

Browse files
authored
Merge pull request #29 from quantifyearth/sw-model-validation
Add model validation diagnostics
2 parents 6dd0c73 + be419db commit 6fc2c22

File tree

3 files changed

+138
-13
lines changed

3 files changed

+138
-13
lines changed

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,9 @@ cython_debug/
158158
# and can be added to the global gitignore or merged into this file. For a more nuclear
159159
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
160160
#.idea/
161+
162+
# VS Code
163+
.vscode/
164+
165+
# Claude Code
166+
.claude/

aoh/validation/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ This directory contains code to implement the model base validation proposed by
44

55
This directory contains the following scripts:
66

7-
* `collate_data.py` - Then you generate a series of AOH GeoTIFFs, besides each one is a JSON file that contains information required for validation. This script takes a folder containing the AOH output of a run and collates all those JSON files into a single CSV file that can be used for a validation run.
7+
* `collate_data.py` - When you generate a series of AOH GeoTIFFs, besides each one is a JSON file that contains information required for validation. This script takes a folder containing the AOH output of a run and collates all those JSON files into a single CSV file that can be used for a validation run.
88
* `validate_map_prevalence.py` - This uses the data in the collated CSV to do a model validation as per the Dahal et al paper.
99
* `fetch_gbif_data.py` - This script takes the collated CSV file and attempts to find occurrence data on GBIF that can be used for point validation as per the Dahal et al paper.
1010
* `validate_occurences.py` - This uses the data fetched from GBIF to check the occurrences against a corpus of AOHs.

aoh/validation/validate_map_prevalence.py

Lines changed: 131 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,90 @@
1111
except (ImportError, ValueError):
1212
pymer4 = None
1313

14-
def model_validation(aoh_df: pd.DataFrame) -> pd.DataFrame:
15-
if pymer4 is None:
16-
raise ImportError("pymer4 is required for model validation but not installed. "
17-
"This requires R to be installed on the system.")
14+
def generate_results_summary(aoh_df: pd.DataFrame, outliers: pd.DataFrame) -> str:
    """Render a markdown summary of the validation run.

    Reports overall counts (species with and without AOH, total outliers)
    followed by a per-taxonomic-class breakdown of species counts and
    outlier rates. Returns the markdown document as a single string.
    """
    parts = [
        "# Model Validation Results Summary\n\n",
        "## Summary Statistics\n\n",
        f"- **Total species analyzed**: {len(aoh_df[aoh_df.aoh_total > 0])}\n",
        f"- **Species with no AOH**: {len(aoh_df[aoh_df.aoh_total == 0])}\n",
        f"- **Total outliers detected**: {len(outliers)}\n\n",
        "## Species by Taxonomic Class\n",
    ]

    # Per-class species and outlier counts; outliers are a subset of aoh_df,
    # so every outlier class also appears in species_per_class.
    species_per_class = aoh_df.groupby('class_name').size().to_dict()
    outliers_per_class = outliers.groupby('class_name').size().to_dict()

    for name in sorted(species_per_class):
        n_species = species_per_class[name]
        n_outliers = outliers_per_class.get(name, 0)
        pct = (n_outliers / n_species * 100) if n_species else 0
        parts.append(
            f"- **{name}**: {n_species} species, {n_outliers} outliers ({pct:.1f}%)\n"
        )

    return "".join(parts)
35+
36+
def add_diagnostic_columns(
    klass_df: pd.DataFrame,
    upper_fence: float,
    lower_fence: float
) -> pd.DataFrame:
    """Annotate a per-class AOH frame with outlier diagnostic columns.

    Adds, in place: ``outlier_type`` (normal / over-predicted /
    under-predicted, based on where ``fit_diff`` falls relative to the IQR
    fences), a human-readable ``explanation``, and one
    ``<predictor>_vs_class_mean`` column per predictor giving the percentage
    difference from the class mean. Returns the same DataFrame.
    """
    predictor_cols = ['elevation_rangekm', 'elevation_midkm', 'n_habitats', 'prevalence']
    class_means = klass_df[predictor_cols].mean()

    # Classify each species by which fence (if any) its residual crosses.
    over_mask = klass_df.fit_diff < lower_fence
    under_mask = klass_df.fit_diff > upper_fence
    klass_df['outlier_type'] = 'normal'
    klass_df.loc[over_mask, 'outlier_type'] = 'over-predicted'
    klass_df.loc[under_mask, 'outlier_type'] = 'under-predicted'

    # Human-readable explanation using rounded observed/predicted values.
    observed = klass_df['prevalence'].round(3).astype(str)
    predicted = klass_df['fit'].round(3).astype(str)
    klass_df['explanation'] = 'Within normal range'
    klass_df.loc[klass_df.outlier_type == 'under-predicted', 'explanation'] = (
        'Observed prevalence (' + observed + ') much higher than predicted (' + predicted + ')'
    )
    klass_df.loc[klass_df.outlier_type == 'over-predicted', 'explanation'] = (
        'Observed prevalence (' + observed + ') much lower than predicted (' + predicted + ')'
    )

    # Context columns: percentage difference of each predictor from the
    # class mean, rendered as a string like "12.5%".
    context_columns = {
        'elevation_rangekm': 'elevation_range_vs_class_mean',
        'elevation_midkm': 'elevation_mid_vs_class_mean',
        'n_habitats': 'n_habitats_vs_class_mean',
        'prevalence': 'prevalence_vs_class_mean',
    }
    for source_col, target_col in context_columns.items():
        delta_pct = (klass_df[source_col] - class_means[source_col]) / class_means[source_col] * 100
        klass_df[target_col] = delta_pct.round(1).astype(str) + '%'

    return klass_df
1879

19-
# Ger rid of any where we had no AoH
20-
aoh_df = aoh_df[aoh_df.prevalence > 0]
80+
def extract_model_coefficients(model: "pymer4.models.Lmer", class_name: str) -> pd.DataFrame:
    """Return the model's fixed-effect coefficient table in tidy form.

    The coefficient index is promoted to an explicit ``variable`` column
    (easier downstream pivoting) and ``class_name`` is attached to every row.
    """
    tidy = (
        model.coefs
        .copy()
        .reset_index()
        .rename(columns={'index': 'variable'})
        .assign(class_name=class_name)
    )
    return tidy
86+
87+
def extract_random_effects(model: "pymer4.models.Lmer", class_name: str) -> pd.DataFrame:
    """Extract per-family random intercepts from a fitted mixed model.

    Returns a DataFrame with one row per grouping level, with columns
    ``family_name`` (the random-effect index), ``random_effect`` (the
    intercept value), and ``class_name``.

    Raises:
        ValueError: if no intercept column can be found in the ranef table.
    """
    ranef_df = model.ranef.copy()
    ranef_df['class_name'] = class_name
    ranef_df = ranef_df.reset_index()
    # pymer4 uses 'X.Intercept.' as the column name for random intercepts.
    # Fail with a descriptive error rather than an opaque IndexError if the
    # expected column is missing (e.g. a model fitted without an intercept).
    try:
        intercept_col = next(col for col in ranef_df.columns if 'Intercept' in col)
    except StopIteration as exc:
        raise ValueError(
            f"No random-intercept column found for class {class_name}; "
            f"columns were {list(ranef_df.columns)}"
        ) from exc
    ranef_df = ranef_df.rename(columns={'index': 'family_name', intercept_col: 'random_effect'})
    return ranef_df
95+
96+
def add_predictors_to_aoh_df(aoh_df: pd.DataFrame) -> pd.DataFrame:
97+
"""Calculate and standardize predictor variables."""
2198

2299
aoh_df['elevation_range'] = aoh_df['elevation_upper'] - aoh_df['elevation_lower']
23100
aoh_df['elevation_mid'] = (aoh_df['elevation_upper'] + aoh_df['elevation_lower']) / 2
@@ -35,12 +112,29 @@ def model_validation(aoh_df: pd.DataFrame) -> pd.DataFrame:
35112
aoh_df['std_n_habitats'] = (aoh_df.n_habitats - means.n_habitats) \
36113
/ standard_devs.n_habitats
37114

38-
per_class_df = []
115+
return aoh_df
116+
117+
def model_validation(aoh_df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
118+
if pymer4 is None:
119+
raise ImportError("pymer4 is required for model validation but not installed. "
120+
"This requires R to be installed on the system.")
121+
122+
# Get rid of any where we had no AoH
123+
aoh_df = aoh_df[aoh_df.prevalence > 0].copy()
124+
125+
# Prepare predictor variables
126+
aoh_df = add_predictors_to_aoh_df(aoh_df)
39127

128+
# Get unique taxonomic classes
40129
klasses = aoh_df.class_name.unique()
41130
if len(klasses) == 0:
42131
raise ValueError("No species classes were found")
43132

133+
# Fit models for each class
134+
per_class_outliers_df = []
135+
per_class_model_coefficients = []
136+
per_class_random_effects = []
137+
44138
for klass in klasses:
45139
klass_df = aoh_df[aoh_df.class_name == klass].copy()
46140
print(f"{klass}:\n\taohs: {len(klass_df)}")
@@ -56,21 +150,46 @@ def model_validation(aoh_df: pd.DataFrame) -> pd.DataFrame:
56150
q1 = klass_df.fit_diff.quantile(q=0.25)
57151
q3 = klass_df.fit_diff.quantile(q=0.75)
58152
iqr = q3 - q1
153+
lower_fence = q1 - (1.5 * iqr)
154+
upper_fence = q3 + (1.5 * iqr)
59155

60-
klass_df['outlier'] = (klass_df.fit_diff > q3 + (1.5 * iqr)) | (klass_df.fit_diff < (q1 - (1.5 * iqr)))
156+
klass_df['outlier'] = (klass_df.fit_diff > upper_fence ) | (klass_df.fit_diff < lower_fence )
157+
klass_df = add_diagnostic_columns(klass_df, upper_fence, lower_fence)
61158
klass_outliers = klass_df[klass_df.outlier == True] # pylint: disable = C0121
62159
print(f"\toutliers: {len(klass_outliers)}")
63-
per_class_df.append(klass_outliers)
160+
per_class_outliers_df.append(klass_outliers)
161+
162+
coef_df = extract_model_coefficients(model, klass)
163+
per_class_model_coefficients.append(coef_df)
164+
165+
ranef_df = extract_random_effects(model, klass)
166+
per_class_random_effects.append(ranef_df)
167+
168+
# Concatenate results
169+
outliers_df = pd.concat(per_class_outliers_df) # type: ignore[arg-type]
170+
model_coefficients_df = pd.concat(per_class_model_coefficients) # type: ignore[arg-type]
171+
random_effects_df = pd.concat(per_class_random_effects) # type: ignore[arg-type]
64172

65-
return pd.concat(per_class_df) # type: ignore[no-any-return]
173+
return outliers_df, model_coefficients_df, random_effects_df
66174

67175
def validate_map_prevalence(
    collated_data_path: Path,
    output_path: Path,
) -> None:
    """Run map-prevalence model validation and write all result files.

    Reads the collated AOH CSV, runs the per-class validation models, writes
    the outlier table to ``output_path``, and drops supporting diagnostics
    (zero-AOH species, model coefficients, random effects, and a markdown
    summary) into the same directory.
    """
    aoh_df = pd.read_csv(collated_data_path)
    outliers, model_coefficients, random_effects = model_validation(aoh_df)
    outliers.to_csv(output_path, index=False)

    # Supporting diagnostic files live alongside the main outlier CSV.
    output_dir = output_path.parent

    no_aoh = aoh_df[aoh_df.aoh_total == 0]
    no_aoh.to_csv(output_dir / "species_with_no_aoh.csv", index=False)

    coefficients_wide = model_coefficients.pivot(
        index='class_name', columns='variable', values='Estimate'
    )
    coefficients_wide.to_csv(output_dir / "model_coefficients.csv", index=True)

    random_effects.to_csv(output_dir / "random_effects.csv", index=False)

    (output_dir / "summary.md").write_text(
        generate_results_summary(aoh_df, outliers), encoding='utf-8'
    )
74193

75194
def main() -> None:
76195
parser = argparse.ArgumentParser(description="Validate map prevalence.")

0 commit comments

Comments
 (0)