Source code for aiqclib.train.models.support_vector_machine

"""
This module provides a Support Vector Machine (SVM) model wrapper, inheriting
from `aiqclib.common.base.scikit_learn_model_base.SklearnModelBase`.

It facilitates training, prediction, and evaluation of an SVM classifier using
Polars DataFrames, converting them to Pandas for compatibility with the
`sklearn` library.
"""

from typing import Dict, Any

from sklearn.svm import SVC as SklearnSVC

from aiqclib.common.base.config_base import ConfigBase
from aiqclib.common.base.scikit_learn_model_base import SklearnModelBase


[docs] class SupportVectorMachine(SklearnModelBase): """ A Support Vector Machine (SVM) model wrapper class for training and testing. Inherits from :class:`aiqclib.common.base.scikit_learn_model_base.SklearnModelBase` to reuse common Scikit-Learn API logic. Features include: - Automatic application of ``model_params`` from the YAML config, if defined; otherwise, uses default hyperparameters. - Uses :class:`sklearn.svm.SVC`. - Enforces ``probability=True`` by default to support the generation of model-scores tables and ROC/PR curves used in the parent class. - Uses a linear kernel by default. .. note:: Standard SVM implementations (especially with a linear kernel) typically do not support the ``n_jobs`` parameter directly, as parallelization is often handled by underlying BLAS libraries. """ expected_class_name: str = "SupportVectorMachine" short_name: str = "SVM" def __init__(self, config: ConfigBase) -> None: """ Initialize the SVM model with default or user-specified parameters. :param config: A configuration object providing model parameters. :type config: ConfigBase """ super().__init__(config=config) self.model_params: Dict[str, Any] = { "C": 1.0, "kernel": "linear", "probability": True, # Required for predict_proba used in base class "tol": 1e-3, "max_iter": 200, "random_state": None, "class_weight": "balanced", } # Update model parameters with config step parameters model_params = self.config.get_model_params( self.expected_class_name, self.short_name ) self.model_params.update(model_params) self.allow_na = False def _get_model_class(self) -> Any: """ Return the Scikit-Learn SVC class. This method is an abstract method implementation required by :class:`aiqclib.common.base.scikit_learn_model_base.SklearnModelBase`. :returns: The Scikit-Learn SVC class. :rtype: Any """ return SklearnSVC