HMM Training

Parameters and training modes for the V3 gap-aware HMM. For a conceptual guide, see HMM Training Modes.

Parameters

HMMParams dataclass

HMMParams(
    mode: Literal[
        "unsupervised", "semi_supervised", "supervised"
    ] = "unsupervised",
    n_states: int = 3,
    p_stay_per_base: float = 0.92,
    init_prob: NDArray[float64] = (
        lambda: array([0.98, 0.01, 0.01], dtype=float64)
    )(),
    emission_transform: (
        EmissionCalibrator | EmissionKDE | None
    ) = None,
    unmod_emission_beta: tuple[float, float] = (2.0, 8.0),
    flank_emission_beta: tuple[float, float] = (3.0, 3.0),
    mod_emission_beta: tuple[float, float] = (8.0, 2.0),
    training_species: list[str] = list(),
    n_training_positions: int = 0,
    n_training_reads: int = 0,
)

Learned or default HMM parameters for V3.

All fields have defaults, so the dataclass can be constructed incrementally or via create_unsupervised_params().

The n_states field controls the HMM topology:

- n_states=2: Unmodified / Modified (original behaviour).
- n_states=3: Unmodified / Flank / Modified. The Flank state absorbs the ±2-base signal halo around modification sites, so only true modification positions contribute to p_mod_hmm.

unmod_emission_beta

unmod_emission_beta: tuple[float, float] = (2.0, 8.0)

Beta(2, 8) emission for the Unmodified state: mean = 2 / (2 + 8) = 0.2, so the density concentrates on low kNN scores.

flank_emission_beta

flank_emission_beta: tuple[float, float] = (3.0, 3.0)

Beta(3, 3) emission for the Flank state: mean = 0.5, symmetric around moderate kNN scores.

mod_emission_beta

mod_emission_beta: tuple[float, float] = (8.0, 2.0)

Beta(8, 2) emission for the Modified state: mean = 8 / (8 + 2) = 0.8, so the density concentrates on high kNN scores.
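As a quick sanity check on these defaults, the sketch below computes each Beta's mean and, for a given score, which state's emission density is largest. It uses only the standard library; the helper names are illustrative and not part of baleen, and choosing a state by emission density alone deliberately ignores the transition model and init_prob that the real HMM also uses.

```python
import math

# Default emission parameters from HMMParams: state -> (alpha, beta).
EMISSION_BETAS = {
    "unmod": (2.0, 8.0),
    "flank": (3.0, 3.0),
    "mod": (8.0, 2.0),
}


def beta_mean(alpha: float, beta: float) -> float:
    """Mean of a Beta(alpha, beta) distribution: alpha / (alpha + beta)."""
    return alpha / (alpha + beta)


def beta_logpdf(x: float, alpha: float, beta: float) -> float:
    """Log-density of Beta(alpha, beta) at x, for 0 < x < 1."""
    log_norm = math.lgamma(alpha + beta) - math.lgamma(alpha) - math.lgamma(beta)
    return log_norm + (alpha - 1.0) * math.log(x) + (beta - 1.0) * math.log1p(-x)


def most_likely_state(score: float) -> str:
    """State whose emission density is highest at `score` (ignores transitions)."""
    return max(EMISSION_BETAS, key=lambda s: beta_logpdf(score, *EMISSION_BETAS[s]))


for name, (a, b) in EMISSION_BETAS.items():
    print(name, beta_mean(a, b))  # unmod 0.2, flank 0.5, mod 0.8
print([most_likely_state(x) for x in (0.1, 0.5, 0.9)])  # ['unmod', 'flank', 'mod']
```

At a score of 0.5 the Unmodified and Modified densities tie (both 0.28125), and the Flank density (1.875) dominates, which is the intended role of the symmetric Beta(3, 3).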

EmissionCalibrator dataclass

EmissionCalibrator(a: float, b: float)

Platt-scaling calibrator for V2 → V3 emission mapping (Mode B).

Transforms a raw score x (p_mod_raw) with the logistic sigmoid σ(a·x + b) = 1 / (1 + exp(-(a·x + b))).

transform

transform(p_mod_raw: NDArray[float64]) -> NDArray[float64]

Map raw P(mod) to calibrated P(mod).

Source code in baleen/eventalign/_hmm_training.py
def transform(self, p_mod_raw: NDArray[np.float64]) -> NDArray[np.float64]:
    """Map raw P(mod) to calibrated P(mod)."""
    z = self.a * np.asarray(p_mod_raw, dtype=np.float64) + self.b
    return 1.0 / (1.0 + np.exp(-z))
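For intuition, the scalar sketch below reproduces the same mapping with the standard library. PlattCalibrator is a hypothetical stand-in for EmissionCalibrator, and a = 10, b = -5 are arbitrary illustrative values. The branch on the sign of z is an addition: it keeps math.exp from overflowing for large |z|, whereas the listing computes 1 / (1 + exp(-z)) directly. The raw score at which the calibrated probability crosses 0.5 is -b / a.

```python
import math
from dataclasses import dataclass


@dataclass
class PlattCalibrator:
    """Hypothetical stand-in mirroring EmissionCalibrator: sigmoid(a * x + b)."""

    a: float
    b: float

    def transform(self, x: float) -> float:
        z = self.a * x + self.b
        # Numerically stable logistic: only ever exponentiate a non-positive value.
        if z >= 0:
            return 1.0 / (1.0 + math.exp(-z))
        ez = math.exp(z)
        return ez / (1.0 + ez)

    def midpoint(self) -> float:
        """Raw score mapped to exactly 0.5 (requires a != 0)."""
        return -self.b / self.a


cal = PlattCalibrator(a=10.0, b=-5.0)
print(cal.midpoint())                           # 0.5
print(cal.transform(0.5))                       # 0.5
print(cal.transform(0.9) > cal.transform(0.1))  # True: monotone when a > 0
```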

EmissionKDE dataclass

EmissionKDE(
    grid: NDArray[float64],
    density_unmod: NDArray[float64],
    density_mod: NDArray[float64],
)

KDE-based emission likelihood model (Mode C).

Stores two pre-evaluated density curves on a fixed grid: P(p_mod_raw | unmodified) and P(p_mod_raw | modified).

At inference time, emission_probs() returns per-observation likelihoods by linear interpolation on the grid.

emission_probs

emission_probs(
    p_mod_raw: NDArray[float64],
) -> tuple[NDArray[float64], NDArray[float64]]

Return (P(obs|unmod), P(obs|mod)) via interpolation.

Values are clamped to [1e-10, ∞) to avoid log(0) issues.

Source code in baleen/eventalign/_hmm_training.py
def emission_probs(
    self, p_mod_raw: NDArray[np.float64]
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Return ``(P(obs|unmod), P(obs|mod))`` via interpolation.

    Values are clamped to ``[1e-10, ∞)`` to avoid log(0) issues.
    """
    x = np.asarray(p_mod_raw, dtype=np.float64)
    p_unmod = np.interp(x, self.grid, self.density_unmod)
    p_mod = np.interp(x, self.grid, self.density_mod)
    # Floor to avoid zero-emission
    p_unmod = np.maximum(p_unmod, 1e-10)
    p_mod = np.maximum(p_mod, 1e-10)
    return p_unmod, p_mod
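The scalar sketch below mirrors emission_probs: linear interpolation with edge clamping (the documented behaviour of numpy.interp outside the grid), followed by the 1e-10 floor. The helper names are hypothetical, and the three-point grid and density values are made up for illustration; a trained EmissionKDE holds kde_n_bins points on [0, 1].

```python
import bisect


def interp(x: float, grid: list[float], values: list[float]) -> float:
    """Linear interpolation with edge clamping, like numpy.interp for one scalar."""
    if x <= grid[0]:
        return values[0]
    if x >= grid[-1]:
        return values[-1]
    j = bisect.bisect_right(grid, x)  # grid[j - 1] <= x < grid[j]
    x0, x1 = grid[j - 1], grid[j]
    y0, y1 = values[j - 1], values[j]
    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


def emission_probs(x, grid, density_unmod, density_mod, floor=1e-10):
    """Return (P(obs|unmod), P(obs|mod)) for one raw score, floored at `floor`."""
    p_unmod = max(interp(x, grid, density_unmod), floor)
    p_mod = max(interp(x, grid, density_mod), floor)
    return p_unmod, p_mod


grid = [0.0, 0.5, 1.0]
density_unmod = [1.8, 0.2, 0.0]  # made-up curves for illustration
density_mod = [0.0, 0.2, 1.8]
print(emission_probs(0.25, grid, density_unmod, density_mod))  # (1.0, 0.1)
print(emission_probs(1.0, grid, density_unmod, density_mod))   # (1e-10, 1.8)
```

The floor matters at the right edge: the made-up unmodified density is exactly 0 there, and without the floor its log-likelihood would be -inf.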

Training modes

create_unsupervised_params

create_unsupervised_params(n_states: int = 3) -> HMMParams

Build default (unsupervised) HMM parameters.

Parameters:

- n_states (int, default 3): Number of HMM states. 2 returns the legacy 2-state (Unmodified / Modified) model with a uniform init_prob of [0.5, 0.5]; 3 (the default, and any value other than 2) adds an explicit Flank state that absorbs the ±2-base signal halo and keeps the HMMParams defaults.
Source code in baleen/eventalign/_hmm_training.py
def create_unsupervised_params(n_states: int = 3) -> HMMParams:
    """Build default (unsupervised) HMM parameters.

    Parameters
    ----------
    n_states
        Number of HMM states.  ``2`` returns the legacy 2-state
        (Unmodified / Modified) model; ``3`` (default) adds an
        explicit Flank state that absorbs the ±2-base signal halo.
    """
    if n_states == 2:
        return HMMParams(
            mode="unsupervised",
            n_states=2,
            init_prob=np.array([0.5, 0.5], dtype=np.float64),
        )
    return HMMParams(mode="unsupervised", n_states=3)
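The name p_stay_per_base suggests a per-base probability that is scaled by the gap between observed positions. The sketch below assumes geometric compounding (stay probability p_stay_per_base ** gap) and a symmetric 2-state matrix. Both are assumptions for illustration, and the helper names are hypothetical: this page does not document how the V3 HMM actually applies the parameter.

```python
def stay_probability(p_stay_per_base: float, gap: int) -> float:
    """ASSUMED gap model (not documented on this page): the per-base stay
    probability compounds geometrically over a gap of `gap` bases."""
    if gap < 1:
        raise ValueError("gap must be >= 1")
    return p_stay_per_base ** gap


def two_state_transition(p_stay_per_base: float, gap: int) -> list[list[float]]:
    """Symmetric 2x2 transition matrix under the assumed gap model
    (rows: from-state, columns: to-state; order Unmodified, Modified)."""
    stay = stay_probability(p_stay_per_base, gap)
    return [[stay, 1.0 - stay], [1.0 - stay, stay]]


for gap in (1, 5, 20):
    print(gap, round(stay_probability(0.92, gap), 3))
```

Under this assumption the default of 0.92 decays quickly: two positions 20 bases apart would share a state with probability 0.92 ** 20, roughly 0.19.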

train_semi_supervised

train_semi_supervised(
    training_data: dict[str, ContigModificationResult],
    labels: dict[tuple[str, int], bool],
    *,
    species_name: str = "",
    species_names: list[str] | None = None,
    learn_transitions: bool = True,
    emission_source: str = "p_mod_raw",
    n_states: int = 2
) -> HMMParams

Train Mode B (semi-supervised) HMM parameters.

Parameters:

- training_data (dict[str, ContigModificationResult], required): {contig_name: ContigModificationResult}, computed with V1→V2 (run_hmm=False is fine).
- labels (dict[tuple[str, int], bool], required): {(contig, pipeline_position): is_modified}; True means the position is known to carry a modification. At positive positions, native reads are treated as modified and IVT reads as unmodified; at negative positions, all reads are treated as unmodified.
- species_name (str, default ""): Optional single species tag stored in training_species.
- species_names (list[str] | None, default None): Optional list of species names for multi-organism pooling. Takes precedence over species_name if provided.
- learn_transitions (bool, default True): If True, learn p_stay_per_base from labeled trajectories. If False, use a fixed value of 0.98 (which differs from the HMMParams default of 0.92).
- emission_source (str, default "p_mod_raw"): Name of the per-read score attribute read from each position's statistics to fit the calibrator.
- n_states (int, default 2): Number of HMM states. 2 for U/M; 3 for U/Flank/M, where the Flank state uses Beta(3, 3) emissions.

Returns:

- HMMParams: With a Platt-calibrated emission transform (EmissionCalibrator), init_prob learned from the label base rate, and optionally learned transition parameters.

Raises:

- ValueError: If fewer than 20 labeled positions are provided, if there are fewer than 10 positive or fewer than 10 negative labels, or if no reads are found at the labeled positions.
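_fit_platt_scaling is a private helper whose implementation is not shown on this page. As a hypothetical stand-in, Platt scaling can be sketched as a one-feature logistic regression: the code below fits (a, b) by Newton's method on the log-loss, given raw scores paired with 0/1 labels as in the source listing (native reads at positive positions labeled 1; IVT reads and all reads at negative positions labeled 0). The toy scores and labels are made up.

```python
import math


def fit_platt(xs: list[float], ys: list[float], n_iter: int = 50) -> tuple[float, float]:
    """Fit sigmoid(a * x + b) to 0/1 labels by Newton's method on the log-loss.

    Hypothetical stand-in for baleen's private _fit_platt_scaling.
    Assumes the classes overlap (perfectly separable data has no finite optimum).
    """
    a, b = 0.0, 0.0
    for _ in range(n_iter):
        g_a = g_b = h_aa = h_ab = h_bb = 0.0
        for x, y in zip(xs, ys):
            p = 1.0 / (1.0 + math.exp(-(a * x + b)))
            w = p * (1.0 - p)   # Hessian weight of the log-loss
            g_a += (p - y) * x  # gradient with respect to a
            g_b += p - y        # gradient with respect to b
            h_aa += w * x * x
            h_ab += w * x
            h_bb += w
        det = h_aa * h_bb - h_ab * h_ab
        # Newton step: solve the 2x2 system H @ delta = g and subtract delta.
        a -= (h_bb * g_a - h_ab * g_b) / det
        b -= (h_aa * g_b - h_ab * g_a) / det
    return a, b


raw = [0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9]
is_mod = [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0]
a, b = fit_platt(raw, is_mod)
print(a > 0)  # True: higher raw scores map to higher calibrated P(mod)
```

At the optimum the gradient with respect to b vanishes, so the mean calibrated probability equals the fraction of positive labels.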

Source code in baleen/eventalign/_hmm_training.py
def train_semi_supervised(
    training_data: dict[str, ContigModificationResult],
    labels: dict[tuple[str, int], bool],
    *,
    species_name: str = "",
    species_names: list[str] | None = None,
    learn_transitions: bool = True,
    emission_source: str = "p_mod_raw",
    n_states: int = 2,
) -> HMMParams:
    """Train Mode B (semi-supervised) HMM parameters.

    Parameters
    ----------
    training_data
        ``{contig_name: ContigModificationResult}`` — must have been
        computed with V1→V2 (``run_hmm=False`` is fine).
    labels
        ``{(contig, pipeline_position): is_modified}`` — True means the
        position is known to carry a modification.
    species_name
        Optional single species tag stored in metadata.
    species_names
        Optional list of species names for multi-organism pooling.
        Takes precedence over ``species_name`` if provided.
    learn_transitions
        If True (default), learn ``p_stay_per_base`` from labeled
        trajectories instead of using the hardcoded 0.98 default.
    n_states
        Number of HMM states.  ``2`` (default) for U/M; ``3`` for
        U/Flank/M where the Flank state uses Beta(3,3) emissions.

    Returns
    -------
    HMMParams
        With Platt-calibrated emission transform, learned ``init_prob``,
        and optionally learned transition parameters.

    Raises
    ------
    ValueError
        If fewer than 20 labeled positions are provided, or fewer than
        10 positive / 10 negative labels.
    """
    # ── Validate label counts ────────────────────────────────────────────
    n_pos = sum(1 for v in labels.values() if v)
    n_neg = sum(1 for v in labels.values() if not v)
    if n_pos + n_neg < 20:
        raise ValueError(
            f"Semi-supervised training requires >= 20 labeled positions, "
            f"got {n_pos + n_neg}"
        )
    if n_pos < 10:
        raise ValueError(
            f"Need >= 10 positive (modified) labels, got {n_pos}"
        )
    if n_neg < 10:
        raise ValueError(
            f"Need >= 10 negative (unmodified) labels, got {n_neg}"
        )

    # ── Collect (p_mod_raw, is_modified) pairs ───────────────────────────
    raw_vals: list[float] = []
    true_vals: list[float] = []
    n_reads_total = 0

    for (contig, pos), is_mod in labels.items():
        if contig not in training_data:
            continue
        cmr = training_data[contig]
        if pos not in cmr.position_stats:
            continue

        ps = cmr.position_stats[pos]
        n_reads_total += ps.n_native + ps.n_ivt

        source_scores = getattr(ps, emission_source)
        if is_mod:
            # Native reads at modified positions → label 1
            for i in range(ps.n_native):
                raw_vals.append(float(source_scores[i]))
                true_vals.append(1.0)
            # IVT reads at modified positions → label 0 (IVT never modified)
            for i in range(ps.n_native, ps.n_native + ps.n_ivt):
                raw_vals.append(float(source_scores[i]))
                true_vals.append(0.0)
        else:
            # All reads at unmodified positions → label 0
            for i in range(ps.n_native + ps.n_ivt):
                raw_vals.append(float(source_scores[i]))
                true_vals.append(0.0)

    if len(raw_vals) == 0:
        raise ValueError(
            "No reads found at labeled positions — check contig/position keys."
        )

    raw_arr = np.array(raw_vals, dtype=np.float64)
    true_arr = np.array(true_vals, dtype=np.float64)

    # ── Fit Platt scaling ────────────────────────────────────────────────
    a, b = _fit_platt_scaling(raw_arr, true_arr)
    calibrator = EmissionCalibrator(a=a, b=b)

    # ── Learn transition parameters from labeled trajectories ────────────
    if learn_transitions:
        p_stay = _learn_transition_from_labels(training_data, labels)
        logger.info("Semi-supervised learned p_stay=%.4f from labeled data", p_stay)
    else:
        p_stay = 0.98

    # ── Learned init_prob from base rate ─────────────────────────────────
    base_rate = n_pos / max(n_pos + n_neg, 1)
    if n_states == 3:
        flank_rate = min(2.0 * base_rate, 0.3)
        init_prob = np.array(
            [1.0 - base_rate - flank_rate, flank_rate, base_rate],
            dtype=np.float64,
        )
    else:
        init_prob = np.array([1.0 - base_rate, base_rate], dtype=np.float64)

    # ── Species metadata ─────────────────────────────────────────────────
    if species_names is not None:
        species_list = list(species_names)
    elif species_name:
        species_list = [species_name]
    else:
        species_list = []

    return HMMParams(
        mode="semi_supervised",
        n_states=n_states,
        p_stay_per_base=p_stay,
        init_prob=init_prob,
        emission_transform=calibrator,
        training_species=species_list,
        n_training_positions=n_pos + n_neg,
        n_training_reads=n_reads_total,
    )
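The label-count checks and the init_prob rule in the listing above reduce to a few lines of arithmetic, reproduced here in plain Python with hypothetical helper names. With 12 positive and 108 negative labels the base rate is 0.1, giving [0.9, 0.1] for two states and [0.7, 0.2, 0.1] for three (the flank share is twice the base rate, capped at 0.3).

```python
def validate_label_counts(n_pos: int, n_neg: int) -> None:
    """Same thresholds as train_semi_supervised: >= 20 total, >= 10 per class."""
    if n_pos + n_neg < 20:
        raise ValueError(f"need >= 20 labeled positions, got {n_pos + n_neg}")
    if n_pos < 10 or n_neg < 10:
        raise ValueError(f"need >= 10 per class, got {n_pos} pos / {n_neg} neg")


def init_prob_from_labels(n_pos: int, n_neg: int, n_states: int = 2) -> list[float]:
    """Initial-state distribution from the positive-label base rate."""
    base_rate = n_pos / max(n_pos + n_neg, 1)
    if n_states == 3:
        flank_rate = min(2.0 * base_rate, 0.3)  # flank share capped at 0.3
        return [1.0 - base_rate - flank_rate, flank_rate, base_rate]
    return [1.0 - base_rate, base_rate]


validate_label_counts(12, 108)
print(init_prob_from_labels(12, 108))
print(init_prob_from_labels(12, 108, n_states=3))
```

Because the flank share is capped at 0.3, the Unmodified entry of the 3-state vector is 0.7 - base_rate once the base rate exceeds 0.15, and it turns negative if more than 70% of labels are positive: 30 positive and 10 negative labels pass validation but give -0.05.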

train_supervised

train_supervised(
    training_data: dict[str, ContigModificationResult],
    labels: dict[tuple[str, int], bool],
    *,
    species_name: str = "",
    kde_n_bins: int = 200,
    kde_bandwidth: float | None = None,
    emission_source: str = "p_mod_raw",
    n_states: int = 2
) -> HMMParams

Train Mode C (fully supervised) HMM parameters.

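Parameters:

- training_data (dict[str, ContigModificationResult], required): {contig_name: ContigModificationResult}, V1→V2 results.
- labels (dict[tuple[str, int], bool], required): {(contig, pipeline_position): is_modified}. Reads are assigned to classes as in train_semi_supervised: native reads at positive positions are modified; IVT reads and all reads at negative positions are unmodified.
- species_name (str, default ""): Optional species tag stored in training_species.
- kde_n_bins (int, default 200): Number of evenly spaced evaluation points for the KDE grid on [0, 1].
- kde_bandwidth (float | None, default None): Bandwidth passed to the KDE as bw_method; None uses Scott's rule.
- emission_source (str, default "p_mod_raw"): Name of the per-read score attribute read from each position's statistics.
- n_states (int, default 2): Number of HMM states. 2 for U/M; 3 for U/Flank/M, where the Flank state uses Beta(3, 3) emissions.

Returns:

- HMMParams: With a KDE emission model (EmissionKDE), a transition estimate from labeled trajectories (consecutive labeled positions weighted by 1/gap, p_stay_per_base clamped to [0.8, 0.999], falling back to 0.98 when no transitions are observed), and init_prob learned from the label base rate.

Raises:

- ValueError: If there are fewer than 50 labeled positions, if the labels span fewer than 3 contigs, or if either class has fewer than 5 reads for KDE fitting.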
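_gaussian_kde is private and not shown here. The sketch below is a plain-Python 1-D Gaussian KDE that assumes scipy.stats.gaussian_kde conventions: with no bandwidth given, the kernel width is the sample standard deviation times Scott's factor n ** (-1/5), and a scalar bw_method acts as a multiplier on the standard deviation rather than an absolute width. Whether baleen's helper behaves exactly this way is an assumption; the helper names and score lists are made up.

```python
import math
import statistics
from typing import Callable, Optional


def gaussian_kde(samples: list[float], bw_factor: Optional[float] = None) -> Callable[[float], float]:
    """1-D Gaussian KDE; bw_factor=None applies Scott's rule (assumed scipy-like)."""
    n = len(samples)
    factor = bw_factor if bw_factor is not None else n ** (-1.0 / 5.0)  # Scott, d = 1
    h = factor * statistics.stdev(samples)  # absolute kernel width
    norm = 1.0 / (n * h * math.sqrt(2.0 * math.pi))

    def density(x: float) -> float:
        return norm * sum(math.exp(-0.5 * ((x - s) / h) ** 2) for s in samples)

    return density


def linspace(start: float, stop: float, num: int) -> list[float]:
    """Evenly spaced points including both endpoints (num >= 2)."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


mod_scores = [0.70, 0.75, 0.80, 0.85, 0.90]  # made-up per-read scores
unmod_scores = [0.10, 0.15, 0.20, 0.25, 0.30]
kde_mod = gaussian_kde(mod_scores)
kde_unmod = gaussian_kde(unmod_scores)

grid = linspace(0.0, 1.0, 200)  # kde_n_bins = 200
density_mod = [kde_mod(x) for x in grid]
density_unmod = [kde_unmod(x) for x in grid]
```

The pair (grid, density_unmod, density_mod) has the same shape as the fields of EmissionKDE. As in the listing, the curves are evaluated on [0, 1] without renormalisation.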

Source code in baleen/eventalign/_hmm_training.py
def train_supervised(
    training_data: dict[str, ContigModificationResult],
    labels: dict[tuple[str, int], bool],
    *,
    species_name: str = "",
    kde_n_bins: int = 200,
    kde_bandwidth: float | None = None,
    emission_source: str = "p_mod_raw",
    n_states: int = 2,
) -> HMMParams:
    """Train Mode C (fully supervised) HMM parameters.

    Parameters
    ----------
    training_data
        ``{contig_name: ContigModificationResult}`` — V1→V2 results.
    labels
        ``{(contig, pipeline_position): is_modified}``
    species_name
        Optional species tag.
    kde_n_bins
        Number of evaluation points for KDE grid.
    kde_bandwidth
        Explicit bandwidth for KDE; ``None`` = Scott's rule (default).
    n_states
        Number of HMM states.  ``2`` (default) for U/M; ``3`` for
        U/Flank/M where the Flank state uses Beta(3,3) emissions.

    Returns
    -------
    HMMParams
        With MLE transition, KDE emission model, and learned ``init_prob``.

    Raises
    ------
    ValueError
        If fewer than 50 labeled positions or fewer than 3 contigs.
    """
    n_pos = sum(1 for v in labels.values() if v)
    n_neg = sum(1 for v in labels.values() if not v)
    contig_set = {c for c, _ in labels.keys()}

    if n_pos + n_neg < 50:
        raise ValueError(
            f"Supervised training requires >= 50 labeled positions, "
            f"got {n_pos + n_neg}"
        )
    if len(contig_set) < 3:
        raise ValueError(
            f"Supervised training requires >= 3 contigs, got {len(contig_set)}"
        )

    # ── 1. Collect per-read p_mod_raw by label ───────────────────────────
    mod_vals: list[float] = []
    unmod_vals: list[float] = []
    n_reads_total = 0

    for (contig, pos), is_mod in labels.items():
        if contig not in training_data:
            continue
        cmr = training_data[contig]
        if pos not in cmr.position_stats:
            continue

        ps = cmr.position_stats[pos]
        n_reads_total += ps.n_native + ps.n_ivt

        source_scores = getattr(ps, emission_source)
        if is_mod:
            # Native reads → modified
            for i in range(ps.n_native):
                mod_vals.append(float(source_scores[i]))
            # IVT reads → unmodified
            for i in range(ps.n_native, ps.n_native + ps.n_ivt):
                unmod_vals.append(float(source_scores[i]))
        else:
            # All reads → unmodified
            for i in range(ps.n_native + ps.n_ivt):
                unmod_vals.append(float(source_scores[i]))

    if len(mod_vals) < 5 or len(unmod_vals) < 5:
        raise ValueError(
            f"Need >= 5 reads in each class for KDE fitting. "
            f"Got {len(mod_vals)} modified, {len(unmod_vals)} unmodified."
        )

    # ── 2. Fit KDE emission model ────────────────────────────────────────
    mod_arr = np.array(mod_vals, dtype=np.float64)
    unmod_arr = np.array(unmod_vals, dtype=np.float64)

    bw_kwargs = {"bw_method": kde_bandwidth} if kde_bandwidth is not None else {}
    kde_mod = _gaussian_kde(mod_arr, **bw_kwargs)
    kde_unmod = _gaussian_kde(unmod_arr, **bw_kwargs)

    grid = np.linspace(0.0, 1.0, kde_n_bins)
    emission_kde = EmissionKDE(
        grid=grid,
        density_unmod=kde_unmod(grid).astype(np.float64),
        density_mod=kde_mod(grid).astype(np.float64),
    )

    # ── 3. MLE transition from labeled trajectories ──────────────────────
    same_count = 0.0
    diff_count = 0.0

    for contig_name, cmr in training_data.items():
        all_trajs = list(cmr.native_trajectories) + list(cmr.ivt_trajectories)
        is_ivt_offset = len(cmr.native_trajectories)

        for traj_idx, traj in enumerate(all_trajs):
            is_ivt = traj_idx >= is_ivt_offset
            # Build state sequence at labeled positions
            labeled_pairs: list[tuple[int, int]] = []  # (position, state)
            for pos in traj.positions:
                key = (contig_name, pos)
                if key not in labels:
                    continue
                if is_ivt:
                    state = 0  # IVT reads are always unmodified
                else:
                    state = 1 if labels[key] else 0
                labeled_pairs.append((pos, state))

            # Consecutive pairs weighted by 1/gap
            for i in range(len(labeled_pairs) - 1):
                pos_i, state_i = labeled_pairs[i]
                pos_j, state_j = labeled_pairs[i + 1]
                gap = max(pos_j - pos_i, 1)
                if state_i == state_j:
                    same_count += 1.0 / gap
                else:
                    diff_count += 1.0 / gap

    total_transitions = same_count + diff_count
    if total_transitions > 0:
        p_stay = same_count / total_transitions
    else:
        p_stay = 0.98  # fallback to default

    # Clamp to [0.8, 0.999]
    p_stay = max(0.8, min(p_stay, 0.999))

    # ── 4. Learned init_prob ─────────────────────────────────────────────
    base_rate = n_pos / max(n_pos + n_neg, 1)
    if n_states == 3:
        flank_rate = min(2.0 * base_rate, 0.3)
        init_prob = np.array(
            [1.0 - base_rate - flank_rate, flank_rate, base_rate],
            dtype=np.float64,
        )
    else:
        init_prob = np.array([1.0 - base_rate, base_rate], dtype=np.float64)

    species_list = [species_name] if species_name else []

    return HMMParams(
        mode="supervised",
        n_states=n_states,
        p_stay_per_base=p_stay,
        init_prob=init_prob,
        emission_transform=emission_kde,
        training_species=species_list,
        n_training_positions=n_pos + n_neg,
        n_training_reads=n_reads_total,
    )
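Step 3 of the listing can be isolated as a small function. estimate_p_stay is a hypothetical name. It takes, per trajectory, the (position, state) pairs at labeled positions (IVT trajectories are all state 0; native trajectories take the label) and reproduces the 1/gap weighting, the 0.98 fallback, and the [0.8, 0.999] clamp.

```python
def estimate_p_stay(
    trajectories: list[list[tuple[int, int]]],
    default: float = 0.98,
    lo: float = 0.8,
    hi: float = 0.999,
) -> float:
    """Gap-weighted stay probability, mirroring step 3 of train_supervised.

    Each trajectory is a list of (position, state) pairs at labeled positions.
    Consecutive pairs count as 'same' or 'diff', each weighted by 1 / gap.
    """
    same = diff = 0.0
    for pairs in trajectories:
        for (pos_i, s_i), (pos_j, s_j) in zip(pairs, pairs[1:]):
            weight = 1.0 / max(pos_j - pos_i, 1)
            if s_i == s_j:
                same += weight
            else:
                diff += weight
    total = same + diff
    p_stay = same / total if total > 0 else default  # fallback when no transitions
    return max(lo, min(p_stay, hi))


native = [(10, 0), (12, 1), (13, 1), (20, 0)]  # crosses a modified site
ivt = [(10, 0), (12, 0), (13, 0), (20, 0)]     # IVT: always unmodified
print(estimate_p_stay([native, ivt]))
```

In this toy input, same-state pairs contribute 2.5 + 1/7 and state changes contribute 0.5 + 1/7, so p_stay = 37/46 ≈ 0.804, just above the 0.8 floor.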

Labels & cross-validation

labels_from_known_modifications

labels_from_known_modifications(
    known_mods: dict[tuple[str, int], tuple[str, str]],
    contig_results: dict[str, ContigModificationResult],
    *,
    position_offset: int = 3,
    auto_negatives: bool = True,
    min_coverage: int = 5
) -> dict[tuple[str, int], bool]

Convert known biological modification sites to training labels.

Parameters:

- known_mods (dict[tuple[str, int], tuple[str, str]], required): {(contig, bio_position): (mod_short, mod_full), ...}, where bio_position is the 1-based biological coordinate. Sites whose contig is missing from contig_results, or whose converted position has no entry in position_stats, are skipped.
- contig_results (dict[str, ContigModificationResult], required): {contig_name: ContigModificationResult}, keyed by contig name.
- position_offset (int, default 3): Conversion offset, pipeline_position = bio_position - position_offset. The default of 3 corresponds to the eventalign 5-mer centre.
- auto_negatives (bool, default True): If True, every position with n_native + n_ivt >= min_coverage that is not already a positive label becomes a negative (unmodified) label.
- min_coverage (int, default 5): Minimum total read coverage (n_native + n_ivt) for auto-negative positions.

Returns:

- labels (dict[tuple[str, int], bool]): {(contig, pipeline_position): is_modified}.
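The label construction can be exercised without baleen's result objects. In the hypothetical sketch below, a plain {contig: {pipeline_position: coverage}} dictionary stands in for position_stats, and the contig name and site are invented. A known site at biological position 37 maps to pipeline position 34 with the default offset of 3; a sufficiently covered position becomes a negative label, while one below min_coverage is left unlabeled.

```python
def labels_from_sites(
    known_mods: dict[tuple[str, int], tuple[str, str]],
    coverage: dict[str, dict[int, int]],
    position_offset: int = 3,
    auto_negatives: bool = True,
    min_coverage: int = 5,
) -> dict[tuple[str, int], bool]:
    """Simplified mirror of labels_from_known_modifications.

    `coverage` maps contig -> {pipeline_position: n_native + n_ivt} and stands
    in for ContigModificationResult.position_stats (hypothetical structure).
    """
    labels: dict[tuple[str, int], bool] = {}
    for contig, bio_pos in known_mods:
        pipeline_pos = bio_pos - position_offset
        if pipeline_pos in coverage.get(contig, {}):
            labels[(contig, pipeline_pos)] = True  # positive label
    if auto_negatives:
        for contig, positions in coverage.items():
            for pos, depth in positions.items():
                if (contig, pos) not in labels and depth >= min_coverage:
                    labels[(contig, pos)] = False  # auto-negative label
    return labels


known = {("contig_A", 37): ("modX", "hypothetical modification")}
coverage = {"contig_A": {34: 12, 35: 3, 40: 8}}
labels = labels_from_sites(known, coverage)
print(labels)  # {('contig_A', 34): True, ('contig_A', 40): False}
```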

Source code in baleen/eventalign/_hmm_training.py
def labels_from_known_modifications(
    known_mods: dict[tuple[str, int], tuple[str, str]],
    contig_results: dict[str, ContigModificationResult],
    *,
    position_offset: int = 3,
    auto_negatives: bool = True,
    min_coverage: int = 5,
) -> dict[tuple[str, int], bool]:
    """Convert known biological modification sites to training labels.

    Parameters
    ----------
    known_mods
        ``{(contig, bio_position): (mod_short, mod_full), ...}`` where
        *bio_position* is the 1-based biological coordinate.
    contig_results
        ``{contig_name: ContigModificationResult}`` keyed by contig name.
    position_offset
        ``bio_position - offset = pipeline_position``.  Default 3 for
        eventalign 5-mer centre.
    auto_negatives
        If True, positions with ``n_native + n_ivt >= min_coverage`` that
        are **not** in *known_mods* become negative (unmodified) labels.
    min_coverage
        Minimum total read coverage for auto-negative positions.

    Returns
    -------
    labels
        ``{(contig, pipeline_position): is_modified}``
    """
    labels: dict[tuple[str, int], bool] = {}

    # Positive labels from known modifications
    for (contig, bio_pos), (_mod_short, _mod_full) in known_mods.items():
        pipeline_pos = bio_pos - position_offset
        if contig not in contig_results:
            continue
        cmr = contig_results[contig]
        if pipeline_pos not in cmr.position_stats:
            continue
        labels[(contig, pipeline_pos)] = True

    # Auto-negative labels
    if auto_negatives:
        for contig_name, cmr in contig_results.items():
            for pos, ps in cmr.position_stats.items():
                key = (contig_name, pos)
                if key in labels:
                    continue  # already labeled (positive)
                if ps.n_native + ps.n_ivt >= min_coverage:
                    labels[key] = False

    return labels

cross_validate_hmm

cross_validate_hmm(
    contig_results: dict[str, ContigResult],
    labels: dict[tuple[str, int], bool],
    mode: Literal["semi_supervised", "supervised"],
    *,
    cv_strategy: Literal[
        "leave_one_contig_out", "kfold"
    ] = "leave_one_contig_out",
    k: int = 5,
    emission_source: str = "p_mod_raw",
    **hierarchical_kwargs
) -> CVResult

Cross-validate HMM training to detect overfitting.

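Parameters:

- contig_results (dict[str, ContigResult], required): Raw pipeline output per contig (ContigResult). V1+V2 is run once per labeled contig, and V3 is re-run on each fold's test contigs.
- labels (dict[tuple[str, int], bool], required): {(contig, pipeline_position): is_modified}.
- mode (Literal["semi_supervised", "supervised"], required): Training mode to evaluate. The trainer is called with its default n_states of 2.
- cv_strategy (Literal["leave_one_contig_out", "kfold"], default "leave_one_contig_out"): "leave_one_contig_out" or "kfold"; in both cases folds are formed by contig, never by position.
- k (int, default 5): Target number of folds for k-fold CV. The sorted labeled contigs are split into consecutive chunks of max(1, n // k), so the actual number of folds can differ from k; for example, 7 contigs with k=5 give 7 single-contig folds.
- emission_source (str, default "p_mod_raw"): Per-read score attribute used both for training and as the emission source when V3 is run on the test contigs.
- **hierarchical_kwargs: Forwarded to baleen.eventalign._hierarchical.compute_sequential_modification_probabilities().

Returns:

- CVResult: Per-fold and aggregate AUROC/AUPRC of per-read p_mod_hmm at labeled positions on the held-out contigs.

Raises:

- ValueError: If labels cover fewer than 2 contigs present in contig_results, or if no fold completes. Folds whose training split fails the chosen trainer's validation are skipped with a warning, and folds with no scored test reads are skipped silently.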
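The fold construction is easiest to see on a concrete contig list. build_folds is a hypothetical helper that mirrors the fold-building block of the source listing for this function; it shows that with k-fold CV the number of folds follows the chunk size max(1, n // k) rather than k itself.

```python
def build_folds(contigs: list[str], strategy: str = "leave_one_contig_out", k: int = 5):
    """Mirror of the fold construction in cross_validate_hmm.

    Returns a list of (test_contigs, train_contigs) pairs.
    """
    contigs = sorted(contigs)
    if strategy == "leave_one_contig_out":
        return [([c], [x for x in contigs if x != c]) for c in contigs]
    fold_size = max(1, len(contigs) // k)
    folds = []
    for i in range(0, len(contigs), fold_size):
        test = contigs[i : i + fold_size]
        train = [c for c in contigs if c not in test]
        if train:  # a fold that would leave nothing to train on is dropped
            folds.append((test, train))
    return folds


contigs = [f"contig_{i}" for i in range(7)]
print(len(build_folds(contigs, "kfold", k=5)))  # 7: chunk size 7 // 5 = 1
print(len(build_folds(contigs, "kfold", k=3)))  # 4: chunks of 2, 2, 2, 1
```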
Source code in baleen/eventalign/_hmm_training.py
def cross_validate_hmm(
    contig_results: dict[str, ContigResult],
    labels: dict[tuple[str, int], bool],
    mode: Literal["semi_supervised", "supervised"],
    *,
    cv_strategy: Literal["leave_one_contig_out", "kfold"] = "leave_one_contig_out",
    k: int = 5,
    emission_source: str = "p_mod_raw",
    **hierarchical_kwargs,
) -> CVResult:
    """Cross-validate HMM training to detect overfitting.

    Parameters
    ----------
    contig_results
        Raw pipeline output per contig (``ContigResult``).
    labels
        ``{(contig, pipeline_position): is_modified}``
    mode
        Training mode to evaluate (``"semi_supervised"`` or ``"supervised"``).
    cv_strategy
        ``"leave_one_contig_out"`` (default) or ``"kfold"`` (by contig).
    k
        Number of folds for k-fold CV.
    **hierarchical_kwargs
        Forwarded to
        :func:`~baleen.eventalign._hierarchical.compute_sequential_modification_probabilities`.

    Returns
    -------
    CVResult
    """
    # Lazy import to avoid circular dependency
    from baleen.eventalign._hierarchical import (
        compute_sequential_modification_probabilities,
    )

    # ── Build folds ──────────────────────────────────────────────────────
    contigs_with_labels = sorted({c for c, _ in labels.keys()})
    # Only keep contigs that exist in contig_results
    contigs_with_labels = [c for c in contigs_with_labels if c in contig_results]

    if len(contigs_with_labels) < 2:
        raise ValueError(
            "Need labels in >= 2 contigs for cross-validation, "
            f"got {len(contigs_with_labels)}"
        )

    if cv_strategy == "leave_one_contig_out":
        folds = [([c], [x for x in contigs_with_labels if x != c])
                 for c in contigs_with_labels]
    else:
        # k-fold by contig
        n = len(contigs_with_labels)
        fold_size = max(1, n // k)
        folds = []
        for i in range(0, n, fold_size):
            test_contigs = contigs_with_labels[i : i + fold_size]
            train_contigs = [c for c in contigs_with_labels if c not in test_contigs]
            if train_contigs:
                folds.append((test_contigs, train_contigs))

    # ── Run V1+V2 on all contigs once ────────────────────────────────────
    v2_results: dict[str, ContigModificationResult] = {}
    for contig_name in contigs_with_labels:
        cr = contig_results[contig_name]
        v2_results[contig_name] = compute_sequential_modification_probabilities(
            cr, run_hmm=False, **hierarchical_kwargs
        )

    # ── Per-fold train/test ──────────────────────────────────────────────
    per_fold_auroc: list[float] = []
    per_fold_auprc: list[float] = []
    fold_details: list[dict[str, Any]] = []

    for test_contigs, train_contigs in folds:
        # Build train labels & data
        train_labels = {
            (c, p): v for (c, p), v in labels.items() if c in train_contigs
        }
        train_data = {c: v2_results[c] for c in train_contigs if c in v2_results}

        # Check minimum requirements — skip fold if insufficient
        n_train_pos = sum(1 for v in train_labels.values() if v)
        n_train_neg = sum(1 for v in train_labels.values() if not v)

        try:
            if mode == "semi_supervised":
                hmm_params = train_semi_supervised(
                    train_data, train_labels, emission_source=emission_source,
                )
            else:
                hmm_params = train_supervised(
                    train_data, train_labels, emission_source=emission_source,
                )
        except ValueError as e:
            logger.warning(
                "Skipping fold (test=%s): insufficient training data — %s",
                test_contigs,
                e,
            )
            continue

        # Run V3 with trained params on test contigs
        y_true_list: list[float] = []
        y_score_list: list[float] = []

        for test_contig in test_contigs:
            if test_contig not in contig_results:
                continue
            cr = contig_results[test_contig]
            test_result = compute_sequential_modification_probabilities(
                cr, hmm_params=hmm_params, emission_source=emission_source,
                **hierarchical_kwargs,
            )

            # Collect scores at labeled test positions
            test_labels = {
                (c, p): v
                for (c, p), v in labels.items()
                if c == test_contig
            }
            for (_, pos), is_mod in test_labels.items():
                if pos not in test_result.position_stats:
                    continue
                ps = test_result.position_stats[pos]
                if is_mod:
                    # Native reads → y_true=1
                    for i in range(ps.n_native):
                        y_true_list.append(1.0)
                        y_score_list.append(float(ps.p_mod_hmm[i]))
                    # IVT reads → y_true=0
                    for i in range(ps.n_native, ps.n_native + ps.n_ivt):
                        y_true_list.append(0.0)
                        y_score_list.append(float(ps.p_mod_hmm[i]))
                else:
                    # All reads → y_true=0
                    for i in range(ps.n_native + ps.n_ivt):
                        y_true_list.append(0.0)
                        y_score_list.append(float(ps.p_mod_hmm[i]))

        if len(y_true_list) == 0:
            continue

        y_true = np.array(y_true_list, dtype=np.float64)
        y_score = np.array(y_score_list, dtype=np.float64)

        auroc = _manual_auroc(y_true, y_score)
        auprc = _manual_auprc(y_true, y_score)
        per_fold_auroc.append(auroc)
        per_fold_auprc.append(auprc)
        fold_details.append({
            "test_contigs": test_contigs,
            "train_contigs": train_contigs,
            "n_test_reads": len(y_true_list),
            "n_train_positions": n_train_pos + n_train_neg,
            "auroc": auroc,
            "auprc": auprc,
        })

    if len(per_fold_auroc) == 0:
        raise ValueError("No folds completed — insufficient data for CV.")

    return CVResult(
        per_fold_auroc=per_fold_auroc,
        per_fold_auprc=per_fold_auprc,
        mean_auroc=float(np.mean(per_fold_auroc)),
        mean_auprc=float(np.mean(per_fold_auprc)),
        std_auroc=float(np.std(per_fold_auroc)),
        std_auprc=float(np.std(per_fold_auprc)),
        fold_details=fold_details,
    )
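Each fold is scored per read: native reads at positively labeled test positions are positives, and IVT reads plus all reads at negative positions are negatives, each scored by p_mod_hmm. _manual_auroc is private and its implementation is not shown. The hypothetical sketch below computes AUROC by its pairwise definition (the probability that a random positive outscores a random negative, counting ties as one half), which is quadratic in the number of reads and meant only for illustration.

```python
def auroc(y_true: list[float], y_score: list[float]) -> float:
    """P(random positive scores above random negative), ties counted as 0.5."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1.0]
    neg = [s for t, s in zip(y_true, y_score) if t == 0.0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


print(auroc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```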

CVResult dataclass

CVResult(
    per_fold_auroc: list[float],
    per_fold_auprc: list[float],
    mean_auroc: float,
    mean_auprc: float,
    std_auroc: float,
    std_auprc: float,
    fold_details: list[dict[str, Any]],
)

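Cross-validation results. per_fold_auroc and per_fold_auprc hold one value per completed fold; the mean_* and std_* fields aggregate them (std_* is the population standard deviation, as computed by np.std with its default ddof=0). fold_details records, for each fold, the test and training contigs, the number of test reads, the number of training positions, and that fold's AUROC and AUPRC.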

Persistence

save_hmm_params

save_hmm_params(
    params: HMMParams, path: str | Path
) -> None

Serialize trained HMM parameters to JSON.

Source code in baleen/eventalign/_hmm_training.py
def save_hmm_params(params: HMMParams, path: str | Path) -> None:
    """Serialize trained HMM parameters to JSON."""
    data = {
        "mode": params.mode,
        "n_states": params.n_states,
        "p_stay_per_base": params.p_stay_per_base,
        "init_prob": params.init_prob.tolist(),
        "emission_transform": _serialize_emission_transform(
            params.emission_transform
        ),
        "unmod_emission_beta": list(params.unmod_emission_beta),
        "flank_emission_beta": list(params.flank_emission_beta),
        "mod_emission_beta": list(params.mod_emission_beta),
        "training_species": params.training_species,
        "n_training_positions": params.n_training_positions,
        "n_training_reads": params.n_training_reads,
    }
    Path(path).write_text(json.dumps(data, indent=2))

load_hmm_params

load_hmm_params(path: str | Path) -> HMMParams

Load previously trained HMM parameters from JSON.

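Backward-compatible: files without n_states load as 2-state models, and missing emission Beta parameters, training species, and training counts fall back to their defaults.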

Source code in baleen/eventalign/_hmm_training.py
def load_hmm_params(path: str | Path) -> HMMParams:
    """Load previously trained HMM parameters from JSON.

    Backward-compatible: files without ``n_states`` default to 2-state.
    """
    data = json.loads(Path(path).read_text())
    return HMMParams(
        mode=data["mode"],
        n_states=data.get("n_states", 2),
        p_stay_per_base=data["p_stay_per_base"],
        init_prob=np.array(data["init_prob"], dtype=np.float64),
        emission_transform=_deserialize_emission_transform(
            data["emission_transform"]
        ),
        unmod_emission_beta=tuple(data.get("unmod_emission_beta", (2.0, 8.0))),
        flank_emission_beta=tuple(data.get("flank_emission_beta", (3.0, 3.0))),
        mod_emission_beta=tuple(data.get("mod_emission_beta", (8.0, 2.0))),
        training_species=data.get("training_species", []),
        n_training_positions=data.get("n_training_positions", 0),
        n_training_reads=data.get("n_training_reads", 0),
    )
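The on-disk layout is plain JSON, with tuples stored as lists. The sketch below uses a hypothetical MiniParams dataclass holding a subset of the HMMParams fields (emission_transform is left out because its (de)serialization helpers are private) to show the round trip and the backward-compatible default of two states for files without n_states.

```python
import json
from dataclasses import asdict, dataclass, field


@dataclass
class MiniParams:
    """Hypothetical subset of HMMParams, for illustrating the JSON layout."""

    mode: str = "unsupervised"
    n_states: int = 3
    p_stay_per_base: float = 0.92
    init_prob: list[float] = field(default_factory=lambda: [0.98, 0.01, 0.01])
    unmod_emission_beta: tuple[float, float] = (2.0, 8.0)


def dumps(params: MiniParams) -> str:
    data = asdict(params)
    data["unmod_emission_beta"] = list(params.unmod_emission_beta)  # JSON has no tuples
    return json.dumps(data, indent=2)


def loads(text: str) -> MiniParams:
    data = json.loads(text)
    return MiniParams(
        mode=data["mode"],
        n_states=data.get("n_states", 2),  # legacy files predate n_states
        p_stay_per_base=data["p_stay_per_base"],
        init_prob=list(data["init_prob"]),
        unmod_emission_beta=tuple(data.get("unmod_emission_beta", (2.0, 8.0))),
    )


original = MiniParams()
assert loads(dumps(original)) == original
legacy = '{"mode": "semi_supervised", "p_stay_per_base": 0.95, "init_prob": [0.9, 0.1]}'
print(loads(legacy).n_states)  # 2
```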