# Source code for ballet.eng.base
from typing import Callable, Optional
import funcy as fy
import numpy as np
import pandas as pd
import sklearn.base
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import FunctionTransformer
from sklearn.utils.validation import check_is_fitted
from sklearn_pandas import DataFrameMapper
from sklearn_pandas import __version__ as sklearn_pandas_version
import ballet.transformer  # avoid circular import
from ballet.exc import BalletError
from ballet.util import get_arr_desc
from ballet.util.typing import OneOrMore, TransformerLike
# Public names exported by ``from ballet.eng.base import *``.
__all__ = (
    'BaseTransformer', 'ConditionalTransformer', 'GroupedFunctionTransformer',
    'GroupwiseTransformer', 'NoFitMixin', 'SimpleFunctionTransformer',
    'SubsetTransformer',
)
class BaseTransformer(NoFitMixin, TransformerMixin, BaseEstimator):
    """Base transformer class for developing new transformers

    Combines ``NoFitMixin`` with scikit-learn's ``TransformerMixin`` (which
    supplies ``fit_transform``) and ``BaseEstimator`` (which supplies
    ``get_params``/``set_params``). Subclasses are expected to implement
    ``transform``.
    """
class SimpleFunctionTransformer(FunctionTransformer):
    """Transformer that applies a callable to its input

    The callable will be called on the input X in the transform stage,
    optionally with additional keyword arguments.

    A simple wrapper around :py:class:`FunctionTransformer`.

    Args:
        func: callable to apply
        func_kwargs: keyword arguments to pass to ``func``. If omitted, no
            additional keyword arguments are passed.
    """

    def __init__(self,
                 func: Callable,
                 func_kwargs: Optional[dict] = None):
        self.func = func
        # Replace only ``None`` with a fresh dict. An explicitly passed dict
        # (even an empty one) must be stored as the very same object:
        # ``sklearn.base.clone`` re-creates the estimator from
        # ``get_params()`` and raises if the new instance does not store the
        # identical parameter object. ``func_kwargs or {}`` swapped an empty
        # dict for a new one and made cloning fail.
        self.func_kwargs = func_kwargs if func_kwargs is not None else {}
        super().__init__(
            func=self.func,
            kw_args=self.func_kwargs)
class GroupedFunctionTransformer(FunctionTransformer):
    """Transformer that applies a callable to each group of a groupby

    Args:
        func: callable to apply
        func_kwargs: keyword arguments to pass to ``func``
        groupby_kwargs: keyword arguments to ``pd.DataFrame.groupby``. If
            omitted, no grouping is performed and the function is called on
            the entire DataFrame.
    """

    def __init__(self,
                 func: Callable,
                 func_kwargs: Optional[dict] = None,
                 groupby_kwargs: Optional[dict] = None):
        self.func = func
        # Replace only ``None`` with a fresh dict so that explicitly passed
        # dicts (even empty ones) are stored as-is; ``sklearn.base.clone``
        # requires the re-created estimator to hold the identical parameter
        # objects, which ``x or {}`` broke for empty dicts.
        self.func_kwargs = func_kwargs if func_kwargs is not None else {}
        self.groupby_kwargs = (
            groupby_kwargs if groupby_kwargs is not None else {})
        super().__init__(
            func=func,
            kw_args=self.func_kwargs)

    def transform(self, X, **transform_kwargs):
        """Apply ``func`` to every group of ``X``, or to all of ``X``

        Args:
            X: DataFrame to transform
            **transform_kwargs: accepted for interface compatibility; not
                used

        Returns:
            the result of ``X.groupby(**groupby_kwargs).apply(...)`` when
            ``groupby_kwargs`` is non-empty, otherwise of ``X.pipe(...)``;
            in both cases :py:meth:`FunctionTransformer.transform` is the
            function being applied.
        """
        if self.groupby_kwargs:
            call = X.groupby(**self.groupby_kwargs).apply
        else:
            call = X.pipe
        return call(super().transform)
class GroupwiseTransformer(BaseTransformer):
    """Transformer that does something different for every group

    For each group identified in the training set by the groupby operation,
    a separate transformer is cloned and fit. This is useful to learn
    group-wise transformers that do not leak data between the training and
    test sets. Consider the case of imputing missing values with the mean of
    some group. A normal, pure-pandas implementation, such as
    ``X_te.groupby(by='foo').apply('mean')`` would leak information about
    the test set means, which might differ from the training set means.

    Args:
        transformer: the transformer to apply to each group. If transformer
            is a transformer-like class (i.e. instances of the class are
            transformer-like), then it is initialized with no arguments for
            each group. If transformer is a transformer-like instance (i.e.
            has fit, transform methods etc.), then it is cloned for each
            group, even if the instance itself is also callable. Any other
            callable is called with no arguments for each group.
        groupby_kwargs: keyword arguments to pd.DataFrame.groupby
        column_selection: column, or list of columns,
            to select after the groupby. Equivalent to
            ``df.groupby(...)[column_selection]``. Defaults to None, i.e. no
            column selection is performed.
        handle_unknown: 'error' or 'ignore', default='error'. Whether to
            raise an error or ignore if an unknown group is encountered during
            transform. When this parameter is set to 'ignore' and an unknown
            group is encountered during transform, the group's values will be
            passed through unchanged.
        handle_error: 'error' or 'ignore', default='error'. Whether to
            raise an error or ignore if an error is raised during transforming
            an individual group. When this parameter is set to 'ignore' and
            an error is raised when calling the transformer's transform
            method on an individual group, the group's values will be passed
            through unchanged.

    Example usage:
        In this example, we create a groupwise transformer that fits a
        separate imputer for each group encountered. For new data points,
        values will be imputed according to the mean of its group on the
        training set, avoiding any data leakage.

        .. code-block:: python

           >>> from sklearn.impute import SimpleImputer
           >>> transformer = GroupwiseTransformer(
           ...     SimpleImputer(strategy='mean'),
           ...     groupby_kwargs = {'level': 'name'}
           ... )

    Raises:
        ValueError: at fit-time, if handle_unknown or handle_error is not
            'error' or 'ignore'.
        ballet.exc.BalletError: if handle_unknown=='error' and an unknown group
            is encountered at transform-time.
    """

    def __init__(self,
                 transformer: TransformerLike,
                 groupby_kwargs: Optional[dict] = None,
                 column_selection: OneOrMore[str] = None,
                 handle_unknown: str = 'error',
                 handle_error: str = 'error'):
        self.transformer = transformer
        self.groupby_kwargs = groupby_kwargs
        self.column_selection = column_selection
        self.handle_unknown = handle_unknown
        self.handle_error = handle_error

    def _make_transformer(self):
        """Return a fresh, unfitted transformer for a single group."""
        transformer = self.transformer
        if isinstance(transformer, type):
            # A transformer-like class: instantiate with no arguments.
            return transformer()
        is_transformer_like = (
            hasattr(transformer, 'fit') and hasattr(transformer, 'transform'))
        if callable(transformer) and not is_transformer_like:
            # A factory callable (e.g. a function or functools.partial).
            return transformer()
        # A transformer-like instance -- clone it even if it defines
        # ``__call__``, so it is never invoked as if it were a factory.
        return sklearn.base.clone(transformer)

    @staticmethod
    def _take_rows(y, positions):
        """Select the entries/rows of ``y`` at integer ``positions``."""
        # ``GroupBy.indices`` yields positional indices. ``y[positions]`` is
        # label-based on a pandas Series with a non-default index (and picks
        # columns of a DataFrame), so pandas objects are indexed via iloc.
        if hasattr(y, 'iloc'):
            return y.iloc[positions]
        return y[positions]

    def fit(self, X, y=None, **fit_kwargs):
        """Fit a separate transformer on each group of ``X``

        Args:
            X: DataFrame to group and fit on
            y: optional targets, aligned row-by-row with ``X``; the rows for
                each group are selected by position
            **fit_kwargs: accepted for interface compatibility; not used

        Returns:
            self

        Raises:
            ValueError: if handle_unknown or handle_error is invalid
        """
        # validation on inputs
        self.groupby_kwargs_ = self.groupby_kwargs or {}
        if self.handle_unknown not in ('error', 'ignore'):
            raise ValueError(
                f'Invalid value for handle_unknown: {self.handle_unknown}')
        if self.handle_error not in ('error', 'ignore'):
            raise ValueError(
                f'Invalid value for handle_error: {self.handle_error}')

        # Get the groups
        grouper = X.groupby(**self.groupby_kwargs_)
        self.groups_ = set(grouper.groups.keys())

        # Create and fit a transformer for each group
        self.transformers_ = {}
        for group_name, x_group in grouper:
            transformer = self._make_transformer()
            if self.column_selection is not None:
                x_group = x_group[self.column_selection]
            if y is not None:
                y_group = self._take_rows(y, grouper.indices[group_name])
                transformer.fit(x_group, y_group)
            else:
                transformer.fit(x_group)
            self.transformers_[group_name] = transformer

        return self

    def transform(self, X, **transform_kwargs):
        """Transform each group of ``X`` with the transformer fit on it

        Args:
            X: DataFrame to group and transform
            **transform_kwargs: forwarded through ``GroupBy.apply`` to each
                per-group transformer's ``transform``

        Returns:
            the result of ``X.groupby(**groupby_kwargs_).apply(...)``

        Raises:
            ballet.exc.BalletError: if handle_unknown=='error' and a group
                not seen during fit is encountered
        """
        check_is_fitted(self, ['groups_', 'transformers_'])

        def _transform(x_group, *args, **kwargs):
            # If the group is not a DataFrame, there are two problems
            # 1. We can't rely on group.name to lookup the right transformer
            # 2. We can't "reassemble" the transformed
            # However, the contract of ``pandas.core.groupby.GroupBy.apply`` is
            # that the input is a DataFrame, so this should never occur.
            if not isinstance(x_group, pd.DataFrame):
                raise NotImplementedError

            group_name = x_group.name
            if self.column_selection is not None:
                x_group = x_group[self.column_selection]

            if group_name not in self.transformers_:
                if self.handle_unknown == 'error':
                    raise BalletError(f'Unknown group: {group_name}')
                elif self.handle_unknown == 'ignore':
                    return x_group
                else:
                    # Unreachable code (handle_unknown is validated in fit)
                    raise RuntimeError

            transformer = self.transformers_[group_name]
            try:
                data = transformer.transform(x_group, *args, **kwargs)
                # This post-processing step is required because sklearn
                # transform converts a DataFrame to an array. This is my
                # best attempt so far to approximate the following:
                # >>> result = x_group.copy()
                # >>> result.values = data
                # which is an error as `values` cannot be set.
                return pd.DataFrame(
                    data=data, index=x_group.index, columns=x_group.columns)
            except Exception:
                # Deliberate best-effort mode: pass the group through
                if self.handle_error == 'ignore':
                    return x_group
                raise

        return (
            X
            .groupby(**self.groupby_kwargs_)
            .apply(_transform, **transform_kwargs)
        )
class ConditionalTransformer(BaseTransformer):
    """Transform columns that satisfy a condition during training

    In the fit stage, determines which variables (columns) satisfy the
    condition. In the transform stage, applies the given transformation to
    the satisfied columns. If a second transformation is given, applies the
    second transformation to the complement of the satisfied columns (i.e.
    the columns that fail to satisfy the condition). Otherwise, these
    unsatisfied columns are passed through unchanged.

    Args:
        condition: condition function
        satisfy_transform: transform function for satisfied columns
        unsatisfy_transform: transform function for unsatisfied columns
            (defaults to identity)
    """

    def __init__(
        self,
        condition: Callable,
        satisfy_transform: Callable,
        unsatisfy_transform: Optional[Callable] = None
    ):
        super().__init__()
        self.condition = condition
        self.satisfy_transform = satisfy_transform
        self.unsatisfy_transform = unsatisfy_transform or fy.identity

    def fit(self, X, y=None, **fit_args):
        """Evaluate ``condition`` on ``X`` and remember the column mask

        Returns:
            self
        """
        # satisfied_columns_ is a bool or array[bool]
        self.satisfied_columns_ = self.condition(X)
        self.unsatisfied_columns_ = np.logical_not(self.satisfied_columns_)
        return self

    def transform(self, X, **transform_args):
        """Apply the satisfy/unsatisfy transforms to the matching columns

        DataFrames are copied and updated column-wise; 1-d inputs are
        transformed as a whole according to the (scalar) fitted condition;
        2-d ndarrays are copied, cast to float and updated column-wise.

        Raises:
            TypeError: if X is of an unsupported type and some columns
                satisfied the condition
        """
        check_is_fitted(self, ['satisfied_columns_', 'unsatisfied_columns_'])
        if isinstance(X, pd.DataFrame):
            X = X.copy()
            # Skip empty selections so the transforms are never called on a
            # zero-column frame (consistent with the ndarray branch below).
            if np.any(self.satisfied_columns_):
                X.loc[:, self.satisfied_columns_] = self.satisfy_transform(
                    X.loc[:, self.satisfied_columns_])
            if np.any(self.unsatisfied_columns_):
                X.loc[:, self.unsatisfied_columns_] = self.unsatisfy_transform(
                    X.loc[:, self.unsatisfied_columns_])
            return X
        elif np.ndim(X) == 1:
            return (
                self.satisfy_transform(X)
                if self.satisfied_columns_
                else self.unsatisfy_transform(X)
            )
        elif isinstance(X, np.ndarray):
            X = X.copy().astype('float')
            # Assign through a boolean column index. The previous
            # ``np.putmask(X, tiled_mask, values)`` indexed ``values`` by the
            # flat position in X rather than by the count of True entries,
            # scrambling results whenever only some columns were selected.
            if self.satisfied_columns_.any():
                X[:, self.satisfied_columns_] = self.satisfy_transform(
                    X[:, self.satisfied_columns_])
            if self.unsatisfied_columns_.any():
                X[:, self.unsatisfied_columns_] = self.unsatisfy_transform(
                    X[:, self.unsatisfied_columns_])
            return X
        elif not self.satisfied_columns_:
            # if we wouldn't otherwise have known what to do, we can pass
            # through X if transformation was not necessary anyways
            return self.unsatisfy_transform(X)
        else:
            raise TypeError(
                f'Couldn\'t apply transformer on features in '
                f'{get_arr_desc(X)}.')
class SubsetTransformer(DataFrameMapper):
    """Transform a subset of columns with another transformer

    Args:
        input: column name, or list of column names, to select from the
            input DataFrame
        transformer: transformer-like object applied to the selected
            columns; it is passed through
            ``ballet.transformer.desugar_transformer`` before being handed to
            DataFrameMapper
        alias: optional name given to DataFrameMapper as the ``alias``
            option of the transformed output
    """

    def __init__(self,
                 input: OneOrMore[str],
                 transformer: TransformerLike,
                 alias: Optional[str] = None):
        self.input = input
        self.transformer = transformer
        self.alias = alias

        desugared = ballet.transformer.desugar_transformer(transformer)
        feature_spec = (input, desugared, {'alias': alias})
        super().__init__(
            [feature_spec],
            # per sklearn-pandas, default=None passes the columns that were
            # not selected through unchanged
            default=None,
            input_df=True,
            df_out=True,
        )

    if sklearn_pandas_version.startswith('1'):
        def __setstate__(self, state):
            # FIXME: works around a sklearn-pandas 1.x bug (fixed in 2.x);
            # delete after upgrading. In 1.x DataFrameMapper.__setstate__
            # does not call super().__setstate__, so BaseEstimator's (its
            # parent's) implementation is invoked explicitly first.
            BaseEstimator.__setstate__(self, state)
            DataFrameMapper.__setstate__(self, state)