Source code for aiqclib.classify.step7_concat_datasets.concat_base

"""
Module for concatenating input datasets with machine learning model predictions.

This module provides the :class:`ConcatDatasetsBase` class, which handles the logic
of merging original input features with generated predictions from one or more
targets and saving the result to a Parquet file.
"""

import os
from typing import Dict, Optional

import polars as pl

from aiqclib.common.base.config_base import ConfigBase
from aiqclib.common.base.dataset_base import DataSetBase


class ConcatDatasetsBase(DataSetBase):
    """
    Base class for concatenating predictions and the original dataset.

    Inherits from :class:`~aiqclib.common.base.dataset_base.DataSetBase` to ensure
    configuration consistency. The concatenated dataset, once generated, can be
    written to a Parquet file.
    """

    def __init__(
        self,
        config: ConfigBase,
        input_data: Optional[pl.DataFrame] = None,
        predictions: Optional[Dict[str, pl.DataFrame]] = None,
    ) -> None:
        """
        Initialize the dataset concatenation base class.

        :param config: The configuration object, containing paths and target
                       definitions.
        :type config: aiqclib.common.base.config_base.ConfigBase
        :param input_data: A Polars DataFrame providing the full dataset to which
                           predictions will be concatenated, defaults to None.
        :type input_data: Optional[polars.DataFrame]
        :param predictions: A dictionary mapping each target to its respective
                            subset of predictions, defaults to None.
        :type predictions: Optional[Dict[str, polars.DataFrame]]
        :raises NotImplementedError: Presumably raised by the base-class
                                     initializer if the subclass does not define
                                     ``expected_class_name`` (not visible here).
        :raises ValueError: Presumably raised by the base-class initializer if the
                            provided YAML config does not match this class's
                            ``expected_class_name`` (not visible here).
        :return: None
        :rtype: NoneType
        """
        super().__init__(step_name="concat", config=config)

        #: Default file name used for the concatenated output when the
        #: configuration is asked to build the full output path.
        self.default_file_name: str = "predictions.parquet"

        #: Output file name to store the concatenated dataset.
        self.output_file_name: str = self.config.get_full_file_name(
            step_name="concat", default_file_name=self.default_file_name
        )

        #: The original dataset that the predictions are joined onto.
        self.input_data: Optional[pl.DataFrame] = input_data

        #: A dict of Polars DataFrames, one per target, containing classification
        #: results.
        self.predictions: Optional[Dict[str, pl.DataFrame]] = predictions

        #: Result of :meth:`merge_predictions`; None until that method has run.
        self.merged_predictions: Optional[pl.DataFrame] = None

    def merge_predictions(self) -> None:
        """
        Merge the input data with the predictions of every target.

        For each target, the 'label', 'predicted_label' and 'score' columns are
        renamed to '<target>_label', '<target>_predicted' and '<target>_score' to
        avoid name collisions, and only those columns plus the identifier columns
        ('platform_code', 'profile_no', 'observation_no') are kept. The per-target
        frames are combined with ``pl.concat(..., how="align")`` and the result is
        joined onto :attr:`input_data` on the identifier columns. No ``how`` is
        passed to the join, so Polars' default join strategy applies.

        The result is stored in the :attr:`merged_predictions` attribute.

        :raises ValueError: If :attr:`input_data` is None, or if
                            :attr:`predictions` is None or contains no targets.
        :return: None
        :rtype: NoneType
        """
        if self.input_data is None:
            raise ValueError("Member variable 'input_data' must not be empty.")

        # An empty dict is rejected as well: concatenating zero frames cannot
        # succeed, and failing here gives a message that names the real cause.
        if not self.predictions:
            raise ValueError("Member variable 'predictions' must not be empty.")

        key_columns = ["platform_code", "profile_no", "observation_no"]

        per_target_frames = [
            df.rename(
                {
                    "label": f"{key}_label",
                    "predicted_label": f"{key}_predicted",
                    "score": f"{key}_score",
                }
            ).select(
                [
                    *key_columns,
                    f"{key}_label",
                    f"{key}_predicted",
                    f"{key}_score",
                ]
            )
            for key, df in self.predictions.items()
        ]

        # "align" lines the per-target frames up on their shared columns, which
        # after the select above are exactly the identifier columns.
        aligned_predictions = pl.concat(per_target_frames, how="align")

        self.merged_predictions = self.input_data.join(
            aligned_predictions,
            on=key_columns,
        )

    def write_merged_predictions(self) -> None:
        """
        Write the merged predictions DataFrame to a Parquet file.

        The output directory is created if it does not exist. The file path is
        determined by :attr:`output_file_name`.

        :raises ValueError: If :attr:`merged_predictions` is None when this method
                            is called.
        :return: None
        :rtype: NoneType
        """
        if self.merged_predictions is None:
            raise ValueError("Member variable 'merged_predictions' must not be empty.")

        # A bare file name has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError; in that case the file goes to the current working
        # directory and there is nothing to create.
        output_dir = os.path.dirname(self.output_file_name)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        self.merged_predictions.write_parquet(self.output_file_name)