Internal discrete policy working (DSAC-nicklashansen#2)
DarrienMcKenzie committed Nov 2, 2024
1 parent c7b7574 commit e0b85e2
Showing 3 changed files with 27 additions and 22 deletions.
36 changes: 21 additions & 15 deletions tdmpc2/common/world_model.py
@@ -2,6 +2,7 @@

import torch
import torch.nn as nn
from torch.distributions import Categorical

from common import layers, math, init
from tensordict.nn import TensorDictParams
@@ -131,13 +132,12 @@ def reward(self, z, a, task):
z = torch.cat([z, a], dim=-1)
return self._reward(z)

def pi(self, z, task):
def pi(self, z, task): #DM: Ideally: remove batch parameter... (it's temporary)
"""
Samples an action from the policy prior.
The policy prior is a Gaussian distribution with
mean and (log) std predicted by a neural network.
For discrete action spaces, the prior is instead a categorical
distribution over actions, with logits predicted by the network.
"""

if self.cfg.multitask:
z = self.task_emb(z, task)

@@ -162,11 +162,19 @@ def pi(self, z, task):
mu, pi, log_pi = math.squash(mu, pi, log_pi)

return mu, pi, log_pi, log_std
else:
pi = self._pi(z)
actions = torch.argmax(torch.nn.functional.softmax(pi,dim=-1), dim=-1, keepdim=True) #DM: Discrete SAC Change #2-2
else: #DM: Discrete SAC Change #2-2
logits = self._pi(z)
policy_dist = Categorical(logits=logits)
a1 = policy_dist.sample()

if len(a1.shape)==2: #DM: handling batched embeddings; temporary fix due to torch.distributions.Categorical sample-shape nuances
actions = torch.reshape(a1, (a1.shape[0], a1.shape[1], 1))
elif len(a1.shape)==1: #DM: handling case in which a single embedding is passed
actions = a1
action_probs = policy_dist.probs
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

return actions, pi, torch.log(pi)
return actions, action_probs, log_probs
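
For reference, a minimal standalone sketch of how torch.distributions.Categorical behaves with batched logits; the shape handling in the branch above exists because .sample() drops the trailing action dimension. The sizes below (horizon 3, batch 256, 5 actions) are illustrative assumptions, not values taken from this commit.

import torch
from torch.distributions import Categorical

logits = torch.randn(3, 256, 5)                # assumed shape [horizon, batch, n_actions]
dist = Categorical(logits=logits)

a = dist.sample()                              # [3, 256]; the trailing action dim is dropped
actions = a.unsqueeze(-1)                      # [3, 256, 1], same effect as the reshape above

action_probs = dist.probs                      # [3, 256, 5], softmax of the logits
log_probs = torch.log_softmax(logits, dim=-1)  # [3, 256, 5], numerically stable log-probabilities

# With a single unbatched embedding the logits are 1-D and sample() returns a 0-D tensor.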

def Q(self, z, a, task, return_type='min', target=False, detach=False):
"""
@@ -182,10 +190,12 @@ def Q(self, z, a, task, return_type='min', target=False, detach=False):
if self.cfg.multitask:
z = self.task_emb(z, task)


#DM: orig shape of z for continuous (pendulum): [3,256,512]
#DM: orig shape of a for continuous (pendulum): [3, 256, 1]
z = torch.cat([z, a], dim=-1) #ORIGINAL
z = torch.cat([z, a], dim=-1) #DM: ORIGINAL
#z1 = torch.cat([z,0], dim=-1)
#z2 = torch.cat([z,1], dim=-1)
#DM: change the layers... don't need to batch?: [HORIZON, ALL_N_AVAILABLE_ACTIONS, PARTICULAR ACTION]
#DM: new shape of z for continuous (pendulum): [3,256,513]
if target:
qnet = self._target_Qs
@@ -195,6 +205,8 @@
qnet = self._Qs

out = qnet(z) #DM: z-shape for continuous: [3,256,513]
#z1 -> qnet(z1)
#z2 -> qnet(z2)

if return_type == 'all':
return out
@@ -203,10 +215,4 @@
Q = math.two_hot_inv(out[qidx], self.cfg)
if return_type == "min":
return Q.min(0).values
return Q.sum(0) / 2

def replace_column_with_argmax(tensor, dim):
"""Replace a column in a multi-dimensional tensor with the argmax values along a specified dimension."""

argmax_values = torch.argmax(tensor, dim=dim, keepdim=True)
return torch.cat([argmax_values, tensor[:, 1:]], dim=1)
return Q.sum(0) / 2
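
The commented-out z1/z2 lines in Q() hint at evaluating the critic for every discrete action rather than only the sampled one. A hedged sketch of that idea follows (an illustration under assumptions, not what this commit implements): the latent z is assumed to have shape [batch, latent_dim], and each candidate action is appended to z as a raw index, mirroring the commented lines; a one-hot encoding would be the more common choice in discrete SAC.

import torch

def q_all_discrete_actions(qnet, z, n_actions):
    """Evaluate qnet on z paired with every discrete action index (illustrative helper)."""
    outs = []
    for a_idx in range(n_actions):
        # append the candidate action index as an extra feature, as in torch.cat([z, a], dim=-1)
        a = torch.full((*z.shape[:-1], 1), float(a_idx), device=z.device, dtype=z.dtype)
        outs.append(qnet(torch.cat([z, a], dim=-1)))
    return torch.stack(outs, dim=0)  # [n_actions, batch, ...] of critic outputs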
7 changes: 3 additions & 4 deletions tdmpc2/tdmpc2.py
@@ -107,8 +107,7 @@ def act(self, obs, t0=False, eval_mode=False, task=None):
a = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task)
else: #DM: Entry point
z = self.model.encode(obs, task)
#print("EVAL MODE = ", eval_mode)
a = self.model.pi(z, task)[int(not eval_mode)][0] #ORIGINAL
a = self.model.pi(z, task)[0] #ORIGINAL
return a.cpu()

@torch.no_grad()
@@ -225,7 +224,7 @@ def update_pi(self, zs, task):

# Loss is a weighted sum of Q-values
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean() #DM: Discrete SAC Change #3 (do we need to change it, here?)
pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean() #DM: Discrete SAC Change #3
pi_loss.backward()
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
self.pi_optim.step()
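
On the 'Discrete SAC Change #3' note: discrete-action SAC is commonly formulated with the policy loss as an exact expectation over the action probabilities rather than over sampled actions. A hedged reference sketch, using hypothetical tensor names (action_probs, log_probs, and per-action Q-values qs), not this commit's code:

import torch

def discrete_sac_pi_loss(action_probs, log_probs, qs, rho, entropy_coef):
    # action_probs, log_probs, qs: [T, batch, n_actions]; rho: [T] temporal weights
    inner = entropy_coef * log_probs - qs           # per-action objective
    per_state = (action_probs * inner).sum(dim=-1)  # exact expectation over the action set
    return (per_state.mean(dim=1) * rho).mean()     # rho-weighted mean over the horizon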
@@ -249,7 +248,7 @@ def _td_target(self, next_z, reward, task):
if not DISCRETE:
pi = self.model.pi(next_z, task)[1] #ORIGINAL
else:
pi = self.model.pi(next_z, task)[0] #redundant change?
pi = self.model.pi(next_z, task)[0] #DM: Discrete

discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
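
The same expectation-over-actions idea applies to the TD target: with next-state action probabilities available, the next-state value can be computed over all discrete actions instead of a single sampled index (the hunk above passes the sampled index straight to Q). A hedged reference sketch with hypothetical names and per-action Q-values of shape [T, batch, n_actions]:

import torch

def discrete_sac_td_target(reward, discount, next_action_probs, next_log_probs, next_qs, entropy_coef):
    # expectation over the discrete action set, including the entropy bonus
    next_v = (next_action_probs * (next_qs - entropy_coef * next_log_probs)).sum(dim=-1, keepdim=True)
    return reward + discount * next_v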
6 changes: 3 additions & 3 deletions tdmpc2/trainer/online_trainer.py
@@ -103,19 +103,19 @@ def train(self):

# Collect experience
if self._step > self.cfg.seed_steps:
action = self.agent.act(obs, t0=len(self._tds)==1) #ORIGINAL
action = torch.argmax(torch.nn.functional.softmax(action))
#print("DELIBERATE ACTION")
action = self.agent.act(obs, t0=len(self._tds)==1) #ORIGINAL
#action = torch.argmax(torch.nn.functional.softmax(action))
#set_trace()
else:
action = self.env.rand_act()
#print("RANDOM ACTION")
#if DISCRETE:
# action = torch.tensor((action,))
obs, reward, done, info = self.env.step(int(action)) #DM: brute force test...

#DM-MOD
if DISCRETE:
#print("ACTION-2: ", action)
action = torch.tensor((action,))
#DM-MODIFIED

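
On the 'brute force' conversion in train(): the sampled action leaves the agent as a tensor, is cast with int(...) for env.step, and is wrapped back into a one-element tensor before being stored. A small hedged sketch of that round-trip, with hypothetical helper names and assuming a single discrete action per step:

import torch

def to_env_action(action_tensor):
    # plain Python int, as gym-style discrete envs expect
    return int(action_tensor.item())

def to_buffer_action(env_action):
    # one-element long tensor so the replay buffer sees a consistent [1]-shaped action
    return torch.tensor([env_action], dtype=torch.long)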
