# Source code for convoys.plotting

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Hashable, Literal

import numpy
from matplotlib import pyplot

import convoys.multi

if TYPE_CHECKING:
    from matplotlib.axes import Axes

__all__ = ["plot_cohorts"]


def _mcmc_model(class_name: str) -> Callable[[bool], convoys.multi.MultiModel]:
    """Return a factory that builds ``convoys.multi.<class_name>(mcmc=flag)``."""

    def build(use_mcmc: bool) -> convoys.multi.MultiModel:
        # The class is looked up on ``convoys.multi`` when the factory is
        # called, not when the registry below is created.
        return getattr(convoys.multi, class_name)(mcmc=use_mcmc)

    return build


# Registry of model factories, keyed by the string names accepted for the
# ``model`` argument of ``plot_cohorts``.  Every factory takes one boolean
# and returns a new, unfitted model; the Kaplan-Meier factory ignores the
# boolean, the others forward it as ``mcmc``.
_models: dict[str, Callable[[bool], convoys.multi.MultiModel]] = {
    "kaplan-meier": lambda _unused: convoys.multi.KaplanMeier(),
    "exponential": _mcmc_model("Exponential"),
    "weibull": _mcmc_model("Weibull"),
    "gamma": _mcmc_model("Gamma"),
    "generalized-gamma": _mcmc_model("GeneralizedGamma"),
}


def plot_cohorts(
    G: numpy.ndarray,
    B: numpy.ndarray,
    T: numpy.ndarray,
    t_max: int | float | None = None,
    model: Literal[
        "kaplan-meier", "exponential", "weibull", "gamma", "generalized-gamma"
    ]
    | convoys.multi.MultiModel = "kaplan-meier",
    ci: float | None = None,
    ax: "Axes" | None = None,
    plot_kwargs: dict[str, Any] | None = None,
    plot_ci_kwargs: dict[str, Any] | None = None,
    groups: list[Hashable] | None = None,
    specific_groups: list[Hashable] | None = None,
    label_fmt: str = "%(group)s (n=%(n).0f, k=%(k).0f)",
) -> convoys.multi.MultiModel:
    """Helper function to fit data using a model and then plot the cohorts.

    :param G: numpy array of shape :math:`n`, containing integers
        representing group assignments.
    :param B: numpy array of shape :math:`n`, containing booleans
        representing whether or not the subject 'converted' at time t.
    :param T: numpy array of shape :math:`n`, containing floats
        representing the time delta between creation and either
        conversion or censoring.
    :param t_max: (optional) max value for x axis. When omitted, the
        larger of the axis' current upper x limit and ``max(T)`` is used.
    :param model: (optional, default is kaplan-meier) model to fit.
        Can be an instance of :class:`multi.MultiModel` or a string
        identifying the model. One of 'kaplan-meier', 'exponential',
        'weibull', 'gamma', or 'generalized-gamma'. A model passed as an
        instance is used as is; this function does not call ``fit`` on it.
    :param ci: confidence interval, value from 0-1, or None (default)
        if no confidence interval is to be plotted. Only ``None`` turns
        the interval off; any other value (including ``0``) requests one.
    :param ax: custom pyplot axis to plot on (defaults to
        ``pyplot.gca()``)
    :param plot_kwargs: extra arguments to pyplot for the lines;
        these override the defaults ``color``, ``linewidth=1.5`` and
        ``alpha=0.7``
    :param plot_ci_kwargs: extra arguments to pyplot for the confidence
        intervals; these override the default ``alpha=0.2``
    :param groups: list of group labels. The position of a label in this
        list is the integer that is matched against `G`. Defaults to the
        distinct values of `G`.
    :param specific_groups: subset of groups to plot (defaults to all of
        `groups`)
    :param label_fmt: custom format for the labels to use in the legend.
        It is %-formatted with the keys ``group``, ``n`` (number of rows
        in the group) and ``k`` (sum of `B` over the group).
    :returns: the model that was used for plotting: the newly fitted
        model when `model` is a string, otherwise the instance passed in.
    :raises ValueError: if `model` is neither a known model name nor a
        :class:`multi.MultiModel` instance, or if `specific_groups`
        contains a label that is not in `groups`.

    See :meth:`convoys.utils.get_arrays` which is handy for converting
    a Pandas dataframe into arrays `G`, `B`, `T`.
    """
    # Check for an instance first so a model object never has to be hashed
    # for the name lookup.
    if not isinstance(model, convoys.multi.MultiModel) and model not in _models:
        raise ValueError("model incorrectly specified")

    if groups is None:
        groups = list(set(G))
    if specific_groups is None:
        specific_groups = groups

    # Validate before touching pyplot or fitting, so that bad input fails
    # fast.  A true subset test is used so that a label listed twice is not
    # misreported as missing.
    if not set(specific_groups).issubset(groups):
        raise ValueError("specific_groups not a subset of groups!")

    if ax is None:
        ax = pyplot.gca()

    # Set x scale
    if t_max is None:
        _, t_max = ax.get_xlim()
        t_max = max(t_max, max(T))

    if isinstance(model, convoys.multi.MultiModel):
        m = model
    else:
        # Fit model.  The flag uses the same ``ci is not None`` test as the
        # plotting loop below, so a model that will be asked for confidence
        # intervals is always built with that in mind (``bool(ci)`` would
        # disagree for ci == 0).
        m = _models[model](ci is not None)
        m.fit(G, B, T)

    # Plot
    t = numpy.linspace(0, t_max, 1000)  # type: ignore[arg-type]
    _, y_max = ax.get_ylim()
    # Reset to first color
    ax.set_prop_cycle(None)  # type:ignore[call-overload]
    for group in specific_groups:
        j = groups.index(group)  # matching index of group

        in_group = G == j
        n = numpy.sum(in_group)
        k = numpy.sum(B[in_group])
        label = label_fmt % {"group": group, "n": n, "k": k}

        if ci is not None:
            p_y, p_y_lo, p_y_hi = m.predict_ci(j, t, ci=ci).T
            merged_plot_ci_kwargs = {"alpha": 0.2}
            if plot_ci_kwargs is not None:
                merged_plot_ci_kwargs.update(plot_ci_kwargs)
            p = ax.fill_between(
                t,
                100.0 * p_y_lo,
                100.0 * p_y_hi,
                **merged_plot_ci_kwargs,  # type: ignore[arg-type]
            )
            color = p.get_facecolor()[0]  # reuse color for the line
        else:
            p_y = m.predict(j, t).T
            color = None

        merged_plot_kwargs = {"color": color, "linewidth": 1.5, "alpha": 0.7}
        if plot_kwargs is not None:
            merged_plot_kwargs.update(plot_kwargs)
        ax.plot(t, 100.0 * p_y, label=label, **merged_plot_kwargs)  # type: ignore[arg-type]
        # Curves are drawn in percent (100 * p_y); 110 * max leaves 10%
        # headroom above the highest curve.
        y_max = max(y_max, 110.0 * max(p_y))

    ax.set_xlim(0, t_max)
    ax.set_ylim(0, y_max)
    ax.set_ylabel("Conversion rate %")
    ax.grid(True)
    return m