anaeim/deep-stock-price-prediction

deep-time-series-forecasting-models-for-stock-price-prediction

Intro

A collection of prominent and state-of-the-art (SOTA) time-series forecasting models for stock price prediction, including deep, univariate, multivariate, and ensemble models.

Installation

To get started, you'll need Git, Python, and pip installed.

  1. Clone the Git repository
git clone https://github.com/anaeim/deep-stock-price-prediction.git
  2. Navigate to the project directory
cd deep-stock-price-prediction
  3. Create directories for the stock data and the trained models
mkdir data
mkdir model_keeper
    Then download the Apple (AAPL) and Tesla (TSLA) historical stock data from Kaggle into the data directory.
  4. Install the requirements
pip install -r requirements.txt

Models

  • univariate LSTM
  • Prophet
  • NeuralProphet
  • multivariate LSTM
  • XGBoost
  • LightGBM
  • ensemble models:
    • XGBoost + multivariate LSTM
    • XGBoost + LightGBM
    • LightGBM + multivariate LSTM
    • XGBoost + LightGBM + multivariate LSTM
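
The README does not say how the members of each ensemble are combined. A common, simple scheme is to average the models' forecasts at each time step, optionally with per-model weights. The sketch below assumes that scheme; the function name ensemble_average is hypothetical and not part of the repository, and the forecast numbers are toy values, not results.

```python
from typing import List, Optional, Sequence


def ensemble_average(
    predictions: Sequence[Sequence[float]],
    weights: Optional[Sequence[float]] = None,
) -> List[float]:
    """Combine per-model forecasts by (weighted) averaging at each time step.

    Hypothetical helper: assumes every model forecasts the same time steps.
    """
    if not predictions:
        raise ValueError("need at least one model's predictions")
    n_steps = len(predictions[0])
    if any(len(p) != n_steps for p in predictions):
        raise ValueError("all models must forecast the same number of steps")
    if weights is None:
        weights = [1.0] * len(predictions)  # plain mean
    if len(weights) != len(predictions):
        raise ValueError("one weight per model is required")
    total = sum(weights)
    if total == 0:
        raise ValueError("weights must not sum to zero")
    # Weighted mean of the models' predictions for each time step t.
    return [
        sum(w * p[t] for w, p in zip(weights, predictions)) / total
        for t in range(n_steps)
    ]


# Toy forecasts standing in for, e.g., XGBoost and multivariate LSTM outputs.
xgb_forecast = [10.0, 11.0, 12.0]
lstm_forecast = [12.0, 13.0, 14.0]
print(ensemble_average([xgb_forecast, lstm_forecast]))
print(ensemble_average([xgb_forecast, lstm_forecast], weights=[1.0, 3.0]))
```

Weighted averaging is the simplest way to combine models; stacking (training a meta-model on the members' outputs) is a common alternative, and either could be what the repository actually uses.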

Training

python train.py --dataset AAPL \
    --ml-model ensemble_XGBoost_lstm_multivariate \
    --test-size 0.35 \
    --time-stamp 100 \
    --epoch 1000 \
    --batch-size 84 \
    --enable-save-model
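
The README does not show how train.py reads these options. Assuming it uses Python's argparse (not confirmed here), a hypothetical parser mirroring the flags in the command above might look like the sketch below; the help texts and types are assumptions. One argparse behavior is worth knowing: long options can be abbreviated to any unambiguous prefix by default, so a parser that defines --epochs also accepts --epoch.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the training flags; the real train.py may differ."""
    parser = argparse.ArgumentParser(description="Train a stock price forecasting model")
    parser.add_argument("--dataset", required=True, help="dataset name, e.g. AAPL or TSLA")
    parser.add_argument("--ml-model", required=True, help="model identifier")
    parser.add_argument("--test-size", type=float, help="fraction of samples held out for testing")
    parser.add_argument("--time-stamp", type=int, help="number of time steps per input window")
    parser.add_argument("--epochs", type=int, help="number of training epochs")
    parser.add_argument("--batch-size", type=int, help="training examples per batch")
    # store_true: the flag takes no value; its presence sets the option to True.
    parser.add_argument("--enable-save-model", action="store_true",
                        help="save the trained model to model_keeper/")
    return parser


# Parse the same arguments as the training command. "--epoch" resolves to
# "--epochs" because argparse accepts unambiguous prefixes (allow_abbrev=True).
args = build_parser().parse_args([
    "--dataset", "AAPL",
    "--ml-model", "ensemble_XGBoost_lstm_multivariate",
    "--test-size", "0.35",
    "--time-stamp", "100",
    "--epoch", "1000",
    "--batch-size", "84",
    "--enable-save-model",
])
# Dashes in option names become underscores in attribute names.
print(args.ml_model, args.epochs, args.time_stamp, args.enable_save_model)
```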

The meaning of the flags:

  • --dataset-path: the directory that contains the dataset
  • --ml-model: the machine learning (ML) model used for time-series forecasting
  • --time-stamp: the window length, i.e., the number of consecutive time steps in each input sample of the windowed dataset
  • --test-size: the proportion of the dataset to include in the test split
  • --epochs: the number of training epochs
  • --batch-size: the number of training examples per batch, i.e., per training iteration
  • --enable-save-model: save the trained model to the model_keeper directory
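
Reading --time-stamp as the lookback window length (an interpretation of its description, not confirmed by the code), the windowed dataset and the --test-size split could be built roughly as follows. The helper names make_windows and chronological_split are hypothetical. The split keeps time order (no shuffling), so the test period lies strictly after the training period and no future prices leak into training; whether train.py splits the same way is not stated.

```python
from typing import List, Sequence, Tuple


def make_windows(series: Sequence[float], time_stamp: int) -> Tuple[List[List[float]], List[float]]:
    """Build (window, next value) pairs: each X row holds `time_stamp`
    consecutive observations and y is the observation that follows them."""
    if time_stamp < 1:
        raise ValueError("time_stamp must be positive")
    X: List[List[float]] = []
    y: List[float] = []
    for start in range(len(series) - time_stamp):
        X.append(list(series[start:start + time_stamp]))
        y.append(series[start + time_stamp])
    return X, y


def chronological_split(X: List[List[float]], y: List[float], test_size: float):
    """Hold out the last `test_size` fraction of samples, preserving time order.

    Rounding the test count to the nearest integer is an assumption.
    """
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be in (0, 1)")
    n_test = int(round(len(X) * test_size))
    split = len(X) - n_test
    return X[:split], X[split:], y[:split], y[split:]


# Toy closing prices (not real data) with a window of 3 steps.
X, y = make_windows([float(v) for v in range(1, 11)], time_stamp=3)
X_train, X_test, y_train, y_test = chronological_split(X, y, test_size=0.35)
print(len(X_train), len(X_test), y_test)
```

For the multivariate models, each element of the series would be a feature vector (for example, open, high, low, close, and volume) rather than a single price; the slicing logic stays the same.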

Prediction

python predict.py --dataset AAPL \
    --ml-model ensemble_XGBoost_lstm_multivariate