"""
This module provides functions for generating and saving performance metric plots,
specifically Receiver Operating Characteristic (ROC) curves and Precision-Recall (PR) curves.
It supports plotting for individual models across multiple cross-validation folds (with
mean and standard deviation) or comparing multiple models/methods on a single plot.
Plots are saved as SVG files.
"""
import os
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from sklearn.metrics import auc, precision_recall_curve, roc_curve
[docs]
def create_metric_plots(model) -> None:
"""
Create and save ROC and Precision-Recall plots as an SVG file for a single model.
Generates a figure with two subplots (ROC on left, PR on right) based on
the data in ``model.model_scores``. If the model-scores table contains
multiple unique 'k' values (folds), it plots individual fold curves and then
the mean curve with a shaded confidence band (standard deviation).
The output file path is determined by ``model.output_file_names['metric_plot']``.
:param model: An object containing evaluation results and output configuration.
It is expected to have the following attributes:
- ``model_scores`` (dict[str, polars.DataFrame]): A dictionary
where keys are target names and values are Polars DataFrames. Each
DataFrame must contain at least 'k' (fold identifier), 'label'
(true binary labels), and 'score' (prediction probabilities/scores)
columns.
- ``output_file_names`` (dict[str, dict[str, str]]): A dictionary
containing output file paths. Specifically,
``output_file_names['metric_plot'][target_name]`` should provide
the full path where the plot for a given target will be saved.
:type model: object
:raises ValueError: If ``model.model_scores`` is empty.
:return: None
:rtype: None
"""
if not model.model_scores:
raise ValueError("Member variable 'model_scores' must not be empty.")
for target_name, df in model.model_scores.items():
output_path = model.output_file_names["metric_plot"][target_name]
os.makedirs(os.path.dirname(output_path), exist_ok=True)
unique_k = df["k"].unique().sort()
has_folds = len(unique_k) > 1
plt.rcParams.update({"font.size": 14})
fig, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(12, 6))
# --- ROC Curve Setup ---
mean_fpr = np.linspace(0, 1, 100)
tprs = []
aucs = []
# --- Precision-Recall Curve Setup ---
# We interpolate over recall to average precision
mean_recall = np.linspace(0, 1, 100)
precisions = []
# Loop through folds (or single run)
for k in unique_k:
fold_data = df.filter(pl.col("k") == k)
y_true = fold_data["label"].to_numpy()
y_score = fold_data["score"].to_numpy()
if len(np.unique(y_true)) < 2:
continue # Skip folds with only one class
# ROC
fpr, tpr, _ = roc_curve(y_true, y_score)
roc_auc = auc(fpr, tpr)
interp_tpr = np.interp(mean_fpr, fpr, tpr)
interp_tpr[0] = 0.0
tprs.append(interp_tpr)
aucs.append(roc_auc)
# PR
prec, rec, _ = precision_recall_curve(y_true, y_score)
# Reverse to ensure increasing recall for interpolation
prec = prec[::-1]
rec = rec[::-1]
interp_prec = np.interp(mean_recall, rec, prec)
precisions.append(interp_prec)
if has_folds:
ax_roc.plot(
fpr,
tpr,
lw=1,
alpha=0.3,
label=f"ROC fold {k} (AUC = {roc_auc:.2f})",
)
ax_pr.plot(rec, prec, lw=1, alpha=0.3)
# --- Plot Mean ROC ---
if tprs:
mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
mean_auc = auc(mean_fpr, mean_tpr)
std_auc = np.std(aucs)
label_roc = (
f"Mean ROC (AUC = {mean_auc:.2f} $\\pm$ {std_auc:.2f})"
if has_folds
else f"ROC (AUC = {mean_auc:.2f})"
)
ax_roc.plot(mean_fpr, mean_tpr, color="b", label=label_roc, lw=2, alpha=0.8)
if has_folds:
std_tpr = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
ax_roc.fill_between(
mean_fpr,
tprs_lower,
tprs_upper,
color="grey",
alpha=0.2,
label=r"$\pm$ 1 std. dev.",
)
# ROC Formatting
ax_roc.plot([0, 1], [0, 1], linestyle="--", lw=2, color="r", alpha=0.8)
ax_roc.set_xlim([-0.05, 1.05])
ax_roc.set_ylim([-0.05, 1.05])
ax_roc.set_xlabel("False Positive Rate")
ax_roc.set_ylabel("True Positive Rate")
ax_roc.set_title(f"ROC Curve - {target_name}")
ax_roc.legend(loc="lower right", fontsize="small")
ax_roc.grid(True, alpha=0.3)
# --- Plot Mean PR ---
if precisions:
mean_precision = np.mean(precisions, axis=0)
mean_ap = auc(mean_recall, mean_precision)
label_pr = (
f"Mean PR (AP = {mean_ap:.2f})"
if has_folds
else f"PR (AP = {mean_ap:.2f})"
)
ax_pr.plot(
mean_recall,
mean_precision,
color="b",
label=label_pr,
lw=2,
alpha=0.8,
)
if has_folds:
std_prec = np.std(precisions, axis=0)
prec_upper = np.minimum(mean_precision + std_prec, 1)
prec_lower = np.maximum(mean_precision - std_prec, 0)
ax_pr.fill_between(
mean_recall,
prec_lower,
prec_upper,
color="grey",
alpha=0.2,
label=r"$\pm$ 1 std. dev.",
)
# PR Formatting
ax_pr.set_xlim([-0.05, 1.05])
ax_pr.set_ylim([-0.05, 1.05])
ax_pr.set_xlabel("Recall")
ax_pr.set_ylabel("Precision")
ax_pr.set_title(f"Precision-Recall Curve - {target_name}")
ax_pr.legend(loc="lower left", fontsize="small")
ax_pr.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(output_path, format="svg")
plt.close(fig)
[docs]
def create_multi_method_metric_plots(model) -> None:
"""
Create and save ROC and Precision-Recall plots for multiple methods overlaid
on the same figure. Assumes the model-scores tables have a 'method' column.
The output file path is determined by ``model.output_file_names['metric_plot']``.
:param model: An object containing evaluation results and output configuration.
It is expected to have the following attributes:
- ``model_scores`` (dict[str, polars.DataFrame]): A dictionary
where keys are target names and values are Polars DataFrames. Each
DataFrame must contain at least 'method' (method identifier), 'label'
(true binary labels), and 'score' (prediction probabilities/scores)
columns. It aggregates results across all folds/runs for each method.
- ``output_file_names`` (dict[str, dict[str, str]]): A dictionary
containing output file paths. Specifically,
``output_file_names['metric_plot'][target_name]`` should provide
the full path where the plot for a given target will be saved.
:type model: object
:raises ValueError: If ``model.model_scores`` is empty.
:return: None
:rtype: None
.. admonition:: Code Issue
:class: warning
The calculation of Average Precision (AP) for the Precision-Recall curve using
``pr_auc = auc(rec[::-1], prec[::-1])`` is incorrect.
The ``sklearn.metrics.precision_recall_curve`` function returns `recall` values
that are already in increasing order. Therefore, `auc(rec, prec)` should be used
directly to calculate the Area Under the Curve for the Precision-Recall plot.
Reversing `rec` and `prec` before passing them to `auc` when `rec` is already
increasing will lead to an incorrect AP value.
"""
if not model.model_scores:
raise ValueError("Member variable 'model_scores' must not be empty.")
for target_name, df in model.model_scores.items():
output_path = model.output_file_names["metric_plot"][target_name]
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.rcParams.update({"font.size": 14})
fig, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(14, 7))
# Retrieve unique methods from the aggregated dataframe
methods = df["method"].unique().to_list()
for method in methods:
method_data = df.filter(pl.col("method") == method)
y_true = method_data["label"].to_numpy()
y_score = method_data["score"].to_numpy()
if len(np.unique(y_true)) < 2:
continue # Skip methods where evaluation lacks distinct classes
# ROC
fpr, tpr, _ = roc_curve(y_true, y_score)
roc_auc = auc(fpr, tpr)
ax_roc.plot(fpr, tpr, lw=2, label=f"{method} (AUC = {roc_auc:.2f})")
# PR
prec, rec, _ = precision_recall_curve(y_true, y_score)
pr_auc = auc(rec[::-1], prec[::-1])
ax_pr.plot(rec, prec, lw=2, label=f"{method} (AP = {pr_auc:.2f})")
# ROC Formatting
ax_roc.plot([0, 1], [0, 1], linestyle="--", lw=2, color="k", alpha=0.5)
ax_roc.set_xlim([-0.05, 1.05])
ax_roc.set_ylim([-0.05, 1.05])
ax_roc.set_xlabel("False Positive Rate")
ax_roc.set_ylabel("True Positive Rate")
ax_roc.set_title(f"ROC Curve - {target_name}")
ax_roc.legend(loc="lower right", fontsize="small")
ax_roc.grid(True, alpha=0.3)
# PR Formatting
ax_pr.set_xlim([-0.05, 1.05])
ax_pr.set_ylim([-0.05, 1.05])
ax_pr.set_xlabel("Recall")
ax_pr.set_ylabel("Precision")
ax_pr.set_title(f"Precision-Recall Curve - {target_name}")
ax_pr.legend(loc="lower left", fontsize="small")
ax_pr.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(output_path, format="svg")
plt.close(fig)