95 changes: 87 additions & 8 deletions src/edvise/modeling/bias_detection.py
@@ -1,16 +1,22 @@
import logging
import time
import random

import typing as t

import matplotlib.figure
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
import pandas as pd
import scipy.stats as st
import seaborn as sns
import sklearn.metrics

import mlflow
from mlflow.entities import Metric
from mlflow.exceptions import RestException

from . import evaluation

LOGGER = logging.getLogger(__name__)
@@ -627,18 +633,33 @@ def log_subgroup_metrics_to_mlflow(
group_col: str,
) -> None:
"""
Logs subgroup-level bias and performance metrics to MLflow in a single batch.

This function aggregates subgroup metrics into a single payload and logs them
to MLflow using the `_log_metrics_batch_with_retry()` helper, which batches
metrics into one request and automatically retries transient failures. This
reduces API call volume and mitigates 'TEMPORARILY_UNAVAILABLE' errors caused
by rapid per-metric logging.

Args:
subgroup_metrics (dict): Dictionary of subgroup bias or performance metrics.
split_name (str): Name of the data split (e.g., "train", "test", "validation").
group_col (str): Column name representing the demographic or grouping variable
used for bias evaluation.
"""
payload = {}
for metric, value in subgroup_metrics.items():
if metric not in {"Subgroup", "Number of Samples"}:
key = f"{split_name}_{group_col}_metrics/{metric}_subgroup"
payload[key] = value

    # Log to the active run if one exists; otherwise open a run so the
    # metrics still land somewhere instead of erroring out.
    active_run = mlflow.active_run()
    run_id = active_run.info.run_id if active_run else None
    if run_id is None:
        with mlflow.start_run(nested=True) as r:
            _log_metrics_batch_with_retry(r.info.run_id, payload)
    else:
        _log_metrics_batch_with_retry(run_id, payload)


def plot_fnr_group(fnr_data: list) -> matplotlib.figure.Figure:
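For reviewers who want to exercise the new batched path locally, here is a minimal, hypothetical sketch of calling the updated helper; the metric names, values, and the `student_group` column are illustrative, not taken from this PR:

```python
import mlflow

from edvise.modeling.bias_detection import log_subgroup_metrics_to_mlflow

# Illustrative subgroup metrics; "Subgroup" and "Number of Samples"
# are intentionally skipped by the logging loop.
subgroup_metrics = {
    "Subgroup": "first_gen",
    "Number of Samples": 412,
    "FNR": 0.18,
    "Accuracy": 0.87,
}

with mlflow.start_run():
    # Logs test_student_group_metrics/FNR_subgroup and
    # test_student_group_metrics/Accuracy_subgroup in one batched request.
    log_subgroup_metrics_to_mlflow(
        subgroup_metrics,
        split_name="test",
        group_col="student_group",
    )
```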
@@ -692,3 +713,61 @@ def plot_fnr_group(fnr_data: list) -> matplotlib.figure.Figure:

plt.tight_layout()
return fig


def _log_metrics_batch_with_retry(
    run_id: str,
    metrics_dict: t.Dict[str, float],
    step: int = 0,
    max_tries: int = 6,
    base_delay: float = 0.5,
) -> None:
"""
Log multiple MLflow metrics in a single batch with retry and backoff.

This function batches metrics into one REST call using MlflowClient.log_batch()
to reduce API call volume and avoid transient 'TEMPORARILY_UNAVAILABLE' errors
from the MLflow tracking server. It automatically retries failed requests using
exponential backoff with jitter and, on repeated failure, falls back to logging
the metrics as a CSV artifact to preserve data.

    Args:
run_id (str): Active MLflow run ID.
metrics_dict (dict): Mapping of metric names to numeric values.
step (int, optional): Metric step value. Defaults to 0.
max_tries (int, optional): Maximum number of retry attempts. Defaults to 6.
base_delay (float, optional): Base delay (in seconds) for exponential backoff. Defaults to 0.5.

Raises:
mlflow.exceptions.RestException: If all retry attempts fail.
"""
client = mlflow.tracking.MlflowClient()
ts = int(time.time() * 1000)
batch = [
Metric(key=k, value=float(v), timestamp=ts, step=step)
for k, v in metrics_dict.items()
]
last_err = None
for attempt in range(1, max_tries + 1):
try:
client.log_batch(run_id, metrics=batch)
return
except RestException as e:
last_err = e
# Retry only on transient cases
            msg = str(e).lower()
            if (
                "temporarily_unavailable" in msg
                or "temporarily unavailable" in msg
                or "rate limit" in msg
            ):
                # Exponential backoff with jitter, capped at 8s; skip the
                # pointless sleep after the final attempt.
                if attempt < max_tries:
                    sleep_s = base_delay * (2 ** (attempt - 1)) + random.uniform(0, 0.2)
                    time.sleep(min(sleep_s, 8.0))
                continue
            raise
    # All retries failed: fall back to a CSV artifact so the metrics aren't lost
try:
import io

        buf = io.StringIO()
        buf.write("timestamp,step,key,value\n")  # header keeps the artifact self-describing
        for k, v in metrics_dict.items():
            buf.write(f"{ts},{step},{k},{v}\n")
mlflow.log_text(buf.getvalue(), "fallback_metrics.csv")
    except Exception:
        LOGGER.warning("Could not write fallback metrics artifact", exc_info=True)
raise last_err
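And a sketch of driving the retry helper directly (assuming an active run; the metric keys here are made up). With the defaults, transient failures back off roughly as 0.5 s, 1 s, 2 s, 4 s, then 8 s capped, each plus up to 0.2 s of jitter, before the CSV fallback kicks in:

```python
import mlflow

with mlflow.start_run() as run:
    # One REST call for the whole payload instead of one call per metric.
    _log_metrics_batch_with_retry(
        run.info.run_id,
        {
            "test_gender_metrics/FNR_subgroup": 0.21,
            "test_gender_metrics/Accuracy_subgroup": 0.84,
        },
    )
```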