Survival Analysis

How to construct survival tables from duration tables

The survival table is the workhorse of univariate survival analysis; it's the last step before estimating survival and cumulative hazard curves and running hypothesis tests. I showed previously how to go from an event log to a duration table, but I skipped the survival table step by using the Python packages lifelines and scikit-survival to get survival and hazard curve estimates. In this post, I'll go back and show how to do the middle step: building the survival table from a duration table.

Once we have the survival table, it's pretty easy to compute survival and cumulative hazard curve estimates, so I'll show that too.

Schematic of data objects in univariate survival analysis
The sequence of data and analysis artifacts in univariate survival analysis. Image by author.


A duration table refresher

To illustrate building a duration table from an event log, I used the Retail Rocket dataset from Kaggle. This dataset is an event log where each of the 2.8 million rows in this event log represents a user's interaction with a product on the Retail Rocket website. We're interested in the duration of time until a user's first completed transaction.

The duration table I computed previously for this dataset looks like this:

import pandas as pd

durations = pd.read_parquet("data/retailrocket_durations.parquet")
durations.head()

                       entry_time endpoint_time          final_obs_time  endpoint_observed  duration_days
visitorid                                                                                                
0         2015-09-11 20:49:49.439           NaT 2015-09-18 02:59:47.788              False       6.256925
1         2015-08-13 17:46:06.444           NaT 2015-09-18 02:59:47.788              False      35.384506
2         2015-08-07 17:51:44.567           NaT 2015-09-18 02:59:47.788              False      41.380593
3         2015-08-01 07:10:35.296           NaT 2015-09-18 02:59:47.788              False      47.825839
4         2015-09-15 21:24:27.167           NaT 2015-09-18 02:59:47.788              False       2.232878

Each row represents a Retail Rocket user; the important columns are endpoint_observed, which indicates whether the user made a transaction during the observation window, and duration_days, which is the time from the user's first visit to their first transaction, if applicable, or to the censoring time (the most recent timestamp in the raw event log). The basic intuition is that for users who haven't yet completed a transaction, we know the time-to-transaction is at least as long as the duration we have observed so far.
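For concreteness, here's a tiny synthetic stand-in for a duration table with the same two key columns. The values are made up for illustration, not taken from the Retail Rocket data:

```python
import pandas as pd

# Synthetic duration table: five users, two of whom completed a transaction
# (endpoint_observed=True) and three who were censored at the end of the
# observation window.
toy_durations = pd.DataFrame(
    {
        "endpoint_observed": [True, False, True, False, False],
        "duration_days": [1.2, 6.3, 2.8, 35.4, 41.4],
    },
    index=pd.Index(range(5), name="visitorid"),
)
print(toy_durations)
```

The censored users' rows say only that their true time-to-transaction, if it ever happens, is longer than the recorded duration.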

What's a survival table?

A survival table shows how many study subjects were still “at risk” of experiencing the outcome event at each duration and how many experienced the event at each duration.1 It also shows how many subjects were censored, although the definition is a little trickier: the censored subjects listed in each row are those whose durations were greater than or equal to that row's duration and less than the next row's duration.

Here's the survival table for the duration table above; this is where we want to end up.

               at_risk  num_obs  events  censored
duration_days                                    
0              1407580      118     118         0
1              1407462    14536    9451      5085
2              1392926     5860     379      5481
3              1387066     9605     172      9433
4              1377461    11092     133     10959
...                ...      ...     ...       ...
124             156580    11086       1     11085
125             145494    10196       1     45118
129             100375    11691       2     11689
130              88684     9011       1      9010
131              79673     6990       1     79672

[124 rows x 4 columns]

Let's look at the fourth row for more concrete intuition. The at_risk column shows that 1,387,066 users had not yet made a transaction two days after they entered the system.2 The num_obs column shows that of those, 9,605 had a duration between 2 and 3 days; 172 of them experienced the outcome event, i.e. made a transaction on the 3rd day, and 9,433 were censored.

Let's look more closely at the censored column. Notice the row with duration_days of 125. That row shows that 10,196 users had a recorded duration of 125 days, but the number of censored users is 45,118—how can that be possible?

Remember what I said above about the censored count including subjects whose censoring time is at least as long as the current row's duration and less than the next row's duration. So those 45,118 censored users have a duration of at least 125 days and less than 129 days.
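That accounting can be verified directly from the neighboring rows of the table:

```python
# Censored at duration 125 = at risk at 125, minus at risk at the next
# row (duration 129), minus the events observed at 125.
at_risk_125, at_risk_129, events_125 = 145_494, 100_375, 1
censored_125 = at_risk_125 - at_risk_129 - events_125
print(censored_125)  # 45118
```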

So this is where we want to end up; how do we get there?

From durations to survival

The first step is to count the number of subjects who experienced the event at each duration and the total number of subjects at each duration. This step also establishes the row index for the final survival table; to make the row index nice and clean, I first round the duration days up to the next whole number.3

import numpy as np

# Round durations up to whole days so the survival table has a clean integer index.
durations["duration_days"] = np.ceil(durations["duration_days"]).astype(int)

# Count total observations and observed events at each duration.
grp = durations.groupby("duration_days")
survival = pd.DataFrame({"num_obs": grp.size(), "events": grp["endpoint_observed"].sum()})

The number of subjects at risk in each row is the total number of subjects minus the cumulative count of subjects with shorter recorded durations. That's a mouthful, and the code is equally non-intuitive, but it's easier to see in the output.

num_subjects = len(durations)

# Number of subjects whose recorded duration is strictly shorter than each row's.
prior_count = survival["num_obs"].cumsum().shift(1, fill_value=0)
survival.insert(0, "at_risk", num_subjects - prior_count)
print(survival.head())

               at_risk  num_obs  events
duration_days                          
0              1407580      118     118
1              1407462    14536    9451
2              1392926     5860     379
3              1387066     9605     172
4              1377461    11092     133

The total number of users in the system is 1,407,580, so this is the number at-risk at time zero. Of those users, 118 completed a transaction right at time 0 (i.e. their first product interaction was a transaction), so the number of users at risk of transacting on the first day—excluding time 0 itself—is the difference: 1,407,462.

14,536 users had a recorded duration of 1 day, either because they made their first transaction within a day or because they were censored that quickly (they entered the system within a day of the end of the observation window). So the number of users at risk of making a transaction after one full day is

\[ 1,407,580 - (118 + 14,536) = 1,392,926 \]

which is reflected in the third row.
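The same cumulative-sum logic can be checked on just the first few num_obs counts from the table above:

```python
import pandas as pd

# First three num_obs counts from the survival table, indexed by duration.
num_obs = pd.Series([118, 14_536, 5_860], index=[0, 1, 2])
total = 1_407_580  # total users in the duration table

# At risk at each duration = total, minus everyone with a shorter duration.
at_risk = total - num_obs.cumsum().shift(1, fill_value=0)
print(at_risk.tolist())  # [1407580, 1407462, 1392926]
```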

The survival table only has rows for durations with at least one observed event, so we can remove the rows with zero events; these are durations at which users were only censored, never observed to transact.

survival = survival.loc[survival["events"] > 0]

The number of censored subjects at each duration is the number of subjects at risk minus the number at risk in the next duration minus the number who experienced the event at the current duration.

# Censored in each row: at risk now, minus at risk at the next event
# duration, minus those who experienced the event in this row.
survival["censored"] = (
    survival["at_risk"]
    - survival["at_risk"].shift(-1, fill_value=0)
    - survival["events"]
)

print(survival)

               at_risk  num_obs  events  censored
duration_days                                    
0              1407580      118     118         0
1              1407462    14536    9451      5085
2              1392926     5860     379      5481
3              1387066     9605     172      9433
4              1377461    11092     133     10959
...                ...      ...     ...       ...
124             156580    11086       1     11085
125             145494    10196       1     45118
129             100375    11691       2     11689
130              88684     9011       1      9010
131              79673     6990       1     79672

[124 rows x 4 columns]

This seems unnecessarily complicated at first; why not just subtract events from num_obs to get the number of censored subjects?

Remember that the censored column should include subjects whose censored duration is greater than or equal to the current duration and less than the next duration. Because we've removed rows with no observed events, just subtracting the two columns would incorrectly ignore subjects censored at durations when nobody made a transaction.
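Here's a small synthetic example of the pitfall, with made-up data. Nobody transacts at duration 2, but two users are censored there; once the zero-event row is dropped, num_obs minus events misses those two users, while the at-risk differences pick them up:

```python
import pandas as pd

# Toy durations: events at days 1 and 3, censoring only at day 2.
toy = pd.DataFrame({
    "duration_days": [1, 1, 2, 2, 3, 3],
    "endpoint_observed": [True, True, False, False, True, False],
})

# Build the survival table exactly as in the main example.
grp = toy.groupby("duration_days")
surv = pd.DataFrame({"num_obs": grp.size(), "events": grp["endpoint_observed"].sum()})
surv.insert(0, "at_risk", len(toy) - surv["num_obs"].cumsum().shift(1, fill_value=0))

# Drop the zero-event row (day 2), as in the main example.
surv = surv.loc[surv["events"] > 0]

naive = surv["num_obs"] - surv["events"]  # misses the day-2 censored users
correct = (
    surv["at_risk"]
    - surv["at_risk"].shift(-1, fill_value=0)
    - surv["events"]
)
print(naive.tolist())    # [0, 1]
print(correct.tolist())  # [2, 1]
```

The naive subtraction says nobody was censored in the day-1 row, but the two users censored on day 2 belong there under the survival table's definition.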

Estimating survival and hazard functions

A common goal for univariate survival analysis is to plot a Kaplan-Meier estimate of the survival curve. The Kaplan-Meier estimator is

\[ \widehat{S}(t) = \prod_{i: t_i \leq t} \left( 1 - \frac{d_i}{n_i} \right) \] where \(d_i\) is the number of observed events at time \(t_i\) and \(n_i\) is the number of subjects at risk at \(t_i\). Computing this is easy once we have the survival table, using the Pandas Series cumprod method.

# Complement of the discrete hazard: the conditional probability of
# "surviving" (not transacting) past each duration.
inverse_hazard = 1 - survival["events"] / survival["at_risk"]
survival["survival_proba"] = inverse_hazard.cumprod()

Survival estimate for Retail Rocket data
Kaplan-Meier estimate of the survival curve for Retail Rocket time-to-first-transaction. I've omitted the plot code for brevity, but you can find it in the Crosstab Kite gists repo and read a detailed description of the plot in this post.


Because the fraction of users in this demo who make a transaction is so small (but realistic!), it can be helpful to flip the survival probability upside down and call it a conversion rate (officially, the cumulative incidence function):

survival["conversion_pct"] = 100 * (1 - survival["survival_proba"])
Conversion rate curve estimate for Retail Rocket data
Kaplan-Meier estimate of conversion rate over time, also known as cumulative incidence. The y-axis is percentage points, so 0.9 means 0.9%, or 0.009. Plot code can be found here.


The cumulative hazard function is another important survival analysis output. The Nelson-Aalen estimator is:

\[ \widehat{H}(t) = \sum_{i: t_i \leq t} \frac{d_i}{n_i} \]

This one is even easier to compute from the survival table; it's a one-liner:

survival['cumulative_hazard'] = (survival['events'] / survival['at_risk']).cumsum()

Cumulative hazard estimate for Retail Rocket data
Nelson-Aalen estimate of cumulative hazard of a Retail Rocket user completing a transaction over the course of their account life. Plot code can be found here.
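As a sanity check, the two estimators are closely related: each Kaplan-Meier factor \(1 - d_i/n_i\) is approximately \(e^{-d_i/n_i}\) when per-interval hazards are small, so \(e^{-\widehat{H}(t)}\) nearly matches \(\widehat{S}(t)\). A quick check with hypothetical counts (not the Retail Rocket values):

```python
import numpy as np
import pandas as pd

# Hypothetical per-duration event counts and at-risk counts.
events = pd.Series([5, 3, 2])
at_risk = pd.Series([1000, 800, 600])

hazard = events / at_risk
km = (1 - hazard).cumprod()  # Kaplan-Meier S(t)
na = hazard.cumsum()         # Nelson-Aalen H(t)

# exp(-H) should be close to S when hazards are small.
print(np.allclose(km, np.exp(-na), rtol=1e-4))  # True
```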


Final thoughts

Our final survival table plus univariate model outputs looks like this:

               at_risk  num_obs  events  censored  survival_proba  conversion_pct  cumulative_hazard
duration_days                                                                                       
0              1407580      118     118         0        0.999916        0.008383           0.000084
1              1407462    14536    9451      5085        0.993202        0.679819           0.006799
2              1392926     5860     379      5481        0.992932        0.706843           0.007071
3              1387066     9605     172      9433        0.992808        0.719156           0.007195
4              1377461    11092     133     10959        0.992713        0.728742           0.007291
...                ...      ...     ...       ...             ...             ...                ...
124             156580    11086       1     11085        0.991275        0.872488           0.008740
125             145494    10196       1     45118        0.991268        0.873169           0.008747
129             100375    11691       2     11689        0.991249        0.875145           0.008767
130              88684     9011       1      9010        0.991237        0.876262           0.008779
131              79673     6990       1     79672        0.991225        0.877506           0.008791

[124 rows x 7 columns]

As I showed in the previous article about building duration tables, packages like lifelines and scikit-survival have functions to compute the Kaplan-Meier and Nelson-Aalen estimates directly, but it's helpful to know how to compute the intermediate survival table, for a couple of reasons. First, the survival table is also the basis for hypothesis tests about group differences in survival or hazard rates, and it's important to have strong intuition about how those tests work when applying them. Second, sometimes we need to compute the survival table without third-party packages, e.g. when working in SQL.

Notes & references

  1. Survival analysis concepts have some of the worst names in the business. Names like survival, hazard, and at-risk imply the outcome of interest is bad and suggest that all subjects will eventually experience the outcome, neither of which is true in general. For the Retail Rocket example in this post, conversions are good and most users never complete a transaction. I think we're stuck with the bad names, though, so try to remember that “at-risk” means at risk for completing a first transaction, and “survival” means a user has not yet made any transactions.

  2. I rounded the duration days up to the next whole number, so the survival table would have a countable row index. So, for example, a duration_days entry of 3 means the user's true duration was in the interval (2, 3].

  3. The code in this article should work with Python 3.8, numpy 1.21.0, and pandas 1.3.0.

  4. Listing image by Murray Campbell on Unsplash.