Shortcuts

ding.reward_model.base_reward_model 源代码

from abc import ABC, abstractmethod
from typing import Dict
from easydict import EasyDict
from ditk import logging
import os
import copy
from typing import Any
from ding.utils import REWARD_MODEL_REGISTRY, import_module, save_file


class BaseRewardModel(ABC):
    """
    Overview:
        The base class of reward model.
    Interface:
        ``default_config``, ``estimate``, ``train``, ``clear_data``, ``collect_data``, ``load_expert_data``, \
        ``reward_deepcopy``, ``state_dict``, ``load_state_dict``, ``save``
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Build a config object from a deep copy of ``cls.config`` and tag it with the class name.
        Returns:
            - cfg (:obj:`EasyDict`): Deep copy of ``cls.config`` with ``cfg_type`` set to \
                ``cls.__name__ + 'Dict'``.

        .. note::
            ``config`` is not defined on this base class, so concrete subclasses are expected to provide it.
        """
        # Deep copy so that edits to the returned config never leak back into the class-level ``config``.
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    @abstractmethod
    def estimate(self, data: list) -> Any:
        """
        Overview:
            Estimate reward.
        Arguments:
            - data (:obj:`List`): The list of data used for estimation.
        Returns / Effects:
            - This can be a side effect function which updates the reward value.
            - If this function returns, an example returned object can be reward (:obj:`Any`): \
                the estimated reward.
        """
        raise NotImplementedError()

    @abstractmethod
    def train(self, data) -> None:
        """
        Overview:
            Train the reward model.
        Arguments:
            - data (:obj:`Any`): Data used for training.
        Effects:
            - This is mostly a side effect function which updates the reward model.
        """
        raise NotImplementedError()

    @abstractmethod
    def collect_data(self, data) -> None:
        """
        Overview:
            Collect training data in the designated format or with the designated transition.
        Arguments:
            - data (:obj:`Any`): Raw training data (e.g. some form of states, actions, obs, etc).
        Returns / Effects:
            - This can be a side effect function which updates the data attribute in ``self``.
        """
        raise NotImplementedError()

    @abstractmethod
    def clear_data(self) -> None:
        """
        Overview:
            Clear training data. This can be a side effect function which clears the data attribute \
            in ``self``.
        """
        raise NotImplementedError()

    def load_expert_data(self, data) -> None:
        """
        Overview:
            Get the expert data, usually used in inverse RL reward model. The base implementation \
            does nothing; subclasses override it when they need expert data.
        Arguments:
            - data (:obj:`Any`): Expert data.
        Effects:
            - This is mostly a side effect function which updates the expert data attribute \
                (e.g. ``self.expert_data``).
        """
        pass

    def reward_deepcopy(self, train_data) -> Any:
        """
        Overview:
            Deep-copy only the ``'reward'`` entry of every sample in ``train_data`` and keep every other \
            entry as a shallow reference, so that the reward stored in the replay buffer is not \
            incorrectly modified when the returned rewards are rewritten.
        Arguments:
            - train_data (:obj:`List`): The list of train data (dict samples) whose reward part is \
                deep-copied.
        Returns:
            - train_data_reward_deepcopy (:obj:`List`): A new list of new dicts; ``'reward'`` values are \
                deep copies, all other values are the same objects as in ``train_data``.
        """
        train_data_reward_deepcopy = [
            {k: copy.deepcopy(v) if k == 'reward' else v
             for k, v in sample.items()} for sample in train_data
        ]
        return train_data_reward_deepcopy

    def state_dict(self) -> Dict:
        """
        Overview:
            Return the state to be checkpointed by ``save``. The base implementation returns an empty \
            dict; this method should be overridden by subclasses.
        """
        return {}

    def load_state_dict(self, _state_dict) -> None:
        """
        Overview:
            Restore state produced by ``state_dict``. The base implementation does nothing; this method \
            should be overridden by subclasses.
        Arguments:
            - _state_dict (:obj:`Any`): The state to load.
        """
        pass

    def save(self, path: str = None, name: str = 'best'):
        """
        Overview:
            Save ``self.state_dict()`` to ``<path>/reward_model/ckpt/ckpt_<name>.pth.tar`` and log the \
            final file path.
        Arguments:
            - path (:obj:`str`): Root directory of the checkpoint. When ``None``, ``self.cfg.exp_name`` \
                is used (``self.cfg`` is not set by this base class, so the subclass must provide it).
            - name (:obj:`str`): Suffix used in the checkpoint file name, ``'best'`` by default.
        """
        if path is None:
            path = self.cfg.exp_name
        ckpt_dir = os.path.join(path, 'reward_model', 'ckpt')
        # ``exist_ok=True`` creates the directory if needed and tolerates another process creating it
        # concurrently, replacing the former exists-check plus ``except FileExistsError`` dance.
        os.makedirs(ckpt_dir, exist_ok=True)
        ckpt_path = os.path.join(ckpt_dir, 'ckpt_{}.pth.tar'.format(name))
        state_dict = self.state_dict()
        save_file(ckpt_path, state_dict)
        logging.info('Saved reward model ckpt in {}'.format(ckpt_path))
def create_reward_model(cfg: dict, device: str, tb_logger: 'SummaryWriter') -> BaseRewardModel:  # noqa
    """
    Overview:
        Build a reward model from the registry according to the type named in the config.
    Arguments:
        - cfg (:obj:`Dict`): Training config. It is deep-copied first, so the caller's object is not \
            modified by the ``pop`` calls below.
        - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda".
        - tb_logger (:obj:`SummaryWriter`): Logger forwarded to the reward model.
    Returns:
        - reward (:obj:`Any`): The reward model built by ``REWARD_MODEL_REGISTRY``.
    """
    local_cfg = copy.deepcopy(cfg)
    if 'import_names' in local_cfg:
        # The entry is removed from the config copy and handed to ``import_module``.
        import_module(local_cfg.pop('import_names'))
    # The type lives under ``cfg.reward_model`` when that attribute exists, otherwise at the top level.
    type_holder = local_cfg.reward_model if hasattr(local_cfg, 'reward_model') else local_cfg
    model_type = type_holder.pop('type')
    return REWARD_MODEL_REGISTRY.build(model_type, local_cfg, device=device, tb_logger=tb_logger)


def get_reward_model_cls(cfg: EasyDict) -> type:
    """
    Overview:
        Look up the reward model class registered under ``cfg.type``.
    Arguments:
        - cfg (:obj:`EasyDict`): Config providing ``type`` and, optionally, ``import_names`` \
            (an empty list is passed to ``import_module`` when the key is absent).
    Returns:
        - cls (:obj:`type`): The entry returned by ``REWARD_MODEL_REGISTRY.get``.
    """
    extra_modules = cfg.get('import_names', [])
    import_module(extra_modules)
    return REWARD_MODEL_REGISTRY.get(cfg.type)