"""
This module provides the KFoldValidation class, an implementation of k-fold cross-validation
for model training and evaluation. It extends ValidationBase to perform iterative model
building and testing across defined data folds, accumulating performance reports.
"""
from typing import Optional, List, Dict
import copy
import polars as pl
from aiqclib.common.base.config_base import ConfigBase
from aiqclib.train.step2_validate_model.validate_base import ValidationBase
[docs]
class KFoldValidation(ValidationBase):
"""
A subclass of :class:`ValidationBase` that performs k-fold cross-validation
on training sets.
This class iterates over the specified number of folds, trains
(builds) the model on all folds except one, then tests it on the
held-out fold. Results are accumulated in :attr:`reports`.
"""
expected_class_name: str = "KFoldValidation"
def __init__(
self,
config: ConfigBase,
training_sets: Optional[Dict[str, pl.DataFrame]] = None,
) -> None:
"""
Initialize the k-fold validation process.
:param config: A training configuration object containing
model parameters, file paths, and other
validation settings.
:type config: ConfigBase
:param training_sets: A dictionary where keys are target names and values are
Polars DataFrames of labeled data. Each DataFrame must
contain a column named ``k_fold`` indicating the fold
assignment for each row. Defaults to None.
:type training_sets: Optional[Dict[str, pl.DataFrame]]
"""
super().__init__(config=config, training_sets=training_sets)
#: The default number of folds if none is specified in the config.
self.default_k_fold: int = 10
self.drop_cols = [
"k_fold",
"row_id",
"platform_code",
"profile_no",
"observation_no",
]
[docs]
def get_k_fold(self) -> int:
"""
Retrieve the number of folds to use for cross-validation from
the ``validate`` section of the YAML config, or fall back
to :attr:`default_k_fold`.
:return: The number of folds for k-fold cross-validation.
:rtype: int
"""
return (
self.config.get_step_params("validate").get("k_fold", self.default_k_fold)
or self.default_k_fold
)
[docs]
def validate(self, target_name: str) -> None:
"""
Conduct k-fold cross-validation for the given target name,
storing model objects, test results, and model-scores tables in
:attr:`models`, :attr:`reports`, and :attr:`model_scores`.
For each fold out of :meth:`get_k_fold`:
1. Reload or re-initialize the model using :meth:`load_base_model`.
2. Set ``base_model.k`` to the fold index.
3. Build the model using all training data except rows in the current fold.
4. Test the model on the held-out fold.
5. Accumulate test results and model-scores tables.
:param target_name: The identifier for which target dataset to validate,
referring to the corresponding DataFrame within
:attr:`training_sets`.
:type target_name: str
"""
self.models[target_name] = []
reports: List[pl.DataFrame] = []
model_scores: List[pl.DataFrame] = []
k_fold: int = self.get_k_fold()
for k in range(k_fold):
self.load_base_model()
current_fold_model = copy.deepcopy(self.base_model)
current_fold_model.k = k + 1
current_fold_model.training_set = (
self.training_sets[target_name]
.filter(pl.col("k_fold") != (k + 1))
.drop(self.drop_cols)
)
current_fold_model.build()
self.models[target_name].append(current_fold_model)
current_fold_model.test_set = (
self.training_sets[target_name]
.filter(pl.col("k_fold") == (k + 1))
.drop(self.drop_cols)
)
current_fold_model.test()
reports.append(current_fold_model.report)
if current_fold_model.model_score is not None:
model_scores.append(current_fold_model.model_score)
self.reports[target_name] = pl.concat(reports)
if model_scores:
self.model_scores[target_name] = pl.concat(model_scores)