Source code for aiqclib.train.models.logistic_regression

"""
This module provides a Logistic Regression model wrapper, inheriting from `aiqclib.common.base.scikit_learn_model_base.SklearnModelBase`.

It facilitates training, prediction, and evaluation of a Scikit-Learn Logistic Regression classifier
using Polars DataFrames.
"""

from typing import Dict, Any

from sklearn.linear_model import LogisticRegression as SklearnLR

from aiqclib.common.base.config_base import ConfigBase
from aiqclib.common.base.scikit_learn_model_base import SklearnModelBase


[docs] class LogisticRegression(SklearnModelBase): """ A Logistic Regression model wrapper class for training and testing. Inherits from :class:`SklearnModelBase` to reuse common Scikit-Learn API logic. Features include: - Automatic application of ``model_params`` from the YAML config, if defined; otherwise, uses default hyperparameters suitable for standard classification tasks. - Uses ``sklearn.linear_model.LogisticRegression``. .. note:: This class sets :attr:`expected_class_name` to ``"LogisticRegression"``. """ expected_class_name: str = "LogisticRegression" short_name: str = "Logit" def __init__(self, config: ConfigBase) -> None: """ Initialize the Logistic Regression model with default or user-specified parameters. :param config: A configuration object providing model parameters. :type config: ConfigBase """ super().__init__(config=config) # Default parameters for Logistic Regression self.model_params: Dict[str, Any] = { "l1_ratio": 0, "C": 1.0, "solver": "lbfgs", "class_weight": "balanced", "max_iter": 200, } # Update model parameters with config step parameters model_params = self.config.get_model_params( self.expected_class_name, self.short_name ) self.model_params.update(model_params) self.allow_na = False def _get_model_class(self) -> Any: """ Return the Scikit-Learn LogisticRegression class. :return: The LogisticRegression class. :rtype: Any """ return SklearnLR