Source code for convoys.export

from __future__ import annotations

from math import floor
from typing import TYPE_CHECKING, Callable, Hashable, Literal

import numpy
import pandas

import convoys.multi

if TYPE_CHECKING:
    pass

__all__ = ["export_cohorts"]


_models: dict[str, Callable[[bool], convoys.multi.MultiModel]] = {
    "kaplan-meier": lambda _: convoys.multi.KaplanMeier(),
    "exponential": lambda ci: convoys.multi.Exponential(mcmc=ci),
    "weibull": lambda ci: convoys.multi.Weibull(mcmc=ci),
    "gamma": lambda ci: convoys.multi.Gamma(mcmc=ci),
    "generalized-gamma": lambda ci: convoys.multi.GeneralizedGamma(mcmc=ci),
}


[docs] def export_cohorts( G: numpy.ndarray, B: numpy.ndarray, T: numpy.ndarray, t_max: int | float | None = None, model: Literal[ "kaplan-meier", "exponential", "weibull", "gamma", "generalized-gamma" ] | convoys.multi.MultiModel = "kaplan-meier", ci: float | None = None, groups: list[Hashable] | None = None, specific_groups: list[Hashable] | None = None, ) -> pandas.DataFrame: """Helper function to fit data using a model and then export the model predictions as a DataFrame. The Dataframe will have the columns "group", "t", and "prediction_value"; if ci is not None, it will also have the columns "ci_low" and "ci_high". "t" will be integers in the range from 0 to t_max (or max(T) if t_max is None). This format enables easy storage in a database and plotting with BI tools. :param G: numpy array of shape :math:`n`, containing integers representing group assignments. :param B: numpy array of shape :math:`n`, containing booleans representing whether or not the subject 'converted' at time delta :math:`t`. :param T: numpy array of shape :math:`n`, containing floats representing the time delta :math:`t` between creation and either conversion or censoring. :param t_max: (optional) max value for x axis :param model: (optional, default is kaplan-meier) model to fit. Can be an instance of :class:`multi.MultiModel` or a string identifying the model. One of 'kaplan-meier', 'exponential', 'weibull', 'gamma', or 'generalized-gamma'. :param ci: confidence interval, value from 0-1, or None (default) if no confidence interval is to be plotted :param groups: list of group labels :param specific_groups: subset of groups to plot See :meth:`convoys.utils.get_arrays` which is handy for converting a Pandas dataframe into arrays `G`, `B`, `T`. """ if model not in _models.keys(): if not isinstance(model, convoys.multi.MultiModel): raise ValueError("model incorrectly specified") if groups is None: groups = list(set(G)) # Set x scale if t_max is None: t_max = max(T) if not isinstance(model, convoys.multi.MultiModel): # Fit model m = _models[model](bool(ci)) m.fit(G, B, T) else: m = model if specific_groups is None: specific_groups = groups if len(set(specific_groups).intersection(groups)) != len(specific_groups): raise ValueError("specific_groups not a subset of groups!") RESULT_LENGTH = floor(t_max + 1) t = numpy.arange(RESULT_LENGTH) data: list[pandas.DataFrame] = [] for group in specific_groups: j = groups.index(group) # matching index of group if ci is not None: result = m.predict_ci(j, t, ci=ci) result_df = pandas.DataFrame( data=result, columns=["prediction_value", "ci_low", "ci_high"] ) else: result = m.predict(j, t) result_df = pandas.DataFrame(data=result, columns=["prediction_value"]) result_df["t"] = t result_df["group"] = group data.append(result_df) unioned_data = pandas.concat(data) return unioned_data