# Source code for aiqclib.classify.step4_select_rows.dataset_all

"""
Module for selecting all data rows from combined Copernicus CTD data.

This module provides the :class:`LocateDataSetAll` class, which extends
:class:`LocatePositionBase` to identify and label data points for machine
learning tasks based on Quality Control (QC) flags.

"""

from typing import Dict, Optional

import polars as pl

from aiqclib.common.base.config_base import ConfigBase
from aiqclib.prepare.step4_select_rows.locate_base import LocatePositionBase


class LocateDataSetAll(LocatePositionBase):
    """
    Row locator that keeps every labelled row of the combined Copernicus
    CTD data instead of sampling a subset.

    The class extends :class:`LocatePositionBase`. For each target it stores
    a DataFrame in :attr:`selected_rows` containing the rows whose QC flag
    matches one of the configured positive or negative flag values, together
    with a binary ``label`` column.

    :cvar expected_class_name: Class name used for validation.
    :vartype expected_class_name: str
    """

    expected_class_name: str = "LocateDataSetAll"

    def __init__(
        self,
        config: ConfigBase,
        input_data: Optional[pl.DataFrame] = None,
        selected_profiles: Optional[pl.DataFrame] = None,
    ) -> None:
        """
        Set up the locator and resolve the per-target output file names.

        :param config: Configuration object; it is forwarded to the base
                       class and queried for the output file name of each
                       target.
        :type config: aiqclib.common.base.config_base.ConfigBase
        :param input_data: Full data set from which rows are selected. May
                           be omitted here and supplied later.
        :type input_data: polars.DataFrame or None
        :param selected_profiles: Previously selected profiles. May be
                                  omitted here and supplied later.
        :type selected_profiles: polars.DataFrame or None
        """
        super().__init__(
            config=config,
            input_data=input_data,
            selected_profiles=selected_profiles,
        )

        #: File name template used when the configuration does not override
        #: it; one file is written per target.
        self.default_file_name: str = "selected_rows_classify_{target_name}.parquet"

        #: Output Parquet file path for each target name.
        self.output_file_names: Dict[str, str] = self.config.get_target_file_names(
            step_name="locate",
            default_file_name=self.default_file_name,
        )

    def select_all_rows(self, target_name: str, target_value: Dict) -> None:
        """
        Label every row whose QC flag is a known positive or negative value.

        The result is stored in ``self.selected_rows[target_name]`` with the
        columns ``row_id``, ``profile_id``, ``platform_code``,
        ``profile_no``, ``observation_no``, ``pres``, ``flag``, ``label``
        and ``pair_id``.

        :param target_name: Key under which the result is stored.
        :type target_name: str
        :param target_value: Target metadata. ``"flag"`` (required) names
                             the QC flag column; ``"pos_flag_values"``
                             (default ``[4]``) and ``"neg_flag_values"``
                             (default ``[1]``) list the flag values that are
                             labelled 1 and 0 respectively.
        :type target_value: Dict
        :raises ValueError: If :attr:`input_data` is None.
        :raises KeyError: If ``"flag"`` is missing from ``target_value``.
        """
        if self.input_data is None:
            raise ValueError("Member variable 'input_data' must not be empty.")

        positive_flags = target_value.get("pos_flag_values", [4])
        negative_flags = target_value.get("neg_flag_values", [1])
        flag_var_name = target_value["flag"]

        # Number the rows before filtering so that row_id is the 1-based
        # position of the row in the unfiltered input data.
        indexed = self.input_data.with_row_index("row_id", offset=1)

        flag_col = pl.col(flag_var_name)
        has_known_flag = flag_col.is_in(positive_flags + negative_flags)

        # A value listed in both flag lists is labelled positive because the
        # positive branch is evaluated first.
        label = (
            pl.when(flag_col.is_in(positive_flags))
            .then(1)
            .when(flag_col.is_in(negative_flags))
            .then(0)
            .otherwise(None)
            .alias("label")
        )
        # Constant placeholder columns: no profile or pair is assigned here.
        profile_id = pl.lit(0, dtype=pl.UInt32).alias("profile_id")
        pair_id = pl.lit("").alias("pair_id")

        labelled = indexed.filter(has_known_flag).with_columns(
            profile_id, pair_id, label
        )

        self.selected_rows[target_name] = labelled.select(
            "row_id",
            "profile_id",
            "platform_code",
            "profile_no",
            "observation_no",
            "pres",
            flag_col.alias("flag"),
            "label",
            "pair_id",
        )

    def locate_target_rows(self, target_name: str, target_value: Dict) -> None:
        """
        Locate the rows for one target by delegating to
        :meth:`select_all_rows`.

        :param target_name: Name of the target variable.
        :type target_name: str
        :param target_value: Target metadata, including the QC flag column
                             name (e.g., ``{"flag": "TEMP_QC_FLAG"}``).
        :type target_value: Dict
        """
        self.select_all_rows(target_name, target_value)