
TD3BC

Overview

TD3BC, proposed in the 2021 paper A Minimalist Approach to Offline Reinforcement Learning, is a simple approach to offline RL that makes only two changes to TD3: a weighted behavior cloning loss is added to the policy update, and the states are normalized. Unlike competing methods, it requires no changes to architecture or underlying hyperparameters. The resulting algorithm is a simple baseline that is easy to implement and tune, while more than halving the overall run time by removing the additional computational overhead of previous methods.

../_images/td3bc_paper_table1.png

Implementation changes that offline RL algorithms make to their underlying base RL algorithm. † marks details that add additional hyperparameter(s), and ‡ marks ones that add a computational cost. (Table from the TD3BC paper.)

Quick Facts

  1. TD3BC is an offline RL algorithm.

  2. TD3BC is based on TD3 and behavior cloning.

Key Equations or Key Graphs

TD3BC simply adds a behavior cloning term to the TD3 policy objective in order to regularize the policy:

\[\begin{aligned} \pi = \arg\max_{\pi} \mathbb{E}_{(s, a) \sim D} [ \lambda Q(s, \pi(s)) - (\pi(s)-a)^2 ] \end{aligned}\]

The behavior cloning term \((\pi(s)-a)^2\) acts as a regularizer that pushes the policy toward the actions contained in the dataset. The hyperparameter \(\lambda\) controls the strength of this regularization.

Assuming an action range of \([-1, 1]\), the BC term is at most 4, whereas the range of Q depends on the scale of the reward. Consequently, the scalar \(\lambda\) is defined as:

\[\begin{aligned} \lambda = \frac{\alpha}{\frac{1}{N}\sum_{s_i, a_i}|Q(s_i, a_i)|} \end{aligned}\]

which is simply a normalization term based on the average absolute value of Q over the mini-batch. This formulation also has the benefit of normalizing the effective learning rate across tasks, since it depends on the scale of Q. The default value of \(\alpha\) is 2.5.
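
A minimal sketch of the resulting actor objective in PyTorch (not DI-engine's implementation; actor and critic are assumed to be plain nn.Module networks and obs/action a mini-batch from the offline dataset):

import torch.nn.functional as F

def td3bc_actor_loss(actor, critic, obs, action, alpha=2.5):
    # pi(s): deterministic action proposed by the current policy
    pi = actor(obs)
    # lambda = alpha / ((1/N) * sum_i |Q(s_i, a_i)|), detached so it only rescales the loss
    lam = alpha / critic(obs, action).abs().mean().detach()
    # Maximize lambda * Q(s, pi(s)) while keeping pi(s) close to the dataset actions (BC term)
    return -lam * critic(obs, pi).mean() + F.mse_loss(pi, action)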

Additionally, all the states in each mini-batch are normalized, such that they have mean 0 and standard deviation 1. This normalization improves the stability of the learned policy.
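
A short sketch of this normalization (the function and variable names are illustrative, not DI-engine's API):

import numpy as np

def normalize_states(states, eps=1e-3):
    # states: (N, obs_dim) array of states from the offline dataset
    mean = states.mean(axis=0, keepdims=True)
    std = states.std(axis=0, keepdims=True) + eps  # eps avoids division by zero
    # The same mean/std should also be applied to states seen at evaluation time
    return (states - mean) / std, mean, std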

Implementations

The default config is defined as follows:

class ding.policy.td3_bc.TD3BCPolicy(cfg: easydict.EasyDict, model: Optional[torch.nn.modules.module.Module] = None, enable_field: Optional[List[str]] = None)[source]
Overview:

Policy class of TD3_BC algorithm.

Since DDPG and TD3 share most of their machinery, this TD3_BC class is derived from the DDPG policy class by changing _actor_update_freq, _twin_critic, and the noise setting in the model wrapper.

https://arxiv.org/pdf/2106.06860.pdf

Property:

learn_mode, collect_mode, eval_mode

Config:

| ID | Symbol | Type | Default Value | Description | Other (Shape) |
| --- | --- | --- | --- | --- | --- |
| 1 | type | str | td3_bc | RL policy register name, refer to registry POLICY_REGISTRY | This arg is optional, a placeholder |
| 2 | cuda | bool | True | Whether to use cuda for network | |
| 3 | random_collect_size | int | 25000 | Number of randomly collected training samples in replay buffer when training starts. | Default to 25000 for DDPG/TD3, 10000 for SAC. |
| 4 | model.twin_critic | bool | True | Whether to use two critic networks or only one. | Default True for TD3; Clipped Double Q-learning method in the TD3 paper. |
| 5 | learn.learning_rate_actor | float | 1e-3 | Learning rate for actor network (aka. policy). | |
| 6 | learn.learning_rate_critic | float | 1e-3 | Learning rate for critic network (aka. Q-network). | |
| 7 | learn.actor_update_freq | int | 2 | When the critic network updates once, how many times the actor network updates. | Default 2 for TD3, 1 for DDPG; Delayed Policy Updates method in the TD3 paper. |
| 8 | learn.noise | bool | True | Whether to add noise to the target network's action. | Default True for TD3, False for DDPG; Target Policy Smoothing Regularization in the TD3 paper. |
| 9 | learn.noise_range | dict | dict(min=-0.5, max=0.5) | Limit for the range of target policy smoothing noise, aka. noise_clip. | |
| 10 | learn.ignore_done | bool | False | Determine whether to ignore the done flag. | Use ignore_done only in the halfcheetah env. |
| 11 | learn.target_theta | float | 0.005 | Used for soft update of the target network. | aka. interpolation factor in Polyak averaging for target networks. |
| 12 | collect.noise_sigma | float | 0.1 | Used for adding noise during collection by controlling the sigma of the distribution. | Noise is sampled from a distribution: Ornstein-Uhlenbeck process in the DDPG paper, Gaussian process in ours. |
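
For illustration only, a subset of these fields can be expressed as an EasyDict override (a sketch, not a complete runnable config; environment and pipeline settings are omitted, and in a DI-engine pipeline this would be merged with the policy's default config):

from easydict import EasyDict

# Partial, illustrative override of the defaults listed in the table above.
td3bc_policy_cfg = EasyDict(dict(
    type='td3_bc',
    cuda=True,
    model=dict(twin_critic=True),
    learn=dict(
        learning_rate_actor=1e-3,
        learning_rate_critic=1e-3,
        actor_update_freq=2,
        noise=True,
        noise_range=dict(min=-0.5, max=0.5),
        target_theta=0.005,
    ),
    collect=dict(noise_sigma=0.1),
))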

Model

Here we provide an example of the ContinuousQAC model, which is the default model for TD3BC.

class ding.model.ContinuousQAC(obs_shape: Union[int, ding.utils.type_helper.SequenceType], action_shape: Union[int, ding.utils.type_helper.SequenceType, easydict.EasyDict], action_space: str, twin_critic: bool = False, actor_head_hidden_size: int = 64, actor_head_layer_num: int = 1, critic_head_hidden_size: int = 64, critic_head_layer_num: int = 1, activation: Optional[torch.nn.modules.module.Module] = ReLU(), norm_type: Optional[str] = None, encoder_hidden_size_list: Optional[ding.utils.type_helper.SequenceType] = None, share_encoder: Optional[bool] = False)[source]
Overview:

The neural network and computation graph of algorithms related to Q-value Actor-Critic (QAC), such as DDPG/TD3/SAC. This model supports continuous and hybrid action spaces. The ContinuousQAC is composed of four parts: actor_encoder, critic_encoder, actor_head and critic_head. Encoders are used to extract features from various observations. Heads are used to predict the corresponding Q-value or action logit. In high-dimensional observation spaces such as 2D images, a shared encoder is often used for both actor_encoder and critic_encoder. In low-dimensional observation spaces such as 1D vectors, separate encoders are often used.

Interfaces:

__init__, forward, compute_actor, compute_critic

compute_actor(obs: torch.Tensor) Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]][source]
Overview:

QAC forward computation graph for actor part, input observation tensor to predict action or action logit.

Arguments:
  • obs (torch.Tensor): The input observation tensor data.

Returns:
  • outputs (Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]): Actor output dict varying from action_space: regression, reparameterization, hybrid.

ReturnsKeys (regression):
  • action (torch.Tensor): Continuous action with same size as action_shape, usually in DDPG/TD3.

ReturnsKeys (reparameterization):
  • logit (Dict[str, torch.Tensor]): The predicted reparameterization action logit, usually in SAC. It is a list containing two tensors: mu and sigma. The former is the mean of the Gaussian distribution, the latter is the standard deviation of the Gaussian distribution.

ReturnsKeys (hybrid):
  • logit (torch.Tensor): The predicted discrete action type logit, it will be the same dimension as action_type_shape, i.e., all the possible discrete action types.

  • action_args (torch.Tensor): Continuous action arguments with same size as action_args_shape.

Shapes:
  • obs (torch.Tensor): \((B, N0)\), B is batch size and N0 corresponds to obs_shape.

  • action (torch.Tensor): \((B, N1)\), B is batch size and N1 corresponds to action_shape.

  • logit.mu (torch.Tensor): \((B, N1)\), B is batch size and N1 corresponds to action_shape.

  • logit.sigma (torch.Tensor): \((B, N1)\), B is batch size.

  • logit (torch.Tensor): \((B, N2)\), B is batch size and N2 corresponds to action_shape.action_type_shape.

  • action_args (torch.Tensor): \((B, N3)\), B is batch size and N3 corresponds to action_shape.action_args_shape.

Examples:
>>> # Regression mode
>>> model = ContinuousQAC(64, 6, 'regression')
>>> obs = torch.randn(4, 64)
>>> actor_outputs = model(obs,'compute_actor')
>>> assert actor_outputs['action'].shape == torch.Size([4, 6])
>>> # Reparameterization Mode
>>> model = ContinuousQAC(64, 6, 'reparameterization')
>>> obs = torch.randn(4, 64)
>>> actor_outputs = model(obs,'compute_actor')
>>> assert actor_outputs['logit'][0].shape == torch.Size([4, 6])  # mu
>>> assert actor_outputs['logit'][1].shape == torch.Size([4, 6])  # sigma
compute_critic(inputs: Dict[str, torch.Tensor]) Dict[str, torch.Tensor][source]
Overview:

QAC forward computation graph for critic part, input observation and action tensor to predict Q-value.

Arguments:
  • inputs (Dict[str, torch.Tensor]): The dict of input data, including obs and action tensor, also contains logit and action_args tensor in hybrid action_space.

ArgumentsKeys:
  • obs (torch.Tensor): Observation tensor data, now supports a batch of 1-dim vector data.

  • action (Union[torch.Tensor, Dict]): Continuous action with same size as action_shape.

  • logit (torch.Tensor): Discrete action logit, only in hybrid action_space.

  • action_args (torch.Tensor): Continuous action arguments, only in hybrid action_space.

Returns:
  • outputs (Dict[str, torch.Tensor]): The output dict of QAC’s forward computation graph for critic, including q_value.

ReturnKeys:
  • q_value (torch.Tensor): Q value tensor with same size as batch size.

Shapes:
  • obs (torch.Tensor): \((B, N1)\), where B is batch size and N1 is obs_shape.

  • logit (torch.Tensor): \((B, N2)\), B is batch size and N2 corresponds to action_shape.action_type_shape.

  • action_args (torch.Tensor): \((B, N3)\), B is batch size and N3 corresponds to action_shape.action_args_shape.

  • action (torch.Tensor): \((B, N4)\), where B is batch size and N4 is action_shape.

  • q_value (torch.Tensor): \((B, )\), where B is batch size.

Examples:
>>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
>>> model = ContinuousQAC(obs_shape=(8, ),action_shape=1, action_space='regression')
>>> assert model(inputs, mode='compute_critic')['q_value'].shape == (4, )  # q value
forward(inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], mode: str) Dict[str, torch.Tensor][source]
Overview:

QAC forward computation graph, input observation tensor to predict Q-value or action logit. Different modes forward through different network modules to produce different outputs and save computation.

Arguments:
  • inputs (Union[torch.Tensor, Dict[str, torch.Tensor]]): The input data for forward computation graph, for compute_actor, it is the observation tensor, for compute_critic, it is the dict data including obs and action tensor.

  • mode (str): The forward mode, all the modes are defined in the beginning of this class.

Returns:
  • output (Dict[str, torch.Tensor]): The output dict of QAC forward computation graph, whose key-values vary in different forward modes.

Examples (Actor):
>>> # Regression mode
>>> model = ContinuousQAC(64, 6, 'regression')
>>> obs = torch.randn(4, 64)
>>> actor_outputs = model(obs,'compute_actor')
>>> assert actor_outputs['action'].shape == torch.Size([4, 6])
>>> # Reparameterization Mode
>>> model = ContinuousQAC(64, 6, 'reparameterization')
>>> obs = torch.randn(4, 64)
>>> actor_outputs = model(obs,'compute_actor')
>>> assert actor_outputs['logit'][0].shape == torch.Size([4, 6])  # mu
>>> assert actor_outputs['logit'][1].shape == torch.Size([4, 6])  # sigma
Examples (Critic):
>>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
>>> model = ContinuousQAC(obs_shape=(8, ),action_shape=1, action_space='regression')
>>> assert model(inputs, mode='compute_critic')['q_value'].shape == (4, )  # q value
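
TD3BC's default config sets model.twin_critic=True; assuming that in this case compute_critic returns q_value as a list of two tensors (one per critic), the clipped double-Q value used by TD3 can be formed as below (a hedged sketch, not part of the documented API above):

>>> model = ContinuousQAC(obs_shape=8, action_shape=1, action_space='regression', twin_critic=True)
>>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
>>> q1, q2 = model(inputs, mode='compute_critic')['q_value']  # assumption: a list of two Q tensors
>>> assert torch.min(q1, q2).shape == (4, )  # clipped double Q-learning, as in TD3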

Benchmark

| environment | best mean reward | evaluation results | config link | comparison |
| --- | --- | --- | --- | --- |
| Halfcheetah (Medium Expert) | 13037 | ../_images/halfcheetah_td3bc.png | config_link_ha | d3rlpy (12124) |
| Walker2d (Medium Expert) | 5066 | ../_images/walker2d_td3bc.png | config_link_w | d3rlpy (5108) |
| Hopper (Medium Expert) | 3653 | ../_images/hopper_td3bc.png | config_link_ho | d3rlpy (3690) |

| environment | random | medium replay | medium expert | medium | expert |
| --- | --- | --- | --- | --- | --- |
| Halfcheetah | 1592 | 5192 | 13037 | 5257 | 13247 |
| Walker2d | 345 | 1724 | 3653 | 3268 | 3664 |
| Hopper | 985 | 2317 | 5066 | 3826 | 5232 |

Note: the D4RL environment used in this benchmark can be found here.

References

  • Scott Fujimoto, Shixiang Shane Gu: “A Minimalist Approach to Offline Reinforcement Learning”, 2021; arXiv:2106.06860. https://arxiv.org/abs/2106.06860

  • Scott Fujimoto, Herke van Hoof, David Meger: “Addressing Function Approximation Error in Actor-Critic Methods”, 2018; arXiv:1802.09477. https://arxiv.org/abs/1802.09477

Other Public Implementations
