Survival Analysis

How to plot survival curves with Plotly and Altair

All three of the major Python survival analysis packages—convoys, lifelines, and scikit-survival—show how to plot survival curves with Matplotlib. In some cases, they bake Matplotlib-based plots directly into trained survival model objects, to enable convenient one-liner plot functions.

The downside of this convenience is that the code is hidden, so it's harder to customize or use a different library. In this article, I'll show how to plot survival curves from scratch with both Altair and Plotly. All of the following code lives in a live, executable Jupyter notebook on Binder and the source file can be found in the Crosstab Kite gists repo.

Survival curves describe the probability that a subject of study will “survive” past a given duration of time. This article assumes you're familiar with the concept; if this is the first you've heard of it, lifelines and scikit-survival both have excellent explanations.

Lifelines introduces survival curve estimation with an example about the tenure of political leaders, using a duration table dataset that's included in the package. Each row represents a head of state; the columns of interest are duration, which is the length of each leader's tenure, and observed, which indicates whether the end of each leader's time in office was observed (it would not be observed if that leader died in office or was still in power when the dataset was collected).

from lifelines.datasets import load_dd

data = load_dd()
data[['ctryname', 'ehead', 'duration', 'observed']].sample(5, random_state=19)

                      ctryname                ehead  duration  observed
1022                Mauritania  Mustapha Ould Salek         1         1
1565               Switzerland        Ruth Dreifuss         1         1
763                    Ireland      Eamon de Valera         2         1
1722  United States of America         Bill Clinton         8         1
1416                   Somalia    Abdirizak Hussain         3         1
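To make the roles of duration and observed concrete, here's a minimal from-scratch sketch of the Kaplan-Meier estimate in plain Python. The function name kaplan_meier and the toy data are made up for illustration; this is not how lifelines implements it, and it leaves out confidence intervals entirely.

```python
from collections import Counter


def kaplan_meier(durations, observed):
    """Return [(time, survival_probability), ...], starting at time 0."""
    # Observed events per time, and total exits (events + censored) per time.
    events = Counter(t for t, o in zip(durations, observed) if o)
    exits = Counter(durations)

    at_risk = len(durations)
    survival = 1.0
    curve = [(0, 1.0)]
    for t in sorted(exits):
        d = events.get(t, 0)
        if d > 0:
            # Only observed events reduce the survival probability.
            survival *= 1 - d / at_risk
        curve.append((t, survival))
        # Censored subjects still leave the risk set after time t.
        at_risk -= exits[t]
    return curve


# Hypothetical toy data: five leaders, two of them censored (observed = 0).
durations = [1, 2, 2, 3, 5]
observed = [1, 1, 0, 1, 0]

for t, s in kaplan_meier(durations, observed):
    print(t, round(s, 3))
```

Censored rows never lower the curve directly; they only shrink the number of subjects at risk for later times.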

A fitted lifelines Kaplan-Meier model has a method plot_survival_function that uses Matplotlib. It's certainly convenient, but because it hides all the logic, it's harder to see how to customize the plot or implement it in a different library.

from lifelines import KaplanMeierFitter

kmf = KaplanMeierFitter()
kmf.fit(durations=data['duration'], event_observed=data['observed'])

kmf.plot_survival_function()  # this is the function we're unpacking

Figure: Matplotlib survival curve from lifelines. Kaplan-Meier estimate of the survival curve for head of state tenure, reproduced from the lifelines documentation. The Y-axis is the probability a political leader's tenure will be greater than the corresponding duration of time on the X-axis.


How to plot a survival curve with Altair

Here's how to generate the same plot from scratch with Altair. There are three things we need to do:

  1. Process the lifelines model output.
  2. Plot the survival curve, as a step function.
  3. Plot the 95% confidence band, as the area between the lower and upper bound step functions.

The fitted lifelines Kaplan-Meier model exposes two pandas DataFrames as attributes: survival_function_ and confidence_interval_. We need to combine these into a single DataFrame to make the Altair plot, and convert the index into a column so we can reference it as the X-axis.

df_plot = kmf.survival_function_.copy(deep=True)
df_plot['lower_bound'] = kmf.confidence_interval_['KM_estimate_lower_0.95']
df_plot['upper_bound'] = kmf.confidence_interval_['KM_estimate_upper_0.95']
df_plot.reset_index(inplace=True)
df_plot.head()

   timeline  KM_estimate  lower_bound  upper_bound
0       0.0     1.000000     1.000000     1.000000
1       1.0     0.721792     0.700522     0.741841
2       2.0     0.601973     0.578805     0.624308
3       3.0     0.510929     0.487205     0.534126
4       4.0     0.418835     0.395233     0.442242
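The same table can also be built as a single method chain, which avoids the in-place mutation. This is only a sketch: the survival and interval frames below are hypothetical stand-ins for kmf.survival_function_ and kmf.confidence_interval_, filled with the first three rows shown above so the snippet runs without lifelines.

```python
import pandas as pd

# Hypothetical stand-ins for kmf.survival_function_ and kmf.confidence_interval_.
timeline = pd.Index([0.0, 1.0, 2.0], name="timeline")
survival = pd.DataFrame({"KM_estimate": [1.0, 0.721792, 0.601973]}, index=timeline)
interval = pd.DataFrame(
    {
        "KM_estimate_lower_0.95": [1.0, 0.700522, 0.578805],
        "KM_estimate_upper_0.95": [1.0, 0.741841, 0.624308],
    },
    index=timeline,
)

df_alt = (
    survival
    .join(interval)  # align the two frames on the shared timeline index
    .rename(columns={
        "KM_estimate_lower_0.95": "lower_bound",
        "KM_estimate_upper_0.95": "upper_bound",
    })
    .reset_index()  # turn the timeline index into a regular column
)

print(df_alt.columns.tolist())
# ['timeline', 'KM_estimate', 'lower_bound', 'upper_bound']
```

With a real fitted model, swapping in the two kmf attributes for survival and interval should give the same columns as df_plot.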

Now we can construct Altair plot objects: first the survival curve as a line mark, then the confidence band as an area mark on top of the line. The only trick is that we use interpolate='step-after' in both the line and area marks to create the correct step function.

import altair as alt

line = (
    alt.Chart(df_plot)
    .mark_line(interpolate='step-after')
    .encode(
        x=alt.X("timeline", axis=alt.Axis(title="Duration")),
        y=alt.Y("KM_estimate", axis=alt.Axis(title="Survival probability"))
    )
)

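band = line.mark_area(opacity=0.4, interpolate='step-after').encode(
    # x is inherited from `line`, including its "Duration" title.
    # Repeating the y title keeps the shared axis of the layered chart
    # labeled consistently instead of picking up the raw column name.
    y=alt.Y('lower_bound', title='Survival probability'),
    y2='upper_bound'
)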

fig = line + band
fig

Figure: Altair version of the survival curve. Kaplan-Meier estimate of the survival curve for political leader tenure. The dataset is from lifelines, the plot was created with Altair.
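If it's not obvious what interpolate='step-after' does to the data, this small sketch spells it out: each value is held flat until the next X position, then the line jumps vertically to the new value. The helper step_after_vertices is hypothetical, written only to illustrate the path; Altair computes the equivalent internally.

```python
def step_after_vertices(xs, ys):
    """Expand (x, y) points into the corner points of a step-after path."""
    vertices = [(xs[0], ys[0])]
    for i in range(1, len(xs)):
        vertices.append((xs[i], ys[i - 1]))  # horizontal: hold the previous value
        vertices.append((xs[i], ys[i]))      # vertical: jump to the new value
    return vertices


# Rounded values from the first three rows of the plot table above.
print(step_after_vertices([0, 1, 2], [1.0, 0.72, 0.60]))
# [(0, 1.0), (1, 1.0), (1, 0.72), (2, 0.72), (2, 0.6)]
```

This is the right shape for a Kaplan-Meier curve because the estimate only changes at the durations where events were observed.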


How to plot a survival curve with Plotly

It's slightly trickier to draw the same plot with Plotly because there's no single band mark: the confidence band is built from two separate traces, with one filled toward the other. First, we set up the figure and add the survival curve as a line plot. We specify shape='hv' (horizontal, then vertical) in the line parameters to get the correct step function.

Note that we don't need to create a plot DataFrame here because we're going to draw each Series as a standalone Plotly trace.

import plotly.graph_objects as go  # graph_objs is the older alias for this module

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=kmf.survival_function_.index, y=kmf.survival_function_['KM_estimate'],
    line=dict(shape='hv', width=3, color='rgb(31, 119, 180)'),
    showlegend=False
))

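Next, we add traces for the upper and lower bounds of the confidence band, separately and in that order. The order matters: fill='tonexty' on the lower bound trace fills the area between it and the trace added immediately before it, which here is the upper bound. Despite the name, the "next" trace is the previous one defined in the code.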

fig.add_trace(go.Scatter(
    x=kmf.confidence_interval_.index, 
    y=kmf.confidence_interval_['KM_estimate_upper_0.95'],
    line=dict(shape='hv', width=0),
    showlegend=False,
))

fig.add_trace(go.Scatter(
    x=kmf.confidence_interval_.index,
    y=kmf.confidence_interval_['KM_estimate_lower_0.95'],
    line=dict(shape='hv', width=0),
    fill='tonexty',
    fillcolor='rgba(31, 119, 180, 0.4)',
    showlegend=False
))
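If the dependence on trace order feels fragile, one alternative is to draw the band as a single closed polygon and fill it with fill='toself'. The sketch below only builds the polygon coordinates in plain Python; step_vertices and band_polygon are hypothetical helpers, and the numbers are rounded stand-ins for the first three rows of the confidence interval.

```python
def step_vertices(xs, ys):
    """Corner points of a horizontal-then-vertical ('hv') step path."""
    vertices = [(xs[0], ys[0])]
    for i in range(1, len(xs)):
        vertices.append((xs[i], ys[i - 1]))  # hold the previous value
        vertices.append((xs[i], ys[i]))      # jump to the new value
    return vertices


def band_polygon(xs, lower, upper):
    """Trace the upper steps left to right, then the lower steps right to left."""
    outline = step_vertices(xs, upper) + step_vertices(xs, lower)[::-1]
    poly_x = [x for x, _ in outline]
    poly_y = [y for _, y in outline]
    return poly_x, poly_y


# Rounded stand-ins for the first three rows of kmf.confidence_interval_.
poly_x, poly_y = band_polygon(
    xs=[0, 1, 2],
    lower=[1.0, 0.70, 0.58],
    upper=[1.0, 0.74, 0.62],
)
print(list(zip(poly_x, poly_y)))
```

The resulting lists can be passed to one go.Scatter trace with fill='toself', the same fillcolor, and line=dict(width=0). Because the step corners are already explicit in the coordinates, that trace should keep Plotly's default linear line shape rather than shape='hv'.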

Finally, we add axis titles and styling, then show the plot. I've omitted some styling here for brevity.

fig.update_layout(
    xaxis_title="Duration",
    yaxis_title="Survival probability"
)

fig.show()

Figure: Plotly version of the survival curve. Kaplan-Meier estimate of the survival curve for head of state tenure. The dataset is from lifelines, the plot was created with Plotly.


Check out the live notebook on Binder and the source code in the GitHub repo, and let me know what you think in the comments below!

Notes & references

  1. Listing image by Abdul A on Unsplash.