Shortcuts

Source code for ding.policy.sql

from typing import List, Dict, Any, Tuple, Union, Optional
from collections import namedtuple, deque
import copy
import torch
from torch.distributions import Categorical
from ditk import logging
from easydict import EasyDict
from ding.torch_utils import Adam, to_device
from ding.utils.data import default_collate, default_decollate
from ding.rl_utils import q_nstep_td_data, q_nstep_sql_td_error, get_nstep_return_data, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from .base_policy import Policy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('sql')
class SQLPolicy(Policy):
    r"""
    Overview:
        Policy class of SQL algorithm.
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='sql',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether to use priority (priority sample, IS weight, update priority).
        priority=False,
        # (float) Reward's future discount factor, aka. gamma.
        discount_factor=0.97,
        # (int) N-step reward for target q_value estimation.
        nstep=1,
        learn=dict(
            # (int) How many updates (iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy -> collect data -> ...
            update_per_collect=3,  # after the batch data come into the learner, train with the data for 3 times
            batch_size=64,
            learning_rate=0.001,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (int) Frequency of target network update.
            target_update_freq=100,
            # (bool) Whether to ignore done (usually for max step termination env).
            ignore_done=False,
            # (float) Coefficient passed as ``alpha`` both to ``q_nstep_sql_td_error`` (learn mode)
            # and to the collect model's sampler (collect mode).
            alpha=0.1,
        ),
        # collect_mode config
        collect=dict(
            # (int) Only one of [n_sample, n_episode] should be set.
            # n_sample=8,  # collect 8 samples and put them in collector
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
        ),
        eval=dict(),
        # other config
        other=dict(
            # Epsilon greedy with decay.
            eps=dict(
                # (str) Decay type. Support ['exp', 'linear'].
                type='exp',
                start=0.95,
                end=0.1,
                # (int) Decay length (env step).
                decay=10000,
            ),
            replay_buffer=dict(replay_buffer_size=10000, ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm's default model setting for demonstration.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names.

        .. note::
            The user can define and use a customized network model but must obey the same interface \
            definition indicated by the import_names path. For DQN, ``ding.model.template.q_learning.DQN``.
        """
        return 'dqn', ['ding.model.template.q_learning']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init the optimizer, algorithm config, main and target models.
        """
        self._priority = self._cfg.priority
        # Optimizer
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep
        self._alpha = self._cfg.learn.alpha

        # Use wrapper instead of plugin. The target network is a deep copy of the main model whose
        # weights are assigned from the main model every ``target_update_freq`` calls to ``update``.
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='assign',
            update_kwargs={'freq': self._cfg.learn.target_update_freq}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Training information with keys ``cur_lr``, ``total_loss``, \
                ``priority`` (per-sample absolute TD error, as a list) and ``record_value_function`` \
                (mean target value, as a Python float).
        """
        data = default_preprocess_learn(
            data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # Q-learning forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # Current q value (main model)
        q_value = self._learn_model.forward(data['obs'])['logit']
        with torch.no_grad():
            # Target q value
            target_q_value = self._target_model.forward(data['next_obs'])['logit']
            # Max q value action (main model)
            target_q_action = self._learn_model.forward(data['next_obs'])['action']

        data_n = q_nstep_td_data(
            q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
        )
        value_gamma = data.get('value_gamma')
        # ``self._alpha`` is cached from ``cfg.learn.alpha`` in ``_init_learn``.
        loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(
            data_n, self._gamma, self._alpha, nstep=self._nstep, value_gamma=value_gamma
        )
        # Reduce to a plain Python float (like ``total_loss``) so the returned info dict holds only
        # Python scalars/lists rather than a (possibly CUDA) tensor.
        record_target_v = record_target_v.mean().item()
        # ====================
        # Q-learning update
        # ====================
        self._optimizer.zero_grad()
        loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        self._optimizer.step()

        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'priority': td_error_per_sample.abs().tolist(),
            'record_value_function': record_target_v,
            # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
            # '[histogram]action_distribution': data['action'],
        }

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the learn-mode state: main model, target model and optimizer state dicts.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Restore the learn-mode state produced by ``_state_dict_learn``.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init traj and unroll length, collect model.
            Enable the eps_greedy_multinomial_sample wrapper.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._gamma = self._cfg.discount_factor  # necessary for parallel
        self._nstep = self._cfg.nstep  # necessary for parallel
        self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_multinomial_sample')
        self._collect_model.reset()

    def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
        r"""
        Overview:
            Forward function for collect mode with eps_greedy.
        Arguments:
            - data (:obj:`Dict[int, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
            - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
        Returns:
            - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
        ReturnsKeys
            - necessary: ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            # Read alpha from config: in collect-only (e.g. parallel) setups ``_init_learn`` may not have run.
            output = self._collect_model.forward(data, eps=eps, alpha=self._cfg.learn.alpha)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            For a given trajectory (transitions, a list of transition) data, process it into a list of samples that \
            can be used for training directly. A train sample can be a processed transition (DQN with nstep TD) \
            or some continuous transitions (DRQN).
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transition), each element is the \
                same format as the return value of ``self._process_transition`` method.
        Returns:
            - samples (:obj:`List[Dict[str, Any]]`): The list of training samples.

        .. note::
            We will vectorize ``process_transition`` and ``get_train_sample`` method in the following release \
            version. And the user can customize this data processing procedure by overriding these two methods \
            and the collector itself.
        """
        data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
        return get_train_sample(data, self._unroll_len)

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data.
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'action': model_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Init eval model with argmax strategy.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode, similar to ``self._forward_collect``.
        Arguments:
            - data (:obj:`Dict[int, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
        ReturnsKeys
            - necessary: ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Variables to monitor in learn mode: the base policy's variables plus ``record_value_function``.
        """
        return super()._monitor_vars_learn() + ['record_value_function']

© Copyright 2021, OpenDILab Contributors. Revision 069ece72.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.