This repository contains the PyTorch implementation of the Spectral Attention Autoregressive Model (SAAM) proposed in the paper 'Deep Autoregressive Models with Spectral Attention', published in the journal Pattern Recognition.
If you use this code, please cite the Pattern Recognition article:
@article{moreno2022deep,
title={Deep autoregressive models with spectral attention},
author={Moreno-Pino, Fernando and Olmos, Pablo M and Art{\'e}s-Rodr{\'\i}guez, Antonio},
journal={Pattern Recognition},
pages={109014},
year={2022},
publisher={Elsevier}
}
(The pre-print is also available on arXiv.)
The repository is divided into three parts:
- SAAM_LSTM_Embedding: contains an implementation of SAAM based on the DeepAR model. In this implementation, an LSTM performs the embedding described in the paper.
- SAAM_Transformer_Embedding: contains an implementation of SAAM based on the ConvTrans model. In this implementation, a decoder-only Transformer performs the embedding described in the paper.
- Synthetic_Datasets_Exps: contains the code to train and evaluate SAAM (using an LSTM for the embedding) on the synthetic data.
To use the LSTM-based or the Transformer-based model, change to the corresponding folder. The arguments differ slightly between the two models; they are defined in SAAM_LSTM_Embedding/train_SAAM_LSTM_Emb.py and SAAM_Transformer_Embedding/train_SAAM_Transformer_Emb.py. Some examples of how to run both models on the solar dataset are:
# Example for training SAAM performing the embedding through an LSTM on the solar dataset:
python train_FAAM_ct.py --dataset solar --model-name solar --cuda-device 0 --sampling
# Example for training SAAM performing the embedding through a ConvTrans on the solar dataset:
CUDA_VISIBLE_DEVICES=0 python main.py --path data/solar.csv --outdir solar --dataset solar --enc_len 168 --dec_len 24 --batch-size 128
Some important parameters are:
- --dataset: the available options are ele, traffic, m4, wind and solar.
- --sampling: whether to compute p-risk metrics.
- --path: path to the dataset location.
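The parameters listed above can be pictured with a small argparse sketch. This is only an illustration: the parser below is hypothetical, covers just the three flags named in this README, and assumes that --sampling is a boolean switch (it is passed without a value in the example command). The real argument lists live in the training scripts of each folder and may differ.

```python
import argparse

# Hypothetical parser that mirrors only the flags named in this README.
# It is NOT the argument list of the actual training scripts.
DATASETS = ["ele", "traffic", "m4", "wind", "solar"]


def build_parser():
    parser = argparse.ArgumentParser(description="Illustrative SAAM-style CLI")
    # Restrict --dataset to the options listed in the README.
    parser.add_argument("--dataset", choices=DATASETS, required=True,
                        help="name of the dataset to train on")
    # Assumption: --sampling is a switch that enables p-risk metrics.
    parser.add_argument("--sampling", action="store_true",
                        help="also compute p-risk metrics")
    parser.add_argument("--path", default=None,
                        help="path to the dataset location")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch
# behaves the same wherever it is run.
args = build_parser().parse_args(
    ["--dataset", "solar", "--sampling", "--path", "data/solar.csv"]
)
print(args.dataset, args.sampling, args.path)
```

Using choices makes argparse reject a dataset name outside the listed options before any training code runs.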
An example notebook is provided in Synthetic_Datasets_Exps/Synthetic Dataset Experiment.ipynb. Running this notebook displays predictions on the synthetic dataset from a pre-trained model.
The datasets used in the paper are publicly available. They can be downloaded from the following sources:
- Electricity dataset: https://archive.ics.uci.edu/ml/datasets/ElectricityLoadDiagrams20112014#.
- Traffic dataset: https://archive.ics.uci.edu/ml/datasets/PEMS-SF.
- Solar dataset: https://www.nrel.gov/grid/solar-power-data.html.
- Wind dataset: https://www.kaggle.com/sohier/30-years-of-european-wind-generation.
- M4 Hourly dataset: https://www.kaggle.com/yogesh94/m4-forecasting-competition-dataset.
- Synthetic dataset: see Synthetic_Datasets_Exps/dataloaders/dataloader_sin_cos.py.
The code has been tested with the following dependencies:
python 3.8.5
torch 1.6.0
matplotlib 3.3.2
numpy 1.19.2
pandas 1.1.5
scikit-learn 0.23.2
Note that PyTorch 1.7.0 introduced several changes to its FFT functions (https://pytorch.org/blog/the-torch.fft-module-accelerated-fast-fourier-transforms-with-autograd-in-pyTorch/). The model was developed with PyTorch 1.6.0, so running it on a newer version will cause errors. The code will be updated to the new PyTorch API as soon as possible.
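Since the failure mode depends on the installed PyTorch version, a small guard can report the problem before training starts. The following is a minimal sketch, not part of the repository: the helper names are hypothetical, and it only inspects a version string, so it runs without PyTorch installed.

```python
import re

# Assumption: the repository's FFT calls target the pre-1.7 PyTorch API,
# as stated above. The helper names below are hypothetical.
FIRST_VERSION_WITH_NEW_FFT = (1, 7)


def parse_major_minor(version):
    """Return (major, minor) from a version string such as '1.6.0+cu101'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def uses_legacy_fft_api(version):
    """True when the version predates the FFT changes of PyTorch 1.7.0."""
    return parse_major_minor(version) < FIRST_VERSION_WITH_NEW_FFT


print(uses_legacy_fft_api("1.6.0"))
print(uses_legacy_fft_api("1.7.0+cu110"))
```

In practice the string would come from str(torch.__version__), and a script could stop with a clear message when uses_legacy_fft_api returns False. Comparing integer tuples rather than raw strings avoids ordering '1.10' before '1.7'.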
Fernando Moreno-Pino, Pablo M. Olmos, and Antonio Artés-Rodríguez.
For further information: [email protected]