Source code for convoys.single

import warnings
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import numpy
import scipy.stats  # type: ignore[import-untyped]

if TYPE_CHECKING:
    from numpy.typing import ArrayLike

__all__ = ["KaplanMeier"]


class SingleModel(ABC):
    @abstractmethod
    def fit(self, B: "ArrayLike", T: "ArrayLike") -> None:
        raise NotImplementedError("Need to implement fit")

    @abstractmethod
    def predict(self, t: "ArrayLike") -> numpy.ndarray:
        raise NotImplementedError("Need to implement predict")

    @abstractmethod
    def predict_ci(self, t: "ArrayLike", ci: float) -> numpy.ndarray:
        raise NotImplementedError("Need to implement predict_ci")


[docs] class KaplanMeier(SingleModel): """Implementation of the Kaplan-Meier nonparametric method."""
[docs] def fit(self, B: "ArrayLike", T: "ArrayLike") -> None: """Fits the model :param B: numpy array of shape :math:`n`, containing booleans representing whether or not the subject 'converted' at time delta :math:`t`. :param T: numpy array of shape :math:`n`, containing floats representing the time delta :math:`t` between creation and either conversion or censoring. """ if not isinstance(B, numpy.ndarray): B = numpy.array(B) if not isinstance(T, numpy.ndarray): T = numpy.array(T) # See https://www.math.wustl.edu/~sawyer/handouts/greenwood.pdf BT = [ (b, t) for b, t in zip(B, T, strict=False) if t >= 0 and 0 <= float(b) <= 1 ] if len(BT) < len(B): n_removed = len(B) - len(BT) warnings.warn( "Warning! Removed %d/%d entries from inputs where " "T < 0 or B not 0/1" % (n_removed, len(B)), stacklevel=2, ) B, T = ([z[i] for z in BT] for i in range(2)) n = len(T) self._ts = [0.0] self._ss = [1.0] self._vs = [0.0] sum_var_terms = 0.0 prod_s_terms = 1.0 for t, b in sorted(zip(T, B, strict=False)): d = float(b) self._ts.append(t) prod_s_terms *= 1 - d / n self._ss.append(prod_s_terms) if d == n == 1: sum_var_terms = float("inf") else: sum_var_terms += d / (n * (n - d)) if sum_var_terms > 0: self._vs.append(1 / numpy.log(prod_s_terms) ** 2 * sum_var_terms) else: self._vs.append(0) n -= 1 # Just prevent overflow warning when computing the confidence interval eps = 1e-9 self._ss_clipped = numpy.clip(self._ss, eps, 1.0 - eps)
[docs] def predict(self, t: "ArrayLike") -> numpy.ndarray: """Returns the predicted values.""" t = numpy.array(t) res = numpy.zeros(t.shape) for indexes, value in numpy.ndenumerate(t): j = numpy.searchsorted(self._ts, value, side="right") - 1 if j >= len(self._ts) - 1: # Make the plotting stop at the last value of t res[indexes] = float("nan") else: res[indexes] = 1 - self._ss[j] return res
[docs] def predict_ci(self, t: "ArrayLike", ci: float = 0.8) -> numpy.ndarray: """Returns the predicted values with a confidence interval.""" t = numpy.array(t) res = numpy.zeros(t.shape + (3,)) for indexes, value in numpy.ndenumerate(t): j = numpy.searchsorted(self._ts, value, side="right") - 1 if j >= len(self._ts) - 1: # Make the plotting stop at the last value of t res[indexes] = [float("nan")] * 3 else: z_lo, z_hi = scipy.stats.norm.ppf([(1 - ci) / 2, (1 + ci) / 2]) res[indexes] = ( 1 - self._ss[j], 1 - numpy.exp( -numpy.exp( numpy.log(-numpy.log(self._ss_clipped[j])) + z_hi * self._vs[j] ** 0.5 ) ), 1 - numpy.exp( -numpy.exp( numpy.log(-numpy.log(self._ss_clipped[j])) + z_lo * self._vs[j] ** 0.5 ) ), ) return res