Joint Ensemble Member Inference (Discharge × Total Nitrogen)

Propagate joint (discharge, total_nitrogen) ensemble uncertainty through the saved Random Forest bloom classifier. For each of the 4 UQ methods (bootstrap, GLUE, EnKF, LPU), sample 5 random same-method (discharge_member, tn_member) pairs and feed both variables simultaneously into the RF feature frame. Produces the same 4 output artifacts as discharge/notebooks/06_ensemble_member_inference.ipynb but reflects joint uncertainty in both variables rather than marginal uncertainty in one.

Run location: workspace root (Senior_Project_Final/) — see import-strategy cell below.

Import strategy — notebook-only, inline helpers

This notebook does not import from red_tide_reanalysis (neither sub-project’s copy). All helpers are inlined in the cell below.

Why:

  • Both discharge/src/ and total_nitrogen/src/ ship a Python package named red_tide_reanalysis — only one can be installed at a time. Inlining sidesteps the collision entirely.

  • sys.path.insert(0, os.path.abspath("src")) is fragile: it resolves against the Jupyter kernel’s CWD, which breaks under papermill or when the notebook is opened from a parent directory. This notebook avoids that pattern.

  • The two helpers total ~20 lines, so inlining is cheap and keeps the deliverable notebook-only (no workspace-level Python package).
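The CWD fragility described in the second bullet can be reproduced with a short sketch. It uses only the standard library and a temporary directory standing in for the workspace; the directory layout is hypothetical and nothing here touches the real project.

```python
import os
import tempfile
from pathlib import Path

start = Path.cwd()
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp).resolve()          # stand-in for the workspace root
    (root / "notebooks").mkdir()
    try:
        os.chdir(root)
        from_root = Path("src").resolve()        # kernel started at the root
        os.chdir(root / "notebooks")
        from_notebooks = Path("src").resolve()   # kernel started in notebooks/
    finally:
        os.chdir(start)                 # always restore the original CWD

# The same relative string points at two different directories.
print(from_root.relative_to(root))
print(from_notebooks.relative_to(root))
```

The same relative path resolves to two different locations depending on where the kernel was started, which is why the notebook inlines its helpers instead of extending sys.path.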

Drift warning: if discharge/src/red_tide_reanalysis/ml/feature_builder.py or discharge/src/red_tide_reanalysis/ml/inference.py ever changes upstream (new feature column, new scaling factor, new preprocessing step), the inlined copies below must be manually reconciled. The source-comment citations in the helpers cell serve as tripwires at code review.
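One possible way to automate the drift check is to record a digest of each upstream file when the helpers are inlined and compare it on later runs. The sketch below is illustrative only: `file_digest`, `check_drift`, and the `recorded` dict are hypothetical names that do not exist in this notebook or in red_tide_reanalysis, and a temporary file stands in for the upstream module.

```python
import hashlib
import tempfile
from pathlib import Path


def file_digest(path: Path) -> str:
    """Return the SHA-256 hex digest of a file's bytes."""
    return hashlib.sha256(path.read_bytes()).hexdigest()


def check_drift(recorded: dict[Path, str]) -> list[Path]:
    """Return the files whose current digest differs from the recorded one."""
    return [p for p, digest in recorded.items() if file_digest(p) != digest]


# Toy demonstration: a temporary file stands in for an upstream module.
with tempfile.TemporaryDirectory() as tmp:
    upstream = Path(tmp) / "feature_builder.py"
    upstream.write_text("M3S_TO_CFS = 35.3147\n")
    recorded = {upstream: file_digest(upstream)}      # captured at inlining time

    unchanged = check_drift(recorded)                 # nothing has drifted yet
    upstream.write_text("M3S_TO_CFS = 35.3146667\n")  # simulated upstream edit
    drifted = check_drift(recorded)                   # the edited file is flagged

print(f"before edit: {len(unchanged)} drifted, after edit: {len(drifted)} drifted")
```

A digest comparison only says that a file changed, not whether the change affects the inlined helpers, so a manual review is still needed when it fires.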

Run location: workspace root. Paths in the USER CONFIGURATION cell below are relative to this CWD; the existence assertions there fail with actionable errors if the notebook is launched from the wrong directory.

import os as _os
from pathlib import Path as _Path
# nbconvert executes notebooks with the kernel CWD set to the notebook's directory.
# When this notebook lives in notebooks/, Path('discharge/...') would resolve from
# notebooks/ instead of the workspace root. Detect and correct automatically.
_cwd = _Path.cwd()
if _cwd.name == "notebooks" and (_cwd.parent / "discharge").is_dir():
    _os.chdir(_cwd.parent)  # move to workspace root (Senior_Project_Final/)
del _os, _Path, _cwd  # clean up after ourselves

# ── USER CONFIGURATION ────────────────────────────────────────────────────────────────────────────────
# Single source of truth for every configurable path and run parameter.
# All paths are resolved relative to the WORKSPACE ROOT (Senior_Project_Final/).
# If run from a different CWD, the existence assertions below fire with actionable
# error messages.

from pathlib import Path

# ── Run parameters ────────────────────────────────────────────────────────────────────────
METHODS              = ["bootstrap", "glue", "enkf", "lpu"]     # authoritative order
N_PAIRS_PER_METHOD   = 5
PAIRING_SEED         = 42
TEST_CUTOFF          = "2019-01-01"                             # matches discharge/06 template
SAVE_OUTPUTS         = True

# ── Input paths ──────────────────────────────────────────
DISCHARGE_ENSEMBLE_DIR = Path("discharge/notebooks/data/outputs/ensembles")
TN_ENSEMBLE_DIR        = Path("total_nitrogen/output/ensembles")
MODEL_PATH             = Path("discharge/models/rf_bloom_classifier.joblib")
SCALER_PATH            = Path("discharge/models/robust_scaler.joblib")
OBS_CSV                = Path("discharge/machineLearning/machineLearning/input/data_weekly_intepolated.csv")  # — note: "intepolated" spelling preserved from upstream

# ── Output paths ─────────────────────────────────────────────────────────────────────────
OUTPUT_DIR  = Path("output/joint")
FIGURES_DIR = OUTPUT_DIR / "figures"
CSV_DIR     = OUTPUT_DIR / "csv"

# ── Ensemble-file dicts ─────────────────────────────────────────────────────────────────────────
# Discharge filenames are NOT uniform — only bootstrap has "peace_river" in the name
# whereas glue/enkf/lpu do not. Explicit dict avoids brittle f-string reconstruction.
ENSEMBLE_FILES_DISCHARGE = {
    "bootstrap": "bootstrap_02296750_peace_river_discharge_members.csv",
    "glue":      "glue_02296750_discharge_members.csv",
    "enkf":      "enkf_02296750_discharge_members.csv",
    "lpu":       "lpu_02296750_discharge_members.csv",
}
# TN filenames ARE uniform: {method}_02296750_peace_river_total_nitrogen_members.csv
ENSEMBLE_FILES_TN = {
    "bootstrap": "bootstrap_02296750_peace_river_total_nitrogen_members.csv",
    "glue":      "glue_02296750_peace_river_total_nitrogen_members.csv",
    "enkf":      "enkf_02296750_peace_river_total_nitrogen_members.csv",
    "lpu":       "lpu_02296750_peace_river_total_nitrogen_members.csv",
}

# ── Path-existence assertions ─────────────────────────────────────────────────────────────────────────
assert MODEL_PATH.exists(), (
    f"Model not found: {MODEL_PATH.resolve()}\n"
    f"  → Run discharge/notebooks/00_serialize_model.ipynb to regenerate, or\n"
    f"    repoint MODEL_PATH at total_nitrogen/models/rf_bloom_classifier.joblib."
)
assert SCALER_PATH.exists(), (
    f"Scaler not found: {SCALER_PATH.resolve()}\n"
    f"  → Run discharge/notebooks/00_serialize_model.ipynb to regenerate, or\n"
    f"    repoint SCALER_PATH at total_nitrogen/models/robust_scaler.joblib."
)
assert DISCHARGE_ENSEMBLE_DIR.is_dir(), (
    f"Discharge ensemble directory not found: {DISCHARGE_ENSEMBLE_DIR.resolve()}\n"
    f"  → Run discharge/notebooks/{{01_bootstrap,02_glue,03_enkf,04_lpu}}_ci.ipynb to generate."
)
assert TN_ENSEMBLE_DIR.is_dir(), (
    f"Total-nitrogen ensemble directory not found: {TN_ENSEMBLE_DIR.resolve()}\n"
    f"  → Run total_nitrogen/{{01_bootstrap,02_glue,03_enkf,04_lpu}}_ci.ipynb to generate."
)
assert OBS_CSV.exists(), (
    f"Weekly-interpolated obs CSV not found: {OBS_CSV.resolve()}\n"
    f"  → Confirm the file lives under discharge/machineLearning/machineLearning/input/"
)

# Per-method ensemble CSV existence checks (8 files total — 4 discharge + 4 TN)
for method in METHODS:
    d_csv = DISCHARGE_ENSEMBLE_DIR / ENSEMBLE_FILES_DISCHARGE[method]
    assert d_csv.exists(), f"Missing discharge ensemble: {d_csv.resolve()}"
    t_csv = TN_ENSEMBLE_DIR / ENSEMBLE_FILES_TN[method]
    assert t_csv.exists(), f"Missing TN ensemble: {t_csv.resolve()}"

# ── Auto-create output subdirectories ─────────────────────────────────────────────────────────────────────────
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
CSV_DIR.mkdir(parents=True, exist_ok=True)

# ── Resolution summary ─────────────────────────────────────────────────────────────────────────────────
print("USER CONFIGURATION resolved:")
print(f"  METHODS                  = {METHODS}")
print(f"  N_PAIRS_PER_METHOD       = {N_PAIRS_PER_METHOD}")
print(f"  PAIRING_SEED             = {PAIRING_SEED}")
print(f"  TEST_CUTOFF              = {TEST_CUTOFF}")
print(f"  DISCHARGE_ENSEMBLE_DIR   = {DISCHARGE_ENSEMBLE_DIR.resolve()}")
print(f"  TN_ENSEMBLE_DIR          = {TN_ENSEMBLE_DIR.resolve()}")
print(f"  MODEL_PATH               = {MODEL_PATH.resolve()}")
print(f"  SCALER_PATH              = {SCALER_PATH.resolve()}")
print(f"  OBS_CSV                  = {OBS_CSV.resolve()}")
print(f"  OUTPUT_DIR               = {OUTPUT_DIR.resolve()}")
print(f"  FIGURES_DIR              = {FIGURES_DIR.resolve()}  (auto-created)")
print(f"  CSV_DIR                  = {CSV_DIR.resolve()}  (auto-created)")
USER CONFIGURATION resolved:
  METHODS                  = ['bootstrap', 'glue', 'enkf', 'lpu']
  N_PAIRS_PER_METHOD       = 5
  PAIRING_SEED             = 42
  TEST_CUTOFF              = 2019-01-01
  DISCHARGE_ENSEMBLE_DIR   = discharge\notebooks\data\outputs\ensembles
  TN_ENSEMBLE_DIR          = total_nitrogen\output\ensembles
  MODEL_PATH               = discharge\models\rf_bloom_classifier.joblib
  SCALER_PATH              = discharge\models\robust_scaler.joblib
  OBS_CSV                  = discharge\machineLearning\machineLearning\input\data_weekly_intepolated.csv
  OUTPUT_DIR               = output\joint
  FIGURES_DIR              = output\joint\figures  (auto-created)
  CSV_DIR                  = output\joint\csv  (auto-created)
# ── INLINED HELPERS ─────────────────────────────────────────────────────────────────────────────
# These are verbatim copies of the two helpers from the discharge sub-project.
# See the markdown cell above for why they are inlined rather than imported.
# If upstream changes, reconcile manually.

from pathlib import Path
import numpy as np
import pandas as pd
import joblib


# source: discharge/src/red_tide_reanalysis/ml/feature_builder.py:13
# Unit conversion: WAM model ensemble output is in m³/s; obs CSV peace_discharge is in CFS.
# The RobustScaler was fit on CFS values — all discharge features fed to the model must be CFS.
M3S_TO_CFS: float = 35.3147


# source: discharge/src/red_tide_reanalysis/ml/feature_builder.py:18-25
# 15 features — exact order from ml.ipynb cell 10.
# The RF classifier and RobustScaler are version-locked to THIS order.
FEATURE_COLS: list[str] = [
    "kb", "kb_prev1", "kb_prev2",
    "zos", "salinity", "water_temp",
    "wind_speed", "wind_direction",
    "peace_discharge", "peace_TN", "peace_TP",
    "peace_discharge_prev1", "peace_TN_prev1", "peace_TP_prev1",
    "discharge_4w_avg",
]


# source: discharge/src/red_tide_reanalysis/ml/feature_builder.py:28-44
def resample_ensemble(ensemble_csv: Path) -> pd.DataFrame:
    """Read daily ensemble CSV (wide: date x member_NNNN), resample to W-MON.

    Parameters
    ----------
    ensemble_csv : Path
        Path to daily ensemble CSV with DatetimeIndex and member columns.

    Returns
    -------
    pd.DataFrame
        Shape (T_weeks, N_members) with Monday-anchored DatetimeIndex.
    """
    df = pd.read_csv(ensemble_csv, index_col=0, parse_dates=True)
    # CRITICAL: closed='left', label='left' — Sunday anchor silently corrupts lag features
    weekly = df.resample("W-MON", closed="left", label="left").mean()
    return weekly


# source: discharge/src/red_tide_reanalysis/ml/inference.py:11-43
def run_inference(
    members_array: np.ndarray,
    model_path: Path,
    scaler_path: Path,
) -> np.ndarray:
    """Run bloom probability inference for each ensemble member.

    CRITICAL: Only .transform() is called on the scaler. NEVER .fit_transform() —
    fit_transform() silently shifts the feature space with no error.

    Parameters
    ----------
    members_array : np.ndarray
        Shape (N, T_weeks, 15) feature array from build_features.
    model_path : Path
        Path to rf_bloom_classifier.joblib.
    scaler_path : Path
        Path to robust_scaler.joblib.

    Returns
    -------
    np.ndarray
        Shape (N, T_weeks) float32 bloom probabilities (class 1).
    """
    model = joblib.load(model_path)
    scaler = joblib.load(scaler_path)

    N, T, F = members_array.shape
    proba = np.empty((N, T), dtype=np.float32)

    for i in range(N):
        X = members_array[i]              # (T, 15)
        Xs = scaler.transform(X)          # NEVER fit_transform
        proba[i] = model.predict_proba(Xs)[:, 1].astype(np.float32)

    return proba


print(f"Inlined helpers loaded: M3S_TO_CFS={M3S_TO_CFS}, FEATURE_COLS (15 entries), resample_ensemble, run_inference")


# source: adapted from discharge/src/red_tide_reanalysis/ml/feature_builder.py:47-94
def build_features_joint(
    d_weekly_df: pd.DataFrame,
    tn_weekly_df: pd.DataFrame,
    obs_csv: Path,
) -> np.ndarray:
    """Build joint (discharge + TN) feature array for N_PAIRS_PER_METHOD pairs.

    Parameters
    ----------
    d_weekly_df : pd.DataFrame
        Shape (T_d_weeks, N_pairs) - discharge members pre-selected and resampled.
        Columns: member_NNNN strings from resample_ensemble output.
        Values in m^3/s - scaled to CFS inside this helper.
    tn_weekly_df : pd.DataFrame
        Shape (T_tn_weeks, N_pairs) - TN members pre-selected and resampled.
        Columns: member_NNNN strings. Values in mg/L - no scaling applied.
        Column i of tn_weekly_df pairs positionally with column i of d_weekly_df.
    obs_csv : Path
        Path to data_weekly_intepolated.csv.

    Returns
    -------
    np.ndarray
        Shape (N_pairs, T_weeks_aligned, 15) with zero NaN.

    Raises
    ------
    ValueError
        If any NaN remains in the feature matrix for any pair after dropna.
    """
    obs = pd.read_csv(obs_csv, parse_dates=["time"], index_col="time")
    # drop BOTH replaced columns - single-variable build_features drops only one
    obs = obs.drop(columns=["peace_discharge", "peace_TN"])

    arrays: list[np.ndarray] = []
    d_cols = list(d_weekly_df.columns)
    tn_cols = list(tn_weekly_df.columns)
    for d_col, tn_col in zip(d_cols, tn_cols):  #  positional pairing
        d_member = (d_weekly_df[d_col] * M3S_TO_CFS).rename("peace_discharge")
        tn_member = tn_weekly_df[tn_col].rename("peace_TN")                      # (unscaled)

        # sequential inner join: obs ∩ d_member ∩ tn_member
        merged = obs.join(d_member, how="inner").join(tn_member, how="inner")

        # Lag features -  peace_TN_prev1 from injected ensemble member
        merged["kb_prev1"]              = merged["kb"].shift(1)
        merged["kb_prev2"]              = merged["kb"].shift(2)
        merged["peace_discharge_prev1"] = merged["peace_discharge"].shift(1)
        merged["peace_TN_prev1"]        = merged["peace_TN"].shift(1)   #  ensemble TN
        merged["peace_TP_prev1"]        = merged["peace_TP"].shift(1)   # observed TP (no UQ)
        merged["discharge_4w_avg"]      = merged["peace_discharge"].rolling(window=4).mean()

        # dropna removes TN gap weeks + lag initialization rows
        merged = merged.dropna()

        X = merged[FEATURE_COLS].values  # (T_weeks_aligned, 15) -  canonical order
        if np.isnan(X).any():
            raise ValueError(f"NaN in feature matrix for pair ({d_col}, {tn_col}) after dropna")
        arrays.append(X)

    return np.stack(arrays, axis=0)  # (N_pairs, T_weeks_aligned, 15)
Inlined helpers loaded: M3S_TO_CFS=35.3147, FEATURE_COLS (15 entries), resample_ensemble, run_inference
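The Monday anchoring used by resample_ensemble can be checked on a toy series. This is a minimal sketch with made-up daily values (14 days starting Monday 2024-01-01); the column name member_0000 only mimics the ensemble layout. The default-argument call is printed for comparison.

```python
import numpy as np
import pandas as pd

# Two full weeks of daily values, starting on Monday 2024-01-01.
days = pd.date_range("2024-01-01", periods=14, freq="D")
daily = pd.DataFrame({"member_0000": np.arange(14, dtype=float)}, index=days)

# Same call as resample_ensemble: bins are [Monday, next Monday),
# labelled by the Monday that opens the bin.
weekly_left = daily.resample("W-MON", closed="left", label="left").mean()

# pandas defaults for weekly frequencies (closed="right", label="right"),
# shown only for contrast.
weekly_default = daily.resample("W-MON").mean()

print(weekly_left)
print(weekly_default)
```

With left-closed, left-labelled bins, each Monday label covers that Monday through the following Sunday, which matches a Monday-indexed weekly observation table.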
## Step 1 — Sample joint ensemble pairs

import numpy as np

# METHOD_COLORS for plotting
METHOD_COLORS = {
    "bootstrap": "steelblue",
    "glue": "darkorange",
    "enkf": "forestgreen",
    "lpu": "mediumpurple",
}

# single rng advanced sequentially across all 4 methods
rng = np.random.default_rng(PAIRING_SEED)

# Storage
weekly_d_by_method = {}    # {method: DataFrame(T_d_weeks, 5)}
weekly_tn_by_method = {}   # {method: DataFrame(T_tn_weeks, 5)}
pair_index_by_method = {}  # {method: list[tuple[int, int]]}

for method in METHODS:  #  ["bootstrap", "glue", "enkf", "lpu"]
    # Resample full 200-member ensembles (cell 3 helper)
    d_full = resample_ensemble(DISCHARGE_ENSEMBLE_DIR / ENSEMBLE_FILES_DISCHARGE[method])
    tn_full = resample_ensemble(TN_ENSEMBLE_DIR / ENSEMBLE_FILES_TN[method])
    # mirror total_nitrogen/06_ensemble_member_inference.ipynb cell_step1_code:
    # TN ensemble CSVs are sparse (obs at ~monthly cadence); interpolate so inner-join
    # in build_features_joint keeps the full weekly grid instead of collapsing to ~253 dates.
    tn_full = tn_full.interpolate(method="time", limit_direction="both")

    # two independent draws from [0, 200) without replacement
    d_indices = rng.choice(200, N_PAIRS_PER_METHOD, replace=False)
    tn_indices = rng.choice(200, N_PAIRS_PER_METHOD, replace=False)

    d_cols = [f"member_{i:04d}" for i in d_indices]
    tn_cols = [f"member_{i:04d}" for i in tn_indices]

    weekly_d_by_method[method] = d_full[d_cols].copy()
    weekly_tn_by_method[method] = tn_full[tn_cols].copy()
    pair_index_by_method[method] = list(zip(d_indices.tolist(), tn_indices.tolist()))

    print(f"  {method}: {d_full.shape[0]} d-weeks, {tn_full.shape[0]} tn-weeks  "
          f"({d_full.index[0].date()} \u2192 {d_full.index[-1].date()})")

# print pair table
print()
print(f"{'Method':<12} {'Pair':>4}  {'Discharge col':<16}  {'TN col':<16}")
print("-" * 54)
for method in METHODS:
    for i, (d_idx, tn_idx) in enumerate(pair_index_by_method[method]):
        print(f"{method:<12} {i:>4}  member_{d_idx:04d}        member_{tn_idx:04d}")
  bootstrap: 1305 d-weeks, 1249 tn-weeks  (1998-12-28 → 2023-12-25)
  glue: 1305 d-weeks, 1249 tn-weeks  (1998-12-28 → 2023-12-25)
  enkf: 1305 d-weeks, 1249 tn-weeks  (1998-12-28 → 2023-12-25)
  lpu: 1305 d-weeks, 1249 tn-weeks  (1998-12-28 → 2023-12-25)

Method       Pair  Discharge col     TN col          
------------------------------------------------------
bootstrap       0  member_0152        member_0193
bootstrap       1  member_0087        member_0018
bootstrap       2  member_0129        member_0103
bootstrap       3  member_0017        member_0152
bootstrap       4  member_0086        member_0146
glue            0  member_0099        member_0107
glue            1  member_0164        member_0045
glue            2  member_0088        member_0089
glue            3  member_0073        member_0087
glue            4  member_0036        member_0161
enkf            0  member_0033        member_0176
enkf            1  member_0168        member_0155
enkf            2  member_0163        member_0087
enkf            3  member_0054        member_0190
enkf            4  member_0125        member_0134
lpu             0  member_0008        member_0180
lpu             1  member_0097        member_0074
lpu             2  member_0108        member_0064
lpu             3  member_0148        member_0080
lpu             4  member_0030        member_0189
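Because a single generator is advanced across all four methods, the pairs assigned to a method depend on its position in METHODS. The sketch below repeats the sampling logic in a standalone function to illustrate this; `sample_pairs` is a hypothetical helper written for this example, not part of the notebook.

```python
import numpy as np


def sample_pairs(methods, n_members, n_pairs, seed):
    """Draw n_pairs (discharge_idx, tn_idx) tuples per method from one shared generator."""
    rng = np.random.default_rng(seed)
    pairs = {}
    for method in methods:
        # Two independent draws, each without replacement within its own variable.
        d_idx = rng.choice(n_members, n_pairs, replace=False)
        tn_idx = rng.choice(n_members, n_pairs, replace=False)
        pairs[method] = list(zip(d_idx.tolist(), tn_idx.tolist()))
    return pairs


methods = ["bootstrap", "glue", "enkf", "lpu"]
first = sample_pairs(methods, n_members=200, n_pairs=5, seed=42)
again = sample_pairs(methods, n_members=200, n_pairs=5, seed=42)
reordered = sample_pairs(methods[::-1], n_members=200, n_pairs=5, seed=42)

# Same seed and same method order reproduce the same pairs.
print(first == again)
# With the order reversed, "lpu" is processed first and receives the draws
# that "bootstrap" received in the original order.
print(first["bootstrap"] == reordered["lpu"])
```

Reordering METHODS or adding a method therefore changes the pairs of every later method, even with PAIRING_SEED fixed. Per-method generators would remove that coupling, at the cost of no longer matching the pair table above.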
## Step 2 — Build joint features and run per-method inference

import pandas as pd
import joblib

# Joint features per method
features_by_method = {}
for method in METHODS:
    features_by_method[method] = build_features_joint(
        weekly_d_by_method[method],
        weekly_tn_by_method[method],
        OBS_CSV,
    )  # (5, T_weeks_aligned, 15)
    T = features_by_method[method].shape[1]
    print(f"  {method}: feature shape (5, {T}, 15)")

# Aligned DatetimeIndex — derived from method 0 / pair 0 to match build_features_joint internal join
obs_no_dq = pd.read_csv(OBS_CSV, parse_dates=["time"], index_col="time").drop(
    columns=["peace_discharge", "peace_TN"]
)
d0 = (weekly_d_by_method[METHODS[0]].iloc[:, 0] * M3S_TO_CFS).rename("peace_discharge")
tn0 = weekly_tn_by_method[METHODS[0]].iloc[:, 0].rename("peace_TN")
merged0 = obs_no_dq.join(d0, how="inner").join(tn0, how="inner")
merged0["kb_prev1"] = merged0["kb"].shift(1)
merged0["kb_prev2"] = merged0["kb"].shift(2)
merged0["peace_discharge_prev1"] = merged0["peace_discharge"].shift(1)
merged0["peace_TN_prev1"] = merged0["peace_TN"].shift(1)
merged0["peace_TP_prev1"] = merged0["peace_TP"].shift(1)
merged0["discharge_4w_avg"] = merged0["peace_discharge"].rolling(window=4).mean()
aligned_index = merged0.dropna().index  # same length as T_weeks_aligned (1246 in the run shown below)

# True labels
obs_full = pd.read_csv(OBS_CSV, parse_dates=["time"], index_col="time")
obs_full["bloom"] = (obs_full["kb"] >= 100_000).astype(int)
obs_full["target_next_week"] = obs_full["bloom"].shift(-1)
y_true_series = obs_full.loc[aligned_index, "target_next_week"]

# shift(-1) guard: the last obs row has no next-week target (NaN), so drop it if it
# falls inside aligned_index. No-op when the joint end date is obs[-2] or earlier.
valid_mask = y_true_series.notna().values
aligned_index = aligned_index[valid_mask]
for method in METHODS:
    features_by_method[method] = features_by_method[method][:, valid_mask, :]
y_true = y_true_series.dropna().astype(int)

# Baseline: observed feature values, no ensemble injection
obs_flat = pd.read_csv(OBS_CSV, parse_dates=["time"]).sort_values("time").reset_index(drop=True)
obs_flat["bloom"] = (obs_flat["kb"] >= 100_000).astype(int)
obs_flat["target_next_week"] = obs_flat["bloom"].shift(-1)
obs_flat["kb_prev1"] = obs_flat["kb"].shift(1)
obs_flat["kb_prev2"] = obs_flat["kb"].shift(2)
obs_flat["peace_discharge_prev1"] = obs_flat["peace_discharge"].shift(1)
obs_flat["peace_TN_prev1"] = obs_flat["peace_TN"].shift(1)
obs_flat["peace_TP_prev1"] = obs_flat["peace_TP"].shift(1)
obs_flat["discharge_4w_avg"] = obs_flat["peace_discharge"].rolling(window=4).mean()
obs_flat = obs_flat.dropna(subset=FEATURE_COLS + ["target_next_week"]).reset_index(drop=True)

baseline_dates = pd.DatetimeIndex(obs_flat["time"])
X_baseline = obs_flat[FEATURE_COLS].values
y_baseline = obs_flat["target_next_week"].astype(int).values

model = joblib.load(MODEL_PATH)
scaler = joblib.load(SCALER_PATH)
baseline_proba = model.predict_proba(scaler.transform(X_baseline))[:, 1].astype(np.float32)
baseline_pred = (baseline_proba >= 0.5).astype(int)

# Test-period masks
test_mask_members = aligned_index >= TEST_CUTOFF
test_mask_baseline = baseline_dates >= TEST_CUTOFF
y_test_members = y_true[test_mask_members].values
y_test_baseline = y_baseline[test_mask_baseline]

# Per-method inference
proba_by_method = {}
pred_by_method = {}
for method in METHODS:
    proba_by_method[method] = run_inference(features_by_method[method], MODEL_PATH, SCALER_PATH)
    pred_by_method[method] = (proba_by_method[method] >= 0.5).astype(int)

print(f"Features aligned to shape (5, {len(aligned_index)}, 15) per method")
print(f"Aligned dates: {aligned_index[0].date()} \u2192 {aligned_index[-1].date()}")
print(f"Test period: members {test_mask_members.sum()} weeks, baseline {test_mask_baseline.sum()} weeks")
print(f"Bloom weeks — test (members): {int(y_test_members.sum())}  |  test (baseline): {int(y_test_baseline.sum())}")
  bootstrap: feature shape (5, 1246, 15)
  glue: feature shape (5, 1246, 15)
  enkf: feature shape (5, 1246, 15)
  lpu: feature shape (5, 1246, 15)
Features aligned to shape (5, 1246, 15) per method
Aligned dates: 2000-01-24 → 2023-12-04
Test period: members 257 weeks, baseline 259 weeks
Bloom weeks — test (members): 80  |  test (baseline): 80
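The next-week label and the shift(-1) guard can be seen on a toy table. This is a minimal sketch with invented kb values on five consecutive Mondays; only the 100_000 threshold and the column names are taken from the notebook.

```python
import pandas as pd

# Toy weekly record on five consecutive Mondays (illustrative values only).
idx = pd.date_range("2023-11-06", periods=5, freq="W-MON")
obs = pd.DataFrame({"kb": [0, 250_000, 90_000, 100_000, 5_000]}, index=idx)

obs["bloom"] = (obs["kb"] >= 100_000).astype(int)   # same threshold as the notebook
obs["target_next_week"] = obs["bloom"].shift(-1)    # label for week t is bloom at t+1

# The last row has no following week, so its target is NaN. The same mask
# must be applied to the labels and to the feature rows to keep them aligned.
valid_mask = obs["target_next_week"].notna().values
y_true = obs.loc[valid_mask, "target_next_week"].astype(int)

print(obs)
print(y_true.tolist())
```

In the run above the member features end at 2023-12-04, before the last observation row, so the guard trims nothing and the shape stays (5, 1246, 15).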

Step 3 — Output artifacts

Per-method joint-pair confusion matrices.
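Each panel title reports balanced accuracy, which is the mean of the per-class recalls and can be read directly off a confusion matrix. The sketch below assumes the scikit-learn layout (rows = actual, columns = predicted); `balanced_accuracy_from_cm` is a hypothetical helper and the counts are invented, not results from this notebook.

```python
import numpy as np


def balanced_accuracy_from_cm(cm: np.ndarray) -> float:
    """Mean of per-class recall for a matrix with rows = actual, columns = predicted."""
    recalls = np.diag(cm) / cm.sum(axis=1)   # correct predictions / actual count, per class
    return float(recalls.mean())


# Illustrative counts only.
#            predicted: No Bloom  Bloom
cm = np.array([[90, 10],     # actual No Bloom
               [20, 80]])    # actual Bloom

bal_acc = balanced_accuracy_from_cm(cm)
print(f"Bal. Acc = {bal_acc:.3f}")
```

This sketch does not handle a class with zero actual samples (the division would produce NaN); the notebook itself relies on balanced_accuracy_score from scikit-learn.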

# === Cell B: Per-method joint-pair confusion-matrix grids ===
# source: adapted from discharge/notebooks/06_ensemble_member_inference.ipynb::plot_member_cms
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, balanced_accuracy_score


def plot_joint_member_cms(method_name: str) -> None:
    """Plot 5 confusion matrices side-by-side for the 5 joint pairs of `method_name`.
    Uses test-period predictions only. Title form, filename, caption, and
    SAVE_OUTPUTS gating follow the discharge/06 template.
    """
    pred = pred_by_method[method_name]
    pair_idx = pair_index_by_method[method_name]                             # list[(d_idx, tn_idx)]
    pair_labels = [f"d{d:04d} \u00D7 tn{t:04d}" for d, t in pair_idx]        # (Unicode x)

    fig, axes = plt.subplots(1, 5, figsize=(20, 4.2))
    for j, (label, ax) in enumerate(zip(pair_labels, axes)):
        y_pred_j = pred[j][test_mask_members]
        cm = confusion_matrix(y_test_members, y_pred_j, labels=[0, 1])
        bal_acc = balanced_accuracy_score(y_test_members, y_pred_j)
        sns.heatmap(
            cm, annot=True, fmt="d", cmap="Blues", ax=ax, cbar=False,
            xticklabels=["No Bloom", "Bloom"],
            yticklabels=["No Bloom", "Bloom"],
        )
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual" if j == 0 else "")
        ax.set_title(f"{label}\nBal. Acc = {bal_acc:.3f}")

    fig.suptitle(
        f"{method_name.upper()} \u2014 Per-Pair Joint Confusion Matrices "
        f"(Test Period from {TEST_CUTOFF})",
        fontsize=13,
    )

    # caption below the grid mapping pair index -> full member_XXXX names
    caption_lines = [
        f"pair_{i}: member_{d:04d} \u00D7 member_{t:04d}"
        for i, (d, t) in enumerate(pair_idx)
    ]
    fig.text(0.5, -0.02, "   |   ".join(caption_lines),
             ha="center", fontsize=9, color="0.25")

    plt.tight_layout()
    if SAVE_OUTPUTS:
        fname = FIGURES_DIR / f"{method_name}_joint_member_cms.png"
        plt.savefig(fname, dpi=150, bbox_inches="tight")
        print(f"Saved: {fname}")
    plt.show()


for method in METHODS:
    plot_joint_member_cms(method)
Saved: output\joint\figures\bootstrap_joint_member_cms.png
[Figure: BOOTSTRAP per-pair joint confusion matrices, test period from 2019-01-01]
Saved: output\joint\figures\glue_joint_member_cms.png
[Figure: GLUE per-pair joint confusion matrices, test period from 2019-01-01]
Saved: output\joint\figures\enkf_joint_member_cms.png
[Figure: ENKF per-pair joint confusion matrices, test period from 2019-01-01]
Saved: output\joint\figures\lpu_joint_member_cms.png
[Figure: LPU per-pair joint confusion matrices, test period from 2019-01-01]
# === Cell C: Comparison CM - baseline + 3 ranked methods ===
# source: adapted from discharge/notebooks/06_ensemble_member_inference.ipynb cell_final_cm_code
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from sklearn.metrics import confusion_matrix, balanced_accuracy_score

# --- Rank methods by mean TP across their 5 joint pairs ---
method_mean_tp = {}
for method in METHODS:
    cms = np.stack([
        confusion_matrix(
            y_test_members,
            pred_by_method[method][j][test_mask_members],
            labels=[0, 1],
        )
        for j in range(N_PAIRS_PER_METHOD)
    ])
    method_mean_tp[method] = cms.mean(axis=0)[1, 1]

ranked_methods = sorted(METHODS, key=lambda m: method_mean_tp[m], reverse=True)

# --- Figure scaffolding: baseline top-left, top 3 ranked methods filling remaining slots ---
fig = plt.figure(figsize=(12, 9))
gs = gridspec.GridSpec(2, 2, figure=fig)
positions = [(0, 0), (0, 1), (1, 0), (1, 1)]
ax_bl = fig.add_subplot(gs[positions[0]])

# --- Baseline cell: recomputed CM from baseline_pred[test_mask_baseline] ---
cm_baseline = confusion_matrix(
    y_test_baseline, baseline_pred[test_mask_baseline], labels=[0, 1]
)
bal_acc_bl = balanced_accuracy_score(
    y_test_baseline, baseline_pred[test_mask_baseline]
)
tp_bl = int(cm_baseline[1, 1])
annot_bl = np.array(
    [[str(int(cm_baseline[i, k])) for k in range(2)] for i in range(2)]
)
sns.heatmap(
    cm_baseline, annot=annot_bl, fmt="", cmap="Blues",
    ax=ax_bl, cbar=False,
    xticklabels=["No Bloom", "Bloom"],
    yticklabels=["No Bloom", "Bloom"],
    annot_kws={"size": 11},
)
ax_bl.set_title(
    f"BASELINE  (observed discharge + TN)\nBal. Acc = {bal_acc_bl:.3f}   TP = {tp_bl}",
    fontsize=11,
)

# --- Method cells: mean [p05, p95] annotations, title "{METHOD} (N=5 pairs)" ---
for slot, method in enumerate(ranked_methods[:3], start=1):
    ax_m = fig.add_subplot(gs[positions[slot]])
    cms = np.stack([
        confusion_matrix(
            y_test_members,
            pred_by_method[method][j][test_mask_members],
            labels=[0, 1],
        )
        for j in range(N_PAIRS_PER_METHOD)
    ])
    cm_mean = cms.mean(axis=0)
    cm_p05 = np.percentile(cms, 5, axis=0)
    cm_p95 = np.percentile(cms, 95, axis=0)

    annot = np.empty_like(cm_mean, dtype=object)
    for i in range(2):
        for k in range(2):
            annot[i, k] = (
                f"{cm_mean[i, k]:.1f}\n"
                f"[{cm_p05[i, k]:.0f}, {cm_p95[i, k]:.0f}]"
            )

    bal_accs = [
        balanced_accuracy_score(
            y_test_members,
            pred_by_method[method][j][test_mask_members],
        )
        for j in range(N_PAIRS_PER_METHOD)
    ]
    bal_acc_mean = float(np.mean(bal_accs))

    sns.heatmap(
        cm_mean, annot=annot, fmt="", cmap="Blues",
        ax=ax_m, cbar=False,
        xticklabels=["No Bloom", "Bloom"],
        yticklabels=["No Bloom", "Bloom"],
        annot_kws={"size": 10},
    )
    ax_m.set_title(
        f"{method.upper()} (N=5 pairs)\nBal. Acc = {bal_acc_mean:.3f}",
        fontsize=11,
    )

fig.suptitle(
    f"Joint Ensemble Member Inference - Comparison Confusion Matrices "
    f"(Test Period from {TEST_CUTOFF})",
    fontsize=13,
)
plt.tight_layout(rect=(0, 0, 1, 0.96))

if SAVE_OUTPUTS:
    fname = FIGURES_DIR / "joint_comparison_cm.png"
    plt.savefig(fname, dpi=150, bbox_inches="tight")
    print(f"Saved: {fname}")
plt.show()
Saved: output\joint\figures\joint_comparison_cm.png
[Figure: Joint Ensemble Member Inference comparison confusion matrices, test period from 2019-01-01]
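The cell annotations in the comparison figure are cell-wise statistics over a stack of 2x2 matrices. The sketch below shows that aggregation on five invented matrices standing in for the five joint pairs of one method; the counts are illustrative and are not results from this notebook.

```python
import numpy as np

# Five illustrative 2x2 confusion matrices (rows = actual, columns = predicted).
cms = np.array([
    [[150, 27], [20, 60]],
    [[152, 25], [22, 58]],
    [[149, 28], [18, 62]],
    [[151, 26], [21, 59]],
    [[148, 29], [19, 61]],
])                                          # shape (5, 2, 2)

cm_mean = cms.mean(axis=0)                  # cell-wise mean across pairs
cm_p05 = np.percentile(cms, 5, axis=0)      # cell-wise 5th percentile
cm_p95 = np.percentile(cms, 95, axis=0)     # cell-wise 95th percentile

mean_tp = cm_mean[1, 1]                     # the quantity used to rank methods
print(cm_mean)
print(f"TP: mean={mean_tp:.1f}, p05={cm_p05[1, 1]:.1f}, p95={cm_p95[1, 1]:.1f}")
```

With only five pairs, NumPy's default linear interpolation places the 5th and 95th percentiles between the two smallest and the two largest values, so the bracketed range is close to the min and max and is better read as a spread indicator than as a confidence interval.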
# === Cell D: Precision-recall curves ===
# source: adapted from discharge/notebooks/06_ensemble_member_inference.ipynb::cell_pr_code
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score

fig, ax = plt.subplots(figsize=(9, 7))

for method in METHODS:
    color = METHOD_COLORS[method]
    aps = []
    for j in range(N_PAIRS_PER_METHOD):
        proba_j = proba_by_method[method][j][test_mask_members]
        prec, rec, _ = precision_recall_curve(y_test_members, proba_j)
        ap = average_precision_score(y_test_members, proba_j)
        aps.append(ap)
        label = (
            f"{method} (mean AP={np.mean(aps):.3f})"
            if j == N_PAIRS_PER_METHOD - 1
            else None
        )
        ax.plot(rec, prec, color=color, linewidth=1.0, alpha=0.7, label=label)

# Baseline on top
prec_bl, rec_bl, _ = precision_recall_curve(
    y_test_baseline, baseline_proba[test_mask_baseline]
)
ap_bl = average_precision_score(
    y_test_baseline, baseline_proba[test_mask_baseline]
)
ax.plot(
    rec_bl, prec_bl, color="black", linewidth=1.8,
    label=f"baseline (AP={ap_bl:.3f})", zorder=10,
)

ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title(
    f"Joint Ensemble Member Inference — Precision-Recall Curves "
    f"(Test Period from {TEST_CUTOFF})",
    fontsize=12,
)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1.02)
ax.grid(True, alpha=0.3)
ax.legend(loc="lower left", fontsize=10)

if SAVE_OUTPUTS:
    fname = FIGURES_DIR / "joint_members_pr.png"
    plt.savefig(fname, dpi=150, bbox_inches="tight")
    print(f"Saved: {fname}")
plt.show()
Saved: output\joint\figures\joint_members_pr.png
[Figure: Joint Ensemble Member Inference precision-recall curves, test period from 2019-01-01]
# === Cell E: Full-range bloom-probability traces ===
# source: adapted from discharge/notebooks/06_ensemble_member_inference.ipynb::cell_traces_code
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# observed bloom dates from OBS_CSV, kb >= 100_000
obs_bloom = pd.read_csv(OBS_CSV, parse_dates=["time"], index_col="time")
bloom_dates = obs_bloom.index[obs_bloom["kb"] >= 100_000]
bloom_train = bloom_dates[bloom_dates <  pd.Timestamp(TEST_CUTOFF)]
bloom_test  = bloom_dates[bloom_dates >= pd.Timestamp(TEST_CUTOFF)]
half_week   = pd.Timedelta(days=3)

fig, ax = plt.subplots(figsize=(16, 6))

# Gold (train) + tomato (test) axvspans
for bd in bloom_train:
    ax.axvspan(bd - half_week, bd + half_week,
               color="gold", alpha=0.15, linewidth=0)
for bd in bloom_test:
    ax.axvspan(bd - half_week, bd + half_week,
               color="tomato", alpha=0.20, linewidth=0)

# Method mean probability lines
for method in METHODS:
    mean_proba = proba_by_method[method].mean(axis=0)
    ax.plot(aligned_index, mean_proba,
            color=METHOD_COLORS[method], linewidth=1.2, alpha=0.85,
            label=f"{method} (mean of {N_PAIRS_PER_METHOD} pairs)")

# Baseline on top
ax.plot(baseline_dates, baseline_proba,
        color="black", linewidth=1.4, label="baseline (observed discharge + TN)", zorder=5)

# Cutoff + 0.5 reference
ax.axvline(pd.Timestamp(TEST_CUTOFF),
           color="black", linestyle="--", linewidth=1.2)
ax.axhline(0.5, color="red", linestyle=":", linewidth=1.0)

ax.set_xlabel("Date")
ax.set_ylabel("Bloom probability")
ax.set_ylim(0, 1)
ax.set_title(
    f"Joint Ensemble — Bloom-probability traces (full range; "
    f"Test Period from {TEST_CUTOFF})",
    fontsize=12,
)

# Extend legend with bloom-span patches
handles, labels = ax.get_legend_handles_labels()
handles.extend([
    mpatches.Patch(color="gold",   alpha=0.4, label="train-period bloom"),
    mpatches.Patch(color="tomato", alpha=0.4, label="test-period bloom"),
])
ax.legend(handles=handles, loc="upper left", fontsize=9, ncol=2)

ax.text(0.01, 0.02, "Joint (d × tn) pairs",
        transform=ax.transAxes, fontsize=8, color="0.4")

plt.tight_layout()
if SAVE_OUTPUTS:
    fname = FIGURES_DIR / "joint_traces_full.png"
    plt.savefig(fname, dpi=150, bbox_inches="tight")
    print(f"Saved: {fname}")
plt.show()
Saved: output\joint\figures\joint_traces_full.png
[Figure: Joint Ensemble bloom-probability traces, full range (joint_traces_full.png)]
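The method lines in Cell E are means taken over the pair axis. A minimal sketch of that reduction on a toy array, assuming each `proba_by_method[method]` has shape `(N_PAIRS_PER_METHOD, n_dates)`, which is consistent with the `proba_by_method[method][j]` indexing in the export cell. The random values and the min/max envelope are illustrative only and are not part of the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for one method's probabilities: 5 pairs x 4 dates (hypothetical sizes)
n_pairs, n_dates = 5, 4
proba = rng.uniform(0.0, 1.0, size=(n_pairs, n_dates))

# Same reduction as Cell E: average across pairs, one value per date
mean_proba = proba.mean(axis=0)

# Optional extra (not in the notebook): per-date spread across the pairs
lo = proba.min(axis=0)
hi = proba.max(axis=0)

print(mean_proba.shape)
```

If the spread matters for interpretation, `lo` and `hi` could be drawn with `ax.fill_between` behind each mean line.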
# === Cell F: Test-period zoom bloom-probability traces ===
# source: adapted from discharge/notebooks/06_ensemble_member_inference.ipynb::cell_zoom_code
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Observed bloom dates (kb >= 100_000), restricted to the test period
obs_bloom = pd.read_csv(OBS_CSV, parse_dates=["time"], index_col="time")
bloom_dates = obs_bloom.index[obs_bloom["kb"] >= 100_000]
bloom_test  = bloom_dates[bloom_dates >= pd.Timestamp(TEST_CUTOFF)]
half_week   = pd.Timedelta(days=3)

fig, ax = plt.subplots(figsize=(14, 5))

# Tomato axvspans ONLY (no gold, no cutoff line)
for bd in bloom_test:
    ax.axvspan(bd - half_week, bd + half_week,
               color="tomato", alpha=0.25, linewidth=0)

# Method mean lines, sliced by test_mask_members
for method in METHODS:
    mean_proba = proba_by_method[method].mean(axis=0)
    ax.plot(
        aligned_index[test_mask_members], mean_proba[test_mask_members],
        color=METHOD_COLORS[method], linewidth=1.2, alpha=0.85,
        label=f"{method} (mean of {N_PAIRS_PER_METHOD} pairs)",
    )

# Baseline, sliced by test_mask_baseline
ax.plot(
    baseline_dates[test_mask_baseline], baseline_proba[test_mask_baseline],
    color="black", linewidth=1.4, label="baseline (observed TN)", zorder=5,
)

ax.axhline(0.5, color="red", linestyle=":", linewidth=1.0)

ax.set_xlabel("Date")
ax.set_ylabel("Bloom probability")
ax.set_ylim(0, 1)
ax.set_title(
    f"Joint Ensemble — Bloom-probability traces (Test Period from {TEST_CUTOFF})",
    fontsize=12,
)

handles, labels = ax.get_legend_handles_labels()
handles.append(
    mpatches.Patch(color="tomato", alpha=0.5, label="test-period bloom")
)
ax.legend(handles=handles, loc="upper left", fontsize=9, ncol=2)

ax.text(0.01, 0.02, "Joint (d × tn) pairs",
        transform=ax.transAxes, fontsize=8, color="0.4")

plt.tight_layout()
if SAVE_OUTPUTS:
    fname = FIGURES_DIR / "joint_traces_test.png"
    plt.savefig(fname, dpi=150, bbox_inches="tight")
    print(f"Saved: {fname}")
plt.show()
Saved: output\joint\figures\joint_traces_test.png
[Figure: Joint Ensemble bloom-probability traces, test period (joint_traces_test.png)]
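Cell F slices by `test_mask_members` and `test_mask_baseline`, which are defined in an earlier cell. A minimal sketch of how such a mask could be built, assuming it is a boolean array marking timestamps on or after `TEST_CUTOFF`. The cutoff value, dates, and probabilities below are hypothetical.

```python
import numpy as np
import pandas as pd

TEST_CUTOFF = "2021-01-01"  # hypothetical value, for illustration only

# Toy weekly index straddling the cutoff, with a toy probability trace
aligned_index = pd.date_range("2020-12-01", periods=8, freq="W")
mean_proba = np.linspace(0.1, 0.8, len(aligned_index))

# Boolean mask: True for timestamps in the test period
test_mask = aligned_index >= pd.Timestamp(TEST_CUTOFF)

# The same mask slices both the dates and the values, so they stay aligned
test_dates = aligned_index[test_mask]
test_proba = mean_proba[test_mask]

print(len(test_dates), len(test_proba))
```

Because the members and the baseline can sit on different date indexes, each needs its own mask built from its own index, which would explain why the notebook keeps two.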
# === Cell G: Joint probability CSV export ===
# source: adapted from discharge/notebooks/06_ensemble_member_inference.ipynb::cell_export_code
import pandas as pd

if SAVE_OUTPUTS:
    series_dict = {
        "baseline": pd.Series(baseline_proba, index=baseline_dates, name="baseline"),
    }
    # METHODS iteration order drives column order (bootstrap→glue→enkf→lpu)
    for method in METHODS:
        for j, (d_idx, tn_idx) in enumerate(pair_index_by_method[method]):
            col = f"{method}_d{d_idx:04d}_tn{tn_idx:04d}"  # e.g. bootstrap_d0007_tn0123
            series_dict[col] = pd.Series(
                proba_by_method[method][j],
                index=aligned_index,
                name=col,
            )

    # outer-join by timestamp — handles aligned_index vs baseline_dates mismatch
    proba_df = pd.DataFrame(series_dict)
    proba_df.index.name = "date"

    # A2: defensive mkdir — CSV_DIR may not have been created in cell 2
    CSV_DIR.mkdir(parents=True, exist_ok=True)

    # seed IS in the CSV filename (unlike the figures, whose filenames omit it)
    proba_csv = CSV_DIR / f"joint_proba_seed{PAIRING_SEED}.csv"
    proba_df.to_csv(proba_csv)
    print(f"Saved: {proba_csv}  shape={proba_df.shape}")
else:
    print("SAVE_OUTPUTS is False — skipping CSV export.")
Saved: output\joint\csv\joint_proba_seed42.csv  shape=(1613, 21)
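The export in Cell G relies on `pd.DataFrame` aligning a dict of Series on the union of their indexes. A minimal sketch of that behavior with made-up dates and values (the column name and numbers are illustrative, not taken from the notebook's data):

```python
import pandas as pd

# Hypothetical stand-ins for baseline_dates and aligned_index: they overlap
# on two dates and each has one date the other lacks
baseline_dates = pd.to_datetime(["2020-01-01", "2020-01-08", "2020-01-15"])
aligned_index = pd.to_datetime(["2020-01-08", "2020-01-15", "2020-01-22"])

series_dict = {
    "baseline": pd.Series([0.1, 0.2, 0.3], index=baseline_dates),
    "bootstrap_d0001_tn0002": pd.Series([0.4, 0.5, 0.6], index=aligned_index),
}

# Building a DataFrame from a dict of Series takes the union of the indexes;
# dates missing from one Series become NaN in that column
demo_df = pd.DataFrame(series_dict)
demo_df.index.name = "date"

n_rows = len(demo_df)
n_missing = int(demo_df.isna().sum().sum())
print(demo_df)
```

In the exported CSV, any such gaps appear as empty fields, so a downstream reader may want to check `isna()` per column before averaging across members.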