Lucidrains Series Source Code Analysis (Part 79)
.\lucidrains\q-transformer\q_transformer\mocks.py
from random import randrange
import torch
from torch.utils.data import Dataset
from beartype.typing import Tuple, Optional
from torchtyping import TensorType
from q_transformer.agent import BaseEnvironment
class MockEnvironment(BaseEnvironment):
def init(self) -> Tuple[
Optional[str],
TensorType[float]
]:
return 'please clean the kitchen', torch.randn(self.state_shape, device = self.device)
def forward(self, actions) -> Tuple[
TensorType[(), float],
TensorType[float],
TensorType[(), bool]
]:
rewards = torch.randn((), device = self.device)
next_states = torch.randn(self.state_shape, device = self.device)
done = torch.zeros((), device = self.device, dtype = torch.bool)
return rewards, next_states, done
class MockReplayDataset(Dataset):
def __init__(
self,
length = 10000,
num_actions = 1,
num_action_bins = 256,
video_shape = (6, 224, 224)
):
self.length = length
self.num_actions = num_actions
self.num_action_bins = num_action_bins
self.video_shape = video_shape
def __len__(self):
return self.length
def __getitem__(self, _):
instruction = "please clean the kitchen"
state = torch.randn(3, *self.video_shape)
if self.num_actions == 1:
action = torch.tensor(randrange(self.num_action_bins + 1))
else:
action = torch.randint(0, self.num_action_bins + 1, (self.num_actions,))
next_state = torch.randn(3, *self.video_shape)
reward = torch.tensor(randrange(2))
done = torch.tensor(randrange(2), dtype = torch.bool)
return instruction, state, action, next_state, reward, done
class MockReplayNStepDataset(Dataset):
def __init__(
self,
length = 10000,
num_steps = 2,
num_actions = 1,
num_action_bins = 256,
video_shape = (6, 224, 224)
):
self.num_steps = num_steps
self.time_shape = (num_steps,)
self.length = length
self.num_actions = num_actions
self.num_action_bins = num_action_bins
self.video_shape = video_shape
def __len__(self):
return self.length
def __getitem__(self, _):
action_dims = (self.num_actions,) if self.num_actions > 1 else tuple()
instruction = "please clean the kitchen"
state = torch.randn(*self.time_shape, 3, *self.video_shape)
action = torch.randint(0, self.num_action_bins + 1, (*self.time_shape, *action_dims))
next_state = torch.randn(3, *self.video_shape)
reward = torch.randint(0, 2, self.time_shape)
done = torch.zeros(self.time_shape, dtype = torch.bool)
return instruction, state, action, next_state, reward, done
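# Usage sketch (illustrative addition, not part of the original file): the mock replay dataset
# plugs directly into a standard PyTorch DataLoader; shapes below assume the defaults above.
from torch.utils.data import DataLoader

mock_loader = DataLoader(MockReplayDataset(length = 16), batch_size = 4)
instruction, state, action, next_state, reward, done = next(iter(mock_loader))
# instruction: list of 4 strings, state / next_state: (4, 3, 6, 224, 224)
# action: (4,), reward: (4,), done: (4,) of dtype bool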
.\lucidrains\q-transformer\q_transformer\optimizer.py
from torch.optim import AdamW, Adam
def separate_weight_decayable_params(params):
wd_params, no_wd_params = [], []
for param in params:
param_list = no_wd_params if param.ndim < 2 else wd_params
param_list.append(param)
return wd_params, no_wd_params
def get_adam_optimizer(
params,
lr = 1e-4,
wd = 1e-2,
betas = (0.9, 0.99),
eps = 1e-8,
filter_by_requires_grad = False,
group_wd_params = True
):
has_wd = wd > 0
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
if group_wd_params and has_wd:
wd_params, no_wd_params = separate_weight_decayable_params(params)
params = [
{'params': wd_params},
{'params': no_wd_params, 'weight_decay': 0},
]
if not has_wd:
return Adam(params, lr = lr, betas = betas, eps = eps)
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
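# Brief illustration (added; a hypothetical two-layer model) of the grouping above: parameters
# with ndim < 2 (biases, norm gains) are excluded from weight decay, everything else is decayed,
# and plain Adam is returned whenever wd == 0.
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
optimizer = get_adam_optimizer(net.parameters(), lr = 1e-4, wd = 1e-2)
# param_groups[0]: the Linear weight (decayed), param_groups[1]: biases and LayerNorm params (weight_decay = 0)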
.\lucidrains\q-transformer\q_transformer\q_learner.py
from pathlib import Path
from functools import partial
from contextlib import nullcontext
from collections import namedtuple
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
from torch.utils.data import Dataset, DataLoader
from torchtyping import TensorType
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
from beartype import beartype
from beartype.typing import Optional, Union, List, Tuple
from q_transformer.q_robotic_transformer import QRoboticTransformer
from q_transformer.optimizer import get_adam_optimizer
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from ema_pytorch import EMA
QIntermediates = namedtuple('QIntermediates', [
'q_pred_all_actions',
'q_pred',
'q_next',
'q_target'
])
Losses = namedtuple('Losses', [
'td_loss',
'conservative_reg_loss'
])
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def is_divisible(num, den):
return (num % den) == 0
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def cycle(dl):
while True:
for batch in dl:
yield batch
def batch_select_indices(t, indices):
indices = rearrange(indices, '... -> ... 1')
selected = t.gather(-1, indices)
return rearrange(selected, '... 1 -> ...')
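# Worked example (added comment): batch_select_indices picks, per batch row, the Q-value of the
# action that was actually taken, e.g.
#   q = torch.tensor([[0.1, 0.9, 0.3], [0.5, 0.2, 0.8]])
#   a = torch.tensor([1, 2])
#   batch_select_indices(q, a)  # -> tensor([0.9000, 0.8000])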
class QLearner(Module):
@beartype
def __init__(
self,
model: Union[QRoboticTransformer, Module],
*,
dataset: Dataset,
batch_size: int,
num_train_steps: int,
learning_rate: float,
min_reward: float = 0.,
grad_accum_every: int = 1,
monte_carlo_return: Optional[float] = None,
weight_decay: float = 0.,
accelerator: Optional[Accelerator] = None,
accelerator_kwargs: dict = dict(),
dataloader_kwargs: dict = dict(
shuffle = True
),
q_target_ema_kwargs: dict = dict(
beta = 0.99,
update_after_step = 10,
update_every = 5
),
max_grad_norm = 0.5,
n_step_q_learning = False,
discount_factor_gamma = 0.98,
conservative_reg_loss_weight = 1.,
optimizer_kwargs: dict = dict(),
checkpoint_folder = './checkpoints',
checkpoint_every = 1000,
):
super().__init__()
self.is_multiple_actions = model.num_actions > 1
self.discount_factor_gamma = discount_factor_gamma
self.n_step_q_learning = n_step_q_learning
self.has_conservative_reg_loss = conservative_reg_loss_weight > 0.
self.conservative_reg_loss_weight = conservative_reg_loss_weight
self.register_buffer('discount_matrix', None, persistent = False)
self.model = model
self.ema_model = EMA(
model,
include_online_model = False,
**q_target_ema_kwargs
)
self.max_grad_norm = max_grad_norm
self.optimizer = get_adam_optimizer(
model.parameters(),
lr = learning_rate,
wd = weight_decay,
**optimizer_kwargs
)
if not exists(accelerator):
accelerator = Accelerator(
kwargs_handlers = [
DistributedDataParallelKwargs(find_unused_parameters = True)
],
**accelerator_kwargs
)
self.accelerator = accelerator
self.min_reward = min_reward
self.monte_carlo_return = monte_carlo_return
self.dataloader = DataLoader(
dataset,
batch_size = batch_size,
**dataloader_kwargs
)
(
self.model,
self.ema_model,
self.optimizer,
self.dataloader
) = self.accelerator.prepare(
self.model,
self.ema_model,
self.optimizer,
self.dataloader
)
self.checkpoint_every = checkpoint_every
self.checkpoint_folder = Path(checkpoint_folder)
self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
assert self.checkpoint_folder.is_dir()
self.register_buffer('zero', torch.tensor(0.))
self.num_train_steps = num_train_steps
self.grad_accum_every = grad_accum_every
self.register_buffer('step', torch.tensor(0))
def save(
self,
checkpoint_num = None,
overwrite = True
):
name = 'checkpoint'
if exists(checkpoint_num):
name += f'-{checkpoint_num}'
path = self.checkpoint_folder / (name + '.pt')
assert overwrite or not path.exists()
pkg = dict(
model = self.unwrap(self.model).state_dict(),
ema_model = self.unwrap(self.ema_model).state_dict(),
optimizer = self.optimizer.state_dict(),
step = self.step.item()
)
torch.save(pkg, str(path))
def load(self, path):
path = Path(path)
assert exists(path)
pkg = torch.load(str(path))
self.unwrap(self.model).load_state_dict(pkg['model'])
self.unwrap(self.ema_model).load_state_dict(pkg['ema_model'])
self.optimizer.load_state_dict(pkg['optimizer'])
self.step.copy_(pkg['step'])
@property
def device(self):
return self.accelerator.device
@property
def is_main(self):
return self.accelerator.is_main_process
def unwrap(self, module):
return self.accelerator.unwrap_model(module)
def print(self, msg):
return self.accelerator.print(msg)
def wait(self):
return self.accelerator.wait_for_everyone()
def get_discount_matrix(self, timestep):
if exists(self.discount_matrix) and self.discount_matrix.shape[-1] >= timestep:
return self.discount_matrix[:timestep, :timestep]
timestep_arange = torch.arange(timestep, device=self.accelerator.device)
powers = (timestep_arange[None, :] - timestep_arange[:, None])
discount_matrix = torch.triu(self.discount_factor_gamma ** powers)
self.register_buffer('discount_matrix', discount_matrix, persistent=False)
return self.discount_matrix
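# Added note: for gamma = 0.9 and timestep = 3 the cached matrix is upper triangular,
#   [[1.00, 0.90, 0.81],
#    [0.00, 1.00, 0.90],
#    [0.00, 0.00, 1.00]]
# so row q holds the discounts applied to rewards from step q onwards in the n-step einsum below.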
def q_learn(
self,
text_embeds: TensorType['b', 'd', float],
states: TensorType['b', 'c', 'f', 'h', 'w', float],
actions: TensorType['b', int],
next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
reward: TensorType['b', float],
done: TensorType['b', bool],
*,
monte_carlo_return=None
) -> Tuple[TensorType[()], QIntermediates]:
γ = self.discount_factor_gamma
not_terminal = (~done).float()
q_pred_all_actions = self.model(states, text_embeds=text_embeds)
q_pred = batch_select_indices(q_pred_all_actions, actions)
q_next = self.ema_model(next_states, text_embeds=text_embeds).amax(dim=-1)
q_next.clamp_(min=default(monte_carlo_return, -1e4))
q_target = reward + not_terminal * (γ * q_next)
loss = F.mse_loss(q_pred, q_target)
return loss, QIntermediates(q_pred_all_actions, q_pred, q_next, q_target)
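# Added note: the single-step target above is the standard Bellman backup
#   q_target = r + (1 - done) * gamma * max_a' Q_ema(s', a')
# where Q_ema is the EMA target network and the clamp applies an optional lower bound
# (monte_carlo_return) to the bootstrapped value.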
def n_step_q_learn(
self,
text_embeds: TensorType['b', 'd', float],
states: TensorType['b', 't', 'c', 'f', 'h', 'w', float],
actions: TensorType['b', 't', int],
next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
rewards: TensorType['b', 't', float],
dones: TensorType['b', 't', bool],
*,
monte_carlo_return=None
) -> Tuple[TensorType[()], QIntermediates]:
"""
einops
b - batch
c - channels
f - frames
h - height
w - width
t - timesteps
a - action bins
q - q values
d - text cond dimension
"""
num_timesteps, device = states.shape[1], states.device
states, time_ps = pack_one(states, '* c f h w')
text_embeds, _ = pack_one(text_embeds, '* d')
repeated_text_embeds = repeat(text_embeds, 'b ... -> (b n) ...', n = num_timesteps)
γ = self.discount_factor_gamma
dones = dones.cumsum(dim = -1) > 0
dones = F.pad(dones, (1, 0), value = False)
not_terminal = (~dones).float()
actions = rearrange(actions, 'b t -> (b t)')
q_pred_all_actions = self.model(states, text_embeds = repeated_text_embeds)
q_pred = batch_select_indices(q_pred_all_actions, actions)
q_pred = unpack_one(q_pred, time_ps, '*')
q_next = self.ema_model(next_states, text_embeds = text_embeds).amax(dim = -1)
q_next.clamp_(min = default(monte_carlo_return, -1e4))
rewards, _ = pack([rewards, q_next], 'b *')
γ = self.get_discount_matrix(num_timesteps + 1)[:-1, :]
q_target = einsum('b t, q t -> b q', not_terminal * rewards, γ)
loss = F.mse_loss(q_pred, q_target)
q_pred_all_actions = unpack_one(q_pred_all_actions, time_ps, '* a')
return loss, QIntermediates(q_pred_all_actions, q_pred, q_next, q_target)
def autoregressive_q_learn_handle_single_timestep(
self,
text_embeds,
states,
actions,
next_states,
rewards,
dones,
*,
monte_carlo_return = None
):
"""
simply detect and handle single timestep
and use `autoregressive_q_learn` as more general function
"""
if states.ndim == 5:
states = rearrange(states, 'b ... -> b 1 ...')
if actions.ndim == 2:
actions = rearrange(actions, 'b ... -> b 1 ...')
if rewards.ndim == 1:
rewards = rearrange(rewards, 'b -> b 1')
if dones.ndim == 1:
dones = rearrange(dones, 'b -> b 1')
return self.autoregressive_q_learn(text_embeds, states, actions, next_states, rewards, dones, monte_carlo_return = monte_carlo_return)
def autoregressive_q_learn(
self,
text_embeds: TensorType['b', 'd', float],
states: TensorType['b', 't', 'c', 'f', 'h', 'w', float],
actions: TensorType['b', 't', 'n', int],
next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
rewards: TensorType['b', 't', float],
dones: TensorType['b', 't', bool],
*,
monte_carlo_return = None
) -> Tuple[TensorType[()], QIntermediates]:
"""
einops
b - batch
c - channels
f - frames
h - height
w - width
t - timesteps
n - number of actions
a - action bins
q - q values
d - text cond dimension
"""
monte_carlo_return = default(monte_carlo_return, -1e4)
num_timesteps, device = states.shape[1], states.device
states, time_ps = pack_one(states, '* c f h w')
actions, _ = pack_one(actions, '* n')
text_embeds, _ = pack_one(text_embeds, '* d')
repeated_text_embeds = repeat(text_embeds, 'b ... -> (b n) ...', n = num_timesteps)
dones = dones.cumsum(dim = -1) > 0
dones = F.pad(dones, (1, -1), value = False)
not_terminal = (~dones).float()
rewards = rewards * not_terminal
γ = self.discount_factor_gamma
q_pred_all_actions = self.model(states, text_embeds = repeated_text_embeds, actions = actions)
q_pred = batch_select_indices(q_pred_all_actions, actions)
q_pred = unpack_one(q_pred, time_ps, '* n')
q_next = self.ema_model(next_states, text_embeds = text_embeds)
q_next = q_next.max(dim = -1).values
q_next.clamp_(min = monte_carlo_return)
q_target_all_actions = self.ema_model(states, text_embeds = repeated_text_embeds, actions = actions)
q_target = q_target_all_actions.max(dim = -1).values
q_target.clamp_(min = monte_carlo_return)
q_target = unpack_one(q_target, time_ps, '* n')
q_pred_rest_actions, q_pred_last_action = q_pred[..., :-1], q_pred[..., -1]
q_target_first_action, q_target_rest_actions = q_target[..., 0], q_target[..., 1:]
losses_all_actions_but_last = F.mse_loss(q_pred_rest_actions, q_target_rest_actions, reduction = 'none')
q_target_last_action, _ = pack([q_target_first_action[..., 1:], q_next], 'b *')
q_target_last_action = rewards + γ * q_target_last_action
losses_last_action = F.mse_loss(q_pred_last_action, q_target_last_action, reduction = 'none')
losses, _ = pack([losses_all_actions_but_last, losses_last_action], '*')
return losses.mean(), QIntermediates(q_pred_all_actions, q_pred, q_next, q_target)
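# Added note: in the autoregressive (per action dimension) case above, every action dimension
# except the last is regressed towards the max Q of the next action dimension at the same
# timestep (no reward, no discount), while the last action dimension of each timestep receives
# the usual target r_t + gamma * max Q of the first action dimension at timestep t + 1 (with
# q_next bootstrapping the final timestep). This is the per-dimension maximization scheme of
# the Q-transformer paper.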
def learn(
self,
*args,
min_reward: Optional[float] = None,
monte_carlo_return: Optional[float] = None
):
_, _, actions, *_ = args
q_learn_kwargs = dict(
monte_carlo_return = monte_carlo_return
)
if self.is_multiple_actions:
td_loss, q_intermediates = self.autoregressive_q_learn_handle_single_timestep(*args, **q_learn_kwargs)
num_timesteps = actions.shape[1]
elif self.n_step_q_learning:
td_loss, q_intermediates = self.n_step_q_learn(*args, **q_learn_kwargs)
num_timesteps = actions.shape[1]
else:
td_loss, q_intermediates = self.q_learn(*args, **q_learn_kwargs)
num_timesteps = 1
if not self.has_conservative_reg_loss:
return td_loss, Losses(td_loss, self.zero)
batch = actions.shape[0]
q_preds = q_intermediates.q_pred_all_actions
q_preds = rearrange(q_preds, '... a -> (...) a')
num_action_bins = q_preds.shape[-1]
num_non_dataset_actions = num_action_bins - 1
actions = rearrange(actions, '... -> (...) 1')
dataset_action_mask = torch.zeros_like(q_preds).scatter_(-1, actions, torch.ones_like(q_preds))
q_actions_not_taken = q_preds[~dataset_action_mask.bool()]
q_actions_not_taken = rearrange(q_actions_not_taken, '(b t a) -> b t a', b = batch, a = num_non_dataset_actions)
conservative_reg_loss = ((q_actions_not_taken - (min_reward * num_timesteps)) ** 2).sum() / num_non_dataset_actions
loss = 0.5 * td_loss + \
0.5 * conservative_reg_loss * self.conservative_reg_loss_weight
loss_breakdown = Losses(td_loss, conservative_reg_loss)
return loss, loss_breakdown
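# Added note: the conservative regularization mirrors CQL - the Q-values of action bins that do
# not appear in the dataset are pushed towards min_reward * num_timesteps, and the final loss is
# an equal 0.5 / 0.5 mix of the TD loss and this regularizer (scaled by conservative_reg_loss_weight).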
def forward(
self,
*,
monte_carlo_return: Optional[float] = None,
min_reward: Optional[float] = None
):
monte_carlo_return = default(monte_carlo_return, self.monte_carlo_return)
min_reward = default(min_reward, self.min_reward)
step = self.step.item()
replay_buffer_iter = cycle(self.dataloader)
self.model.train()
self.ema_model.train()
while step < self.num_train_steps:
self.optimizer.zero_grad()
for grad_accum_step in range(self.grad_accum_every):
is_last = grad_accum_step == (self.grad_accum_every - 1)
context = partial(self.accelerator.no_sync, self.model) if not is_last else nullcontext
with self.accelerator.autocast(), context():
loss, (td_loss, conservative_reg_loss) = self.learn(
*next(replay_buffer_iter),
min_reward = min_reward,
monte_carlo_return = monte_carlo_return
)
self.accelerator.backward(loss / self.grad_accum_every)
self.print(f'td loss: {td_loss.item():.3f}')
self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.optimizer.step()
self.wait()
self.ema_model.update()
step += 1
self.step.add_(1)
self.wait()
if self.is_main and is_divisible(step, self.checkpoint_every):
checkpoint_num = step // self.checkpoint_every
self.save(checkpoint_num)
self.wait()
self.print('training complete')
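# End-to-end wiring sketch (added; hyperparameters are illustrative and kept small, and the mock
# dataset yields instruction strings where q_learn() expects text embeddings, so treat this as a
# wiring illustration rather than a verified runnable example). num_action_bins is set to 255 so
# the mock actions stay inside the model's 256 bins.
from q_transformer.q_robotic_transformer import QRoboticTransformer
from q_transformer.mocks import MockReplayDataset

model = QRoboticTransformer(
    vit = dict(
        num_classes = 1000,
        dim = 64,
        dim_conv_stem = 64,
        dim_head = 64,
        depth = (2, 2, 5, 2),
        window_size = 7
    ),
    num_actions = 8,
    action_bins = 256,
    depth = 1,
    cond_drop_prob = 0.2
)

q_learner = QLearner(
    model,
    dataset = MockReplayDataset(num_actions = 8, num_action_bins = 255),
    batch_size = 4,
    num_train_steps = 100,
    learning_rate = 3e-4,
    grad_accum_every = 1
)

q_learner()  # runs the training loop defined in forward() above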
.\lucidrains\q-transformer\q_transformer\q_robotic_transformer.py
from random import random
from functools import partial, cache
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.cuda.amp import autocast
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
from beartype import beartype
from beartype.typing import Union, List, Optional, Callable, Tuple, Dict, Any
from einops import pack, unpack, repeat, reduce, rearrange
from einops.layers.torch import Rearrange, Reduce
from q_transformer.attend import Attend
from classifier_free_guidance_pytorch import (
TextConditioner,
AttentionTextConditioner,
NullConditioner,
classifier_free_guidance
)
def exists(val):
return val is not None
def xnor(x, y):
""" (True, True) or (False, False) -> True """
return not (x ^ y)
def divisible_by(num, den):
return (num % den) == 0
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def l2norm(t, dim = -1):
return F.normalize(t, dim = dim)
def pack_one(x, pattern):
return pack([x], pattern)
def unpack_one(x, ps, pattern):
return unpack(x, ps, pattern)[0]
class RotaryEmbedding(Module):
def __init__(self, dim, omega = 10000):
super().__init__()
inv_freq = 1.0 / (omega ** (torch.arange(0, dim, 4).float() / dim))
self.register_buffer('inv_freq', inv_freq)
@autocast(enabled = False)
def forward(self, height_width):
device, dtype = self.inv_freq.device, self.inv_freq.dtype
axial_pos = torch.arange(height_width, device = device).type(dtype)
freqs = torch.einsum('i, j -> i j', axial_pos, self.inv_freq)
freqs = repeat(freqs, '... f -> ... (f c)', c = 2)
freqs = torch.broadcast_tensors(freqs[None, :, :], freqs[:, None, :])
freqs = torch.cat(freqs, dim = -1)
return rearrange(freqs, '... f -> (...) f')
def rotate_half(x):
x1, x2 = rearrange(x, '... (d c) -> ... d c', c = 2).unbind(dim = -1)
x = torch.stack((-x2, x1), dim = -1)
return rearrange(x, '... d c -> ... (d c)')
@autocast(enabled = False)
def apply_rotary_pos_emb(pos, t):
return t * pos.cos() + rotate_half(t) * pos.sin()
@cache
def get_is_distributed():
return dist.is_initialized() and dist.get_world_size() > 1
def MaybeSyncBatchnorm2d(is_distributed = None):
is_distributed = default(is_distributed, get_is_distributed())
return nn.SyncBatchNorm if is_distributed else nn.BatchNorm2d
class RMSNorm(Module):
def __init__(self, dim, affine = True):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim)) if affine else 1.
def forward(self, x):
return l2norm(x) * self.gamma * self.scale
class ChanRMSNorm(Module):
def __init__(self, dim, affine = True):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim, 1, 1)) if affine else 1.
def forward(self, x):
return l2norm(x, dim = 1) * self.gamma * self.scale
def posemb_sincos_1d(seq, dim, temperature = 10000, device = None, dtype = torch.float32):
n = torch.arange(seq, device = device)
omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
omega = 1. / (temperature ** omega)
n = n[:, None] * omega[None, :]
pos_emb = torch.cat((n.sin(), n.cos()), dim = 1)
return pos_emb.type(dtype)
class Residual(Module):
@beartype
def __init__(self, fn: Module):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class FeedForward(Module):
def __init__(
self,
dim,
mult = 4,
dropout = 0.,
adaptive_ln = False
):
super().__init__()
self.adaptive_ln = adaptive_ln
inner_dim = int(dim * mult)
self.norm = RMSNorm(dim, affine = not adaptive_ln)
self.net = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(
self,
x,
cond_fn: Optional[Callable] = None
):
x = self.norm(x)
assert xnor(self.adaptive_ln, exists(cond_fn))
if exists(cond_fn):
x = cond_fn(x)
return self.net(x)
class SqueezeExcitation(Module):
def __init__(self, dim, shrinkage_rate = 0.25):
super().__init__()
hidden_dim = int(dim * shrinkage_rate)
self.gate = nn.Sequential(
Reduce('b c h w -> b c', 'mean'),
nn.Linear(dim, hidden_dim, bias = False),
nn.SiLU(),
nn.Linear(hidden_dim, dim, bias = False),
nn.Sigmoid(),
Rearrange('b c -> b c 1 1')
)
def forward(self, x):
return x * self.gate(x)
class MBConvResidual(Module):
def __init__(self, fn, dropout = 0.):
super().__init__()
self.fn = fn
self.dropsample = Dropsample(dropout)
def forward(self, x):
out = self.fn(x)
out = self.dropsample(out)
return out + x
class Dropsample(Module):
def __init__(self, prob = 0):
super().__init__()
self.prob = prob
def forward(self, x):
batch, device = x.shape[0], x.device
if self.prob == 0. or (not self.training):
return x
keep_mask = torch.empty((batch, 1, 1, 1), device = device).uniform_() > self.prob
return x * keep_mask / (1 - self.prob)
def MBConv(
dim_in,
dim_out,
*,
downsample,
expansion_rate = 4,
shrinkage_rate = 0.25,
dropout = 0.,
is_distributed = None,
use_layernorm = True
):
hidden_dim = int(expansion_rate * dim_out)
stride = 2 if downsample else 1
if use_layernorm:
norm_klass = ChanRMSNorm
else:
norm_klass = MaybeSyncBatchnorm2d(is_distributed)
net = nn.Sequential(
nn.Conv2d(dim_in, hidden_dim, 1),
norm_klass(hidden_dim),
nn.GELU(),
nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
norm_klass(hidden_dim),
nn.GELU(),
SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
nn.Conv2d(hidden_dim, dim_out, 1),
norm_klass(dim_out)
)
if dim_in == dim_out and not downsample:
net = MBConvResidual(net, dropout = dropout)
return net
class Attention(Module):
def __init__(
self,
dim,
heads = 8,
dim_head = 32,
dropout = 0.,
window_size = 7,
num_mem_kv = 4,
flash = True
):
super().__init__()
dim_inner = dim_head * heads
self.norm = RMSNorm(dim)
self.heads = heads
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
self.to_v_gates = nn.Sequential(
nn.Linear(dim, self.heads),
nn.Sigmoid(),
Rearrange('b n h -> b h n 1')
)
self.attend = Attend(
causal = False,
dropout = dropout,
flash = flash
)
self.to_out = nn.Sequential(
nn.Linear(dim_inner, dim, bias = False),
nn.Dropout(dropout)
)
def forward(
self,
x,
rotary_emb = None
):
batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
x = self.norm(x)
x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
g = self.to_v_gates(x)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
if exists(rotary_emb):
q = apply_rotary_pos_emb(rotary_emb, q)
k = apply_rotary_pos_emb(rotary_emb, k)
out = self.attend(q, k, v)
out = out * g
out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)
out = self.to_out(out)
return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)
class MaxViT(Module):
@beartype
def __init__(
self,
*,
num_classes,
dim,
depth: Tuple[int, ...],
heads = 8,
dim_head = 64,
dim_conv_stem = None,
window_size = 7,
mbconv_expansion_rate = 4,
mbconv_shrinkage_rate = 0.25,
use_layernorm = True,
dropout = 0.1,
channels = 3,
flash_attn = True
):
super().__init__()
dim_conv_stem = default(dim_conv_stem, dim)
self.conv_stem = nn.Sequential(
nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
)
num_stages = len(depth)
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
dims = (dim_conv_stem, *dims)
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
self.layers = ModuleList([])
self.window_size = window_size
w = window_size
assert divisible_by(dim_head, 4), f'{dim_head} must be divisible by 4 for axial rotary embedding for maxvit'
self.axial_rotary_emb = RotaryEmbedding(dim_head)
self.register_buffer('cached_rotary_emb', self.axial_rotary_emb(window_size), persistent = False)
cond_hidden_dims = []
for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
for stage_ind in range(layer_depth):
is_first = stage_ind == 0
stage_dim_in = layer_dim_in if is_first else layer_dim
cond_hidden_dims.append(stage_dim_in)
block = nn.ModuleList([
MBConv(
stage_dim_in,
layer_dim,
downsample = is_first,
expansion_rate = mbconv_expansion_rate,
shrinkage_rate = mbconv_shrinkage_rate,
use_layernorm = use_layernorm
),
Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),
Residual(Attention(dim = layer_dim, heads = heads, dim_head = dim_head, dropout = dropout, window_size = w, flash = flash_attn)),
Residual(FeedForward(dim = layer_dim, dropout = dropout)),
Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),
Residual(Attention(dim = layer_dim, heads = heads, dim_head = dim_head, dropout = dropout, window_size = w, flash = flash_attn)),
Residual(FeedForward(dim = layer_dim, dropout = dropout)),
Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
])
self.layers.append(block)
embed_dim = dims[-1]
self.embed_dim = dims[-1]
self.cond_hidden_dims = cond_hidden_dims
self.mlp_head = nn.Sequential(
Reduce('b d h w -> b d', 'mean'),
RMSNorm(embed_dim),
nn.Linear(embed_dim, num_classes)
)
@beartype
def forward(
self,
img,
texts: Optional[List[str]] = None,
cond_fns: Optional[Tuple[Callable, ...]] = None,
cond_drop_prob = 0.,
return_embeddings = False
):
assert all([divisible_by(d, self.window_size) for d in img.shape[-2:]])
x = self.conv_stem(img)
rotary_emb = self.cached_rotary_emb
cond_fns = iter(default(cond_fns, []))
for (
mb_conv,
rearr_windowed_in,
windowed_attn,
windowed_ff,
rearr_windowed_out,
rearr_grid_in,
grid_attn,
grid_ff,
rearr_grid_out
) in self.layers:
cond_fn = next(cond_fns, None)
if exists(cond_fn):
x = cond_fn(x)
x = mb_conv(x)
x = rearr_windowed_in(x)
x = windowed_attn(x, rotary_emb = rotary_emb)
x = windowed_ff(x)
x = rearr_windowed_out(x)
x = rearr_grid_in(x)
x = grid_attn(x, rotary_emb = rotary_emb)
x = grid_ff(x)
x = rearr_grid_out(x)
if return_embeddings:
return x
return self.mlp_head(x)
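# Added note: each MaxViT stage alternates MBConv (downsampling on the first block of a stage),
# window-local "block" attention over w x w patches, and "grid" attention that mixes tokens
# across windows - the two Rearrange pairs above implement exactly that windowing / gridding.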
class TransformerAttention(Module):
def __init__(
self,
dim,
dim_head = 64,
dim_context = None,
heads = 8,
num_mem_kv = 4,
norm_context = False,
adaptive_ln = False,
dropout = 0.1,
flash = True,
causal = False
):
super().__init__()
self.heads = heads
inner_dim = dim_head * heads
dim_context = default(dim_context, dim)
self.adaptive_ln = adaptive_ln
self.norm = RMSNorm(dim, affine = not adaptive_ln)
self.context_norm = RMSNorm(dim_context) if norm_context else None
self.attn_dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
self.num_mem_kv = num_mem_kv
self.mem_kv = None
if num_mem_kv > 0:
self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))
self.attend = Attend(
dropout = dropout,
flash = flash,
causal = causal
)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
nn.Dropout(dropout)
)
def forward(
self,
x,
context = None,
mask = None,
attn_mask = None,
cond_fn: Optional[Callable] = None,
cache: Optional[Tensor] = None,
return_cache = False
):
b = x.shape[0]
assert xnor(exists(context), exists(self.context_norm))
if exists(context):
context = self.context_norm(context)
kv_input = default(context, x)
x = self.norm(x)
assert xnor(exists(cond_fn), self.adaptive_ln)
if exists(cond_fn):
x = cond_fn(x)
q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
if exists(cache):
ck, cv = cache
k = torch.cat((ck, k), dim = -2)
v = torch.cat((cv, v), dim = -2)
new_kv_cache = torch.stack((k, v))
if exists(self.mem_kv):
mk, mv = map(lambda t: repeat(t, '... -> b ...', b = b), self.mem_kv)
k = torch.cat((mk, k), dim = -2)
v = torch.cat((mv, v), dim = -2)
if exists(mask):
mask = F.pad(mask, (self.num_mem_kv, 0), value = True)
if exists(attn_mask):
attn_mask = F.pad(attn_mask, (self.num_mem_kv, 0), value = True)
out = self.attend(q, k, v, mask = mask, attn_mask = attn_mask)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
if not return_cache:
return out
return out, new_kv_cache
class Transformer(Module):
def __init__(
self,
dim,
dim_head = 64,
heads = 8,
depth = 6,
attn_dropout = 0.,
ff_dropout = 0.,
adaptive_ln = False,
flash_attn = True,
cross_attend = False,
causal = False,
final_norm = True
):
super().__init__()
self.layers = ModuleList([])
attn_kwargs = dict(
dim = dim,
heads = heads,
dim_head = dim_head,
dropout = attn_dropout,
flash = flash_attn
)
for _ in range(depth):
self.layers.append(ModuleList([
TransformerAttention(**attn_kwargs, causal = causal, adaptive_ln = adaptive_ln, norm_context = False),
TransformerAttention(**attn_kwargs, norm_context = True) if cross_attend else None,
FeedForward(dim = dim, dropout = ff_dropout, adaptive_ln = adaptive_ln)
]))
self.norm = RMSNorm(dim) if final_norm else nn.Identity()
@beartype
def forward(
self,
x,
cond_fns: Optional[Tuple[Callable, ...]] = None,
attn_mask = None,
context: Optional[Tensor] = None,
cache: Optional[Tensor] = None,
return_cache = False
):
has_cache = exists(cache)
if has_cache:
x_prev, x = x[..., :-1, :], x[..., -1:, :]
cond_fns = iter(default(cond_fns, []))
cache = iter(default(cache, []))
new_caches = []
for attn, maybe_cross_attn, ff in self.layers:
attn_out, new_cache = attn(
x,
attn_mask = attn_mask,
cond_fn = next(cond_fns, None),
return_cache = True,
cache = next(cache, None)
)
new_caches.append(new_cache)
x = x + attn_out
if exists(maybe_cross_attn):
assert exists(context)
x = maybe_cross_attn(x, context = context) + x
x = ff(x, cond_fn = next(cond_fns, None)) + x
new_caches = torch.stack(new_caches)
if has_cache:
x = torch.cat((x_prev, x), dim = -2)
out = self.norm(x)
if not return_cache:
return out
return out, new_caches
class TokenLearner(Module):
"""
https://arxiv.org/abs/2106.11297
using the 1.1 version with the MLP (2 dense layers with gelu) for generating attention map
"""
def __init__(
self,
*,
dim,
ff_mult = 2,
num_output_tokens = 8,
num_layers = 2
):
super().__init__()
inner_dim = dim * ff_mult * num_output_tokens
self.num_output_tokens = num_output_tokens
self.net = nn.Sequential(
nn.Conv2d(dim * num_output_tokens, inner_dim, 1, groups = num_output_tokens),
nn.GELU(),
nn.Conv2d(inner_dim, num_output_tokens, 1, groups = num_output_tokens),
)
def forward(self, x):
x, ps = pack_one(x, '* c h w')
x = repeat(x, 'b c h w -> b (g c) h w', g = self.num_output_tokens)
attn = self.net(x)
attn = rearrange(attn, 'b g h w -> b 1 g h w')
x = rearrange(x, 'b (g c) h w -> b c g h w', g = self.num_output_tokens)
x = reduce(x * attn, 'b c g h w -> b c g', 'mean')
x = unpack_one(x, ps, '* c n')
return x
class DuelingHead(Module):
def __init__(
self,
dim,
expansion_factor = 2,
action_bins = 256
):
super().__init__()
dim_hidden = dim * expansion_factor
self.stem = nn.Sequential(
nn.Linear(dim, dim_hidden),
nn.SiLU()
)
self.to_values = nn.Sequential(
nn.Linear(dim_hidden, 1)
)
self.to_advantages = nn.Sequential(
nn.Linear(dim_hidden, action_bins)
)
def forward(self, x):
x = self.stem(x)
advantages = self.to_advantages(x)
advantages = advantages - reduce(advantages, '... a -> ... 1', 'mean')
values = self.to_values(x)
q_values = values + advantages
return q_values.sigmoid()
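# Added note: this is the standard dueling decomposition Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)),
# followed by a sigmoid so predicted Q-values stay in [0, 1], consistent with the 0/1 rewards
# produced by the mock datasets earlier in this post.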
class QHeadSingleAction(Module):
def __init__(
self,
dim,
*,
num_learned_tokens = 8,
action_bins = 256,
dueling = False
):
super().__init__()
self.action_bins = action_bins
if dueling:
self.to_q_values = nn.Sequential(
Reduce('b (f n) d -> b d', 'mean', n = num_learned_tokens),
DuelingHead(
dim,
action_bins = action_bins
)
)
else:
self.to_q_values = nn.Sequential(
Reduce('b (f n) d -> b d', 'mean', n = num_learned_tokens),
RMSNorm(dim),
nn.Linear(dim, action_bins),
nn.Sigmoid()
)
def get_random_actions(self, batch_size):
return torch.randint(0, self.action_bins, (batch_size,), device = self.device)
def get_optimal_actions(
self,
encoded_state,
return_q_values = False,
actions = None,
**kwargs
):
assert not exists(actions), 'single actions will never receive previous actions'
q_values = self.forward(encoded_state)
max_q, action_indices = q_values.max(dim = -1)
if not return_q_values:
return action_indices
return action_indices, max_q
def forward(self, encoded_state):
return self.to_q_values(encoded_state)
class QHeadMultipleActions(Module):
def __init__(
self,
dim,
*,
num_actions = 8,
action_bins = 256,
attn_depth = 2,
attn_dim_head = 32,
attn_heads = 8,
dueling = False,
weight_tie_action_bin_embed = False
):
super().__init__()
self.num_actions = num_actions
self.action_bins = action_bins
self.action_bin_embeddings = nn.Parameter(torch.zeros(num_actions, action_bins, dim))
nn.init.normal_(self.action_bin_embeddings, std = 0.02)
self.to_q_values = None
if not weight_tie_action_bin_embed:
self.to_q_values = nn.Linear(dim, action_bins)
self.transformer = Transformer(
dim = dim,
depth = attn_depth,
dim_head = attn_dim_head,
heads = attn_heads,
cross_attend = True,
adaptive_ln = False,
causal = True,
final_norm = True
)
self.final_norm = RMSNorm(dim)
self.dueling = dueling
if dueling:
self.to_values = nn.Parameter(torch.zeros(num_actions, dim))
@property
def device(self):
return self.action_bin_embeddings.device
def maybe_append_actions(self, sos_tokens, actions: Optional[Tensor] = None):
if not exists(actions):
return sos_tokens
batch, num_actions = actions.shape
action_embeddings = self.action_bin_embeddings[:num_actions]
action_embeddings = repeat(action_embeddings, 'n a d -> b n a d', b = batch)
past_action_bins = repeat(actions, 'b n -> b n 1 d', d = action_embeddings.shape[-1])
bin_embeddings = action_embeddings.gather(-2, past_action_bins)
bin_embeddings = rearrange(bin_embeddings, 'b n 1 d -> b n d')
tokens, _ = pack((sos_tokens, bin_embeddings), 'b * d')
tokens = tokens[:, :self.num_actions]
return tokens
def get_q_values(self, embed):
num_actions = embed.shape[-2]
if exists(self.to_q_values):
logits = self.to_q_values(embed)
else:
action_bin_embeddings = self.action_bin_embeddings[:num_actions]
action_bin_embeddings = torch.roll(action_bin_embeddings, shifts = -1, dims = 1)
logits = einsum('b n d, n a d -> b n a', embed, action_bin_embeddings)
if self.dueling:
advantages = logits
values = einsum('b n d, n d -> b n', embed, self.to_values[:num_actions])
values = rearrange(values, 'b n -> b n 1')
q_values = values + (advantages - reduce(advantages, '... a -> ... 1', 'mean'))
else:
q_values = logits
return q_values.sigmoid()
def get_random_actions(self, batch_size, num_actions = None):
num_actions = default(num_actions, self.num_actions)
return torch.randint(0, self.action_bins, (batch_size, num_actions), device = self.device)
@torch.no_grad()
def get_optimal_actions(
self,
encoded_state,
return_q_values = False,
actions: Optional[Tensor] = None,
prob_random_action: float = 0.5,
**kwargs
):
assert 0. <= prob_random_action <= 1.
batch = encoded_state.shape[0]
if prob_random_action == 1:
return self.get_random_actions(batch)
sos_token = reduce(encoded_state, 'b ... d -> b 1 d', 'mean')
tokens = self.maybe_append_actions(sos_token, actions = actions)
action_bins = []
cache = None
for action_idx in range(self.num_actions):
embed, cache = self.transformer(
tokens,
context = encoded_state,
cache = cache,
return_cache = True
)
last_embed = embed[:, action_idx]
bin_embeddings = self.action_bin_embeddings[action_idx]
q_values = einsum('b d, a d -> b a', last_embed, bin_embeddings)
selected_action_bins = q_values.argmax(dim = -1)
if prob_random_action > 0.:
random_mask = torch.zeros_like(selected_action_bins).float().uniform_(0., 1.) < prob_random_action
random_actions = self.get_random_actions(batch, 1)
random_actions = rearrange(random_actions, '... 1 -> ...')
selected_action_bins = torch.where(
random_mask,
random_actions,
selected_action_bins
)
next_action_embed = bin_embeddings[selected_action_bins]
tokens, _ = pack((tokens, next_action_embed), 'b * d')
action_bins.append(selected_action_bins)
action_bins = torch.stack(action_bins, dim = -1)
if not return_q_values:
return action_bins
all_q_values = self.get_q_values(embed)
return action_bins, all_q_values
def forward(
self,
encoded_state: Tensor,
actions: Optional[Tensor] = None
):
"""
einops
b - batch
n - number of actions
a - action bins
d - dimension
"""
sos_token = reduce(encoded_state, 'b ... d -> b 1 d', 'mean')
tokens = self.maybe_append_actions(sos_token, actions = actions)
embed = self.transformer(tokens, context = encoded_state)
return self.get_q_values(embed)
class QRoboticTransformer(Module):
@beartype
def __init__(
self,
*,
vit: Union[Dict[str, Any], MaxViT],
num_actions = 8,
action_bins = 256,
depth = 6,
heads = 8,
dim_head = 64,
token_learner_ff_mult = 2,
token_learner_num_layers = 2,
token_learner_num_output_tokens = 8,
cond_drop_prob = 0.2,
use_attn_conditioner = False,
conditioner_kwargs: dict = dict(),
dueling = False,
flash_attn = True,
condition_on_text = True,
q_head_attn_kwargs: dict = dict(
attn_heads = 8,
attn_dim_head = 64,
attn_depth = 2
),
weight_tie_action_bin_embed = True
):
super().__init__()
if isinstance(vit, dict):
vit = MaxViT(**vit)
self.vit = vit
self.num_vit_stages = len(vit.cond_hidden_dims)
attend_dim = vit.embed_dim
assert num_actions >= 1
self.num_actions = num_actions
self.is_single_action = num_actions == 1
self.action_bins = action_bins
self.condition_on_text = condition_on_text
if condition_on_text:
conditioner_klass = AttentionTextConditioner if use_attn_conditioner else TextConditioner
self.conditioner = conditioner_klass(
hidden_dims = (*tuple(vit.cond_hidden_dims), *((attend_dim,) * depth * 2)),
hiddens_channel_first = (*((True,) * self.num_vit_stages), *((False,) * depth * 2)),
cond_drop_prob = cond_drop_prob,
**conditioner_kwargs
)
else:
self.conditioner = NullConditioner(hidden_dims = tuple())
self.token_learner = TokenLearner(
dim = vit.embed_dim,
ff_mult = token_learner_ff_mult,
num_output_tokens = token_learner_num_output_tokens,
num_layers = token_learner_num_layers
)
self.num_learned_tokens = token_learner_num_output_tokens
self.transformer_depth = depth
self.transformer = Transformer(
dim = attend_dim,
dim_head = dim_head,
heads = heads,
depth = depth,
flash_attn = flash_attn,
adaptive_ln = condition_on_text,
final_norm = True
)
self.cond_drop_prob = cond_drop_prob
if self.is_single_action:
self.q_head = QHeadSingleAction(
attend_dim,
num_learned_tokens = self.num_learned_tokens,
action_bins = action_bins,
dueling = dueling
)
else:
self.q_head = QHeadMultipleActions(
attend_dim,
action_bins = action_bins,
dueling = dueling,
weight_tie_action_bin_embed = weight_tie_action_bin_embed,
**q_head_attn_kwargs
)
@property
def device(self):
return next(self.parameters()).device
def get_random_actions(self, batch_size = 1):
return self.q_head.get_random_actions(batch_size)
@beartype
def embed_texts(self, texts: List[str]):
return self.conditioner.embed_texts(texts)
@torch.no_grad()
def get_optimal_actions(
self,
*args,
return_q_values = False,
actions: Optional[Tensor] = None,
**kwargs
):
encoded_state = self.encode_state(*args, **kwargs)
return self.q_head.get_optimal_actions(encoded_state, return_q_values = return_q_values, actions = actions)
def get_actions(
self,
video,
*args,
prob_random_action = 0.,
**kwargs,
):
batch_size = video.shape[0]
assert 0. <= prob_random_action <= 1.
if random() < prob_random_action:
return self.get_random_actions(batch_size = batch_size)
return self.get_optimal_actions(video, *args, **kwargs)
def encode_state(
self,
video: Tensor,
texts: Optional[Union[List[str], Tuple[str]]] = None,
text_embeds: Optional[Tensor] = None,
actions: Optional[Tensor] = None,
cond_drop_prob = 0.,
):
"""
einops
b - batch
c - channels
f - frames
h - height
w - width
n - number of learned tokens
"""
if not self.condition_on_text:
assert (not exists(texts) and not exists(text_embeds)), 'neither texts nor text embeds should be passed in'
else:
assert exists(texts) ^ exists(text_embeds), 'either texts or text embeds must be passed in if conditioning on instructions'
if exists(texts) and isinstance(texts, tuple):
texts = list(texts)
text_cond_kwargs = dict(texts = texts, text_embeds = text_embeds)
depth = self.transformer_depth
cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)
frames, device = video.shape[2], video.device
cond_fns, _ = self.conditioner(
**text_cond_kwargs,
cond_drop_prob = cond_drop_prob,
repeat_batch = (*((frames,) * self.num_vit_stages), *((1,) * self.transformer_depth * 2))
)
vit_cond_fns, transformer_cond_fns = cond_fns[:-(depth * 2)], cond_fns[-(depth * 2):]
video = rearrange(video, 'b c f h w -> b f c h w')
images, packed_shape = pack_one(video, '* c h w')
tokens = self.vit(
images,
texts = texts,
cond_fns = vit_cond_fns,
cond_drop_prob = cond_drop_prob,
return_embeddings = True
)
tokens = unpack_one(tokens, packed_shape, '* c h w')
learned_tokens = self.token_learner(tokens)
tokens_per_frame = learned_tokens.shape[-1]
learned_tokens = rearrange(learned_tokens, 'b f c n -> b (f n) c')
attn_mask = ~torch.ones((frames, frames), dtype = torch.bool, device = device).triu(1)
attn_mask = repeat(attn_mask, 'i j -> (i r1) (j r2)', r1 = self.num_learned_tokens, r2 = self.num_learned_tokens)
pos_emb = posemb_sincos_1d(frames, learned_tokens.shape[-1], dtype = learned_tokens.dtype, device = learned_tokens.device)
learned_tokens = learned_tokens + repeat(pos_emb, 'n d -> (n r) d', r = self.num_learned_tokens)
attended_tokens = self.transformer(learned_tokens, cond_fns = transformer_cond_fns, attn_mask = attn_mask)
return attended_tokens
@classifier_free_guidance
def forward(
self,
video: Tensor,
texts: Optional[List[str]] = None,
text_embeds: Optional[Tensor] = None,
actions: Optional[Tensor] = None,
cond_drop_prob = 0.,
):
video = video.to(self.device)
if exists(actions):
actions = actions.to(self.device)
encoded_state = self.encode_state(
video = video,
texts = texts,
text_embeds = text_embeds,
actions = actions,
cond_drop_prob = cond_drop_prob
)
if self.is_single_action:
assert not exists(actions), 'actions should not be passed in for single action robotic transformer'
q_values = self.q_head(encoded_state)
else:
q_values = self.q_head(encoded_state, actions = actions)
return q_values
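# Inference sketch (added; assumes a QRoboticTransformer built as in the q_learner.py wiring
# example above, i.e. num_actions = 8; shapes follow the einops legends in this file):
video = torch.randn(2, 3, 6, 224, 224)          # (batch, channels, frames, height, width)
instructions = ['please clean the kitchen'] * 2

q_values = model(video, texts = instructions)   # Q-values for the first action dimension: (2, 1, action_bins)
actions = model.get_actions(video, texts = instructions)  # autoregressively decoded action bins: (2, 8)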