How to use PyTorch LSTMs for time series regression

Most intros to LSTM models use natural language processing as the motivating application, but LSTMs can be a good option for multivariable time series regression and classification as well. Here’s how to structure the data and model to make it work.


Brian Kent


October 27, 2021

Many machine learning applications that I’ve come across lately are time series regression tasks, where I want to predict a target variable from several input time series.

I wanted to try LSTM models with these kinds of problems but found it tough to get started. Most LSTM tutorials focus on natural language processing, to the point where it can seem like LSTMs only work with text data. Searching for “LSTM time series” does return some hits, but they’re…not great.

So here’s my attempt: this article shows how to use PyTorch LSTMs for regression with multiple input time series. In particular, I’ll show how to forecast a target time series, but once you have the basic data and model structure down, it’s not hard to adapt LSTMs to other types of supervised learning. Here’s the game plan:

  1. Load, visualize, and preprocess the data
  2. Define PyTorch Dataset and DataLoader objects
  3. Define an LSTM regression model
  4. Train and evaluate the model
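The key structural idea behind step 2 is turning a table of readings into fixed-length windows. Here is a minimal numpy sketch, assuming each sample is the `seq_len` most recent rows of the input series, labeled with the target value at the window’s last row; `make_windows` and `seq_len` are hypothetical names, not code from the repo. The resulting array has shape `(n_windows, seq_len, n_features)`, which is the layout `torch.nn.LSTM` expects when it is constructed with `batch_first=True`.

```python
import numpy as np


def make_windows(features: np.ndarray, target: np.ndarray, seq_len: int):
    """Build (window, label) pairs for sequence regression.

    features: shape (T, n_features); target: shape (T,).
    Window i covers rows i .. i + seq_len - 1 and is labeled with
    target[i + seq_len - 1], the target at the window's last time step.
    """
    n_windows = len(features) - seq_len + 1
    X = np.stack([features[i:i + seq_len] for i in range(n_windows)])
    y = target[seq_len - 1:]
    return X, y  # X: (n_windows, seq_len, n_features), y: (n_windows,)


# Toy example: 10 time steps of 5 input series (e.g. five sensors)
rng = np.random.default_rng(0)
features = rng.normal(size=(10, 5))
target = rng.normal(size=10)
X, y = make_windows(features, target, seq_len=4)
print(X.shape, y.shape)  # (7, 4, 5) (7,)
```

A PyTorch `Dataset` can serve the same windows lazily (slicing in `__getitem__`) instead of materializing every window up front, which matters once the series has tens of thousands of rows.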

In the interest of brevity, I’m going to skip lots of things. Most obviously, what’s an LSTM? For that, I suggest starting with the PyTorch tutorials, Andrej Karpathy’s intro to RNNs, and Christopher Olah’s intro to LSTMs. More advanced readers might be wondering:

All good questions…for another article. In the meantime, please see our GitHub repo for a Jupyter notebook version of the code snippets below.


Our goal in this demo is to forecast air quality in Austin—specifically, 2.5-micron particulate matter (PM2.5)—from sensors around the state of Texas.

Why would we do this, when there are plenty of PM2.5 sensors in Austin? Maybe we don’t want to buy a sensor of our own, but we have a friend who will let us borrow one for a few weeks to collect training data. Or maybe we need a stand-in for the official EPA sensors when they go offline, which seems to happen often.¹


The data come from Purple Air, which sells sensors and makes (participating) customers’ data available for download. I downloaded about seven weeks of this data (September 1 through October 21, 2021) from six sensors around the state of Texas.

Locations of the Purple Air sensors used in this article. We’ll create a model that forecasts the PM2.5 reading at the red sensor in Austin based on the data at the blue sensors around the rest of Texas.
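In table form, that setup is just a split of the columns into one target and several inputs. A minimal sketch using two rows copied from the data printed further down, assuming the sixth column (the Austin sensor) is named `"Austin"`; `target_sensor` is a hypothetical variable name.

```python
import pandas as pd

# Two rows of the real data; the Austin column name is an assumption.
df = pd.DataFrame(
    {
        "Del Rio": [5.37, 5.95],
        "McAllen": [23.39, 23.02],
        "Midlothian": [4.79, 4.60],
        "Midland": [7.57, 8.50],
        "Houston": [7.60, 7.78],
        "Austin": [8.32, 8.93],
    },
    index=["2021-09-01 00:00:00+00:00", "2021-09-01 00:02:00+00:00"],
)

target_sensor = "Austin"                     # red sensor: what we predict
target = df[target_sensor]
features = df.drop(columns=[target_sensor])  # blue sensors: model inputs
print(features.columns.tolist())
# ['Del Rio', 'McAllen', 'Midlothian', 'Midland', 'Houston']
```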

For this demo, I’ve already preprocessed the data to align and sort the timestamps and interpolate a small number of missing values. Please see the script in the GitHub repo for the details. Let’s load the data and visualize it.²
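The preprocessing script isn’t reproduced here, but the two steps it performs (aligning timestamps and interpolating a few gaps) can be sketched in pandas. This sketch assumes each raw sensor feed is a Series with its own irregular timestamps and that averaging readings into 2-minute bins is an acceptable way to align them; the feeds, `to_grid`, and the `limit=2` cap are hypothetical choices, not necessarily what the repo script does.

```python
import pandas as pd

# Two hypothetical raw feeds with irregular timestamps; feed B arrives unsorted.
# (Timezones are omitted for brevity; the real data are in UTC.)
raw_a = pd.Series(
    [5.0, 6.0, 8.0],
    index=pd.to_datetime(
        ["2021-09-01 00:00:40", "2021-09-01 00:02:10", "2021-09-01 00:06:30"]
    ),
)
raw_b = pd.Series(
    [7.9, 7.0, 8.1, 7.5],
    index=pd.to_datetime(
        ["2021-09-01 00:05:00", "2021-09-01 00:01:00",
         "2021-09-01 00:07:00", "2021-09-01 00:03:00"]
    ),
)


def to_grid(s: pd.Series, freq: str = "2min") -> pd.Series:
    """Sort by time, then average all readings that fall in each 2-minute bin."""
    return s.sort_index().resample(freq).mean()


# Align both feeds on one shared grid; feed A has no reading in the 00:04 bin.
aligned = pd.concat({"A": to_grid(raw_a), "B": to_grid(raw_b)}, axis=1)

# Fill short gaps by time-weighted linear interpolation (at most 2 bins in a row).
aligned = aligned.interpolate(method="time", limit=2)
print(aligned)
```

The `limit` argument caps how many consecutive missing bins get filled, which is one way to keep interpolation from inventing long stretches of data during a sensor outage.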

import pandas as pd

# Each row is a timestamp and each column a sensor; the index is read as strings
df = pd.read_csv("processed_pm25.csv", index_col="created_at")
print(df)

                           Del Rio  McAllen  Midlothian  Midland  Houston  Austin
2021-09-01 00:00:00+00:00     5.37    23.39        4.79     7.57     7.60    8.32
2021-09-01 00:02:00+00:00     5.95    23.02        4.60     8.50     7.78    8.93
2021-09-01 00:04:00+00:00     5.84    25.49        5.70     8.07     7.64   10.10
2021-09-01 00:06:00+00:00     7.07    25.22        5.18     8.66     7.69    9.24
2021-09-01 00:08:00+00:00     5.18    23.16        5.29     7.98     7.74    9.02
...                            ...      ...         ...      ...      ...     ...
2021-10-21 23:50:00+00:00     8.31     2.58        9.18    12.62     7.80   13.14
2021-10-21 23:52:00+00:00     8.84     3.02       10.34    11.37     7.05   12.47
2021-10-21 23:54:00+00:00     8.57     2.37       10.28    11.82     6.93   12.34
2021-10-21 23:56:00+00:00     7.96     3.07       10.37    11.07     6.82   12.26
2021-10-21 23:58:00+00:00     7.72     2.88       10.31    12.08     7.22   12.04

[36720 rows x 6 columns]

The columns represent sensors and the rows represent (sorted) timestamps at two-minute intervals. The values are PM2.5 readings, measured in micrograms per cubic meter (µg/m³).³
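Because the model will slice this table into consecutive sequences, it’s worth confirming that the rows really form an unbroken 2-minute grid. A minimal sketch, assuming the index arrives as strings (as it does with the `read_csv` call above); `is_regular_grid` is a hypothetical helper. The first and last timestamps also let us double-check the row count.

```python
import pandas as pd


def is_regular_grid(index, freq: str = "2min") -> bool:
    """True if timestamps are sorted, unique, and exactly `freq` apart (no gaps)."""
    idx = pd.to_datetime(index)             # the read_csv index is plain strings
    steps = pd.Series(idx).diff().iloc[1:]  # differences between consecutive rows
    return bool((steps == pd.Timedelta(freq)).all())


# The first and last timestamps shown above imply this many 2-minute rows:
grid = pd.date_range(
    "2021-09-01 00:00:00+00:00", "2021-10-21 23:58:00+00:00", freq="2min"
)
print(len(grid))  # 36720, i.e. 51 days x 720 readings per day

# On the real data, one would call: is_regular_grid(df.index)
```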

Plotting all six time series together doesn’t reveal much, because a few short but huge spikes stretch the y-axis and flatten everything else. The second plot is zoomed in to a y-axis range of [0, 60]; it shows clear long-run correlations between the sensors but lots of short-run variation both between and within the series. In other words, an interesting dataset!
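One way to put numbers on “long-run correlation, short-run variation” is to compare the correlation of raw 2-minute readings with the correlation of daily means. The sketch below uses synthetic data, not the Purple Air readings: it assumes two hypothetical sensors share a slow weekly signal plus independent noise, so averaging over a day cancels most of the noise and exposes the shared component.

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for two sensors: a shared slow (weekly) signal plus
# independent short-run noise, sampled every 2 minutes for 28 days.
rng = np.random.default_rng(42)
index = pd.date_range("2021-09-01", periods=28 * 720, freq="2min")
t_days = np.arange(len(index)) / 720
shared = 10 + 3 * np.sin(2 * np.pi * t_days / 7)
demo = pd.DataFrame(
    {
        "sensor_a": shared + rng.normal(scale=5, size=len(index)),
        "sensor_b": shared + rng.normal(scale=5, size=len(index)),
    },
    index=index,
)

raw_corr = demo["sensor_a"].corr(demo["sensor_b"])  # short-run: noise dominates
daily = demo.resample("1D").mean()                  # one row per day
daily_corr = daily["sensor_a"].corr(daily["sensor_b"])  # long-run: shared signal
print(f"2-minute corr: {raw_corr:.2f}, daily-mean corr: {daily_corr:.2f}")
```

On the real data, the same comparison is `df.corr()` versus `df.resample("1D").mean().corr()`, after converting the string index with `df.index = pd.to_datetime(df.index)`.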

Pardon a bit of Plotly styling boilerplate up front.

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
pio.templates.default = "plotly_white"

plot_template = dict(
    layout=go.Layout({
        "font_size": 18,
        "xaxis_title_font_size": 24,
        "yaxis_title_font_size": 24})
)

fig = px.line(df, labels=dict(
    created_at="Date", value="PM2.5 (ug/m3)", variable="Sensor"
))
fig.update_layout(
  template=plot_template, legend=dict(orientation='h', y=1.02, title_text="")
)