# Source code for convoys.multi

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Generic, Hashable, TypeVar

import numpy

from convoys import regression, single

if TYPE_CHECKING:
    from numpy import ndarray
    from numpy.typing import ArrayLike

# Type of the ``group`` argument accepted by ``predict``/``predict_ci``:
# array-like for the regression-backed adapter, any hashable label for the
# adapter that fits one model per group.
T_Group = TypeVar("T_Group", "ArrayLike", Hashable)

# Public API: the concrete multi-group model classes defined in this module.
__all__ = ["KaplanMeier", "Exponential", "Weibull", "Gamma", "GeneralizedGamma"]


class MultiModel(ABC, Generic[T_Group]):
    """Abstract interface shared by all multi-group models.

    Concrete subclasses set ``_base_model_cls`` to the underlying model class
    and implement fitting and prediction for a given group.
    """

    # Underlying model class wrapped by the concrete adapter.
    _base_model_cls: type[regression.RegressionModel | single.SingleModel]

    @abstractmethod
    def fit(self, G: "ArrayLike", B: "ArrayLike", T: "ArrayLike") -> None:
        """Fit the model.

        :param G: group assignment of every observation.
        :param B: conversion indicator of every observation.
        :param T: time delta of every observation.
        """
        raise NotImplementedError("Need to implement fit")

    @abstractmethod
    def predict(self, group: T_Group, t: "ArrayLike") -> "ndarray":
        """Return the prediction for ``group`` at time(s) ``t``."""
        raise NotImplementedError("Need to implement predict")

    @abstractmethod
    def predict_ci(self, group: T_Group, t: "ArrayLike", ci: float) -> "ndarray":
        """Return the prediction for ``group`` at time(s) ``t`` with an
        interval at level ``ci``.
        """
        raise NotImplementedError("Need to implement predict_ci")


class RegressionToMulti(MultiModel["ArrayLike"]):
    """Adapts a regression model to the multi-group interface.

    Group indices are one-hot encoded into a feature matrix, and a single
    instance of ``_base_model_cls`` is fitted on it. Positional and keyword
    arguments given to the constructor are forwarded unchanged to
    ``_base_model_cls``.
    """

    # Regression model class instantiated in ``__init__``.
    _base_model_cls: type[regression.RegressionModel]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.base_model = self._base_model_cls(*args, **kwargs)

    def fit(self, G: "ArrayLike", B: "ArrayLike", T: "ArrayLike") -> None:
        """Fits the model

        :param G: numpy array of shape :math:`n`, containing non-negative
            integers representing group assignments.
        :param B: numpy array of shape :math:`n`, containing booleans representing
            whether or not the subject 'converted' at time delta :math:`t`.
        :param T: numpy array of shape :math:`n`, containing floats representing the
            time delta :math:`t` between creation and either conversion or censoring.
        :raises ValueError: if ``G`` is not one-dimensional, is empty, or
            contains a negative group index.
        """
        G = numpy.array(G, dtype=int)
        (n,) = G.shape
        if n == 0:
            raise ValueError("G must contain at least one group assignment")
        # Negative indices would silently wrap around to another group's
        # column (or fail with an opaque IndexError), so reject them up front.
        if G.min() < 0:
            raise ValueError("group assignments in G must be non-negative integers")
        self._n_groups = int(G.max()) + 1
        # Builtin ``bool`` is a valid dtype on every numpy version, whereas
        # the ``numpy.bool`` alias is missing from numpy 1.24-1.26.
        X = numpy.zeros((n, self._n_groups), dtype=bool)
        # One-hot encode: row i gets a True in column G[i].
        X[numpy.arange(n), G] = True
        self.base_model.fit(X, B, T)

    def _get_x(self, group: "ArrayLike") -> "ndarray":
        """Build the feature vector for ``group``.

        Returns a zero vector of length ``self._n_groups`` (set by
        :meth:`fit`) with the position(s) indexed by ``group`` set to 1.
        """
        x = numpy.zeros(self._n_groups)
        g = numpy.array(group)
        x[g] = 1
        return x

    def predict(self, group: "ArrayLike", t: "ArrayLike") -> "ndarray":
        """Return the base model's prediction for ``group`` at time(s) ``t``."""
        return self.base_model.predict(self._get_x(group), t)

    def predict_ci(self, group: "ArrayLike", t: "ArrayLike", ci: float) -> "ndarray":
        """Return the base model's ``predict_ci`` result for ``group`` at
        time(s) ``t`` and level ``ci``.
        """
        return self.base_model.predict_ci(self._get_x(group), t, ci)

    def rvs(
        self, group: "ArrayLike", *args: Any, **kwargs: Any
    ) -> tuple["ndarray", "ndarray"]:
        """Call the base model's ``rvs`` for ``group``.

        Any extra positional and keyword arguments are forwarded unchanged.
        """
        return self.base_model.rvs(self._get_x(group), *args, **kwargs)


class SingleToMulti(MultiModel[Hashable]):
    """Adapts a single-population model to the multi-group interface.

    A separate instance of ``_base_model_cls`` is created and fitted for every
    distinct group label seen in :meth:`fit`. Constructor arguments are
    captured and passed to ``_base_model_cls`` each time a model is created.
    """

    # Single-population model class instantiated once per group.
    _base_model_cls: type[single.SingleModel]

    def __init__(self, *args: Any, **kwargs: Any):
        def make_model() -> single.SingleModel:
            return self._base_model_cls(*args, **kwargs)

        # Factory producing a fresh, unfitted base model on every call.
        self.base_model_init: Callable[[], single.SingleModel] = make_model

    def fit(self, G: "ArrayLike", B: "ArrayLike", T: "ArrayLike") -> None:
        """Fits one base model per group

        :param G: numpy array of shape :math:`n`, containing the (hashable) group
            label of each subject.
        :param B: numpy array of shape :math:`n`, containing booleans representing
            whether or not the subject 'converted' at time delta :math:`t`.
        :param T: numpy array of shape :math:`n`, containing floats representing the
            time delta :math:`t` between creation and either conversion or censoring.
        """
        labels = numpy.array(G)
        converted = numpy.array(B)
        durations = numpy.array(T)

        # Bucket the observations by group label, keeping first-seen order.
        per_group: dict[Hashable, tuple[list, list]] = {}
        for label, b, t in zip(labels, converted, durations, strict=False):
            if label not in per_group:
                per_group[label] = ([], [])
            bs, ts = per_group[label]
            bs.append(b)
            ts.append(t)

        self._group2model: dict[Hashable, single.SingleModel] = {}
        for label, (bs, ts) in per_group.items():
            model = self.base_model_init()
            self._group2model[label] = model
            model.fit(bs, ts)

    def predict(self, group: Hashable, t: "ArrayLike") -> "ndarray":
        """Return the prediction of the model fitted for ``group`` at ``t``.

        :raises KeyError: if no model was fitted for ``group``.
        """
        model = self._group2model[group]
        return model.predict(t)

    def predict_ci(self, group: Hashable, t: "ArrayLike", ci: float) -> "ndarray":
        """Return the ``predict_ci`` result of the model fitted for ``group``.

        :raises KeyError: if no model was fitted for ``group``.
        """
        model = self._group2model[group]
        return model.predict_ci(t, ci)


class Exponential(RegressionToMulti):
    """Multi-group version of :class:`convoys.regression.Exponential`."""

    # Model class that RegressionToMulti instantiates as ``base_model``.
    _base_model_cls = regression.Exponential
class Weibull(RegressionToMulti):
    """Multi-group version of :class:`convoys.regression.Weibull`."""

    # Model class that RegressionToMulti instantiates as ``base_model``.
    _base_model_cls = regression.Weibull
class Gamma(RegressionToMulti):
    """Multi-group version of :class:`convoys.regression.Gamma`."""

    # Model class that RegressionToMulti instantiates as ``base_model``.
    _base_model_cls = regression.Gamma
class GeneralizedGamma(RegressionToMulti):
    """Multi-group version of :class:`convoys.regression.GeneralizedGamma`."""

    # Model class that RegressionToMulti instantiates as ``base_model``.
    _base_model_cls = regression.GeneralizedGamma
class KaplanMeier(SingleToMulti):
    """Multi-group version of :class:`convoys.single.KaplanMeier`."""

    # Model class that SingleToMulti instantiates once per group.
    _base_model_cls = single.KaplanMeier