Commit: cleaned version
Sebastian Risi committed Jun 10, 2021
1 parent 99cb36a commit a2ce190
Showing 9 changed files with 1,167 additions and 2 deletions.
59 changes: 57 additions & 2 deletions README.md
@@ -1,2 +1,57 @@
# dip
Deep Innovation Protection
# PyTorch implementation of Deep Innovation Protection (DIP)

Paper: Risi and Stanley, "Deep Innovation Protection: Confronting the Credit Assignment Problem in Training Heterogeneous Neural Architectures",
Proceedings of the Thirty-Fifth AAAI Conference on Artificial Intelligence (AAAI-2021)

https://arxiv.org/abs/2001.01683


## Prerequisites

The code is partly based on the PyTorch implementation of "World Models" (https://github.com/ctallec/world-models).

The code requires Python 3 and PyTorch (https://pytorch.org). The remaining requirements are listed in the [requirements file](requirements.txt); to install them:
```bash
pip3 install -r requirements.txt
```

## Running the program

The world model is composed of three different components:

1. A Variational Auto-Encoder (VAE)
2. A Mixture-Density Recurrent Network (MDN-RNN)
3. A linear Controller (C), which takes both the latent encoding and the hidden state of the MDN-RNN as input and outputs the agent's action (see the sketch below)
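
As a rough illustration, the three components compose into a policy as in the sketch below (illustrative only: the actual rollout logic lives in `train.py`, and the VAE interface is assumed to match the ctallec/world-models code this repository builds on):

```python
import torch
from models import VAE, MDRNNCell, Controller

LSIZE, RSIZE, ASIZE = 32, 256, 3           # latent, hidden and action sizes (illustrative)

vae = VAE(3, LSIZE)                        # assumed signature: (img_channels, latent_size)
mdrnn = MDRNNCell(LSIZE, ASIZE, RSIZE, 5)  # 5 Gaussian mixture components (illustrative)
controller = Controller(LSIZE, RSIZE, ASIZE)

obs = torch.zeros(1, 3, 64, 64)            # a single 64x64 RGB frame
hidden = (torch.zeros(1, RSIZE), torch.zeros(1, RSIZE))  # LSTM (h, c) state

_, mu, _ = vae(obs)                                # encode the frame into a latent code
action = controller(mu, hidden[0])                 # act on latent code + hidden state
_, _, _, _, _, hidden = mdrnn(action, mu, hidden)  # advance the world model one step
```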

In contrast to the original world model, all three components are trained end-to-end through evolution. To run training:

```bash
python3 main.py
```

To test a specific genome:

```bash
python3 main.py --test best_1_1_G2.p
```

Additional arguments for the training script are listed below; an example invocation follows the list:
* **--folder** : The directory to store the training results.
* **--pop-size** : The population size.
* **--threads** : The number of threads used for training or testing.
* **--generations** : The number of generations used for training.
* **--inno** : 0 = Innovation protection disabled. 1 = Innovation protection enabled.
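
For example, a training run with innovation protection enabled could look like this (the values are only examples; defaults apply to anything omitted):

```bash
python3 main.py --folder results --pop-size 50 --generations 200 --threads 10 --inno 1
```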


### Notes
When running on a headless server, you will need to use `xvfb-run` to launch the training script. For instance:
```bash
xvfb-run -a -s "-screen 0 1400x900x24 +extension RANDR" -- python3 main.py
```

When running with a discrete VAE, the size of the latent vector is increased to 128, up from the 32 dimensions used for the standard VAE.

## Authors

* **Sebastian Risi**

98 changes: 98 additions & 0 deletions main.py
@@ -0,0 +1,98 @@

import argparse
import sys
from os.path import exists
from os import mkdir
from torch.multiprocessing import Process, Queue
import torch
import numpy as np
from train import RolloutGenerator, T1Solution
from nsga2 import NSGAII
import multiprocessing
#print("version ",multiprocessing.__version__)

torch.set_num_threads(1)

def main(argv):
parser = argparse.ArgumentParser()

parser.add_argument('--pop-size', type=int, default = 10, help='Population size.')

parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')

parser.add_argument('--generations', type=int, default=1000, metavar='N',
help='number of generations to train')

parser.add_argument('--threads', type=int, default=5, metavar='N',
help='threads')

parser.add_argument('--inno', type=int, default=1, metavar='N',
help='0 = no protection, 1 = protection')


    parser.add_argument('--test', type=str, default='', metavar='N',
                        help='genome file to load and evaluate, e.g. best_1_1_G2.p (skips training)')

parser.add_argument('--folder', type=str, default='results', metavar='N',
help='folder to store results')


parser.add_argument('--top', type=int, default=3, metavar='N',
help='top elites that should be re-evaluated')


parser.add_argument('--elite_evals', type=int, default=10, metavar='N',
help='how many times should the elite be evaluated')


parser.add_argument('--timelimit', type=int, default=500, metavar='N',
help='time limit on driving task')

args = parser.parse_args()

device = 'cpu'

if args.test!='':
to_evaluate = []

        t1 = T1Solution("test", 'cpu', 10000000, 1, 0, multi=False)
        t1.load_solution(args.test)
        to_evaluate.append(t1)

log_file = open("log.txt", 'a')

for ind in to_evaluate:
if (args.threads == 1):
average = []
print("Evaluting genome ",args.test)
for i in range(1):
f = ind.r_gen.do_rollout(False, True, 0, False)
average += [f]

print(np.average(average), np.std(average) )
else:
print("Evaluating on threads ",args.threads)
pool = multiprocessing.Pool(args.threads)
ind.multi = True
ind.run_solution(pool, 100, early_termination=False, force_eval = True)
avg_f, _, sd = ind.evaluate_solution(100)
print(avg_f, sd)
                log_file.write("%f\t%f\n" % (avg_f, sd))
log_file.flush()

log_file.close()
#print (t1.evaluate_on_test (args.frames) )
exit()

if not exists(args.folder):
mkdir(args.folder)

    nsga2 = NSGAII(2, 0.9, 1.0, args.elite_evals, args.top, args.threads, args.timelimit, args.pop_size, args.inno)  # mutation rate 0.9, crossover rate 1.0

nsga2.run(args.pop_size, args.generations, "{0}_{1}_".format(args.inno, args.seed), args.folder ) #pop_size, num_gens

if __name__ == '__main__':

main(sys.argv)
7 changes: 7 additions & 0 deletions models/__init__.py
@@ -0,0 +1,7 @@
""" Models package """
from models.vae import VAE, Encoder, Decoder
from models.mdrnn import MDRNN, MDRNNCell
from models.controller import Controller

__all__ = ['VAE', 'Encoder', 'Decoder',
'MDRNN', 'MDRNNCell', 'Controller']
13 changes: 13 additions & 0 deletions models/controller.py
@@ -0,0 +1,13 @@
""" Define controller """
import torch
import torch.nn as nn

class Controller(nn.Module):
""" Controller """
def __init__(self, latents, recurrents, actions):
super().__init__()
self.fc = nn.Linear(latents + recurrents, actions)

def forward(self, *inputs):
cat_in = torch.cat(inputs, dim=1)
return self.fc(cat_in)
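
# Example (illustrative): Controller(latents=32, recurrents=256, actions=3)
# concatenates a (BSIZE, 32) latent with a (BSIZE, 256) hidden state and maps
# the result through the single linear layer to a (BSIZE, 3) action vector.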
154 changes: 154 additions & 0 deletions models/mdrnn.py
@@ -0,0 +1,154 @@
"""
Define MDRNN model, supposed to be used as a world model
on the latent space.
"""
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.distributions.normal import Normal

def gmm_loss(batch, mus, sigmas, logpi, reduce=True): # pylint: disable=too-many-arguments
""" Computes the gmm loss.
Compute minus the log probability of batch under the GMM model described
by mus, sigmas, pi. Precisely, with bs1, bs2, ... the sizes of the batch
dimensions (several batch dimension are useful when you have both a batch
axis and a time step axis), gs the number of mixtures and fs the number of
features.
:args batch: (bs1, bs2, *, fs) torch tensor
:args mus: (bs1, bs2, *, gs, fs) torch tensor
:args sigmas: (bs1, bs2, *, gs, fs) torch tensor
:args logpi: (bs1, bs2, *, gs) torch tensor
:args reduce: if not reduce, the mean in the following formula is ommited
:returns:
loss(batch) = - mean_{i1=0..bs1, i2=0..bs2, ...} log(
sum_{k=1..gs} pi[i1, i2, ..., k] * N(
batch[i1, i2, ..., :] | mus[i1, i2, ..., k, :], sigmas[i1, i2, ..., k, :]))
NOTE: The loss is not reduced along the feature dimension (i.e. it should scale ~linearily
with fs).
"""
batch = batch.unsqueeze(-2)
normal_dist = Normal(mus, sigmas)
g_log_probs = normal_dist.log_prob(batch)
g_log_probs = logpi + torch.sum(g_log_probs, dim=-1)
max_log_probs = torch.max(g_log_probs, dim=-1, keepdim=True)[0]
g_log_probs = g_log_probs - max_log_probs

g_probs = torch.exp(g_log_probs)
probs = torch.sum(g_probs, dim=-1)

log_prob = max_log_probs.squeeze() + torch.log(probs)
if reduce:
return - torch.mean(log_prob)
return - log_prob
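
# Example shapes (illustrative): with bs=8 samples, gs=5 mixtures and fs=4
# features, batch is (8, 4), mus and sigmas are (8, 5, 4), logpi is (8, 5),
# and gmm_loss returns a scalar. The unsqueeze(-2) above broadcasts the batch
# against the mixture axis, and subtracting max_log_probs is the standard
# log-sum-exp stabilization.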

class _MDRNNBase(nn.Module):
def __init__(self, latents, actions, hiddens, gaussians):
super().__init__()
self.latents = latents
self.actions = actions
self.hiddens = hiddens
self.gaussians = gaussians

self.gmm_linear = nn.Linear(
hiddens, (2 * latents + 1) * gaussians + 2)
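        # Output layout (matching the slicing in the forward passes below):
        # gaussians * latents means, gaussians * latents sigmas, gaussians
        # mixture logits, then one reward and one terminal prediction.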

def forward(self, *inputs):
pass

class MDRNN(_MDRNNBase):
""" MDRNN model for multi steps forward """
def __init__(self, latents, actions, hiddens, gaussians):
super().__init__(latents, actions, hiddens, gaussians)
self.rnn = nn.LSTM(latents + actions, hiddens)

def forward(self, actions, latents): # pylint: disable=arguments-differ
""" MULTI STEPS forward.
:args actions: (SEQ_LEN, BSIZE, ASIZE) torch tensor
:args latents: (SEQ_LEN, BSIZE, LSIZE) torch tensor
:returns: mu_nlat, sig_nlat, pi_nlat, rs, ds, parameters of the GMM
prediction for the next latent, gaussian prediction of the reward and
logit prediction of terminality.
- mu_nlat: (SEQ_LEN, BSIZE, N_GAUSS, LSIZE) torch tensor
- sigma_nlat: (SEQ_LEN, BSIZE, N_GAUSS, LSIZE) torch tensor
- logpi_nlat: (SEQ_LEN, BSIZE, N_GAUSS) torch tensor
- rs: (SEQ_LEN, BSIZE) torch tensor
- ds: (SEQ_LEN, BSIZE) torch tensor
"""
seq_len, bs = actions.size(0), actions.size(1)

ins = torch.cat([actions, latents], dim=-1)
outs, _ = self.rnn(ins)
gmm_outs = self.gmm_linear(outs)

stride = self.gaussians * self.latents

mus = gmm_outs[:, :, :stride]
mus = mus.view(seq_len, bs, self.gaussians, self.latents)

sigmas = gmm_outs[:, :, stride:2 * stride]
sigmas = sigmas.view(seq_len, bs, self.gaussians, self.latents)
sigmas = torch.exp(sigmas)

pi = gmm_outs[:, :, 2 * stride: 2 * stride + self.gaussians]
pi = pi.view(seq_len, bs, self.gaussians)
logpi = f.log_softmax(pi, dim=-1)

rs = gmm_outs[:, :, -2]

ds = gmm_outs[:, :, -1]

return mus, sigmas, logpi, rs, ds
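
# Example (illustrative): for a 16-step sequence with batch size 4,
#   mdrnn = MDRNN(latents=32, actions=3, hiddens=256, gaussians=5)
#   mus, sigmas, logpi, rs, ds = mdrnn(actions, latents)
# where actions has shape (16, 4, 3) and latents has shape (16, 4, 32).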

class MDRNNCell(_MDRNNBase):
""" MDRNN model for one step forward """
def __init__(self, latents, actions, hiddens, gaussians):
super().__init__(latents, actions, hiddens, gaussians)
self.rnn = nn.LSTMCell(latents + actions, hiddens)

def forward(self, action, latent, hidden): # pylint: disable=arguments-differ
""" ONE STEP forward.
        :args action: (BSIZE, ASIZE) torch tensor
        :args latent: (BSIZE, LSIZE) torch tensor
:args hidden: (BSIZE, RSIZE) torch tensor
:returns: mu_nlat, sig_nlat, pi_nlat, r, d, next_hidden, parameters of
the GMM prediction for the next latent, gaussian prediction of the
reward, logit prediction of terminality and next hidden state.
- mu_nlat: (BSIZE, N_GAUSS, LSIZE) torch tensor
- sigma_nlat: (BSIZE, N_GAUSS, LSIZE) torch tensor
- logpi_nlat: (BSIZE, N_GAUSS) torch tensor
- rs: (BSIZE) torch tensor
- ds: (BSIZE) torch tensor
"""
in_al = torch.cat([action, latent], dim=1)

next_hidden = self.rnn(in_al, hidden)
out_rnn = next_hidden[0]

out_full = self.gmm_linear(out_rnn)

stride = self.gaussians * self.latents

mus = out_full[:, :stride]
mus = mus.view(-1, self.gaussians, self.latents)

sigmas = out_full[:, stride:2 * stride]
sigmas = sigmas.view(-1, self.gaussians, self.latents)
sigmas = torch.exp(sigmas)

pi = out_full[:, 2 * stride:2 * stride + self.gaussians]
pi = pi.view(-1, self.gaussians)
        logpi = f.log_softmax(pi, dim=-1)  # normalize over the mixture components, as in MDRNN above

r = out_full[:, -2]

d = out_full[:, -1]

return mus, sigmas, logpi, r, d, next_hidden