Source code for ding.model.template.bcq
from typing import Union, Dict, Optional, List
from easydict import EasyDict
import numpy as np
import torch
import torch.nn as nn
from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import RegressionHead, ReparameterizationHead
from .vae import VanillaVAE
@MODEL_REGISTRY.register('bcq')
class BCQ(nn.Module):
    """
    Overview:
        Model of BCQ (Batch-Constrained deep Q-learning), from the paper
        "Off-Policy Deep Reinforcement Learning without Exploration" (https://arxiv.org/abs/1812.02900).
        It holds two critic MLPs, one perturbation actor MLP and a ``VanillaVAE`` that generates candidate actions.
    Interface:
        ``__init__``, ``forward``, ``compute_critic``, ``compute_actor``, ``compute_vae``, ``compute_eval``
    Property:
        ``mode``
    """

    # Names of the methods that ``forward`` is allowed to dispatch to.
    mode = ['compute_actor', 'compute_critic', 'compute_vae', 'compute_eval']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            actor_head_hidden_size: List = [400, 300],
            critic_head_hidden_size: List = [400, 300],
            activation: Optional[nn.Module] = nn.ReLU(),
            vae_hidden_dims: List = [750, 750],
            phi: float = 0.05
    ) -> None:
        """
        Overview:
            Initialize the neural networks: two critics, the perturbation actor and the VAE.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Dimension of the observation.
            - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Dimension of the action.
            - actor_head_hidden_size (:obj:`list`): Hidden layer sizes of the actor MLP.
            - critic_head_hidden_size (:obj:`list`): Hidden layer sizes of each critic MLP.
            - activation (:obj:`nn.Module`): Activation module inserted after every hidden layer, \
                defaults to ``nn.ReLU()``. The same module instance is reused in all networks.
            - vae_hidden_dims (:obj:`list`): Hidden layer sizes passed to ``VanillaVAE``.
            - phi (:obj:`float`): Scale of the actor perturbation; ``compute_actor`` adds \
                ``phi * tanh(actor_output)`` to the input action.
        """
        super(BCQ, self).__init__()
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.input_size = obs_shape
        self.phi = phi

        # Critics and actor all consume the concatenation of observation and action.
        critic_input_size = self.input_size + action_shape

        # Networks are created in the order critic[0], critic[1], actor, vae so that
        # random parameter initialisation is reproducible for a fixed seed.
        self.critic = nn.ModuleList(
            [self._build_mlp(critic_input_size, critic_head_hidden_size, activation, 1) for _ in range(2)]
        )

        # NOTE(review): the actor's output width is 1, so the perturbation computed in ``compute_actor`` is a
        # single value per sample that is broadcast over every action dimension. The BCQ paper describes a
        # perturbation in action space; confirm whether this width should be ``action_shape``. It is left
        # unchanged here because changing it alters parameter shapes and breaks existing checkpoints.
        self.actor = self._build_mlp(critic_input_size, actor_head_hidden_size, activation, 1)

        # The third argument (``action_shape * 2``) matches the width of the latent noise sampled in
        # ``compute_eval``.
        self.vae = VanillaVAE(action_shape, obs_shape, action_shape * 2, vae_hidden_dims)

    @staticmethod
    def _build_mlp(input_size: int, hidden_sizes: List, activation: nn.Module, output_size: int) -> nn.Sequential:
        """
        Overview:
            Build ``Linear -> activation`` for every hidden size, followed by a final ``Linear`` without activation.
        Arguments:
            - input_size (:obj:`int`): Width of the input features.
            - hidden_sizes (:obj:`list`): Width of each hidden layer, in order.
            - activation (:obj:`nn.Module`): Activation module appended after every hidden ``Linear``.
            - output_size (:obj:`int`): Width of the final ``Linear`` layer.
        Returns:
            - net (:obj:`nn.Sequential`): The assembled network.
        """
        layers = []
        width = input_size
        for hidden in hidden_sizes:
            layers.append(nn.Linear(width, hidden))
            layers.append(activation)
            width = hidden
        layers.append(nn.Linear(width, output_size))
        return nn.Sequential(*layers)

    def forward(self, inputs: Dict[str, torch.Tensor], mode: str) -> Dict[str, torch.Tensor]:
        """
        Overview:
            The unique execution (forward) method of BCQ. ``mode`` selects which of ``compute_actor``, \
            ``compute_critic``, ``compute_vae`` or ``compute_eval`` is called with ``inputs``.
        Arguments:
            - inputs (:obj:`Dict`): Input dict data; see the selected method for the keys it reads.
            - mode (:obj:`str`): Name of the method to run, must be one of ``self.mode``.
        Returns:
            - output (:obj:`Dict`): Whatever the selected method returns.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
            >>> model = BCQ(32, 6)
            >>> outputs = model(inputs, mode='compute_actor')
            >>> outputs = model(inputs, mode='compute_critic')
            >>> outputs = model(inputs, mode='compute_vae')
            >>> outputs = model(inputs, mode='compute_eval')

        .. note::
            For specific examples, one can refer to the API doc of each ``compute_*`` method.
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, List[torch.Tensor]]:
        """
        Overview:
            Use both critic networks to compute q values for the given observation-action pairs.
        Arguments:
            - inputs (:obj:`Dict`): Input dict data with keys ``obs`` and ``action``.
        Returns:
            - outputs (:obj:`Dict`): Dict with key ``q_value``, a list holding one tensor per critic (two in total).
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(*, O)`, where O is the observation dimension.
            - action (:obj:`torch.Tensor`): :math:`(*, A)`, or :math:`(B, )` which is treated as :math:`(B, 1)`.
            - q_value (:obj:`List[torch.Tensor]`): each element is :math:`(*)`.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
            >>> model = BCQ(32, 6)
            >>> outputs = model.compute_critic(inputs)
        """
        obs, action = inputs['obs'], inputs['action']
        if len(action.shape) == 1:  # (B, ) -> (B, 1)
            action = action.unsqueeze(1)
        critic_input = torch.cat([obs, action], dim=-1)
        # Only drop the trailing size-1 output dimension. A bare ``squeeze()`` would also remove a
        # size-1 batch dimension, which breaks ``compute_eval`` when a single observation is evaluated.
        q_values = [critic(critic_input).squeeze(-1) for critic in self.critic]
        return {'q_value': q_values}

    def compute_actor(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Overview:
            Use the actor network to perturb the given action: the result is \
            ``clamp(action + phi * tanh(actor(obs, action)), -1, 1)``.
        Arguments:
            - inputs (:obj:`Dict`): Input dict data with keys ``obs`` and ``action``.
        Returns:
            - outputs (:obj:`Dict`): Dict with key ``action`` (:obj:`torch.Tensor`), the perturbed action.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(*, O)`, where O is the observation dimension.
            - action (:obj:`torch.Tensor`): :math:`(*, A)`, where A is the action dimension.
            - outputs action (:obj:`torch.Tensor`): :math:`(*, A)`.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
            >>> model = BCQ(32, 6)
            >>> outputs = model.compute_actor(inputs)
        """
        obs, action = inputs['obs'], inputs['action']
        actor_input = torch.cat([obs, action], -1)
        # tanh bounds the raw output to [-1, 1]; phi scales it to the allowed perturbation range.
        perturbation = self.phi * torch.tanh(self.actor(actor_input))
        # The perturbed action is clipped to [-1, 1].
        return {'action': (perturbation + action).clamp(-1, 1)}

    def compute_vae(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Run the VAE on ``inputs`` and return its output unchanged.
        Arguments:
            - inputs (:obj:`Dict`): Input dict data, passed as-is to ``VanillaVAE.forward`` \
                (expected to include obs and action tensors).
        Returns:
            - outputs (:obj:`Dict`): The dict produced by ``VanillaVAE.forward``. According to the previous \
                documentation of this method it contains ``recons_action``, ``prediction_residual``, ``input``, \
                ``mu``, ``log_var`` and ``z``; refer to ``VanillaVAE`` for the authoritative list.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
            >>> model = BCQ(32, 6)
            >>> outputs = model.compute_vae(inputs)
        """
        return self.vae.forward(inputs)

    def compute_eval(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Select an action for evaluation: decode 100 candidate actions per observation from clipped random \
            latents with the VAE, perturb them with the actor, score them with the first critic and return the \
            candidate with the highest q value.
        Arguments:
            - inputs (:obj:`Dict`): Input dict data; only the key ``obs`` is read.
        Returns:
            - outputs (:obj:`Dict`): Dict with key ``action`` (:obj:`torch.Tensor`).
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is the observation dimension.
            - outputs action (:obj:`torch.Tensor`): :math:`(B, A)`, where A is the action dimension. The batch \
                and action dimensions are kept even when they have size 1.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
            >>> model = BCQ(32, 6)
            >>> outputs = model.compute_eval(inputs)
        """
        num_samples = 100  # number of candidate actions drawn per observation
        obs = inputs['obs']
        # (B, O) -> (num_samples, B, O); repeat_interleave already returns a new tensor.
        obs_rep = obs.unsqueeze(0).repeat_interleave(num_samples, dim=0)
        # Latent noise is clipped to [-0.5, 0.5] before decoding.
        z = torch.randn((obs_rep.shape[0], obs_rep.shape[1], self.action_shape * 2)).to(obs.device).clamp(-0.5, 0.5)
        sample_action = self.vae.decode_with_obs(z, obs_rep)['reconstruction_action']
        action = self.compute_actor({'obs': obs_rep, 'action': sample_action})['action']
        # Only the first critic is used for ranking the candidates; q is (num_samples, B).
        q = self.compute_critic({'obs': obs_rep, 'action': action})['q_value'][0]
        # Index of the best candidate per batch element, expanded to (1, B, A) for gather along dim 0.
        idx = q.argmax(dim=0).unsqueeze(0).unsqueeze(-1)
        idx = idx.repeat_interleave(action.shape[-1], dim=-1)
        # Drop only the leading sample dimension so the result is always (B, A).
        action = action.gather(0, idx).squeeze(0)
        return {'action': action}