06 Ensemble Member Inference (All Methods)#
Run 5 selected members from each of the 4 UQ ensembles through the saved RF bloom classifier. For each ensemble, display 5 per-member confusion matrices directly after the inference call. At the end, compare the methods via mean [5th, 95th percentile] confusion matrices, shown alongside the deterministic baseline.
# ── USER CONFIGURATION ────────────────────────────────────────────────────────
# Ensemble member indices; the same set is selected from EVERY ensemble
# (bootstrap, glue, enkf, lpu).
MEMBER_INDICES = [0, 40, 80, 120, 160]
# UQ methods to run (all 4 by default).
METHODS = ["bootstrap", "glue", "enkf", "lpu"]
# When True, figures and the probability CSV are written to disk.
SAVE_OUTPUTS = True
# First date of the test period; rows on or after it are scored.
TEST_CUTOFF = "2019-01-01"
# ─────────────────────────────────────────────────────────────────────────────

# Fail fast on a misconfigured cell before any data is loaded.
assert len(MEMBER_INDICES) == 5, "This notebook expects exactly 5 members."
assert not any(i < 0 or i >= 200 for i in MEMBER_INDICES), "Member indices must be in 0-199."
assert set(METHODS) <= {"bootstrap", "glue", "enkf", "lpu"}
print(f"Config OK: methods={METHODS}, members={MEMBER_INDICES}")
Config OK: methods=['bootstrap', 'glue', 'enkf', 'lpu'], members=[0, 40, 80, 120, 160]
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import joblib
from sklearn.metrics import (
balanced_accuracy_score, precision_score, recall_score, f1_score,
confusion_matrix, precision_recall_curve, average_precision_score,
)
from red_tide_reanalysis.ml import resample_ensemble, build_features, run_inference
from red_tide_reanalysis.ml.feature_builder import FEATURE_COLS, M3S_TO_CFS
# ── Paths (relative to notebooks/) ───────────────────────────────────────────
OBS_CSV = Path("../machineLearning/machineLearning/input/data_weekly_intepolated.csv")
ENSEMBLE_DIR = Path("data/outputs/ensembles")
MODEL_PATH = Path("../models/rf_bloom_classifier.joblib")
SCALER_PATH = Path("../models/robust_scaler.joblib")
OUTPUT_DIR = Path("data/outputs/ml")
FIGURES_DIR = OUTPUT_DIR.joinpath("figures")

# Per-method member CSV, looked up inside ENSEMBLE_DIR.
ENSEMBLE_FILES = dict(
    bootstrap="bootstrap_02296750_peace_river_discharge_members.csv",
    glue="glue_02296750_discharge_members.csv",
    enkf="enkf_02296750_discharge_members.csv",
    lpu="lpu_02296750_discharge_members.csv",
)

# The serialized model and scaler must already exist on disk.
assert MODEL_PATH.exists(), f"Model not found: {MODEL_PATH} — run 00_serialize_model.ipynb first"
assert SCALER_PATH.exists(), f"Scaler not found: {SCALER_PATH}"

# parents=True also creates OUTPUT_DIR when it is missing.
if SAVE_OUTPUTS:
    FIGURES_DIR.mkdir(parents=True, exist_ok=True)

# One line colour per method, shared by every comparison plot.
METHOD_COLORS = dict(
    bootstrap="steelblue",
    glue="darkorange",
    enkf="forestgreen",
    lpu="mediumpurple",
)

# Column names of the selected members, e.g. index 40 -> "member_0040".
selected_cols = [f"member_{i:04d}" for i in MEMBER_INDICES]

# Filled in method by method as each inference cell runs.
proba_by_method, pred_by_method = {}, {}
def plot_member_cms(method_name):
    """Draw one confusion matrix per selected member of `method_name`.

    The panels sit side by side in a single row of five. Only the rows
    selected by `test_mask_members` are scored against `y_test_members`,
    and each panel title shows that member's balanced accuracy. When
    SAVE_OUTPUTS is set, the figure is also written to FIGURES_DIR as
    `<method_name>_member_cms.png`.
    """
    class_names = ["No Bloom", "Bloom"]
    member_preds = pred_by_method[method_name]

    fig, axes = plt.subplots(1, 5, figsize=(20, 4.2))
    for j, (ax, col) in enumerate(zip(axes, selected_cols)):
        # Restrict this member's predictions to the test period.
        y_hat = member_preds[j][test_mask_members]
        cm = confusion_matrix(y_test_members, y_hat, labels=[0, 1])
        bal_acc = balanced_accuracy_score(y_test_members, y_hat)
        sns.heatmap(
            cm, annot=True, fmt="d", cmap="Blues", ax=ax, cbar=False,
            xticklabels=class_names,
            yticklabels=class_names,
        )
        ax.set_xlabel("Predicted")
        # Only the left-most panel carries the y-axis label.
        ax.set_ylabel("" if j else "Actual")
        ax.set_title(f"{col}\nBal. Acc = {bal_acc:.3f}")

    fig.suptitle(
        f"{method_name.upper()} — Per-Member Confusion Matrices (Test Period 2019+)",
        fontsize=13,
    )
    plt.tight_layout()

    if SAVE_OUTPUTS:
        fname = FIGURES_DIR / f"{method_name}_member_cms.png"
        fig.savefig(fname, dpi=150, bbox_inches="tight")
        print(f"Saved: {fname}")
    plt.show()


print("Setup OK")
Setup OK
Step 1 — Resample all ensembles and select members#
# Resample each method's ensemble to weekly rows and keep only the
# configured member columns.
weekly_selected_by_method = {}
for method in METHODS:
    ens_csv = ENSEMBLE_DIR.joinpath(ENSEMBLE_FILES[method])
    weekly_full = resample_ensemble(ens_csv)  # (T_weeks, 200), m³/s
    weekly_selected_by_method[method] = weekly_full.loc[:, selected_cols].copy()
    # Report the row count and date span of the resampled grid.
    print(
        f" {method}: {weekly_full.shape[0]} weekly rows "
        f"({weekly_full.index[0].date()} → {weekly_full.index[-1].date()})"
    )

print(f"\nSelected members (applied to each method): {selected_cols}")
bootstrap: 1305 weekly rows (1998-12-28 → 2023-12-25)
glue: 1305 weekly rows (1998-12-28 → 2023-12-25)
enkf: 1305 weekly rows (1998-12-28 → 2023-12-25)
lpu: 1305 weekly rows (1998-12-28 → 2023-12-25)
Selected members (applied to each method): ['member_0000', 'member_0040', 'member_0080', 'member_0120', 'member_0160']
Step 2 — Build features, baseline, and test-period masks#
All 4 methods share the same weekly grid, so the aligned date index and true bloom labels are computed once here.
# ── Ensemble features per method ─────────────────────────────────────────────
features_by_method = {}
for method in METHODS:
    features_by_method[method] = build_features(weekly_selected_by_method[method], OBS_CSV)

# ── Observations, parsed once and reused below ───────────────────────────────
obs_indexed = pd.read_csv(OBS_CSV, parse_dates=["time"], index_col="time")

# ── Aligned DatetimeIndex (shared across all methods) ────────────────────────
# Rebuild the set of dates that survive the lag/rolling features, using the
# first selected member of the first method as the discharge series.
# NOTE(review): this presumes build_features keeps exactly these rows — confirm.
obs_no_q = obs_indexed.drop(columns=["peace_discharge"])
member0_cfs = (weekly_selected_by_method[METHODS[0]].iloc[:, 0] * M3S_TO_CFS).rename("peace_discharge")
merged0 = obs_no_q.join(member0_cfs, how="inner")
merged0["kb_prev1"] = merged0["kb"].shift(1)
merged0["discharge_4w_avg"] = merged0["peace_discharge"].rolling(window=4).mean()
aligned_index = merged0.dropna().index

# ── True labels ──────────────────────────────────────────────────────────────
# A week is a bloom when kb >= 100,000; the target is the NEXT row's bloom flag.
obs_full = obs_indexed.copy()  # copy so the shared frame is not mutated
obs_full["bloom"] = (obs_full["kb"] >= 100_000).astype(int)
obs_full["target_next_week"] = obs_full["bloom"].shift(-1)
y_true_series = obs_full.loc[aligned_index, "target_next_week"]

# shift(-1) leaves NaN on last row — trim index and features
valid_mask = y_true_series.notna().values
aligned_index = aligned_index[valid_mask]
for method in METHODS:
    # The mask is applied along the second (time) axis of each feature array.
    features_by_method[method] = features_by_method[method][:, valid_mask, :]
y_true = y_true_series.dropna().astype(int)

# ── Baseline features ────────────────────────────────────────────────────────
# Same lag/rolling features, built from the observed discharge column instead
# of an ensemble member.
obs_flat = pd.read_csv(OBS_CSV, parse_dates=["time"]).sort_values("time").reset_index(drop=True)
obs_flat["bloom"] = (obs_flat["kb"] >= 100_000).astype(int)
obs_flat["target_next_week"] = obs_flat["bloom"].shift(-1)
obs_flat["kb_prev1"] = obs_flat["kb"].shift(1)
obs_flat["kb_prev2"] = obs_flat["kb"].shift(2)
obs_flat["peace_discharge_prev1"] = obs_flat["peace_discharge"].shift(1)
obs_flat["peace_TN_prev1"] = obs_flat["peace_TN"].shift(1)
obs_flat["peace_TP_prev1"] = obs_flat["peace_TP"].shift(1)
obs_flat["discharge_4w_avg"] = obs_flat["peace_discharge"].rolling(window=4).mean()
obs_flat = obs_flat.dropna(subset=FEATURE_COLS + ["target_next_week"]).reset_index(drop=True)
baseline_dates = pd.DatetimeIndex(obs_flat["time"])
X_baseline = obs_flat[FEATURE_COLS].values
y_baseline = obs_flat["target_next_week"].astype(int).values

# ── Baseline inference ───────────────────────────────────────────────────────
model = joblib.load(MODEL_PATH)
scaler = joblib.load(SCALER_PATH)
baseline_proba = model.predict_proba(scaler.transform(X_baseline))[:, 1].astype(np.float32)
baseline_pred = (baseline_proba >= 0.5).astype(int)

# ── Test-period masks and aligned true labels ────────────────────────────────
test_mask_members = aligned_index >= TEST_CUTOFF
test_mask_baseline = baseline_dates >= TEST_CUTOFF
y_test_members = y_true[test_mask_members].values
y_test_baseline = y_baseline[test_mask_baseline]

# Report the shape actually produced rather than an assumed one, so a change
# in member count or feature set is visible here.
feature_shape = tuple(features_by_method[METHODS[0]].shape)
print(f"Features aligned to shape {feature_shape} per method")
print(f"Aligned dates: {aligned_index[0].date()} → {aligned_index[-1].date()}")
print(f"Test period: members {test_mask_members.sum()} weeks, baseline {test_mask_baseline.sum()} weeks")
print(f"Bloom weeks — test (members): {int(y_test_members.sum())} | "
      f"test (baseline): {int(y_test_baseline.sum())}")
Features aligned to shape (5, 1301, 15) per method
Aligned dates: 1999-01-18 → 2023-12-18
Test period: members 259 weeks, baseline 259 weeks
Bloom weeks — test (members): 80 | test (baseline): 80
Step 3 — Per-method inference with confusion matrices#
Bootstrap#
# ── BOOTSTRAP — score the 5 selected members, threshold at 0.5, then plot ────
_bootstrap_proba = run_inference(features_by_method["bootstrap"], MODEL_PATH, SCALER_PATH)
proba_by_method["bootstrap"] = _bootstrap_proba
pred_by_method["bootstrap"] = (_bootstrap_proba >= 0.5).astype(int)
plot_member_cms("bootstrap")
Saved: data\outputs\ml\figures\bootstrap_member_cms.png
GLUE#
# ── GLUE — score the 5 selected members, threshold at 0.5, then plot ─────────
_glue_proba = run_inference(features_by_method["glue"], MODEL_PATH, SCALER_PATH)
proba_by_method["glue"] = _glue_proba
pred_by_method["glue"] = (_glue_proba >= 0.5).astype(int)
plot_member_cms("glue")
Saved: data\outputs\ml\figures\glue_member_cms.png
EnKF#
# ── ENKF — score the 5 selected members, threshold at 0.5, then plot ─────────
_enkf_proba = run_inference(features_by_method["enkf"], MODEL_PATH, SCALER_PATH)
proba_by_method["enkf"] = _enkf_proba
pred_by_method["enkf"] = (_enkf_proba >= 0.5).astype(int)
plot_member_cms("enkf")
Saved: data\outputs\ml\figures\enkf_member_cms.png
LPU#
# ── LPU — score the 5 selected members, threshold at 0.5, then plot ──────────
_lpu_proba = run_inference(features_by_method["lpu"], MODEL_PATH, SCALER_PATH)
proba_by_method["lpu"] = _lpu_proba
pred_by_method["lpu"] = (_lpu_proba >= 0.5).astype(int)
plot_member_cms("lpu")
Saved: data\outputs\ml\figures\lpu_member_cms.png
Final Comparison — Mean [5th, 95th Percentile] Confusion Matrix per Method#
For each method, the 5 per-member confusion matrices are averaged element-wise,
and each cell is annotated as mean on top with [p05, p95] below. The four
ensemble methods occupy the 2×2 block on the left; the baseline CM (single
deterministic run, integer counts) spans the full right column.
import matplotlib.gridspec as gridspec

# ── Per-member test-period confusion matrices and balanced accuracies ───────
# Computed once per method and reused for both the ranking and the panels.
cms_by_method = {}
bal_accs_by_method = {}
for method in METHODS:
    member_preds = [
        pred_by_method[method][j][test_mask_members]
        for j in range(len(MEMBER_INDICES))
    ]
    cms_by_method[method] = np.stack([
        confusion_matrix(y_test_members, y_pred_j, labels=[0, 1])
        for y_pred_j in member_preds
    ])
    bal_accs_by_method[method] = [
        balanced_accuracy_score(y_test_members, y_pred_j)
        for y_pred_j in member_preds
    ]

# ── Rank ensemble methods by mean True Positive count (descending) ───────────
method_mean_tp = {m: cms_by_method[m].mean(axis=0)[1, 1] for m in METHODS}
ranked_methods = sorted(METHODS, key=lambda m: method_mean_tp[m], reverse=True)

# ── Layout ───────────────────────────────────────────────────────────────────
# Ensemble methods fill a two-column block on the left in ranked, row-major
# order; the baseline spans the full right column. Every method gets a slot,
# so none is dropped from the figure.
n_rows = max(1, (len(ranked_methods) + 1) // 2)
positions = [(r, c) for r in range(n_rows) for c in range(2)]
fig = plt.figure(figsize=(17, 4.5 * n_rows))
gs = gridspec.GridSpec(n_rows, 3, figure=fig)

# ── Ensemble methods (left block, ranked by mean TP) ─────────────────────────
for method, (r, c) in zip(ranked_methods, positions):
    ax = fig.add_subplot(gs[r, c])
    cms = cms_by_method[method]
    cm_mean = cms.mean(axis=0)
    cm_p05 = np.percentile(cms, 5, axis=0)
    cm_p95 = np.percentile(cms, 95, axis=0)

    # Each cell shows the mean on top and the [p05, p95] range underneath.
    annot = np.empty_like(cm_mean, dtype=object)
    for i in range(2):
        for k in range(2):
            annot[i, k] = f"{cm_mean[i, k]:.1f}\n[{cm_p05[i, k]:.0f}, {cm_p95[i, k]:.0f}]"

    sns.heatmap(cm_mean, annot=annot, fmt="", cmap="Blues", ax=ax, cbar=False,
                xticklabels=["No Bloom", "Bloom"],
                yticklabels=["No Bloom", "Bloom"],
                annot_kws={"size": 11})

    bal_accs = bal_accs_by_method[method]
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual" if c == 0 else "")
    ax.set_title(
        f"{method.upper()} (N={len(MEMBER_INDICES)} members)\n"
        f"Bal. Acc = {np.mean(bal_accs):.3f} ± {np.std(bal_accs):.3f} "
        f"mean TP = {cm_mean[1, 1]:.1f}",
        fontsize=11,
    )

# ── Baseline (full right column) ─────────────────────────────────────────────
# Derived from the baseline predictions so it tracks the loaded model and data
# instead of relying on copied-in counts.
ax_bl = fig.add_subplot(gs[:, 2])
y_pred_baseline_test = baseline_pred[test_mask_baseline]
cm_baseline = confusion_matrix(y_test_baseline, y_pred_baseline_test, labels=[0, 1])
bal_acc_bl = balanced_accuracy_score(y_test_baseline, y_pred_baseline_test)
sns.heatmap(cm_baseline, annot=True, fmt="d", cmap="Blues", ax=ax_bl, cbar=False,
            xticklabels=["No Bloom", "Bloom"],
            yticklabels=["No Bloom", "Bloom"],
            annot_kws={"size": 11})
ax_bl.set_xlabel("Predicted")
ax_bl.set_ylabel("Actual")
ax_bl.set_title(
    f"BASELINE (observed TN)\n"
    f"Bal. Acc = {bal_acc_bl:.3f} TP = {cm_baseline[1, 1]}",
    fontsize=11,
)

fig.suptitle(
    f"Confusion Matrices — mean [5th, 95th percentile] across {len(MEMBER_INDICES)} members "
    f"(Test Period from {TEST_CUTOFF})",
    fontsize=13,
)
plt.tight_layout()

if SAVE_OUTPUTS:
    fname = FIGURES_DIR / f"all_methods_mean_cm_{'_'.join(str(i) for i in MEMBER_INDICES)}.png"
    plt.savefig(fname, dpi=150, bbox_inches="tight")
    print(f"Saved: {fname}")

print(f"Method ranking by mean TP: {' > '.join(f'{m} ({method_mean_tp[m]:.1f})' for m in ranked_methods)}")
plt.show()
Saved: data\outputs\ml\figures\all_methods_mean_cm_0_40_80_120_160.png
Method ranking by mean TP: lpu (65.6) > enkf (65.4) > bootstrap (55.6) > glue (41.6)
Precision-Recall Curves — All Methods#
One line per (method, member) combination, colored by method. Baseline in black.
fig, ax = plt.subplots(figsize=(9, 7))

# Baseline curve: thick black line drawn above the member curves.
baseline_test_proba = baseline_proba[test_mask_baseline]
prec_bl, rec_bl, _ = precision_recall_curve(y_test_baseline, baseline_test_proba)
ap_bl = average_precision_score(y_test_baseline, baseline_test_proba)
ax.plot(rec_bl, prec_bl, color="black", linewidth=2.2,
        label=f"Baseline (AP={ap_bl:.3f})", zorder=10)

# One thin curve per (method, member). Only the last member of each method
# gets a legend entry, because by then `aps` holds every member's AP.
last_member = len(selected_cols) - 1
for method in METHODS:
    color = METHOD_COLORS[method]
    proba_test = proba_by_method[method][:, test_mask_members]
    aps = []
    for j, col in enumerate(selected_cols):
        prec, rec, _ = precision_recall_curve(y_test_members, proba_test[j])
        ap = average_precision_score(y_test_members, proba_test[j])
        aps.append(ap)
        legend_label = None
        if j == last_member:
            legend_label = f"{method} (mean AP={np.mean(aps):.3f})"
        ax.plot(rec, prec, color=color, linewidth=1.0, alpha=0.7, label=legend_label)

ax.set(xlabel="Recall", ylabel="Precision")
ax.set_title("Precision-Recall Curves — All Methods × Members (Test Period 2019+)")
ax.grid(True, alpha=0.3)
ax.legend(fontsize=9, loc="lower left")
plt.tight_layout()

if SAVE_OUTPUTS:
    fname = FIGURES_DIR / f"all_methods_members_{'_'.join(str(i) for i in MEMBER_INDICES)}_pr.png"
    fig.savefig(fname, dpi=150, bbox_inches="tight")
    print(f"Saved: {fname}")
plt.show()
Saved: data\outputs\ml\figures\all_methods_members_0_40_80_120_160_pr.png
Bloom Probability Traces — Method Means vs Baseline#
# Bloom event dates (from obs): rows where kb reaches the 100,000 threshold,
# split at the test cutoff.
obs_bloom = pd.read_csv(OBS_CSV, parse_dates=["time"], index_col="time")
bloom_dates = obs_bloom.index[obs_bloom["kb"] >= 100_000]
bloom_train = bloom_dates[bloom_dates < TEST_CUTOFF]
bloom_test = bloom_dates[bloom_dates >= TEST_CUTOFF]
half_week = pd.Timedelta(days=3)
cutoff_ts = pd.Timestamp(TEST_CUTOFF)

fig, ax = plt.subplots(figsize=(16, 6))

# Shade each bloom week: gold before the cutoff, tomato from it onwards.
for dates, shade, opacity in ((bloom_train, "gold", 0.15), (bloom_test, "tomato", 0.20)):
    for bd in dates:
        ax.axvspan(bd - half_week, bd + half_week, color=shade, alpha=opacity, linewidth=0)

# Dashed marker at the train/test boundary; the label uses axes-fraction y.
ax.axvline(cutoff_ts, color="black", linestyle="--", linewidth=1.2)
ax.text(cutoff_ts, 0.97, " Test →",
        transform=ax.get_xaxis_transform(), fontsize=9, color="black", va="top")

ax.plot(baseline_dates, baseline_proba, color="black", linewidth=1.6,
        label="Baseline (observed)", zorder=5)

# One trace per method: the probability averaged over the selected members.
for method in METHODS:
    mean_trace = proba_by_method[method].mean(axis=0)
    ax.plot(aligned_index, mean_trace,
            color=METHOD_COLORS[method], linewidth=1.2, alpha=0.9,
            label=f"{method} (mean of {len(MEMBER_INDICES)} members)", zorder=4)

ax.axhline(0.5, color="red", linestyle=":", linewidth=0.9, label="0.5 threshold")

# The shaded spans have no legend handles, so add proxy patches for them.
train_patch = mpatches.Patch(color="gold", alpha=0.5, label="Bloom event (train)")
test_patch = mpatches.Patch(color="tomato", alpha=0.5, label="Bloom event (test)")
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles + [train_patch, test_patch], loc="upper left", fontsize=8, ncol=2)

ax.set(ylim=(0, 1), ylabel="Bloom Probability P(bloom)", xlabel="Date")
ax.set_title("Method Means vs Baseline — Full Range", fontsize=12)
plt.tight_layout()

if SAVE_OUTPUTS:
    fname = FIGURES_DIR / f"all_methods_members_{'_'.join(str(i) for i in MEMBER_INDICES)}_traces_full.png"
    fig.savefig(fname, dpi=150, bbox_inches="tight")
    print(f"Saved: {fname}")
plt.show()
Saved: data\outputs\ml\figures\all_methods_members_0_40_80_120_160_traces_full.png
Test Period Zoom (2019+) — Method Means vs Baseline#
fig, ax = plt.subplots(figsize=(14, 5))

# Shade every bloom week of the test period.
for bd in bloom_test:
    ax.axvspan(bd - half_week, bd + half_week, color="tomato", alpha=0.25, linewidth=0)

# Baseline restricted to the test period.
ax.plot(baseline_dates[test_mask_baseline], baseline_proba[test_mask_baseline],
        color="black", linewidth=2.0, label="Baseline (observed)", zorder=5)

# Member-averaged probability per method, restricted to the test period.
test_dates_members = aligned_index[test_mask_members]
for method in METHODS:
    mean_trace = proba_by_method[method].mean(axis=0)
    ax.plot(test_dates_members, mean_trace[test_mask_members],
            color=METHOD_COLORS[method], linewidth=1.5, alpha=0.9,
            label=f"{method} mean", zorder=4)

ax.axhline(0.5, color="red", linestyle=":", linewidth=0.9, label="0.5 threshold")

# Proxy patch so the shaded bloom weeks appear in the legend.
test_patch = mpatches.Patch(color="tomato", alpha=0.5, label="Bloom event")
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles + [test_patch], loc="upper right", fontsize=9)

ax.set(ylim=(0, 1), ylabel="Bloom Probability P(bloom)", xlabel="Date")
ax.set_title("Method Means vs Baseline — Test Period (2019+)", fontsize=12)
plt.tight_layout()

if SAVE_OUTPUTS:
    fname = FIGURES_DIR / f"all_methods_members_{'_'.join(str(i) for i in MEMBER_INDICES)}_traces_test.png"
    fig.savefig(fname, dpi=150, bbox_inches="tight")
    print(f"Saved: {fname}")
plt.show()
Saved: data\outputs\ml\figures\all_methods_members_0_40_80_120_160_traces_test.png
Optional — Export Probability CSV#
if SAVE_OUTPUTS:
    # Column order: baseline first, then method-major / member-minor.
    series_dict = {
        "baseline": pd.Series(baseline_proba, index=baseline_dates, name="baseline"),
    }
    series_dict.update(
        (
            f"{method}_{col}",
            pd.Series(
                proba_by_method[method][j], index=aligned_index,
                name=f"{method}_{col}",
            ),
        )
        for method in METHODS
        for j, col in enumerate(selected_cols)
    )

    # The baseline and member series carry different date indexes; the
    # DataFrame constructor aligns them on a common index.
    proba_df = pd.DataFrame(series_dict).rename_axis("date")
    proba_csv = OUTPUT_DIR / f"all_methods_proba_{'_'.join(str(i) for i in MEMBER_INDICES)}.csv"
    proba_df.to_csv(proba_csv)
    print(f"Saved: {proba_csv} shape={proba_df.shape}")
Saved: data\outputs\ml\all_methods_proba_0_40_80_120_160.csv shape=(1613, 21)