diff --git a/README.md b/README.md index 7125ee8fe6..b3b14e3538 100644 --- a/README.md +++ b/README.md @@ -298,6 +298,7 @@ Here is a list of models built on `Qlib`. - [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](examples/benchmarks/TRA/) - [TCN based on pytorch (Shaojie Bai, et al. 2018)](examples/benchmarks/TCN/) - [ADARNN based on pytorch (YunTao Du, et al. 2021)](examples/benchmarks/ADARNN/) +- [ADD based on pytorch (Hongshun Tang, et al. 2020)](examples/benchmarks/ADD/) Your PR of new Quant models is highly welcomed. diff --git a/examples/benchmarks/ADD/README.md b/examples/benchmarks/ADD/README.md new file mode 100644 index 0000000000..999521425f --- /dev/null +++ b/examples/benchmarks/ADD/README.md @@ -0,0 +1,3 @@ +# ADD +* Paper: [ADD: Augmented Disentanglement Distillation Framework for Improving Stock Trend Forecasting](https://arxiv.org/abs/2012.06289). + diff --git a/examples/benchmarks/ADD/requirements.txt b/examples/benchmarks/ADD/requirements.txt new file mode 100644 index 0000000000..1fc2779c0f --- /dev/null +++ b/examples/benchmarks/ADD/requirements.txt @@ -0,0 +1,4 @@ +numpy==1.17.4 +pandas==1.1.2 +scikit_learn==0.23.2 +torch==1.7.0 diff --git a/examples/benchmarks/ADD/workflow_config_add_Alpha360.yaml b/examples/benchmarks/ADD/workflow_config_add_Alpha360.yaml new file mode 100644 index 0000000000..033d4d22e4 --- /dev/null +++ b/examples/benchmarks/ADD/workflow_config_add_Alpha360.yaml @@ -0,0 +1,94 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn +market: &market csi300 +benchmark: &benchmark SH000300 +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: DropnaLabel + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy + kwargs: + signal: + - <MODEL> + - <DATASET> + topk: 50 + n_drop: 5 + backtest: + start_time: 2017-01-01 + end_time: 2020-08-01 + account: 100000000 + benchmark: *benchmark + exchange_kwargs: + limit_threshold: 0.095 + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 +task: + model: + class: ADD + module_path: qlib.contrib.model.pytorch_add + kwargs: + d_feat: 6 + hidden_size: 64 + num_layers: 2 + dropout: 0.1 + dec_dropout: 0.0 + n_epochs: 200 + lr: 1e-3 + early_stop: 20 + batch_size: 5000 + metric: ic + base_model: GRU + gamma: 0.1 + gamma_clip: 0.2 + optimizer: adam + mu: 0.2 + GPU: 0 + dataset: + class: DatasetH + module_path: qlib.data.dataset + kwargs: + handler: + class: Alpha360 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: + model: <MODEL> + dataset: <DATASET> + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/examples/benchmarks/README.md b/examples/benchmarks/README.md index 779a2bc127..24f8870eea 100644 --- a/examples/benchmarks/README.md +++
b/examples/benchmarks/README.md @@ -56,6 +56,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of | TCN(Shaojie Bai, et al.) | Alpha360 | 0.0441±0.00 | 0.3301±0.02 | 0.0519±0.00 | 0.4130±0.01 | 0.0604±0.02 | 0.8295±0.34 | -0.1018±0.03 | | ALSTM (Yao Qin, et al.) | Alpha360 | 0.0497±0.00 | 0.3829±0.04 | 0.0599±0.00 | 0.4736±0.03 | 0.0626±0.02 | 0.8651±0.31 | -0.0994±0.03 | | LSTM(Sepp Hochreiter, et al.) | Alpha360 | 0.0448±0.00 | 0.3474±0.04 | 0.0549±0.00 | 0.4366±0.03 | 0.0647±0.03 | 0.8963±0.39 | -0.0875±0.02 | +| ADD | Alpha360 | 0.0430±0.00 | 0.3188±0.04 | 0.0559±0.00 | 0.4301±0.03 | 0.0667±0.02 | 0.8992±0.34 | -0.0855±0.02 | | GRU(Kyunghyun Cho, et al.) | Alpha360 | 0.0493±0.00 | 0.3772±0.04 | 0.0584±0.00 | 0.4638±0.03 | 0.0720±0.02 | 0.9730±0.33 | -0.0821±0.02 | | AdaRNN(Yuntao Du, et al.) | Alpha360 | 0.0464±0.01 | 0.3619±0.08 | 0.0539±0.01 | 0.4287±0.06 | 0.0753±0.03 | 1.0200±0.40 | -0.0936±0.03 | | GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 | diff --git a/qlib/contrib/model/__init__.py b/qlib/contrib/model/__init__.py index b691db1560..fab1af734c 100644 --- a/qlib/contrib/model/__init__.py +++ b/qlib/contrib/model/__init__.py @@ -31,8 +31,9 @@ from .pytorch_tabnet import TabnetModel from .pytorch_sfm import SFM_Model from .pytorch_tcn import TCN + from .pytorch_add import ADD - pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN) + pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN, ADD) except ModuleNotFoundError: pytorch_classes = () print("Please install necessary libs for PyTorch models.") diff --git a/qlib/contrib/model/pytorch_add.py b/qlib/contrib/model/pytorch_add.py new file mode 100644 index 0000000000..234d662999 --- /dev/null +++ b/qlib/contrib/model/pytorch_add.py @@ -0,0 +1,598 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import division +from __future__ import print_function + + +import copy +import math +from typing import Text, Union + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from qlib.contrib.model.pytorch_gru import GRUModel +from qlib.contrib.model.pytorch_lstm import LSTMModel +from qlib.contrib.model.pytorch_utils import count_parameters +from qlib.data.dataset import DatasetH +from qlib.data.dataset.handler import DataHandlerLP +from qlib.data.dataset.processor import CSRankNorm +from qlib.log import get_module_logger +from qlib.model.base import Model +from qlib.utils import get_or_create_path +from torch.autograd import Function + + +class ADD(Model): + """ADD Model + + Parameters + ---------- + lr : float + learning rate + d_feat : int + input dimensions for each time step + metric : str + the evaluate metric used in early stop + optimizer : str + optimizer name + GPU : int + the GPU ID used for training + """ + + def __init__( + self, + d_feat=6, + hidden_size=64, + num_layers=2, + dropout=0.0, + dec_dropout=0.0, + n_epochs=200, + lr=0.001, + metric="mse", + batch_size=5000, + early_stop=20, + base_model="GRU", + model_path=None, + optimizer="adam", + gamma=0.1, + gamma_clip=0.4, + mu=0.05, + GPU=0, + seed=None, + **kwargs + ): + # Set logger. + self.logger = get_module_logger("ADD") + self.logger.info("ADD pytorch version...") + + # set hyper-parameters. 
+ self.d_feat = d_feat + self.hidden_size = hidden_size + self.num_layers = num_layers + self.dropout = dropout + self.dec_dropout = dec_dropout + self.n_epochs = n_epochs + self.lr = lr + self.metric = metric + self.batch_size = batch_size + self.early_stop = early_stop + self.optimizer = optimizer.lower() + self.base_model = base_model + self.model_path = model_path + self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.seed = seed + + self.gamma = gamma + self.gamma_clip = gamma_clip + self.mu = mu + + self.logger.info( + "ADD parameters setting:" + "\nd_feat : {}" + "\nhidden_size : {}" + "\nnum_layers : {}" + "\ndropout : {}" + "\ndec_dropout : {}" + "\nn_epochs : {}" + "\nlr : {}" + "\nmetric : {}" + "\nbatch_size : {}" + "\nearly_stop : {}" + "\noptimizer : {}" + "\nbase_model : {}" + "\nmodel_path : {}" + "\ngamma : {}" + "\ngamma_clip : {}" + "\nmu : {}" + "\ndevice : {}" + "\nuse_GPU : {}" + "\nseed : {}".format( + d_feat, + hidden_size, + num_layers, + dropout, + dec_dropout, + n_epochs, + lr, + metric, + batch_size, + early_stop, + optimizer.lower(), + base_model, + model_path, + gamma, + gamma_clip, + mu, + self.device, + self.use_gpu, + seed, + ) + ) + + if self.seed is not None: + np.random.seed(self.seed) + torch.manual_seed(self.seed) + + self.ADD_model = ADDModel( + d_feat=self.d_feat, + hidden_size=self.hidden_size, + num_layers=self.num_layers, + dropout=self.dropout, + dec_dropout=self.dec_dropout, + base_model=self.base_model, + gamma=self.gamma, + gamma_clip=self.gamma_clip, + ) + self.logger.info("model:\n{:}".format(self.ADD_model)) + self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ADD_model))) + + if optimizer.lower() == "adam": + self.train_optimizer = optim.Adam(self.ADD_model.parameters(), lr=self.lr) + elif optimizer.lower() == "gd": + self.train_optimizer = optim.SGD(self.ADD_model.parameters(), lr=self.lr) + else: + raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + + self.fitted = False + self.ADD_model.to(self.device) + + @property + def use_gpu(self): + return self.device != torch.device("cpu") + + def loss_pre_excess(self, pred_excess, label_excess, record=None): + mask = ~torch.isnan(label_excess) + pre_excess_loss = F.mse_loss(pred_excess[mask], label_excess[mask]) + if record is not None: + record["pre_excess_loss"] = pre_excess_loss.item() + return pre_excess_loss + + def loss_pre_market(self, pred_market, label_market, record=None): + pre_market_loss = F.cross_entropy(pred_market, label_market) + if record is not None: + record["pre_market_loss"] = pre_market_loss.item() + return pre_market_loss + + def loss_pre(self, pred_excess, label_excess, pred_market, label_market, record=None): + pre_loss = self.loss_pre_excess(pred_excess, label_excess, record) + self.loss_pre_market( + pred_market, label_market, record + ) + if record is not None: + record["pre_loss"] = pre_loss.item() + return pre_loss + + def loss_adv_excess(self, adv_excess, label_excess, record=None): + mask = ~torch.isnan(label_excess) + adv_excess_loss = F.mse_loss(adv_excess.squeeze()[mask], label_excess[mask]) + if record is not None: + record["adv_excess_loss"] = adv_excess_loss.item() + return adv_excess_loss + + def loss_adv_market(self, adv_market, label_market, record=None): + adv_market_loss = F.cross_entropy(adv_market, label_market) + if record is not None: + record["adv_market_loss"] = adv_market_loss.item() + return adv_market_loss + + def loss_adv(self, adv_excess, 
label_excess, adv_market, label_market, record=None): + adv_loss = self.loss_adv_excess(adv_excess, label_excess, record) + self.loss_adv_market( + adv_market, label_market, record + ) + if record is not None: + record["adv_loss"] = adv_loss.item() + return adv_loss + + def loss_fn(self, x, preds, label_excess, label_market, record=None): + loss = ( + self.loss_pre(preds["excess"], label_excess, preds["market"], label_market, record) + + self.loss_adv(preds["adv_excess"], label_excess, preds["adv_market"], label_market, record) + + self.mu * self.loss_rec(x, preds["reconstructed_feature"], record) + ) + if record is not None: + record["loss"] = loss.item() + return loss + + def loss_rec(self, x, rec_x, record=None): + x = x.reshape(len(x), self.d_feat, -1) + x = x.permute(0, 2, 1) + rec_loss = F.mse_loss(x, rec_x) + if record is not None: + record["rec_loss"] = rec_loss.item() + return rec_loss + + def get_daily_inter(self, df, shuffle=False): + # organize the train data into daily batches + daily_count = df.groupby(level=0).size().values + daily_index = np.roll(np.cumsum(daily_count), 1) + daily_index[0] = 0 + if shuffle: + # shuffle data + daily_shuffle = list(zip(daily_index, daily_count)) + np.random.shuffle(daily_shuffle) + daily_index, daily_count = zip(*daily_shuffle) + return daily_index, daily_count + + def cal_ic_metrics(self, pred, label): + metrics = {} + metrics["mse"] = -F.mse_loss(pred, label).item() + metrics["loss"] = metrics["mse"] + pred = pd.Series(pred.cpu().detach().numpy()) + label = pd.Series(label.cpu().detach().numpy()) + metrics["ic"] = pred.corr(label) + metrics["ric"] = pred.corr(label, method="spearman") + return metrics + + def test_epoch(self, data_x, data_y, data_m): + x_values = data_x.values + y_values = np.squeeze(data_y.values) + m_values = np.squeeze(data_m.values.astype(int)) + self.ADD_model.eval() + + metrics_list = [] + + daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False) + + for idx, count in zip(daily_index, daily_count): + batch = slice(idx, idx + count) + feature = torch.from_numpy(x_values[batch]).float().to(self.device) + label_excess = torch.from_numpy(y_values[batch]).float().to(self.device) + label_market = torch.from_numpy(m_values[batch]).long().to(self.device) + + metrics = {} + preds = self.ADD_model(feature) + self.loss_fn(feature, preds, label_excess, label_market, metrics) + metrics.update(self.cal_ic_metrics(preds["excess"], label_excess)) + metrics_list.append(metrics) + metrics = {} + keys = metrics_list[0].keys() + for k in keys: + vs = [m[k] for m in metrics_list] + metrics[k] = sum(vs) / len(vs) + + return metrics + + def train_epoch(self, x_train_values, y_train_values, m_train_values): + self.ADD_model.train() + + indices = np.arange(len(x_train_values)) + np.random.shuffle(indices) + + cur_step = 1 + + for i in range(len(indices))[:: self.batch_size]: + if len(indices) - i < self.batch_size: + break + batch = indices[i : i + self.batch_size] + feature = torch.from_numpy(x_train_values[batch]).float().to(self.device) + label_excess = torch.from_numpy(y_train_values[batch]).float().to(self.device) + label_market = torch.from_numpy(m_train_values[batch]).long().to(self.device) + + preds = self.ADD_model(feature) + + loss = self.loss_fn(feature, preds, label_excess, label_market) + + self.train_optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_value_(self.ADD_model.parameters(), 3.0) + self.train_optimizer.step() + cur_step += 1 + + def log_metrics(self, mode, metrics): + metrics = ["{}/{}: 
{:.6f}".format(k, mode, v) for k, v in metrics.items()] + metrics = ", ".join(metrics) + self.logger.info(metrics) + + def bootstrap_fit(self, x_train, y_train, m_train, x_valid, y_valid, m_valid): + stop_steps = 0 + best_score = -np.inf + best_epoch = 0 + + # train + self.logger.info("training...") + self.fitted = True + x_train_values = x_train.values + y_train_values = np.squeeze(y_train.values) + m_train_values = np.squeeze(m_train.values.astype(int)) + + for step in range(self.n_epochs): + self.logger.info("Epoch%d:", step) + self.logger.info("training...") + self.train_epoch(x_train_values, y_train_values, m_train_values) + self.logger.info("evaluating...") + train_metrics = self.test_epoch(x_train, y_train, m_train) + valid_metrics = self.test_epoch(x_valid, y_valid, m_valid) + self.log_metrics("train", train_metrics) + self.log_metrics("valid", valid_metrics) + + if self.metric in valid_metrics: + val_score = valid_metrics[self.metric] + else: + raise ValueError("unknown metric name `%s`" % self.metric) + if val_score > best_score: + best_score = val_score + stop_steps = 0 + best_epoch = step + best_param = copy.deepcopy(self.ADD_model.state_dict()) + else: + stop_steps += 1 + if stop_steps >= self.early_stop: + self.logger.info("early stop") + break + self.ADD_model.before_adv_excess.step_alpha() + self.ADD_model.before_adv_market.step_alpha() + self.logger.info("bootstrap_fit best score: {:.6f} @ {}".format(best_score, best_epoch)) + self.ADD_model.load_state_dict(best_param) + return best_score + + def gen_market_label(self, df, raw_label): + market_label = raw_label.groupby("datetime").mean().squeeze() + bins = [-np.inf, self.lo, self.hi, np.inf] + market_label = pd.cut(market_label, bins, labels=False) + market_label.name = ("market_return", "market_return") + df = df.join(market_label) + return df + + def fit_thresh(self, train_label): + market_label = train_label.groupby("datetime").mean().squeeze() + self.lo, self.hi = market_label.quantile([1 / 3, 2 / 3]) + + def fit( + self, + dataset: DatasetH, + evals_result=dict(), + save_path=None, + ): + label_train, label_valid = dataset.prepare( + ["train", "valid"], + col_set=["label"], + data_key=DataHandlerLP.DK_R, + ) + self.fit_thresh(label_train) + df_train, df_valid = dataset.prepare( + ["train", "valid"], + col_set=["feature", "label"], + data_key=DataHandlerLP.DK_L, + ) + df_train = self.gen_market_label(df_train, label_train) + df_valid = self.gen_market_label(df_valid, label_valid) + + x_train, y_train, m_train = df_train["feature"], df_train["label"], df_train["market_return"] + x_valid, y_valid, m_valid = df_valid["feature"], df_valid["label"], df_valid["market_return"] + + evals_result["train"] = [] + evals_result["valid"] = [] + # load pretrained base_model + + if self.base_model == "LSTM": + pretrained_model = LSTMModel() + elif self.base_model == "GRU": + pretrained_model = GRUModel() + else: + raise ValueError("unknown base model name `%s`" % self.base_model) + + if self.model_path is not None: + self.logger.info("Loading pretrained model...") + pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device)) + + model_dict = self.ADD_model.enc_excess.state_dict() + pretrained_dict = {k: v for k, v in pretrained_model.rnn.state_dict().items() if k in model_dict} + model_dict.update(pretrained_dict) + self.ADD_model.enc_excess.load_state_dict(model_dict) + model_dict = self.ADD_model.enc_market.state_dict() + pretrained_dict = {k: v for k, v in pretrained_model.rnn.state_dict().items() if k 
in model_dict} + model_dict.update(pretrained_dict) + self.ADD_model.enc_market.load_state_dict(model_dict) + self.logger.info("Loading pretrained model Done...") + + self.bootstrap_fit(x_train, y_train, m_train, x_valid, y_valid, m_valid) + + best_param = copy.deepcopy(self.ADD_model.state_dict()) + save_path = get_or_create_path(save_path) + torch.save(best_param, save_path) + if self.use_gpu: + torch.cuda.empty_cache() + + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): + x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + index = x_test.index + self.ADD_model.eval() + x_values = x_test.values + preds = [] + + daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False) + + for idx, count in zip(daily_index, daily_count): + batch = slice(idx, idx + count) + x_batch = torch.from_numpy(x_values[batch]).float().to(self.device) + + with torch.no_grad(): + pred = self.ADD_model(x_batch) + pred = pred["excess"].detach().cpu().numpy() + + preds.append(pred) + + r = pd.Series(np.concatenate(preds), index=index) + return r + + +class ADDModel(nn.Module): + def __init__( + self, + d_feat=6, + hidden_size=64, + num_layers=1, + dropout=0.0, + dec_dropout=0.5, + base_model="GRU", + gamma=0.1, + gamma_clip=0.4, + ): + super().__init__() + self.d_feat = d_feat + self.base_model = base_model + if base_model == "GRU": + self.enc_excess, self.enc_market = [ + nn.GRU( + input_size=d_feat, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=True, + dropout=dropout, + ) + for _ in range(2) + ] + elif base_model == "LSTM": + self.enc_excess, self.enc_market = [ + nn.LSTM( + input_size=d_feat, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=True, + dropout=dropout, + ) + for _ in range(2) + ] + else: + raise ValueError("unknown base model name `%s`" % base_model) + self.dec = Decoder(d_feat, 2 * hidden_size, num_layers, dec_dropout, base_model) + + ctx_size = hidden_size * num_layers + self.pred_excess, self.adv_excess = [ + nn.Sequential(nn.Linear(ctx_size, ctx_size), nn.BatchNorm1d(ctx_size), nn.Tanh(), nn.Linear(ctx_size, 1)) + for _ in range(2) + ] + self.adv_market, self.pred_market = [ + nn.Sequential(nn.Linear(ctx_size, ctx_size), nn.BatchNorm1d(ctx_size), nn.Tanh(), nn.Linear(ctx_size, 3)) + for _ in range(2) + ] + self.before_adv_market, self.before_adv_excess = [RevGrad(gamma, gamma_clip) for _ in range(2)] + + def forward(self, x): + x = x.reshape(len(x), self.d_feat, -1) + N = x.shape[0] + T = x.shape[-1] + x = x.permute(0, 2, 1) + + out, hidden_excess = self.enc_excess(x) + out, hidden_market = self.enc_market(x) + if self.base_model == "LSTM": + feature_excess = hidden_excess[0].permute(1, 0, 2).reshape(N, -1) + feature_market = hidden_market[0].permute(1, 0, 2).reshape(N, -1) + else: + feature_excess = hidden_excess.permute(1, 0, 2).reshape(N, -1) + feature_market = hidden_market.permute(1, 0, 2).reshape(N, -1) + predicts = {} + predicts["excess"] = self.pred_excess(feature_excess).squeeze(1) + predicts["market"] = self.pred_market(feature_market) + predicts["adv_market"] = self.adv_market(self.before_adv_market(feature_excess)) + predicts["adv_excess"] = self.adv_excess(self.before_adv_excess(feature_market).squeeze(1)) + if self.base_model == "LSTM": + hidden = [torch.cat([hidden_excess[i], hidden_market[i]], -1) for i in range(2)] + else: + hidden = torch.cat([hidden_excess, hidden_market], -1) + x = torch.zeros_like(x[:, 1, :]) + reconstructed_feature = [] + for i in range(T): + x, hidden = 
self.dec(x, hidden) + reconstructed_feature.append(x) + reconstructed_feature = torch.stack(reconstructed_feature, 1) + predicts["reconstructed_feature"] = reconstructed_feature + return predicts + + +class Decoder(nn.Module): + def __init__(self, d_feat=6, hidden_size=128, num_layers=1, dropout=0.5, base_model="GRU"): + super().__init__() + self.base_model = base_model + if base_model == "GRU": + self.rnn = nn.GRU( + input_size=d_feat, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=True, + dropout=dropout, + ) + elif base_model == "LSTM": + self.rnn = nn.LSTM( + input_size=d_feat, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=True, + dropout=dropout, + ) + else: + raise ValueError("unknown base model name `%s`" % base_model) + + self.fc = nn.Linear(hidden_size, d_feat) + + def forward(self, x, hidden): + x = x.unsqueeze(1) + output, hidden = self.rnn(x, hidden) + output = output.squeeze(1) + pred = self.fc(output) + return pred, hidden + + +class RevGradFunc(Function): + @staticmethod + def forward(ctx, input_, alpha_): + ctx.save_for_backward(input_, alpha_) + output = input_ + return output + + @staticmethod + def backward(ctx, grad_output): # pragma: no cover + grad_input = None + _, alpha_ = ctx.saved_tensors + if ctx.needs_input_grad[0]: + grad_input = -grad_output * alpha_ + return grad_input, None + + +class RevGrad(nn.Module): + def __init__(self, gamma=0.1, gamma_clip=0.4, *args, **kwargs): + """ + A gradient reversal layer. + This layer has no parameters, and simply reverses the gradient + in the backward pass. + """ + super().__init__(*args, **kwargs) + + self.gamma = gamma + self.gamma_clip = torch.tensor(float(gamma_clip), requires_grad=False) + self._alpha = torch.tensor(0, requires_grad=False) + self._p = 0 + + def step_alpha(self): + self._p += 1 + self._alpha = min( + self.gamma_clip, torch.tensor(2 / (1 + math.exp(-self.gamma * self._p)) - 1, requires_grad=False) + ) + + def forward(self, input_): + return RevGradFunc.apply(input_, self._alpha)
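Below is a minimal usage sketch (not part of the diff above) of the `RevGrad` gradient reversal layer defined in `pytorch_add.py`: the forward pass is the identity, while the backward pass negates the gradient and scales it by an `alpha` that `step_alpha()` (called once per epoch in `ADD.bootstrap_fit`) ramps from 0 towards `gamma_clip` following `2 / (1 + exp(-gamma * p)) - 1`.

```python
# Sketch only: illustrates RevGrad's behavior, assuming the module path added in this PR.
import torch

from qlib.contrib.model.pytorch_add import RevGrad

layer = RevGrad(gamma=0.1, gamma_clip=0.4)

# alpha starts at 0, so gradients are blocked until step_alpha() is called
# (ADD.bootstrap_fit calls it once per training epoch).
for _ in range(20):
    layer.step_alpha()
print(float(layer._alpha))  # 2 / (1 + exp(-0.1 * 20)) - 1 ≈ 0.76, clipped to gamma_clip = 0.4

x = torch.ones(3, requires_grad=True)
layer(x).sum().backward()   # forward pass is the identity
print(x.grad)               # tensor([-0.4, -0.4, -0.4]): gradient reversed and scaled by alpha
```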
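The market-regime target consumed by `loss_pre_market` comes from `fit_thresh` and `gen_market_label`: the daily cross-sectional mean of the raw label is bucketed into three classes by its 1/3 and 2/3 quantiles over the training period. A self-contained sketch on synthetic data (not part of the diff above):

```python
# Sketch only: synthetic data, mirrors the logic of ADD.fit_thresh / ADD.gen_market_label.
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
dates = pd.date_range("2020-01-01", periods=250, freq="B")
index = pd.MultiIndex.from_product(
    [dates, [f"S{i}" for i in range(30)]], names=["datetime", "instrument"]
)
label = pd.DataFrame({"LABEL0": rng.normal(0, 0.02, len(index))}, index=index)

# fit_thresh: daily cross-sectional mean return, then tercile thresholds on the train period.
market = label.groupby("datetime").mean().squeeze()
lo, hi = market.quantile([1 / 3, 2 / 3])

# gen_market_label: bucket each day into classes {0, 1, 2} (roughly: down / flat / up days).
market_label = pd.cut(market, [-np.inf, lo, hi, np.inf], labels=False)
print(market_label.value_counts())  # by construction, about one third of days per class
```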
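Finally, `test_epoch` and `predict` iterate one trading day at a time using `get_daily_inter`, which converts per-day row counts into `(start offset, count)` pairs, assuming the frame is sorted by `datetime`. A small sketch on synthetic data (not part of the diff above):

```python
# Sketch only: synthetic data, mirrors the per-day slicing in ADD.get_daily_inter.
import numpy as np
import pandas as pd

dates = pd.to_datetime(["2020-01-02"] * 3 + ["2020-01-03"] * 2 + ["2020-01-06"] * 4)
stocks = ["A", "B", "C", "A", "B", "A", "B", "C", "D"]
df = pd.DataFrame(
    {"f": range(9)},
    index=pd.MultiIndex.from_arrays([dates, stocks], names=["datetime", "instrument"]),
)

daily_count = df.groupby(level=0).size().values   # rows per day: [3, 2, 4]
daily_index = np.roll(np.cumsum(daily_count), 1)  # cumsum -> [3, 5, 9], roll -> [9, 3, 5]
daily_index[0] = 0                                # start offsets -> [0, 3, 5]

for idx, count in zip(daily_index, daily_count):
    day = df.iloc[idx : idx + count]
    print(day.index.get_level_values("datetime")[0].date(), count)
```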