06 Ensemble Member Inference (All Methods)#

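Run 5 selected members from each of the 4 UQ ensembles through the saved RF bloom classifier. For each ensemble, display the 5 per-member confusion matrices directly after its inference call. At the end, compare methods with one confusion matrix per method that shows, for each cell, the mean count across members with its [5th, 95th percentile] range, next to the deterministic baseline.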

# ── USER CONFIGURATION ────────────────────────────────────────────────────────
# Ensemble member indices (0-199) evaluated for EVERY UQ method listed below.
MEMBER_INDICES = [0, 40, 80, 120, 160]

# UQ methods to evaluate, in plotting order.
METHODS = ["bootstrap", "glue", "enkf", "lpu"]

SAVE_OUTPUTS = True          # write figures and the probability CSV to disk
TEST_CUTOFF  = "2019-01-01"  # dates >= this are the held-out test period
# ─────────────────────────────────────────────────────────────────────────────

# Sanity-check the configuration before any expensive work runs.
assert len(MEMBER_INDICES) == 5, "This notebook expects exactly 5 members."
assert min(MEMBER_INDICES) >= 0 and max(MEMBER_INDICES) < 200, "Member indices must be in 0-199."
assert not [m for m in METHODS if m not in ("bootstrap", "glue", "enkf", "lpu")]

print(f"Config OK: methods={METHODS}, members={MEMBER_INDICES}")
Config OK: methods=['bootstrap', 'glue', 'enkf', 'lpu'], members=[0, 40, 80, 120, 160]
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import joblib
from sklearn.metrics import (
    balanced_accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, precision_recall_curve, average_precision_score,
)
from red_tide_reanalysis.ml import resample_ensemble, build_features, run_inference
from red_tide_reanalysis.ml.feature_builder import FEATURE_COLS, M3S_TO_CFS

# ── Paths (relative to notebooks/) ───────────────────────────────────────────
OBS_CSV      = Path("../machineLearning/machineLearning/input/data_weekly_intepolated.csv")
ENSEMBLE_DIR = Path("data/outputs/ensembles")
MODEL_PATH   = Path("../models/rf_bloom_classifier.joblib")
SCALER_PATH  = Path("../models/robust_scaler.joblib")
OUTPUT_DIR   = Path("data/outputs/ml")
FIGURES_DIR  = OUTPUT_DIR / "figures"

# Ensemble member CSV per UQ method (all for gauge 02296750), inside ENSEMBLE_DIR.
ENSEMBLE_FILES = {
    "bootstrap": "bootstrap_02296750_peace_river_discharge_members.csv",
    "glue":      "glue_02296750_discharge_members.csv",
    "enkf":      "enkf_02296750_discharge_members.csv",
    "lpu":       "lpu_02296750_discharge_members.csv",
}

# Fail fast on missing inputs instead of deep inside resample_ensemble/read_csv.
assert MODEL_PATH.exists(),  f"Model not found: {MODEL_PATH} — run 00_serialize_model.ipynb first"
assert SCALER_PATH.exists(), f"Scaler not found: {SCALER_PATH}"
assert OBS_CSV.exists(),     f"Observations not found: {OBS_CSV}"
_missing_ensembles = [
    str(ENSEMBLE_DIR / ENSEMBLE_FILES[m])
    for m in METHODS
    if not (ENSEMBLE_DIR / ENSEMBLE_FILES[m]).exists()
]
assert not _missing_ensembles, f"Ensemble file(s) not found: {_missing_ensembles}"

if SAVE_OUTPUTS:
    FIGURES_DIR.mkdir(parents=True, exist_ok=True)

# One fixed colour per method so every figure uses the same legend colours.
METHOD_COLORS = {
    "bootstrap": "steelblue",
    "glue":      "darkorange",
    "enkf":      "forestgreen",
    "lpu":       "mediumpurple",
}

# Column names of the selected members in each ensemble CSV (e.g. member_0040).
selected_cols = [f"member_{i:04d}" for i in MEMBER_INDICES]

# Dicts populated as each method is run (keyed by method name)
proba_by_method = {}
pred_by_method  = {}

def plot_member_cms(method_name):
    """Plot one test-period confusion matrix per selected member of a method.

    Reads module-level notebook state: ``pred_by_method[method_name]`` (one
    row of 0/1 predictions per entry in ``selected_cols``),
    ``test_mask_members`` and ``y_test_members``. Those are created in later
    cells, so call this only after inference has run for ``method_name``.
    The figure is saved to ``FIGURES_DIR`` when ``SAVE_OUTPUTS`` is set, then
    shown.
    """
    pred = pred_by_method[method_name]
    n_members = len(selected_cols)
    # squeeze=False keeps `axes` 2-D even for a single member, so axes[0] is
    # always an iterable of Axes.
    fig, axes = plt.subplots(1, n_members, figsize=(4 * n_members, 4.2), squeeze=False)
    for j, (col, ax) in enumerate(zip(selected_cols, axes[0])):
        y_pred_j = pred[j][test_mask_members]
        cm       = confusion_matrix(y_test_members, y_pred_j, labels=[0, 1])
        bal_acc  = balanced_accuracy_score(y_test_members, y_pred_j)
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", ax=ax, cbar=False,
                    xticklabels=["No Bloom", "Bloom"],
                    yticklabels=["No Bloom", "Bloom"])
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual" if j == 0 else "")
        ax.set_title(f"{col}\nBal. Acc = {bal_acc:.3f}")

    # Title follows TEST_CUTOFF so it stays correct if the cutoff is changed.
    fig.suptitle(
        f"{method_name.upper()} — Per-Member Confusion Matrices (Test Period from {TEST_CUTOFF})",
        fontsize=13,
    )
    fig.tight_layout()

    if SAVE_OUTPUTS:
        fname = FIGURES_DIR / f"{method_name}_member_cms.png"
        fig.savefig(fname, dpi=150, bbox_inches="tight")
        print(f"Saved: {fname}")
    plt.show()


print("Setup OK")
Setup OK

Step 1 — Resample all ensembles and select members#

# Resample each method's ensemble CSV to the weekly grid and keep only the
# selected member columns (values in m³/s, per the resample comment below).
weekly_selected_by_method = {}

for method in METHODS:
    ens_csv = ENSEMBLE_DIR / ENSEMBLE_FILES[method]
    weekly_full = resample_ensemble(ens_csv)               # (T_weeks, 200), m³/s
    missing = [c for c in selected_cols if c not in weekly_full.columns]
    assert not missing, f"{method}: member columns not found in {ens_csv.name}: {missing}"
    weekly_selected_by_method[method] = weekly_full[selected_cols].copy()
    # The separator between the first and last dates was missing, which ran them together.
    print(f"  {method}: {weekly_full.shape[0]} weekly rows  "
          f"({weekly_full.index[0].date()} → {weekly_full.index[-1].date()})")

print(f"\nSelected members (applied to each method): {selected_cols}")
  bootstrap: 1305 weekly rows  (1998-12-28 → 2023-12-25)
  glue: 1305 weekly rows  (1998-12-28 → 2023-12-25)
  glue: 1305 weekly rows  (1998-12-28 → 2023-12-25)
  enkf: 1305 weekly rows  (1998-12-28 → 2023-12-25)
  enkf: 1305 weekly rows  (1998-12-28 → 2023-12-25)
  lpu: 1305 weekly rows  (1998-12-28 → 2023-12-25)

Selected members (applied to each method): ['member_0000', 'member_0040', 'member_0080', 'member_0120', 'member_0160']
  lpu: 1305 weekly rows  (1998-12-28 → 2023-12-25)

Selected members (applied to each method): ['member_0000', 'member_0040', 'member_0080', 'member_0120', 'member_0160']

Step 2 — Build features, baseline, and test-period masks#

All 4 methods share the same weekly grid, so the aligned date index and true bloom labels are computed once here.

# ── Ensemble features per method ─────────────────────────────────────────────
# Each array is indexed below as [member, time, feature]: axis 1 is masked
# with a per-date boolean, and members follow the order of selected_cols.
features_by_method = {
    method: build_features(weekly_selected_by_method[method], OBS_CSV)
    for method in METHODS
}

# ── Aligned DatetimeIndex (shared across all methods) ────────────────────────
# Rebuild the dates of the rows build_features keeps: observations with
# member-0 discharge (cfs) swapped in, plus the lag/rolling columns whose
# leading NaNs are then dropped.
# NOTE(review): assumes build_features applies exactly this join/dropna; the
# length check below catches any drift instead of failing inside the mask.
obs_no_q    = pd.read_csv(OBS_CSV, parse_dates=["time"], index_col="time").drop(columns=["peace_discharge"])
member0_cfs = (weekly_selected_by_method[METHODS[0]].iloc[:, 0] * M3S_TO_CFS).rename("peace_discharge")
merged0     = obs_no_q.join(member0_cfs, how="inner")
merged0["kb_prev1"]         = merged0["kb"].shift(1)
merged0["discharge_4w_avg"] = merged0["peace_discharge"].rolling(window=4).mean()
aligned_index = merged0.dropna().index

for method in METHODS:
    n_time = features_by_method[method].shape[1]
    assert n_time == len(aligned_index), (
        f"{method}: build_features returned {n_time} time steps, but the "
        f"reconstructed date index has {len(aligned_index)}"
    )

# ── True labels ──────────────────────────────────────────────────────────────
# Label at date t = whether kb >= 100_000 in the FOLLOWING row (next week).
obs_full = pd.read_csv(OBS_CSV, parse_dates=["time"], index_col="time")
obs_full["bloom"]            = (obs_full["kb"] >= 100_000).astype(int)
obs_full["target_next_week"] = obs_full["bloom"].shift(-1)
y_true_series = obs_full.loc[aligned_index, "target_next_week"]

# shift(-1) leaves NaN on last row — trim index and features
valid_mask    = y_true_series.notna().values
aligned_index = aligned_index[valid_mask]
for method in METHODS:
    features_by_method[method] = features_by_method[method][:, valid_mask, :]
y_true = y_true_series.dropna().astype(int)

# ── Baseline features ────────────────────────────────────────────────────────
# Same feature set as the ensembles, but built from the observed discharge.
obs_flat = pd.read_csv(OBS_CSV, parse_dates=["time"]).sort_values("time").reset_index(drop=True)
obs_flat["bloom"]                 = (obs_flat["kb"] >= 100_000).astype(int)
obs_flat["target_next_week"]      = obs_flat["bloom"].shift(-1)
obs_flat["kb_prev1"]              = obs_flat["kb"].shift(1)
obs_flat["kb_prev2"]              = obs_flat["kb"].shift(2)
obs_flat["peace_discharge_prev1"] = obs_flat["peace_discharge"].shift(1)
obs_flat["peace_TN_prev1"]        = obs_flat["peace_TN"].shift(1)
obs_flat["peace_TP_prev1"]        = obs_flat["peace_TP"].shift(1)
obs_flat["discharge_4w_avg"]      = obs_flat["peace_discharge"].rolling(window=4).mean()
obs_flat = obs_flat.dropna(subset=FEATURE_COLS + ["target_next_week"]).reset_index(drop=True)

baseline_dates = pd.DatetimeIndex(obs_flat["time"])
X_baseline     = obs_flat[FEATURE_COLS].values
y_baseline     = obs_flat["target_next_week"].astype(int).values

# ── Baseline inference ───────────────────────────────────────────────────────
model  = joblib.load(MODEL_PATH)
scaler = joblib.load(SCALER_PATH)
baseline_proba = model.predict_proba(scaler.transform(X_baseline))[:, 1].astype(np.float32)
baseline_pred  = (baseline_proba >= 0.5).astype(int)

# ── Test-period masks and aligned true labels ────────────────────────────────
test_mask_members  = aligned_index  >= TEST_CUTOFF
test_mask_baseline = baseline_dates >= TEST_CUTOFF
y_test_members     = y_true[test_mask_members].values
y_test_baseline    = y_baseline[test_mask_baseline]

# Report the actual feature shape rather than assuming (5, T, 15).
n_members, n_time, n_features = features_by_method[METHODS[0]].shape
print(f"Features aligned to shape ({n_members}, {n_time}, {n_features}) per method")
print(f"Aligned dates: {aligned_index[0].date()} → {aligned_index[-1].date()}")
print(f"Test period: members {test_mask_members.sum()} weeks, baseline {test_mask_baseline.sum()} weeks")
print(f"Bloom weeks — test (members): {int(y_test_members.sum())}  |  "
      f"test (baseline): {int(y_test_baseline.sum())}")
Features aligned to shape (5, 1301, 15) per method
Aligned dates: 1999-01-18 → 2023-12-18
Test period: members 259 weeks, baseline 259 weeks
Bloom weeks — test (members): 80  |  test (baseline): 80

Step 3 — Per-method inference with confusion matrices#

Bootstrap#

# ── BOOTSTRAP — inference for 5 selected members ────────────────────────
# One P(bloom) trace per selected member; hard labels use a 0.5 threshold,
# followed by the per-member test-period confusion matrices.
proba_by_method["bootstrap"] = run_inference(
    features_by_method["bootstrap"], MODEL_PATH, SCALER_PATH,
)
pred_by_method["bootstrap"] = (0.5 <= proba_by_method["bootstrap"]).astype(int)

plot_member_cms("bootstrap")
Saved: data\outputs\ml\figures\bootstrap_member_cms.png
../_images/b1474d30594e7940fa520c69a4454724b1d30cb3241d92d9e1e5c90ce57cd292.png

GLUE#

# ── GLUE — inference for 5 selected members ────────────────────────
# One P(bloom) trace per selected member; hard labels use a 0.5 threshold,
# followed by the per-member test-period confusion matrices.
proba_by_method["glue"] = run_inference(
    features_by_method["glue"], MODEL_PATH, SCALER_PATH,
)
pred_by_method["glue"] = (0.5 <= proba_by_method["glue"]).astype(int)

plot_member_cms("glue")
Saved: data\outputs\ml\figures\glue_member_cms.png
../_images/d7c22b89f0aa0712e705e84bc19475ac85c77cee11f63857cc7d9b8fa5ebf8ef.png

EnKF#

# ── ENKF — inference for 5 selected members ────────────────────────
# One P(bloom) trace per selected member; hard labels use a 0.5 threshold,
# followed by the per-member test-period confusion matrices.
proba_by_method["enkf"] = run_inference(
    features_by_method["enkf"], MODEL_PATH, SCALER_PATH,
)
pred_by_method["enkf"] = (0.5 <= proba_by_method["enkf"]).astype(int)

plot_member_cms("enkf")
Saved: data\outputs\ml\figures\enkf_member_cms.png
../_images/2828f3cccac20f6a2dd1ab9de990d733f73fec59bc27d212040221616c3c20a1.png

LPU#

# ── LPU — inference for 5 selected members ────────────────────────
# One P(bloom) trace per selected member; hard labels use a 0.5 threshold,
# followed by the per-member test-period confusion matrices.
proba_by_method["lpu"] = run_inference(
    features_by_method["lpu"], MODEL_PATH, SCALER_PATH,
)
pred_by_method["lpu"] = (0.5 <= proba_by_method["lpu"]).astype(int)

plot_member_cms("lpu")
Saved: data\outputs\ml\figures\lpu_member_cms.png
../_images/aea7347cadd5409e3bd2badb063f37efd70861bdf440fcef595f95a42cab68bc.png

Final Comparison — Mean [5th, 95th Percentile] Confusion Matrix per Method#

For each method, the 5 per-member confusion matrices are averaged element-wise, and each cell is annotated as mean on top with [p05, p95] below. The four ensemble methods occupy the 2×2 block on the left; the baseline CM (single deterministic run, integer counts) spans the full right column.

import matplotlib.gridspec as gridspec

# ── Per-member test-period confusion matrices, computed once per method ──────
# cms_by_method[m] stacks one 2×2 matrix per selected member:
# shape (n_members, 2, 2), rows = actual [No Bloom, Bloom], cols = predicted.
cms_by_method = {
    method: np.stack([
        confusion_matrix(
            y_test_members,
            pred_by_method[method][j][test_mask_members],
            labels=[0, 1],
        )
        for j in range(len(MEMBER_INDICES))
    ])
    for method in METHODS
}

# ── Rank ensemble methods by mean True Positive count (descending) ───────────
method_mean_tp = {method: cms.mean(axis=0)[1, 1] for method, cms in cms_by_method.items()}
ranked_methods = sorted(METHODS, key=lambda m: method_mean_tp[m], reverse=True)

# Layout: the ensemble methods fill the 2×2 block on the left (row-major, by
# descending mean TP) and the baseline spans the full right column. The old
# layout gave the baseline one of four slots, which silently dropped the
# lowest-ranked method.
positions = [(r, c) for r in range(2) for c in range(2)]
assert len(ranked_methods) <= len(positions), "Layout holds at most 4 ensemble methods."

fig = plt.figure(figsize=(16, 9))
gs = gridspec.GridSpec(2, 3, figure=fig)

# ── Baseline (right column, full height) ────────────────────────────────────
# Computed from the baseline's test-period predictions instead of hard-coded
# counts, so it stays correct when TEST_CUTOFF or the observations change.
y_pred_bl_test = baseline_pred[test_mask_baseline]
cm_baseline = confusion_matrix(y_test_baseline, y_pred_bl_test, labels=[0, 1])
bal_acc_bl = balanced_accuracy_score(y_test_baseline, y_pred_bl_test)

ax_bl = fig.add_subplot(gs[:, 2])
sns.heatmap(cm_baseline, annot=True, fmt="d", cmap="Blues", ax=ax_bl, cbar=False,
            xticklabels=["No Bloom", "Bloom"],
            yticklabels=["No Bloom", "Bloom"],
            annot_kws={"size": 11})
ax_bl.set_xlabel("Predicted")
ax_bl.set_ylabel("Actual")
ax_bl.set_title(
    f"BASELINE  (observed TN)\n"
    f"Bal. Acc = {bal_acc_bl:.3f}   TP = {cm_baseline[1, 1]}",
    fontsize=11,
)

# ── Ensemble methods (left 2×2 block, ranked by mean TP) ────────────────────
for method, (r, c) in zip(ranked_methods, positions):
    ax = fig.add_subplot(gs[r, c])

    cms = cms_by_method[method]
    cm_mean = cms.mean(axis=0)
    cm_p05  = np.percentile(cms, 5,  axis=0)
    cm_p95  = np.percentile(cms, 95, axis=0)

    # Each cell: mean count on top, [p05, p95] across members below.
    annot = np.empty_like(cm_mean, dtype=object)
    for i in range(2):
        for k in range(2):
            annot[i, k] = f"{cm_mean[i, k]:.1f}\n[{cm_p05[i, k]:.0f}, {cm_p95[i, k]:.0f}]"

    sns.heatmap(cm_mean, annot=annot, fmt="", cmap="Blues", ax=ax, cbar=False,
                xticklabels=["No Bloom", "Bloom"],
                yticklabels=["No Bloom", "Bloom"],
                annot_kws={"size": 11})

    bal_accs = [
        balanced_accuracy_score(y_test_members,
                                pred_by_method[method][j][test_mask_members])
        for j in range(len(MEMBER_INDICES))
    ]
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual" if c == 0 else "")
    ax.set_title(
        f"{method.upper()}  (N={len(MEMBER_INDICES)} members)\n"
        f"Bal. Acc = {np.mean(bal_accs):.3f} ± {np.std(bal_accs):.3f}   "
        f"mean TP = {cm_mean[1, 1]:.1f}",
        fontsize=11,
    )

fig.suptitle(
    f"Confusion Matrices — mean [5th, 95th percentile] across {len(MEMBER_INDICES)} members "
    f"(Test Period from {TEST_CUTOFF})",
    fontsize=13,
)
plt.tight_layout()

if SAVE_OUTPUTS:
    fname = FIGURES_DIR / f"all_methods_mean_cm_{'_'.join(str(i) for i in MEMBER_INDICES)}.png"
    plt.savefig(fname, dpi=150, bbox_inches="tight")
    print(f"Saved: {fname}")

print(f"Method ranking by mean TP: {' > '.join(f'{m} ({method_mean_tp[m]:.1f})' for m in ranked_methods)}")
plt.show()
Saved: data\outputs\ml\figures\all_methods_mean_cm_0_40_80_120_160.png
Method ranking by mean TP: lpu (65.6) > enkf (65.4) > bootstrap (55.6) > glue (41.6)
../_images/186af21d2f951089cea495998f815f25ef4c2511a9ae0968393876a948383706.png

Precision-Recall Curves — All Methods#

One line per (method, member) combination, colored by method. Baseline in black.

fig, ax = plt.subplots(figsize=(9, 7))

# Baseline: single deterministic run on observed inputs, drawn on top.
bl_test_proba = baseline_proba[test_mask_baseline]
prec_bl, rec_bl, _ = precision_recall_curve(y_test_baseline, bl_test_proba)
ap_bl = average_precision_score(y_test_baseline, bl_test_proba)
ax.plot(rec_bl, prec_bl, color="black", linewidth=2.2,
        label=f"Baseline (AP={ap_bl:.3f})", zorder=10)

# One thin line per member, coloured by method. Each method gets a single
# legend entry that reports the mean AP over its members.
for method in METHODS:
    color = METHOD_COLORS[method]
    proba_test = proba_by_method[method][:, test_mask_members]
    member_idx = range(len(selected_cols))
    mean_ap = np.mean([average_precision_score(y_test_members, proba_test[j]) for j in member_idx])
    for j in member_idx:
        prec, rec, _ = precision_recall_curve(y_test_members, proba_test[j])
        ax.plot(rec, prec, color=color, linewidth=1.0, alpha=0.7,
                label=f"{method} (mean AP={mean_ap:.3f})" if j == 0 else None)

ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
# Title follows TEST_CUTOFF rather than a hard-coded year.
ax.set_title(f"Precision-Recall Curves — All Methods × Members (Test Period from {TEST_CUTOFF})")
ax.grid(True, alpha=0.3)
ax.legend(fontsize=9, loc="lower left")
plt.tight_layout()

if SAVE_OUTPUTS:
    fname = FIGURES_DIR / f"all_methods_members_{'_'.join(str(i) for i in MEMBER_INDICES)}_pr.png"
    plt.savefig(fname, dpi=150, bbox_inches="tight")
    print(f"Saved: {fname}")

plt.show()
Saved: data\outputs\ml\figures\all_methods_members_0_40_80_120_160_pr.png
../_images/6be32fb16f575e5cff51d62031936fd41fc201a9ef8f20f2b836b1ed1cf63829.png

Bloom Probability Traces — Method Means vs Baseline#

# Observed bloom weeks (kb >= 100_000), split at the train/test cutoff.
# bloom_test and half_week are reused by the test-period zoom cell.
obs_bloom   = pd.read_csv(OBS_CSV, parse_dates=["time"], index_col="time")
bloom_dates = obs_bloom.index[obs_bloom["kb"] >= 100_000]
bloom_train = bloom_dates[bloom_dates <  TEST_CUTOFF]
bloom_test  = bloom_dates[bloom_dates >= TEST_CUTOFF]
half_week   = pd.Timedelta(days=3)  # half-width of each shaded bloom band

fig, ax = plt.subplots(figsize=(16, 6))

# Shade bloom weeks: gold before the cutoff, tomato after it.
for span_dates, span_color, span_alpha in ((bloom_train, "gold", 0.15),
                                           (bloom_test, "tomato", 0.20)):
    for bd in span_dates:
        ax.axvspan(bd - half_week, bd + half_week,
                   color=span_color, alpha=span_alpha, linewidth=0)

# Dashed marker and label at the start of the test period.
ax.axvline(pd.Timestamp(TEST_CUTOFF), color="black", linestyle="--", linewidth=1.2)
ax.text(pd.Timestamp(TEST_CUTOFF), 0.97, " Test →",
        transform=ax.get_xaxis_transform(), fontsize=9, color="black", va="top")

# NOTE(review): each trace is P(bloom in the following week) (target_next_week)
# plotted at its feature date, so it leads the shaded bloom week by one step.
# Confirm that this alignment is intended.
ax.plot(baseline_dates, baseline_proba, color="black", linewidth=1.6,
        label="Baseline (observed)", zorder=5)

for method in METHODS:
    ax.plot(
        aligned_index,
        proba_by_method[method].mean(axis=0),  # mean over the selected members
        color=METHOD_COLORS[method], linewidth=1.2, alpha=0.9,
        label=f"{method} (mean of {len(MEMBER_INDICES)} members)", zorder=4,
    )

ax.axhline(0.5, color="red", linestyle=":", linewidth=0.9, label="0.5 threshold")

# Proxy artists so the shaded bands appear in the legend.
bloom_patches = [
    mpatches.Patch(color="gold",   alpha=0.5, label="Bloom event (train)"),
    mpatches.Patch(color="tomato", alpha=0.5, label="Bloom event (test)"),
]
line_handles = ax.get_legend_handles_labels()[0]
ax.legend(handles=line_handles + bloom_patches, loc="upper left", fontsize=8, ncol=2)

ax.set_ylim(0, 1)
ax.set_ylabel("Bloom Probability  P(bloom)")
ax.set_xlabel("Date")
ax.set_title("Method Means vs Baseline — Full Range", fontsize=12)
plt.tight_layout()

if SAVE_OUTPUTS:
    fname = FIGURES_DIR / f"all_methods_members_{'_'.join(str(i) for i in MEMBER_INDICES)}_traces_full.png"
    plt.savefig(fname, dpi=150, bbox_inches="tight")
    print(f"Saved: {fname}")

plt.show()
Saved: data\outputs\ml\figures\all_methods_members_0_40_80_120_160_traces_full.png
../_images/40d21546a0cf53f01b19ecfe2ce24cc0e39f8c7f96463240a0f6d8ec31728b84.png

Test Period Zoom (2019+) — Method Means vs Baseline#

fig, ax = plt.subplots(figsize=(14, 5))

# Reuses bloom_test and half_week from the full-range trace cell.
for bd in bloom_test:
    ax.axvspan(bd - half_week, bd + half_week, color="tomato", alpha=0.25, linewidth=0)

ax.plot(baseline_dates[test_mask_baseline], baseline_proba[test_mask_baseline],
        color="black", linewidth=2.0, label="Baseline (observed)", zorder=5)

for method in METHODS:
    # Mean P(bloom) over the selected members, restricted to the test period.
    mean_trace = proba_by_method[method].mean(axis=0)
    ax.plot(aligned_index[test_mask_members], mean_trace[test_mask_members],
            color=METHOD_COLORS[method], linewidth=1.5, alpha=0.9,
            label=f"{method} mean", zorder=4)

ax.axhline(0.5, color="red", linestyle=":", linewidth=0.9, label="0.5 threshold")

# Proxy artist so the shaded bloom bands appear in the legend.
test_patch = mpatches.Patch(color="tomato", alpha=0.5, label="Bloom event")
handles, _ = ax.get_legend_handles_labels()
ax.legend(handles=handles + [test_patch], loc="upper right", fontsize=9)

ax.set_ylim(0, 1)
ax.set_ylabel("Bloom Probability  P(bloom)")
ax.set_xlabel("Date")
# Title follows TEST_CUTOFF rather than a hard-coded year.
ax.set_title(f"Method Means vs Baseline — Test Period (from {TEST_CUTOFF})", fontsize=12)
plt.tight_layout()

if SAVE_OUTPUTS:
    fname = FIGURES_DIR / f"all_methods_members_{'_'.join(str(i) for i in MEMBER_INDICES)}_traces_test.png"
    plt.savefig(fname, dpi=150, bbox_inches="tight")
    print(f"Saved: {fname}")

plt.show()
Saved: data\outputs\ml\figures\all_methods_members_0_40_80_120_160_traces_test.png
../_images/58ad4714d82e44181fa7911161e8c97205a15a2c1ad18a3036652e92f8c529d9.png

Optional — Export Probability CSV#

if SAVE_OUTPUTS:
    # One column per (method, member) plus the baseline. pd.DataFrame aligns
    # them on the union of baseline_dates and aligned_index, leaving NaN where
    # a series has no value for a date.
    member_series = {
        f"{method}_{col}": pd.Series(
            proba_by_method[method][j], index=aligned_index, name=f"{method}_{col}",
        )
        for method in METHODS
        for j, col in enumerate(selected_cols)
    }
    series_dict = {
        "baseline": pd.Series(baseline_proba, index=baseline_dates, name="baseline"),
        **member_series,
    }
    proba_df = pd.DataFrame(series_dict)
    proba_df.index.name = "date"

    member_tag = "_".join(str(i) for i in MEMBER_INDICES)
    proba_csv = OUTPUT_DIR / f"all_methods_proba_{member_tag}.csv"
    proba_df.to_csv(proba_csv)
    print(f"Saved: {proba_csv}  shape={proba_df.shape}")
Saved: data\outputs\ml\all_methods_proba_0_40_80_120_160.csv  shape=(1613, 21)