Shortcuts

ding.utils

autolog

Please refer to ding/utils/autolog for more details.

TimeMode

class ding.utils.autolog.TimeMode(value)[source]
Overview:

Mode that decides the time format used by the range_values function.

  • ABSOLUTE: use absolute time.

  • RELATIVE_LIFECYCLE: use time relative to the property’s lifecycle.

  • RELATIVE_CURRENT_TIME: use time relative to the current time.

ABSOLUTE = 0
RELATIVE_CURRENT_TIME = 2
RELATIVE_LIFECYCLE = 1
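
As a rough illustration of the three modes, the sketch below converts one absolute time range into each representation. The `Mode` enum only mirrors the values listed above; `format_range`, `lifecycle_start`, and `now` are hypothetical names introduced for this example and are not part of ding.utils.autolog. It assumes that the relative modes subtract a reference point from both ends of the range.

```python
from enum import IntEnum
from typing import Tuple


class Mode(IntEnum):
    """Stand-in that mirrors the documented TimeMode values."""
    ABSOLUTE = 0
    RELATIVE_LIFECYCLE = 1
    RELATIVE_CURRENT_TIME = 2


def format_range(
    begin: float, end: float, mode: Mode, lifecycle_start: float, now: float
) -> Tuple[float, float]:
    """Hypothetical helper: express (begin, end) in the requested mode."""
    if mode is Mode.ABSOLUTE:
        return begin, end
    if mode is Mode.RELATIVE_LIFECYCLE:
        # Assumption: measured from the moment the property was created.
        return begin - lifecycle_start, end - lifecycle_start
    # Assumption: measured relative to the current time (negative = in the past).
    return begin - now, end - now


absolute = format_range(105.0, 108.0, Mode.ABSOLUTE, lifecycle_start=100.0, now=110.0)
lifecycle = format_range(105.0, 108.0, Mode.RELATIVE_LIFECYCLE, lifecycle_start=100.0, now=110.0)
current = format_range(105.0, 108.0, Mode.RELATIVE_CURRENT_TIME, lifecycle_start=100.0, now=110.0)
```

Under these assumptions the same record reads as (105.0, 108.0), (5.0, 8.0), or (-5.0, -2.0) depending on the mode.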

RangedData

class ding.utils.autolog.RangedData(expire: float, use_pickle: bool = False)[source]
Overview:

A data structure that can store data for a period of time.

Interfaces:

__init__, append, extend, current, history, expire, __bool__, _get_time.

Properties:
  • expire (float): The expire time.

__append(time_: float, data: _Tp)
Overview:

Append the data.

__append_item(time_: float, data: _Tp)
Overview:

Append the data item.

Arguments:
  • time_ (float): The time.

  • data (_Tp): The data item.

__check_expire()
Overview:

Check the expire time.

__check_time(time_: float)
Overview:

Check the time.

Arguments:
  • time_ (float): The time.

__current()
Overview:

Get the current data.

__flush_history()
Overview:

Flush the history data.

__get_data_item(data_id: int) -> _Tp
Overview:

Get the data item.

Arguments:
  • data_id (int): The data id.

__history()
Overview:

Get the history data.

__history_yield()
Overview:

Yield the history data.

__init__(expire: float, use_pickle: bool = False)[source]
Overview:

Initialize the RangedData object.

Arguments:
  • expire (float): The expire time of the data.

  • use_pickle (bool): Whether to use pickle to serialize the data.

__registry_data_item(data: _Tp) -> int
Overview:

Register the data item.

Arguments:
  • data (_Tp): The data item.

__remove_data_item(data_id: int)
Overview:

Remove the data item.

Arguments:
  • data_id (int): The data id.

abstract _get_time() -> float[source]
Overview:

Get the current time.

append(data: _Tp)[source]
Overview:

Append the data.

current() -> _Tp[source]
Overview:

Get the current data.

property expire: float
Overview:

Get the expire time.

extend(iter_: Iterable[_Tp])[source]
Overview:

Extend the data.

history() -> List[Tuple[int | float, _Tp]][source]
Overview:

Get the history data.
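
To make the idea of "data that is kept for a period of time" concrete, here is a minimal sketch of a windowed store. `WindowedLog` is a hypothetical toy class, not the ding implementation: it takes the timestamp as an explicit argument (the documented class obtains it from `_get_time()`), ignores `use_pickle`, and simply drops every record older than `expire`.

```python
from typing import Any, List, Tuple


class WindowedLog:
    """Toy stand-in for a ranged data store; not the ding implementation."""

    def __init__(self, expire: float) -> None:
        self._expire = expire
        self._items: List[Tuple[float, Any]] = []

    @property
    def expire(self) -> float:
        return self._expire

    def append(self, time_: float, data: Any) -> None:
        # Timestamps must not go backwards, otherwise the window is ill-defined.
        if self._items and time_ < self._items[-1][0]:
            raise ValueError("time must be non-decreasing")
        self._items.append((time_, data))
        self._drop_expired(now=time_)

    def _drop_expired(self, now: float) -> None:
        # Keep only records whose timestamp lies inside [now - expire, now].
        threshold = now - self._expire
        self._items = [(t, d) for t, d in self._items if t >= threshold]

    def current(self) -> Any:
        return self._items[-1][1]

    def history(self) -> List[Tuple[float, Any]]:
        return list(self._items)

    def __bool__(self) -> bool:
        return bool(self._items)


log = WindowedLog(expire=10.0)
for t, value in [(0.0, "a"), (4.0, "b"), (12.0, "c")]:
    log.append(t, value)
```

After the third append the record written at time 0.0 falls outside the 10-second window, so only the last two records remain in `history()`.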

TimeRangedData

class ding.utils.autolog.TimeRangedData(time_: BaseTime, expire: float)[source]
Overview:

A data structure that can store data for a period of time.

Interfaces:

__init__, _get_time, append, extend, current, history, expire, __bool__.

Properties:
  • time (BaseTime): The time.

  • expire (float): The expire time.

__init__(time_: BaseTime, expire: float)[source]
Overview:

Initialize the TimeRangedData object.

Arguments:
  • time_ (BaseTime): The time object.

  • expire (float): The expire time.

_get_time() -> float[source]
Overview:

Get the current time.

property time
Overview:

Get the time.

LoggedModel

class ding.utils.autolog.LoggedModel(time_: _TimeObjectType, expire: _TimeType)[source]
Overview:

A model with a timeline (integer time such as 1st, 2nd, 3rd can also be modeled as a kind of self-defined discrete time, as in the implementation of TickTime). Several values that are associated with each other can be maintained together by using LoggedModel.

Example:

Define an AvgList model like this:

>>> from functools import partial
>>> from typing import Union
>>> from ding.utils.autolog import BaseTime, LoggedValue, LoggedModel
>>> class AvgList(LoggedModel):
>>>     value = LoggedValue(float)
>>>     __property_names = ['value']
>>>
>>>     def __init__(self, time_: BaseTime, expire: Union[int, float]):
>>>         LoggedModel.__init__(self, time_, expire)
>>>         # attention, original value must be set in __init__ function, or it will not
>>>         # be activated, the timeline of this value will also be unexpectedly affected.
>>>         self.value = 0.0
>>>         self.__register()
>>>
>>>     def __register(self):
>>>         def __avg_func(prop_name: str) -> float:  # function to calculate average value of properties
>>>             records = self.range_values[prop_name]()
>>>             (_start_time, _), _ = records[0]
>>>             (_, _end_time), _ = records[-1]
>>>
>>>             _duration = _end_time - _start_time
>>>             _sum = sum([_value * (_end_time - _begin_time) for (_begin_time, _end_time), _value in records])
>>>
>>>             return _sum / _duration
>>>
>>>         for _prop_name in self.__property_names:
>>>             self.register_attribute_value('avg', _prop_name, partial(__avg_func, prop_name=_prop_name))

Use it like this

>>> from ding.utils.autolog import NaturalTime, TimeMode
>>>
>>> if __name__ == "__main__":
>>>     _time = NaturalTime()
>>>     ll = AvgList(_time, expire=10)
>>>
>>>     # just do something here ...
>>>
>>>     print(ll.range_values['value']()) # original range_values function in LoggedModel of last 10 secs
>>>     print(ll.range_values['value'](TimeMode.ABSOLUTE))  # use absolute time
>>>     print(ll.avg['value']())  # average value of last 10 secs
Interfaces:

__init__, time, expire, fixed_time, current_time, freeze, unfreeze, register_attribute_value, __getattr__, get_property_attribute

Property:
  • time (BaseTime): The time.

  • expire (float): The expire time.
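
The `__avg_func` helper in the class example computes a time-weighted average over the records returned by `range_values`. The standalone sketch below isolates that arithmetic. It assumes, as the example's unpacking suggests, that each record has the form `((begin, end), value)` and that the intervals are contiguous; `time_weighted_average` is a name chosen for this illustration.

```python
from typing import List, Tuple

Record = Tuple[Tuple[float, float], float]  # ((begin, end), value)


def time_weighted_average(records: List[Record]) -> float:
    """Average of a piecewise-constant value, weighted by interval length."""
    if not records:
        raise ValueError("records must not be empty")
    start = records[0][0][0]
    end = records[-1][0][1]
    duration = end - start
    if duration <= 0:
        raise ValueError("records must span a positive duration")
    total = sum(value * (e - b) for (b, e), value in records)
    return total / duration


records = [((0.0, 2.0), 1.0), ((2.0, 3.0), 4.0), ((3.0, 5.0), 0.5)]
average = time_weighted_average(records)
```

Here the weighted sum is 1.0 * 2 + 4.0 * 1 + 0.5 * 2 = 7.0 over a duration of 5.0, so the average is 1.4.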

__get_property_ranged_data(name: str) -> TimeRangedData
Overview:

Get ranged data of one property.

Arguments:
  • name (str): The property name.

__get_range_values_func(name: str)
Overview:

Get range_values function of one property.

Arguments:
  • name (str): The property name.

__init__(time_: _TimeObjectType, expire: _TimeType)[source]
Overview:

Initialize the LoggedModel object using the given arguments.

Arguments:
  • time_ (_TimeObjectType): The time object.

  • expire (_TimeType): The expire time.

__init_properties()
Overview:

Initialize all properties.

property __properties: List[str]
Overview:

Get all property names.

__register_default_funcs()
Overview:

Register default functions.

current_time() -> float | int[source]
Overview:

Get current time (the real time, regardless of whether the time proxy is frozen).

Returns:

int or float: current time

property expire: _TimeType
Overview:

Get expire time

Returns:

int or float: time after which old value records expire

fixed_time() -> float | int[source]
Overview:

Get fixed time (the frozen time if the time proxy is frozen). This can be useful when adding a value replay feature in the future.

Returns:

int or float: fixed time

freeze()[source]
Overview:

Freeze the time proxy object. This can be useful when adding a value replay feature in the future.

get_property_attribute(property_name: str) -> List[str][source]
Overview:

Find all registered attributes of the given property (except the common “range_values” attribute, which is not added to self.__prop2attr).

Arguments:
  • property_name (str): name of property to query attributes

Returns:
  • attr_list (List[str]): the registered attributes list of the input property

register_attribute_value(attribute_name: str, property_name: str, value: Any)[source]
Overview:

Register a new attribute for one of the values. An example can be found in the class overview.

Arguments:
  • attribute_name (str): name of attribute

  • property_name (str): name of property

  • value (Any): value of attribute

property time: _TimeObjectType
Overview:

Get the original time object that was passed in. Its methods (such as step()) can be called through this property.

Returns:

BaseTime: time object used by this model

unfreeze()[source]
Overview:

Unfreeze the time proxy object. This can be useful when adding a value replay feature in the future.

BaseTime

class ding.utils.autolog.BaseTime[source]
Overview:

Abstract time interface

Interfaces:

time

abstract time() -> int | float[source]
Overview:

Get time information

Returns:
  • time(float, int): time information

NaturalTime

class ding.utils.autolog.NaturalTime[source]
Overview:

Natural time object

Interfaces:

__init__, time

Example:
>>> from ding.utils.autolog.time_ctl import NaturalTime
>>> time_ = NaturalTime()
__init__()[source]
time() -> float[source]
Overview:

Get current natural time (float format, unix timestamp)

Returns:
  • time(float): unix timestamp

Example:
>>> from ding.utils.autolog.time_ctl import NaturalTime
>>> time_ = NaturalTime()
>>> time_.time()
1603896383.8811457

TickTime

class ding.utils.autolog.TickTime(init: int = 0)[source]
Overview:

Tick time object

Interfaces:

__init__, step, time

Example:
>>> from ding.utils.autolog.time_ctl import TickTime
>>> time_ = TickTime()
__init__(init: int = 0)[source]
Overview:

Constructor of TickTime

Arguments:
  • init (int): initial time, default is 0

step(delta: int = 1) -> int[source]
Overview:

Step the time forward for this TickTime

Arguments:
  • delta (int): steps to step forward, default is 1

Returns:
  • time (int): new time after stepping

Example:
>>> from ding.utils.autolog.time_ctl import TickTime
>>> time_ = TickTime(0)
>>> time_.step()
1
>>> time_.step(2)
3
time() -> int[source]
Overview:

Get current tick time

Returns:

int: current tick time

Example:
>>> from ding.utils.autolog.time_ctl import TickTime
>>> time_ = TickTime(0)
>>> time_.step()
1
>>> time_.time()
1

TimeProxy

class ding.utils.autolog.TimeProxy(time_: BaseTime, frozen: bool = False, lock_type: LockContextType = LockContextType.THREAD_LOCK)[source]
Overview:

Proxy of a time object that can freeze time, which is sometimes useful for reproducing results. This object is thread-safe, and freeze and unfreeze operations are strictly ordered.

Interfaces:

__init__, freeze, unfreeze, time, current_time

Examples:
>>> from ding.utils.autolog.time_ctl import TickTime, TimeProxy
>>> tick_time_ = TickTime()
>>> time_ = TimeProxy(tick_time_)
>>> tick_time_.step()
1
>>> print(tick_time_.time(), time_.time(), time_.current_time())
1 1 1
>>> time_.freeze()
>>> tick_time_.step()
2
>>> print(tick_time_.time(), time_.time(), time_.current_time())
2 1 2
>>> time_.unfreeze()
>>> print(tick_time_.time(), time_.time(), time_.current_time())
2 2 2
__init__(time_: BaseTime, frozen: bool = False, lock_type: LockContextType = LockContextType.THREAD_LOCK)[source]
Overview:

Constructor for Time proxy

Arguments:
  • time_ (BaseTime): another time object it based on

  • frozen (bool): this object will be frozen immediately if true, otherwise not, default is False

  • lock_type (LockContextType): type of the lock, default is THREAD_LOCK

current_time() -> int | float[source]
Overview:

Get current time (will not be frozen time)

Returns:

int or float: current time

freeze()[source]
Overview:

Freeze this time proxy

property is_frozen: bool
Overview:

Get if this time proxy object is frozen

Returns:

bool: true if it is frozen, otherwise false

time() -> int | float[source]
Overview:

Get time (may be frozen time)

Returns:

int or float: the time

unfreeze()[source]
Overview:

Unfreeze this time proxy
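
The freeze/unfreeze behaviour described for TimeProxy can be sketched without the library. `Tick` and `FreezableClock` below are hypothetical stand-ins written for this illustration; they assume a single `threading.Lock` is enough to order freeze and unfreeze calls, which may differ from how the real class uses `LockContextType`.

```python
import threading


class Tick:
    """Minimal discrete clock, standing in for a tick-based time source."""

    def __init__(self, init: int = 0) -> None:
        self._value = init

    def step(self, delta: int = 1) -> int:
        self._value += delta
        return self._value

    def time(self) -> int:
        return self._value


class FreezableClock:
    """Toy proxy: time() can be frozen, current_time() never is."""

    def __init__(self, source: Tick, frozen: bool = False) -> None:
        self._source = source
        self._lock = threading.Lock()
        self._frozen_at = source.time() if frozen else None

    @property
    def is_frozen(self) -> bool:
        with self._lock:
            return self._frozen_at is not None

    def freeze(self) -> None:
        with self._lock:
            # Remember the moment of freezing; later time() calls return it.
            self._frozen_at = self._source.time()

    def unfreeze(self) -> None:
        with self._lock:
            self._frozen_at = None

    def time(self) -> int:
        with self._lock:
            return self._source.time() if self._frozen_at is None else self._frozen_at

    def current_time(self) -> int:
        return self._source.time()


tick = Tick()
clock = FreezableClock(tick)
tick.step()
clock.freeze()
tick.step()
frozen_view = (tick.time(), clock.time(), clock.current_time())
clock.unfreeze()
live_view = (tick.time(), clock.time(), clock.current_time())
```

The two tuples reproduce the pattern of the documented example: while frozen, `time()` lags behind the underlying clock, and after unfreezing all three readings agree again.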

LoggedValue

class ding.utils.autolog.LoggedValue(type_: Type[_ValueType] = object)[source]
Overview:

LoggedValue can be used as a property in LoggedModel, since it has __get__ and __set__ methods. Instances of this class are associated with their owner LoggedModel instance, and all LoggedValue objects of one LoggedModel share a single time object (defined in time_ctl), so that the timeline can be managed properly.

Interfaces:

__init__, __get__, __set__

Properties:
  • __property_name (str): The name of the property.

__get_ranged_data(instance) -> TimeRangedData
Overview:

Get the ranged data.

Interfaces:

__get_ranged_data

__init__(type_: Type[_ValueType] = object)[source]
Overview:

Initialize the LoggedValue object.

Interfaces:

__init__

property __property_name
Overview:

Get the name of the property.

data.structure

Please refer to ding/utils/data/structure for more details.

Cache

class ding.utils.data.structure.Cache(maxlen: int, timeout: float, monitor_interval: float = 1.0, _debug: bool = False)[source]
Overview:

Data cache for reducing concurrency pressure, with timeout and full-queue eject mechanisms.

Interfaces:

__init__, push_data, get_cached_data_iter, run, close

Property:

remain_data_count

__init__(maxlen: int, timeout: float, monitor_interval: float = 1.0, _debug: bool = False) -> None[source]
Overview:

Initialize the cache object.

Arguments:
  • maxlen (int): Maximum length of the cache queue.

  • timeout (float): Maximum time, in seconds, that data can remain in the cache.

  • monitor_interval (float): Interval at which the timeout monitor thread checks the time.

  • _debug (bool): Whether to use debug mode or not, which enables debug print info.

_timeout_monitor() -> None[source]
Overview:

The workflow of the timeout monitor thread.

_warn_if_timeout() -> bool[source]
Overview:

Return whether a timeout has occurred.

Returns:
  • result (bool): Whether a timeout has occurred.

close() -> None[source]
Overview:

Shut down the cache’s internal thread and send the end flag to the send queue’s iterator.

dprint(s: str) -> None[source]
Overview:

Print the debug string when in debug mode.

Arguments:
  • s (str): Debug info to be printed.

get_cached_data_iter() -> callable_iterator[source]
Overview:

Get the iterator of the send queue. Once data is pushed into the send queue, it can be accessed through this iterator. ‘STOP’ is the end flag of this iterator.

Returns:
  • iterator (callable_iterator): The send queue iterator.

push_data(data: Any) -> None[source]
Overview:

Push data into the receive queue. If the receive queue is full after the push, move all data in the receive queue into the send queue.

Arguments:
  • data (Any): The data which needs to be added into receive queue

Tip

thread-safe

property remain_data_count: int
Overview:

Return the number of items remaining in the receive queue.

Returns:
  • count (int): The size of the receive queue.

run() -> None[source]
Overview:

Launch the cache internal thread, e.g. timeout monitor thread.
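
The receive-queue/send-queue flow of Cache can be sketched in a few lines. `MiniCache` is a hypothetical, single-threaded simplification written for this illustration: it omits the timeout monitor thread entirely, and its `close()` flushes any remaining items before sending the end flag, which is a choice of this sketch rather than documented behaviour. The `iter(callable, sentinel)` form is what produces a `callable_iterator` ending at 'STOP'.

```python
import queue
from typing import Any, Iterator, List


class MiniCache:
    """Synchronous toy cache: flushes to the send queue when full."""

    def __init__(self, maxlen: int) -> None:
        self._maxlen = maxlen
        self._receive: List[Any] = []
        self._send: "queue.Queue[Any]" = queue.Queue()

    @property
    def remain_data_count(self) -> int:
        return len(self._receive)

    def push_data(self, data: Any) -> None:
        self._receive.append(data)
        if len(self._receive) >= self._maxlen:
            self.flush()

    def flush(self) -> None:
        # Move everything buffered so far into the send queue, in order.
        for item in self._receive:
            self._send.put(item)
        self._receive.clear()

    def get_cached_data_iter(self) -> Iterator[Any]:
        # iter(callable, sentinel) keeps calling get() until it returns 'STOP'.
        return iter(self._send.get, "STOP")

    def close(self) -> None:
        # Sketch-only choice: flush leftovers, then send the end flag.
        self.flush()
        self._send.put("STOP")


cache = MiniCache(maxlen=2)
for item in ["a", "b", "c"]:
    cache.push_data(item)
pending = cache.remain_data_count  # "c" is still waiting in the receive buffer
cache.close()
received = list(cache.get_cached_data_iter())
```

With `maxlen=2`, the first two items are ejected together as soon as the buffer fills, while the third waits until `close()` is called.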

LifoDeque

class ding.utils.data.structure.LifoDeque(maxsize=0)[source]
Overview:

Like LifoQueue, but automatically replaces the oldest data when the queue is full.

Interfaces:

_init, _put, _get

_init(maxsize)[source]
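
One way to get "LIFO queue that replaces the oldest item when full" is to back a `queue.LifoQueue` with a bounded `collections.deque`. The class below, `DroppingLifoQueue`, is a hypothetical sketch of that idea and is not claimed to match the library's code. It relies on the standard-library behaviour that `Queue.__init__` calls `_init(maxsize)` and that `LifoQueue` uses `append` and `pop` on `self.queue`.

```python
import queue
from collections import deque


class DroppingLifoQueue(queue.LifoQueue):
    """LIFO queue that silently discards the oldest item when full."""

    def _init(self, maxsize: int) -> None:
        # A bounded deque drops items from the opposite end on append.
        self.queue = deque(maxlen=maxsize if maxsize > 0 else None)
        # Report the queue as unbounded so put() never blocks on "full".
        self.maxsize = 0


stack = DroppingLifoQueue(maxsize=2)
for item in (1, 2, 3):
    stack.put(item)
popped = [stack.get(), stack.get()]
```

Putting a third item evicts the oldest one (1), so the two subsequent `get()` calls return the newest items first.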

data.base_dataloader

Please refer to ding/utils/data/base_dataloader for more details.

IDataLoader

class ding.utils.data.base_dataloader.IDataLoader[source]
Overview:

Base class of data loader

Interfaces:

__init__, __next__, __iter__, _get_data, close

_get_data(batch_size: int | None = None) -> List[Tensor][source]
Overview:

Get one batch of data.

Arguments:
  • batch_size (Optional[int]): The batch size may be specified for each iteration; if batch_size is None, the default batch_size value is used.

close() -> None[source]
Overview:

Close data loader

data.collate_fn

Please refer to ding/utils/data/collate_fn for more details.

ttorch_collate

ding.utils.data.collate_fn.ttorch_collate(x, json: bool = False, cat_1dim: bool = True)[source]
Overview:

Collates a list of tensors or nested dictionaries of tensors into a single tensor or nested dictionary of tensors.

Arguments:
  • x : The input list of tensors or nested dictionaries of tensors.

  • json (bool): If True, converts the output to JSON format. Defaults to False.

  • cat_1dim (bool): If True, concatenates tensors with shape (B, 1) along the last dimension. Defaults to True.

Returns:

The collated output tensor or nested dictionary of tensors.

Examples:
>>> # case 1: Collate a list of tensors
>>> tensors = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9])]
>>> collated = ttorch_collate(tensors)
collated = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> # case 2: Collate a nested dictionary of tensors
>>> nested_dict = {
        'a': torch.tensor([1, 2, 3]),
        'b': torch.tensor([4, 5, 6]),
        'c': torch.tensor([7, 8, 9])
    }
>>> collated = ttorch_collate(nested_dict)
collated = {
    'a': torch.tensor([1, 2, 3]),
    'b': torch.tensor([4, 5, 6]),
    'c': torch.tensor([7, 8, 9])
}
>>> # case 3: Collate a list of nested dictionaries of tensors
>>> nested_dicts = [
        {'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 6])},
        {'a': torch.tensor([7, 8, 9]), 'b': torch.tensor([10, 11, 12])}
    ]
>>> collated = ttorch_collate(nested_dicts)
collated = {
    'a': torch.tensor([[1, 2, 3], [7, 8, 9]]),
    'b': torch.tensor([[4, 5, 6], [10, 11, 12]])
}

default_collate

ding.utils.data.collate_fn.default_collate(batch: Sequence, cat_1dim: bool = True, ignore_prefix: list = ['collate_ignore']) -> Tensor | Mapping | Sequence[source]
Overview:

Put each data field into a tensor with outer dimension batch size.

Arguments:
  • batch (Sequence): A data sequence, whose length is batch size, whose element is one piece of data.

  • cat_1dim (bool): Whether to concatenate tensors with shape (B, 1) to (B), defaults to True.

  • ignore_prefix (list): A list of prefixes to ignore when collating dictionaries, defaults to [‘collate_ignore’].

Returns:
  • ret (Union[torch.Tensor, Mapping, Sequence]): The collated data, with the batch size added to each data field. The return type depends on the original element type and can be torch.Tensor, Mapping, or Sequence.

Example:
>>> # a list with B tensors shaped (m, n) -->> a tensor shaped (B, m, n)
>>> a = [torch.zeros(2,3) for _ in range(4)]
>>> default_collate(a).shape
torch.Size([4, 2, 3])
>>>
>>> # a list with B lists, each list contains m elements -->> a list of m tensors, each with shape (B, )
>>> a = [[0 for __ in range(3)] for _ in range(4)]
>>> default_collate(a)
[tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
>>>
>>> # a list with B dicts, whose values are tensors shaped (m, n) -->>
>>> # a dict whose values are tensors with shape (B, m, n)
>>> a = [{i: torch.zeros(i,i+1) for i in range(2, 4)} for _ in range(4)]
>>> print(a[0][2].shape, a[0][3].shape)
torch.Size([2, 3]) torch.Size([3, 4])
>>> b = default_collate(a)
>>> print(b[2].shape, b[3].shape)
torch.Size([4, 2, 3]) torch.Size([4, 3, 4])
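
The recursion behind this kind of collation can be shown without tensors. The sketch below is a simplified stand-in, not the library function: `collate` is a hypothetical name, leaf values are gathered into plain Python lists where a real implementation would stack tensors, and options such as `cat_1dim` and `ignore_prefix` are left out.

```python
from typing import Any, Mapping, Sequence


def collate(batch: Sequence[Any]) -> Any:
    """Recursively regroup a batch so the batch axis becomes the outermost one."""
    first = batch[0]
    if isinstance(first, Mapping):
        # B dicts -> one dict whose values each hold B entries.
        return {key: collate([sample[key] for sample in batch]) for key in first}
    if isinstance(first, (list, tuple)):
        # B sequences of m fields -> m collated fields of B entries each.
        return [collate(list(field)) for field in zip(*batch)]
    # Leaf values: a real implementation would stack these into a tensor.
    return list(batch)


samples = [{"obs": [1, 2], "reward": 0.5}, {"obs": [3, 4], "reward": 1.0}]
collated = collate(samples)
```

Each field of the result now carries the batch dimension: `reward` becomes a list of B values, and `obs` becomes m lists of B values, mirroring the list-of-lists case in the example above.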

timestep_collate

ding.utils.data.collate_fn.timestep_collate(batch: List[Dict[str, Any]]) -> Dict[str, Tensor | list][source]
Overview:

Collates a batch of timestepped data fields into tensors with the outer dimension being the batch size. Each timestepped data field is represented as a tensor with shape [T, B, any_dims], where T is the length of the sequence, B is the batch size, and any_dims represents the shape of the tensor at each timestep.

Arguments:
  • batch(List[Dict[str, Any]]): A list of dictionaries with length B, where each dictionary represents a timestepped data field. Each dictionary contains a key-value pair, where the key is the name of the data field and the value is a sequence of torch.Tensor objects with any shape.

Returns:
  • ret(Dict[str, Union[torch.Tensor, list]]): The collated data, with the timestep and batch size incorporated into each data field. The shape of each data field is [T, B, dim1, dim2, …].

Examples:
>>> batch = [
        {'data': [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]},
        {'data': [torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])]}
    ]
>>> collated_data = timestep_collate(batch)
>>> print(collated_data['data'].shape)
torch.Size([2, 2, 3])
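
The [T, B, ...] layout amounts to swapping the batch and time axes. The following sketch shows that regrouping on nested Python lists; `timestep_major` is a hypothetical helper for illustration only, and it assumes every sample has the same sequence length T for a given field.

```python
from typing import Any, Dict, List


def timestep_major(batch: List[Dict[str, List[Any]]]) -> Dict[str, List[List[Any]]]:
    """Regroup B samples of T steps each into T steps of B samples each."""
    result = {}
    for key in batch[0].keys():
        sequences = [sample[key] for sample in batch]  # layout [B][T]
        lengths = {len(seq) for seq in sequences}
        if len(lengths) != 1:
            raise ValueError(f"field {key!r} has sequences of different lengths")
        result[key] = [list(step) for step in zip(*sequences)]  # layout [T][B]
    return result


batch = [
    {"data": [[1, 2, 3], [4, 5, 6]]},     # sample 0, T = 2 steps
    {"data": [[7, 8, 9], [10, 11, 12]]},  # sample 1, T = 2 steps
]
by_step = timestep_major(batch)
```

The first entry of `by_step["data"]` now holds step 0 of both samples, and the second holds step 1, which corresponds to a [T, B, 3] = [2, 2, 3] layout.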

diff_shape_collate

ding.utils.data.collate_fn.diff_shape_collate(batch: Sequence) -> Tensor | Mapping | Sequence[source]
Overview:

Collates a batch of data with different shapes. This function is similar to default_collate, but it allows tensors in the batch to have None values, which is common in StarCraft observations.

Arguments:
  • batch (Sequence): A sequence of data, where each element is a piece of data.

Returns:
  • ret (Union[torch.Tensor, Mapping, Sequence]): The collated data, with the batch size applied to each data field. The return type depends on the original element type and can be a torch.Tensor, Mapping, or Sequence.

Examples:
>>> # a list with B tensors shaped (m, n) -->> a tensor shaped (B, m, n)
>>> a = [torch.zeros(2,3) for _ in range(4)]
>>> diff_shape_collate(a).shape
torch.Size([4, 2, 3])
>>>
>>> # a list with B lists, each list contains m elements -->> a list of m tensors, each with shape (B, )
>>> a = [[0 for __ in range(3)] for _ in range(4)]
>>> diff_shape_collate(a)
[tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
>>>
>>> # a list with B dicts, whose values are tensors shaped (m, n) -->>
>>> # a dict whose values are tensors with shape (B, m, n)
>>> a = [{i: torch.zeros(i,i+1) for i in range(2, 4)} for _ in range(4)]
>>> print(a[0][2].shape, a[0][3].shape)
torch.Size([2, 3]) torch.Size([3, 4])
>>> b = diff_shape_collate(a)
>>> print(b[2].shape, b[3].shape)
torch.Size([4, 2, 3]) torch.Size([4, 3, 4])

default_decollate

ding.utils.data.collate_fn.default_decollate(batch: Tensor | Sequence | Mapping, ignore: List[str] = ['prev_state', 'prev_actor_state', 'prev_critic_state']) -> List[Any][source]
Overview:

Split collated data along its batch dimension to decollate it. This is the reverse operation of default_collate.

Arguments:
  • batch (Union[torch.Tensor, Sequence, Mapping]): The collated data batch. It can be a tensor, sequence, or mapping.

  • ignore(List[str]): A list of names to be ignored. Only applicable if the input batch is a dictionary. If a key is in this list, its value will remain the same without decollation. Defaults to [‘prev_state’, ‘prev_actor_state’, ‘prev_critic_state’].

Returns:
  • ret (List[Any]): A list with B elements, where B is the batch size.

Examples:
>>> batch = {
    'a': [
        [1, 2, 3],
        [4, 5, 6]
    ],
    'b': [
        [7, 8, 9],
        [10, 11, 12]
    ]}
>>> default_decollate(batch)
{
    0: {'a': [1, 2, 3], 'b': [7, 8, 9]},
    1: {'a': [4, 5, 6], 'b': [10, 11, 12]},
}
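
For the dictionary case, decollation can be pictured as slicing every field at the same batch index. The sketch below is a simplified illustration on plain lists and not the library function: `decollate` is a hypothetical name, it handles only one level of nesting, and it leaves out the `ignore` argument.

```python
from typing import Any, Dict, List, Sequence


def decollate(batch: Dict[str, Sequence[Any]]) -> List[Dict[str, Any]]:
    """Split a dict of batched fields into one dict per sample."""
    sizes = {len(values) for values in batch.values()}
    if len(sizes) != 1:
        raise ValueError("all fields must share the same batch size")
    size = sizes.pop()
    # Sample i takes element i of every field.
    return [{key: values[i] for key, values in batch.items()} for i in range(size)]


batch = {"a": [[1, 2, 3], [4, 5, 6]], "b": [[7, 8, 9], [10, 11, 12]]}
samples = decollate(batch)
```

The result is a list with B = 2 elements, one dictionary per sample, consistent with the documented `List[Any]` return type.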

data.dataloader

Please refer to ding/utils/data/dataloader for more details.

AsyncDataLoader

class ding.utils.data.dataloader.AsyncDataLoader(data_source: Callable | dict, batch_size: int, device: str, chunk_size: int | None = None, collate_fn: Callable | None = None, num_workers: int = 0)[source]
Overview:

An asynchronous dataloader.

Interfaces:

__init__, __iter__, __next__, _get_data, _async_loop, _worker_loop, _cuda_loop, close

__init__(data_source: Callable | dict, batch_size: int, device: str, chunk_size: int | None = None, collate_fn: Callable | None = None, num_workers: int = 0) -> None[source]
Overview:

Initialize the dataloader with the input parameters. If data_source is a dict, data is only processed in get_data_thread and put into async_train_queue. If data_source is a Callable, data is processed by calling that function, in one of two ways:

  • num_workers == 0 or 1: Only the main worker processes the data and puts it into async_train_queue.

  • num_workers > 1: The main worker divides a job into several pieces and pushes every piece into job_queue; the slave workers then get the jobs and execute them; finally, they push the processed data into async_train_queue.

As the last step, if device contains “cuda”, data in async_train_queue is transferred to cuda_queue for the user to access.

Arguments:
  • data_source (Union[Callable, dict]): The data source, e.g. a function to be called (Callable) or the replay buffer’s real data (dict).

  • batch_size (int): Batch size.

  • device (str): Device.

  • chunk_size (int): The size of one chunk of a batch. It should divide batch_size exactly, and it only takes effect when there is more than one worker.

  • collate_fn (Callable): The function used to collate the batch dimension into each data field.

  • num_workers (int): Number of extra workers. 0 or 1 means only one main worker and no extra ones, i.e. multiprocessing is disabled. More than 1 means that multiple workers, implemented with multiprocessing, process the data.
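
The requirement that chunk_size divides batch_size exactly can be illustrated with a small helper that cuts one batch into equally sized jobs. `split_into_chunks` is a hypothetical function written for this example; how the dataloader actually represents jobs in job_queue is not shown in this reference.

```python
from typing import List, Tuple


def split_into_chunks(batch_size: int, chunk_size: int) -> List[Tuple[int, int]]:
    """Return (start, stop) index pairs covering one batch in equal chunks."""
    if chunk_size <= 0 or batch_size % chunk_size != 0:
        raise ValueError("chunk_size must be positive and divide batch_size exactly")
    return [(start, start + chunk_size) for start in range(0, batch_size, chunk_size)]


jobs = split_into_chunks(batch_size=8, chunk_size=2)
```

With a batch of 8 and chunks of 2, four jobs are produced, each of which could be handed to a separate worker; a chunk size of 3 would be rejected because it leaves a remainder.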

_async_loop(p: tm.multiprocessing.connection, c: tm.multiprocessing.connection) -> None[source]
Overview:

Main worker process, run through self.async_process. First, get data from self.get_data_thread. With multiple workers, put the data in self.job_queue for further multiprocessing; with only one worker, process the data and put it directly into self.async_train_queue.

Arguments:
  • p (tm.multiprocessing.connection): Parent connection.

  • c (tm.multiprocessing.connection): Child connection.

_cuda_loop() -> None[source]
Overview:

Run as a thread through self.cuda_thread, only when using cuda. Get data from self.async_train_queue, move it to the target device, and put it into self.cuda_queue.

_get_data(p: tm.multiprocessing.connection, c: tm.multiprocessing.connection) -> None[source]
Overview:

Init dataloader with input parameters. Will run as a thread through self.get_data_thread.

Arguments:
  • p (tm.multiprocessing.connection): Parent connection.

  • c (tm.multiprocessing.connection): Child connection.

_worker_loop() -> None[source]
Overview:

Worker process, run through each element in the list self.worker. Get a data job from self.job_queue, process it, and then put the result into self.async_train_queue. This only takes effect when self.num_workers > 1, i.e. when multiprocessing is used.

close() -> None[source]
Overview:

Delete this dataloader. First set end_flag to True, which means different processes/threads will clear and close all data queues; Then all processes will be terminated and joined.

data.dataset

Please refer to ding/utils/data/dataset for more details.

DatasetStatistics

class ding.utils.data.dataset.DatasetStatistics(mean: ndarray, std: ndarray, action_bounds: ndarray)[source]
Overview:

Dataset statistics.

__init__(mean: ndarray, std: ndarray, action_bounds: ndarray) -> None
action_bounds: ndarray
mean: ndarray
std: ndarray

NaiveRLDataset

class ding.utils.data.dataset.NaiveRLDataset(cfg)[source]
Overview:

Naive RL dataset, which is used for offline RL algorithms.

Interfaces:

__init__, __len__, __getitem__

__getitem__(idx: int) -> Dict[str, Tensor][source]
Overview:

Get the item of the dataset.

__init__(cfg) -> None[source]
Overview:

Initialization method.

Arguments:
  • cfg (dict): Config dict.

__len__() -> int[source]
Overview:

Get the length of the dataset.

D4RLDataset

class ding.utils.data.dataset.D4RLDataset(cfg: dict)[source]
Overview:

D4RL dataset, which is used for offline RL algorithms.

Interfaces:

__init__, __len__, __getitem__

Properties:
  • mean (np.ndarray): Mean of the dataset.

  • std (np.ndarray): Std of the dataset.

  • action_bounds (np.ndarray): Action bounds of the dataset.

  • statistics (dict): Statistics of the dataset.

__getitem__(idx: int) -> Dict[str, Tensor][source]
Overview:

Get the item of the dataset.

__init__(cfg: dict) -> None[source]
Overview:

Initialization method.

Arguments:
  • cfg (dict): Config dict.

__len__() -> int[source]
Overview:

Get the length of the dataset.

_cal_statistics(dataset, env, eps=0.001, add_action_buffer=True)[source]
Overview:

Calculate the statistics of the dataset.

Arguments:
  • dataset (Dict[str, np.ndarray]): The d4rl dataset.

  • env (gym.Env): The environment.

  • eps (float): Epsilon.

_load_d4rl(dataset: Dict[str, ndarray]) -> None[source]
Overview:

Load the d4rl dataset.

Arguments:
  • dataset (Dict[str, np.ndarray]): The d4rl dataset.

_normalize_states(dataset)[source]
Overview:

Normalize the states.

Arguments:
  • dataset (Dict[str, np.ndarray]): The d4rl dataset.
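
State normalization with dataset statistics usually means subtracting a per-dimension mean and dividing by a per-dimension standard deviation. The sketch below shows that arithmetic in pure Python. It is an assumption-laden illustration rather than the library code: `column_statistics` and `normalize` are hypothetical names, the population standard deviation is used, and `eps` is assumed to be added to the standard deviation to avoid division by zero.

```python
from statistics import fmean, pstdev
from typing import List, Sequence, Tuple


def column_statistics(
    states: Sequence[Sequence[float]], eps: float = 1e-3
) -> Tuple[List[float], List[float]]:
    """Per-dimension mean and population std, with eps added to the std."""
    columns = list(zip(*states))
    mean = [fmean(col) for col in columns]
    # Assumption: eps keeps constant dimensions from causing division by zero.
    std = [pstdev(col) + eps for col in columns]
    return mean, std


def normalize(
    states: Sequence[Sequence[float]], mean: Sequence[float], std: Sequence[float]
) -> List[List[float]]:
    """Apply (x - mean) / std to every dimension of every state."""
    return [[(x - m) / s for x, m, s in zip(row, mean, std)] for row in states]


states = [[1.0, 10.0], [3.0, 10.0]]
mean, std = column_statistics(states)
normalized = normalize(states, mean, std)
```

The second dimension is constant, so without `eps` its standard deviation would be zero; with it, the normalized value is simply 0.0.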

property action_bounds: ndarray
Overview:

Get the action bounds of the dataset.

property data: List
property mean
Overview:

Get the mean of the dataset.

property statistics: dict
Overview:

Get the statistics of the dataset.

property std
Overview:

Get the std of the dataset.

HDF5Dataset

class ding.utils.data.dataset.HDF5Dataset(cfg: dict)[source]
Overview:

HDF5 dataset is saved in hdf5 format, which is used for offline RL algorithms. The hdf5 format is a common format for storing large numerical arrays in Python. For more details, please refer to https://support.hdfgroup.org/HDF5/.

Interfaces:

__init__, __len__, __getitem__

Properties:
  • mean (np.ndarray): Mean of the dataset.

  • std (np.ndarray): Std of the dataset.

  • action_bounds (np.ndarray): Action bounds of the dataset.

  • statistics (dict): Statistics of the dataset.

__getitem__(idx: int) -> Dict[str, Tensor][source]
Overview:

Get the item of the dataset.

Arguments:
  • idx (int): The index of the dataset.

__init__(cfg: dict) -> None[source]
Overview:

Initialization method.

Arguments:
  • cfg (dict): Config dict.

__len__() -> int[source]
Overview:

Get the length of the dataset.

_cal_statistics(eps: float = 0.001)[source]
Overview:

Calculate the statistics of the dataset.

Arguments:
  • eps (float): Epsilon.

_load_data(dataset: Dict[str, ndarray]) -> None[source]
Overview:

Load the dataset.

Arguments:
  • dataset (Dict[str, np.ndarray]): The dataset.

_normalize_states()[source]
Overview:

Normalize the states.

property action_bounds: ndarray
Overview:

Get the action bounds of the dataset.

property mean
Overview:

Get the mean of the dataset.

property statistics: dict
Overview:

Get the statistics of the dataset.

property std
Overview:

Get the std of the dataset.

D4RLTrajectoryDataset

class ding.utils.data.dataset.D4RLTrajectoryDataset(cfg: dict)[source]
Overview:

D4RL trajectory dataset, which is used for offline RL algorithms.

Interfaces:

__init__, __len__, __getitem__

D4RL_DATASET_STATS = {'halfcheetah-medium-expert-v2': {'state_mean': [-0.05667462572455406, 0.024369969964027405, -0.061670560389757156, -0.22351515293121338, -0.2675151228904724, -0.07545716315507889, -0.05809682980179787, -0.027675075456500053, 8.110626220703125, -0.06136331334710121, -0.17986927926540375, 0.25175222754478455, 0.24186332523822784, 0.2519369423389435, 0.5879552960395813, -0.24090635776519775, -0.030184272676706314], 'state_std': [0.06103534251451492, 0.36054104566574097, 0.45544400811195374, 0.38476887345314026, 0.2218363732099533, 0.5667523741722107, 0.3196682929992676, 0.2852923572063446, 3.443821907043457, 0.6728139519691467, 1.8616976737976074, 9.575807571411133, 10.029894828796387, 5.903450012207031, 12.128185272216797, 6.4811787605285645, 6.378620147705078]}, 'halfcheetah-medium-replay-v2': {'state_mean': [-0.12880703806877136, 0.3738119602203369, -0.14995987713336945, -0.23479078710079193, -0.2841278612613678, -0.13096535205841064, -0.20157982409000397, -0.06517726927995682, 3.4768247604370117, -0.02785065770149231, -0.015035249292850494, 0.07697279006242752, 0.01266712136566639, 0.027325302362442017, 0.02316424623131752, 0.010438721626996994, -0.015839405357837677], 'state_std': [0.17019015550613403, 1.284424901008606, 0.33442774415016174, 0.3672759234905243, 0.26092398166656494, 0.4784106910228729, 0.3181420564651489, 0.33552637696266174, 2.0931615829467773, 0.8037433624267578, 1.9044333696365356, 6.573209762573242, 7.572863578796387, 5.069749355316162, 9.10555362701416, 6.085654258728027, 7.25300407409668]}, 'halfcheetah-medium-v2': {'state_mean': [-0.06845773756504059, 0.016414547339081764, -0.18354906141757965, -0.2762460708618164, -0.34061527252197266, -0.09339715540409088, -0.21321271359920502, -0.0877423882484436, 5.173007488250732, -0.04275195300579071, -0.036108363419771194, 0.14053793251514435, 0.060498327016830444, 0.09550975263118744, 0.06739100068807602, 0.005627387668937445, 0.013382787816226482], 'state_std': 
[0.07472999393939972, 0.3023499846458435, 0.30207309126853943, 0.34417077898979187, 0.17619241774082184, 0.507205605506897, 0.2567007839679718, 0.3294812738895416, 1.2574149370193481, 0.7600541710853577, 1.9800915718078613, 6.565362453460693, 7.466367721557617, 4.472222805023193, 10.566964149475098, 5.671932697296143, 7.4982590675354]}, 'hopper-medium-expert-v2': {'state_mean': [1.3293815851211548, -0.09836531430482864, -0.5444297790527344, -0.10201650857925415, 0.02277466468513012, 2.3577215671539307, -0.06349576264619827, -0.00374026270583272, -0.1766270101070404, -0.11862941086292267, -0.12097819894552231], 'state_std': [0.17012375593185425, 0.05159067362546921, 0.18141433596611023, 0.16430604457855225, 0.6023368239402771, 0.7737284898757935, 1.4986555576324463, 0.7483318448066711, 1.7953159809112549, 2.0530025959014893, 5.725032806396484]}, 'hopper-medium-replay-v2': {'state_mean': [1.2305138111114502, -0.04371410980820656, -0.44542956352233887, -0.09370097517967224, 0.09094487875699997, 1.3694725036621094, -0.19992674887180328, -0.022861352190375328, -0.5287045240402222, -0.14465883374214172, -0.19652697443962097], 'state_std': [0.1756512075662613, 0.0636928603053093, 0.3438323438167572, 0.19566889107227325, 0.5547984838485718, 1.051029920578003, 1.158307671546936, 0.7963128685951233, 1.4802359342575073, 1.6540331840515137, 5.108601093292236]}, 'hopper-medium-v2': {'state_mean': [1.311279058456421, -0.08469521254301071, -0.5382719039916992, -0.07201576232910156, 0.04932365566492081, 2.1066856384277344, -0.15017354488372803, 0.008783451281487942, -0.2848185896873474, -0.18540096282958984, -0.28461286425590515], 'state_std': [0.17790751159191132, 0.05444620922207832, 0.21297138929367065, 0.14530418813228607, 0.6124444007873535, 0.8517446517944336, 1.4515252113342285, 0.6751695871353149, 1.5362390279769897, 1.616074562072754, 5.607253551483154]}, 'walker2d-medium-expert-v2': {'state_mean': [1.2294334173202515, 0.16869689524173737, -0.07089081406593323, 
-0.16197483241558075, 0.37101927399635315, -0.012209027074277401, -0.42461398243904114, 0.18986578285694122, 3.162475109100342, -0.018092676997184753, 0.03496946766972542, -0.013921679928898811, -0.05937029421329498, -0.19549426436424255, -0.0019200450042262673, -0.062483321875333786, -0.27366524934768677], 'state_std': [0.09932824969291687, 0.25981399416923523, 0.15062759816646576, 0.24249176681041718, 0.6758718490600586, 0.1650741547346115, 0.38140663504600525, 0.6962361335754395, 1.3501490354537964, 0.7641991376876831, 1.534574270248413, 2.1785972118377686, 3.276582717895508, 4.766193866729736, 1.1716983318328857, 4.039782524108887, 5.891613960266113]}, 'walker2d-medium-replay-v2': {'state_mean': [1.209364652633667, 0.13264022767543793, -0.14371201395988464, -0.2046516090631485, 0.5577612519264221, -0.03231537342071533, -0.2784661054611206, 0.19130706787109375, 1.4701707363128662, -0.12504704296588898, 0.0564953051507473, -0.09991033375263214, -0.340340256690979, 0.03546293452382088, -0.08934258669614792, -0.2992438077926636, -0.5984178185462952], 'state_std': [0.11929835379123688, 0.3562574088573456, 0.25852200388908386, 0.42075422406196594, 0.5202291011810303, 0.15685082972049713, 0.36770978569984436, 0.7161387801170349, 1.3763766288757324, 0.8632221817970276, 2.6364643573760986, 3.0134117603302, 3.720684051513672, 4.867283821105957, 2.6681625843048096, 3.845186948776245, 5.4768385887146]}, 'walker2d-medium-v2': {'state_mean': [1.218966007232666, 0.14163373410701752, -0.03704913705587387, -0.13814310729503632, 0.5138224363327026, -0.04719110205769539, -0.47288352251052856, 0.042254164814949036, 2.3948874473571777, -0.03143199160695076, 0.04466355964541435, -0.023907244205474854, -0.1013401448726654, 0.09090937674045563, -0.004192637279629707, -0.12120571732521057, -0.5497063994407654], 'state_std': [0.12311358004808426, 0.3241879940032959, 0.11456084251403809, 0.2623065710067749, 0.5640279054641724, 0.2271878570318222, 0.3837319612503052, 0.7373676896095276, 
1.2387926578521729, 0.798020601272583, 1.5664079189300537, 1.8092705011367798, 3.025604248046875, 4.062486171722412, 1.4586567878723145, 3.7445690631866455, 5.5851287841796875]}}
REF_MAX_SCORE = {'halfcheetah': 12135.0, 'hopper': 3234.3, 'walker2d': 4592.3}
REF_MIN_SCORE = {'halfcheetah': -280.178953, 'hopper': -20.272305, 'walker2d': 1.629008}
__getitem__(idx: int) Tuple[Tensor, Tensor, Tensor, Tensor, Tensor][源代码]
Overview:

Get the item of the dataset.

Arguments:
  • idx (int): The index of the dataset.

__init__(cfg: dict) None[源代码]
Overview:

Initialization method.

Arguments:
  • cfg (dict): Config dict.

__len__() int[源代码]
Overview:

Get the length of the dataset.

get_d4rl_dataset_stats(env_d4rl_name: str) Dict[str, list][源代码]
Overview:

Get the d4rl dataset stats.

Arguments:
  • env_d4rl_name (str): The d4rl env name.

get_max_timestep() int[源代码]
Overview:

Get the max timestep of the dataset.

get_state_stats() Tuple[ndarray, ndarray][源代码]
Overview:

Get the state mean and std of the dataset.
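The REF_MAX_SCORE and REF_MIN_SCORE attributes listed above read like per-environment reference returns. The following is a minimal sketch of how such references are commonly used to rescale a raw episode return, assuming the usual D4RL normalized-score convention (0 at the minimum reference, 100 at the maximum). The function name normalized_score is hypothetical and is not part of ding.

```python
# Reference returns copied from the class attributes documented above.
REF_MAX_SCORE = {'halfcheetah': 12135.0, 'hopper': 3234.3, 'walker2d': 4592.3}
REF_MIN_SCORE = {'halfcheetah': -280.178953, 'hopper': -20.272305, 'walker2d': 1.629008}


def normalized_score(env_key: str, raw_return: float) -> float:
    """Hypothetical helper: rescale a raw return to the assumed 0-100 range."""
    lo, hi = REF_MIN_SCORE[env_key], REF_MAX_SCORE[env_key]
    return 100.0 * (raw_return - lo) / (hi - lo)


print(normalized_score('hopper', 1600.0))
```

Returns outside the reference range map to values below 0 or above 100; this sketch does not clip them.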

D4RLDiffuserDataset

class ding.utils.data.dataset.D4RLDiffuserDataset(dataset_path: str, context_len: int, rtg_scale: float)[源代码]
Overview:

D4RL diffuser dataset, which is used for offline RL algorithms.

Interfaces:

__init__, __len__, __getitem__

__init__(dataset_path: str, context_len: int, rtg_scale: float) None[源代码]
Overview:

Initialization method of D4RLDiffuserDataset.

Arguments:
  • dataset_path (str): The dataset path.

  • context_len (int): The length of the context.

  • rtg_scale (float): The scale of the returns to go.

FixedReplayBuffer

class ding.utils.data.dataset.FixedReplayBuffer(data_dir: str, replay_suffix: int, *args, **kwargs)[源代码]
Overview:

Object composed of a list of OutOfGraphReplayBuffers.

Interfaces:

__init__, get_transition_elements, sample_transition_batch

__init__(data_dir: str, replay_suffix: int, *args, **kwargs)[源代码]
Overview:

Initialize the FixedReplayBuffer class.

Arguments:
  • data_dir (str): Log directory from which to load the replay buffer.

  • replay_suffix (int): If not None, then only load the replay buffer corresponding to the specific suffix in data directory.

  • args (list): Arbitrary extra arguments.

  • kwargs (dict): Arbitrary keyword arguments.

_load_buffer(suffix)[源代码]
Overview:

Load an OutOfGraphReplayBuffer replay buffer.

Arguments:
  • suffix (int): The suffix of the replay buffer.

get_transition_elements()[源代码]
Overview:

Returns the transition elements.

load_single_buffer(suffix)[源代码]
Overview:

Load a single replay buffer.

Arguments:
  • suffix (int): The suffix of the replay buffer.

sample_transition_batch(batch_size=None, indices=None)[源代码]
Overview:

Returns a batch of transitions (including any extra contents).

Arguments:
  • batch_size (int): The batch size.

  • indices (list): The indices of the batch.

PCDataset

class ding.utils.data.dataset.PCDataset(all_data)[源代码]
Overview:

Dataset for Procedure Cloning.

Interfaces:

__init__, __len__, __getitem__

__getitem__(item)[源代码]
Overview:

Get the item of the dataset.

Arguments:
  • item (int): The index of the dataset.

__init__(all_data)[源代码]
Overview:

Initialization method of PCDataset.

Arguments:
  • all_data (tuple): The tuple of all data.

__len__()[源代码]
Overview:

Get the length of the dataset.

load_bfs_datasets

ding.utils.data.dataset.load_bfs_datasets(train_seeds=1, test_seeds=5)[源代码]
Overview:

Load BFS datasets.

Arguments:
  • train_seeds (int): The number of train seeds.

  • test_seeds (int): The number of test seeds.

BCODataset

class ding.utils.data.dataset.BCODataset(data=None)[源代码]
Overview:

Dataset for Behavioral Cloning from Observation.

Interfaces:

__init__, __len__, __getitem__

Properties:
  • obs (np.ndarray): The observation array.

  • action (np.ndarray): The action array.

__getitem__(idx)[源代码]
Overview:

Get the item of the dataset.

Arguments:
  • idx (int): The index of the dataset.

__init__(data=None)[源代码]
Overview:

Initialization method of BCODataset.

Arguments:
  • data (dict): The data dict.

__len__()[源代码]
Overview:

Get the length of the dataset.

property action
Overview:

Get the action array.

property obs
Overview:

Get the observation array.

SequenceDataset

class ding.utils.data.dataset.SequenceDataset(cfg)[源代码]
Overview:

Dataset for diffuser.

Interfaces:

__init__, __len__, __getitem__

__getitem__(idx, eps=0.0001)[源代码]
Overview:

Get the item of the dataset.

Arguments:
  • idx (int): The index of the dataset.

  • eps (float): The epsilon.

__init__(cfg)[源代码]
Overview:

Initialization method of SequenceDataset.

Arguments:
  • cfg (dict): The config dict.

__len__()[源代码]
Overview:

Get the length of the dataset.

_get_bounds()[源代码]
Overview:

Get the bounds of the dataset.

get_conditions(observations)[源代码]
Overview:

Get the conditions on current observation for planning.

Arguments:
  • observations (np.ndarray): The observation array.

make_indices(path_lengths, horizon)[源代码]
Overview:

Make indices for sampling from dataset. Each index maps to a datapoint.

Arguments:
  • path_lengths (np.ndarray): The path length array.

  • horizon (int): The horizon.
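As an illustration of "each index maps to a datapoint", here is a minimal sketch that enumerates fixed-length windows inside each path. It assumes every window must fit entirely inside its path; the actual method may also handle padding or a maximum path length. The name make_indices_sketch is hypothetical.

```python
from typing import List, Tuple


def make_indices_sketch(path_lengths: List[int], horizon: int) -> List[Tuple[int, int, int]]:
    """Enumerate (path_id, start, end) windows of length `horizon` inside each path."""
    indices = []
    for path_id, length in enumerate(path_lengths):
        # Keep only the start positions whose window fits inside the path.
        for start in range(max(length - horizon + 1, 0)):
            indices.append((path_id, start, start + horizon))
    return indices


indices = make_indices_sketch([5, 3], horizon=3)
print(indices)
```

With this layout, a dataset's `__len__` would be `len(indices)` and `__getitem__(idx)` would slice path `indices[idx][0]` from `start` to `end`.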

maze2d_set_terminals(env, dataset)[源代码]
Overview:

Set the terminals for maze2d.

Arguments:
  • env (gym.Env): The gym env.

  • dataset (dict): The dataset dict.

normalize(keys=['observations', 'actions'])[源代码]
Overview:

Normalize the dataset fields that will be predicted by the diffusion model.

Arguments:
  • keys (list): The list of keys.

normalize_value(value)[源代码]
Overview:

Normalize the value.

Arguments:
  • value (np.ndarray): The value array.

process_maze2d_episode(episode)[源代码]
Overview:

Process a maze2d episode by adding a next_observations field to it.

Arguments:
  • episode (dict): The episode dict.

sequence_dataset(env, dataset=None)[源代码]
Overview:

Sequence the dataset.

Arguments:
  • env (gym.Env): The gym env.

hdf5_save

ding.utils.data.dataset.hdf5_save(exp_data, expert_data_path)[源代码]
Overview:

Save the data to hdf5.

naive_save

ding.utils.data.dataset.naive_save(exp_data, expert_data_path)[源代码]
Overview:

Save the data to pickle.

offline_data_save_type

ding.utils.data.dataset.offline_data_save_type(exp_data, expert_data_path, data_type='naive')[源代码]
Overview:

Save the offline data.

create_dataset

ding.utils.data.dataset.create_dataset(cfg, **kwargs) Dataset[源代码]
Overview:

Create dataset.

bfs_helper

Please refer to ding/utils/bfs_helper for more details.

get_vi_sequence

ding.utils.bfs_helper.get_vi_sequence(env: Env, observation: ndarray) Tuple[ndarray, List][源代码]
Overview:

Given an instance of the maze environment and the current observation, use the Breadth-First Search (BFS) algorithm to plan an optimal path and record the result.

Arguments:
  • env (Env): The instance of the maze environment.

  • observation (np.ndarray): The current observation.

Returns:
  • output (Tuple[np.ndarray, List]): The BFS result. output[0] contains the BFS map after each iteration and output[1] contains the optimal actions before reaching the finishing point.
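The BFS planning idea can be sketched on a plain grid. This is not the library's implementation: the grid encoding (0 for a free cell, 1 for a wall) and the helper names bfs_distances and shortest_path are assumptions made for illustration.

```python
from collections import deque
from typing import List, Optional, Tuple

MOVES = ((1, 0), (-1, 0), (0, 1), (0, -1))  # down, up, right, left


def bfs_distances(grid: List[List[int]], goal: Tuple[int, int]) -> List[List[Optional[int]]]:
    """Return the step count from every reachable free cell to `goal` (None if unreachable)."""
    rows, cols = len(grid), len(grid[0])
    dist: List[List[Optional[int]]] = [[None] * cols for _ in range(rows)]
    dist[goal[0]][goal[1]] = 0
    queue = deque([goal])
    while queue:
        r, c = queue.popleft()
        for dr, dc in MOVES:
            nr, nc = r + dr, c + dc
            if 0 <= nr < rows and 0 <= nc < cols and grid[nr][nc] == 0 and dist[nr][nc] is None:
                dist[nr][nc] = dist[r][c] + 1
                queue.append((nr, nc))
    return dist


def shortest_path(dist: List[List[Optional[int]]], start: Tuple[int, int]) -> List[Tuple[int, int]]:
    """Walk from `start` to the goal by always stepping to a cell one step closer."""
    if dist[start[0]][start[1]] is None:
        raise ValueError("start cannot reach the goal")
    path = [start]
    r, c = start
    while dist[r][c] != 0:
        for dr, dc in MOVES:
            nr, nc = r + dr, c + dc
            if (0 <= nr < len(dist) and 0 <= nc < len(dist[0])
                    and dist[nr][nc] == dist[r][c] - 1):
                r, c = nr, nc
                break
        path.append((r, c))
    return path


maze = [[0, 0, 0],
        [1, 1, 0],
        [0, 0, 0]]
dist = bfs_distances(maze, goal=(2, 0))
path = shortest_path(dist, start=(0, 0))
print(path)
```

Running BFS outward from the goal gives every cell its distance in one pass, so an optimal action at any cell is simply a move that decreases the distance by one.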

collection_helper

Please refer to ding/utils/collection_helper for more details.

iter_mapping

ding.utils.collection_helper.iter_mapping(iter_: Iterable[_IterType], mapping: Callable[[_IterType], _IterTargetType])[源代码]
Overview:

Apply the mapping callable to each element of the input iterable.

Arguments:
  • iter_ (Iterable[_IterType]): The iterable whose elements are mapped.

  • mapping (Callable[[_IterType], _IterTargetType]): The callable applied to each element.

Returns:
  • (iter_mapping object): Iterable of the mapped results.

Example:
>>> iterable_list = [1, 2, 3, 4, 5]
>>> _iter = iter_mapping(iterable_list, lambda x: x ** 2)
>>> print(list(_iter))
[1, 4, 9, 16, 25]

compression_helper

Please refer to ding/utils/compression_helper for more details.

CloudPickleWrapper

class ding.utils.compression_helper.CloudPickleWrapper(data: Any)[源代码]
Overview:

CloudPickleWrapper can pickle more Python objects than the standard pickle module (e.g. an object with a lambda expression).

Interfaces:

__init__, __getstate__, __setstate__.

__init__(data: Any) None[源代码]
Overview:

Initialize the CloudPickleWrapper using the given arguments.

Arguments:
  • data (Any): The object to be dumped.

dummy_compressor

ding.utils.compression_helper.dummy_compressor(data: Any) Any[源代码]
Overview:

Return the raw input data.

Arguments:
  • data (Any): The input data of the compressor.

Returns:
  • output (Any): This compressor will exactly return the input data.

zlib_data_compressor

ding.utils.compression_helper.zlib_data_compressor(data: Any) bytes[源代码]
Overview:

Take the input data and return it compressed with zlib, in binary format.

Arguments:
  • data (Any): The input data of the compressor.

Returns:
  • output (bytes): The compressed byte-like result.

Examples:
>>> zlib_data_compressor("Hello")

lz4_data_compressor

ding.utils.compression_helper.lz4_data_compressor(data: Any) bytes[源代码]
Overview:

Return the original data compressed with lz4. The compressor outputs binary data.

Arguments:
  • data (Any): The input data of the compressor.

Returns:
  • output (bytes): The compressed byte-like result.

Examples:
>>> lz4.block.compress(pickle.dumps("Hello"))
b'R€•      ŒHello”.'

jpeg_data_compressor

ding.utils.compression_helper.jpeg_data_compressor(data: ndarray) bytes[源代码]
Overview:

To reduce memory usage, the buffer can store the JPEG-encoded bytes of an image instead of the numpy array. This function encodes the observation numpy array into JPEG bytes.

Arguments:
  • data (np.array): The observation numpy array.

Returns:
  • img_str (bytes): The compressed byte-like result.

get_data_compressor

ding.utils.compression_helper.get_data_compressor(name: str)[源代码]
Overview:

Get the data compressor according to the input name.

Arguments:
  • name (str): Name of the compressor; supported values are ['lz4', 'zlib', 'jpeg', 'none'].

Return:
  • compressor (Callable): Corresponding data_compressor, taking input data returning compressed data.

Example:
>>> compress_fn = get_data_compressor('lz4')
>>> compressed_data = compress_fn(input_data)

dummy_decompressor

ding.utils.compression_helper.dummy_decompressor(data: Any) Any[源代码]
Overview:

Return the input data.

Arguments:
  • data (Any): The input data of the decompressor.

Returns:
  • output (Any): The decompressed result, which is exactly the input.

lz4_data_decompressor

ding.utils.compression_helper.lz4_data_decompressor(compressed_data: bytes) Any[源代码]
Overview:

Return the decompressed original data (lz4 compressor).

Arguments:
  • compressed_data (bytes): The input data of the decompressor.

Returns:
  • output (Any): The decompressed object.

zlib_data_decompressor

ding.utils.compression_helper.zlib_data_decompressor(compressed_data: bytes) Any[源代码]
Overview:

Return the decompressed original data (zlib compressor).

Arguments:
  • compressed_data (bytes): The input data of the decompressor.

Returns:
  • output (Any): The decompressed object.

jpeg_data_decompressor

ding.utils.compression_helper.jpeg_data_decompressor(compressed_data: bytes, gray_scale=False) ndarray[源代码]
Overview:

To reduce memory usage, the buffer can store the JPEG-encoded bytes of an image instead of the numpy array. This function decodes the observation numpy array from the JPEG bytes.

Arguments:
  • compressed_data (bytes): The jpeg strings.

  • gray_scale (bool): True if the observation is grayscale, False if it is RGB.

Returns:
  • arr (np.ndarray): The decompressed numpy array.

get_data_decompressor

ding.utils.compression_helper.get_data_decompressor(name: str) Callable[源代码]
Overview:

Get the data decompressor according to the input name.

Arguments:
  • name (str): Name of the decompressor; supported values are ['lz4', 'zlib', 'none'].

Note

All decompressors require a bytes-like object as input.

Returns:
  • decompressor (Callable): Corresponding data decompressor.

Examples:
>>> decompress_fn = get_data_decompressor('lz4')
>>> origin_data = decompress_fn(compressed_data)
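To make the pairing of compressor and decompressor names concrete, here is a self-contained sketch of a name-based lookup with a full round trip. It uses only the standard library, so just 'none' and 'zlib' are covered. The registry `_CODECS` and the two `*_sketch` functions are hypothetical, and the assumption that the zlib variant pickles before compressing mirrors the lz4 example shown earlier.

```python
import pickle
import zlib
from typing import Any, Callable, Dict, Tuple

# Hypothetical registry: each name maps to a (compressor, decompressor) pair.
_CODECS: Dict[str, Tuple[Callable[[Any], Any], Callable[[Any], Any]]] = {
    'none': (lambda data: data, lambda data: data),
    'zlib': (
        lambda data: zlib.compress(pickle.dumps(data)),    # serialize, then compress
        lambda blob: pickle.loads(zlib.decompress(blob)),  # decompress, then deserialize
    ),
}


def get_compressor_sketch(name: str) -> Callable[[Any], Any]:
    return _CODECS[name][0]


def get_decompressor_sketch(name: str) -> Callable[[Any], Any]:
    return _CODECS[name][1]


compress_fn = get_compressor_sketch('zlib')
decompress_fn = get_decompressor_sketch('zlib')
input_data = {'obs': [0.0] * 100, 'action': 1}
compressed_data = compress_fn(input_data)
origin_data = decompress_fn(compressed_data)
```

The same name must be used on both sides; decoding zlib output with the 'none' decompressor would simply hand back the raw bytes.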

default_helper

Please refer to ding/utils/default_helper for more details.

get_shape0

ding.utils.default_helper.get_shape0(data: List | Dict | Tensor | Tensor) int[源代码]
Overview:

Get shape[0] of data’s torch tensor or treetensor

Arguments:
  • data (Union[List,Dict,torch.Tensor,ttorch.Tensor]): data to be analysed

Returns:
  • shape[0] (int): first dimension length of data, usually the batchsize.

lists_to_dicts

ding.utils.default_helper.lists_to_dicts(data: List[dict | NamedTuple] | Tuple[dict | NamedTuple], recursive: bool = False) Mapping[object, object] | NamedTuple[源代码]
Overview:

Transform a list of dicts to a dict of lists.

Arguments:
  • data (Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]]): A list of dicts to be transformed.

  • recursive (bool): Whether to process dict elements recursively.

Returns:
  • newdata (Union[Mapping[object, object], NamedTuple]): The resulting dict of lists.

Example:
>>> from ding.utils import *
>>> lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}])
{1: [1, 2], 10: [3, 4]}

dicts_to_lists

ding.utils.default_helper.dicts_to_lists(data: Mapping[object, List[object]]) List[Mapping[object, object]][源代码]
Overview:

Transform a dict of lists to a list of dicts.

Arguments:
  • data (Mapping[object, list]): A dict of lists to be transformed.

Returns:
  • newdata (List[Mapping[object, object]]): The resulting list of dicts.

Example:
>>> from ding.utils import *
>>> dicts_to_lists({1: [1, 2], 10: [3, 4]})
[{1: 1, 10: 3}, {1: 2, 10: 4}]

override

ding.utils.default_helper.override(cls: type) Callable[[Callable], Callable][源代码]
Overview:

Annotation for documenting method overrides.

Arguments:
  • cls (type): The superclass that provides the overridden method. If this cls does not actually have the method, an error is raised.

squeeze

ding.utils.default_helper.squeeze(data: object) object[源代码]
Overview:

Squeeze data from tuple, list or dict to single object

Arguments:
  • data (object): data to be squeezed

Example:
>>> a = (4, )
>>> a = squeeze(a)
>>> print(a)
4

default_get

ding.utils.default_helper.default_get(data: dict, name: str, default_value: Any | None = None, default_fn: Callable | None = None, judge_fn: Callable | None = None) Any[源代码]
Overview:

Get the value for name from data. If name exists in data, return the value at name; otherwise, add name to default_get_set and return a value generated by default_fn (or default_value directly), after judge_fn has checked that it is legal.

Arguments:
  • data (dict): Data input dictionary.

  • name (str): Key name.

  • default_value (Optional[Any]): Value returned when name is missing and default_fn is not given. Defaults to None.

  • default_fn (Optional[Callable]): Callable that generates the default value. Defaults to None.

  • judge_fn (Optional[Callable]): Callable that checks whether the default value is legal. Defaults to None.

Returns:
  • ret (Any): The value at name, or the default value.
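The lookup-with-fallback behaviour described in the overview can be sketched as follows. This is an approximation under stated assumptions: it does not track default_get_set, and raising ValueError on a rejected default is a choice made for this sketch, not something the source confirms. The name default_get_sketch is hypothetical.

```python
from typing import Any, Callable, Optional


def default_get_sketch(
    data: dict,
    name: str,
    default_value: Optional[Any] = None,
    default_fn: Optional[Callable[[], Any]] = None,
    judge_fn: Optional[Callable[[Any], bool]] = None,
) -> Any:
    # An existing key always wins over any default.
    if name in data:
        return data[name]
    # Prefer the factory when given, otherwise fall back to the plain value.
    value = default_fn() if default_fn is not None else default_value
    # Assumed behaviour: reject defaults that the judge does not accept.
    if judge_fn is not None and not judge_fn(value):
        raise ValueError("default value for {!r} rejected by judge_fn".format(name))
    return value


cfg = {'batch_size': 64}
batch_size = default_get_sketch(cfg, 'batch_size', default_value=32)
hidden_sizes = default_get_sketch(cfg, 'hidden_sizes', default_fn=list)
```

Using default_fn rather than default_value avoids sharing one mutable default object between callers.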

list_split

ding.utils.default_helper.list_split(data: list, step: int) List[list][源代码]
Overview:

Split list of data by step.

Arguments:
  • data (list): List of data to split.

  • step (int): Number of elements in each split.

Returns:
  • ret (list): List of split data.

  • residual (list): Residual list. This value is None when step evenly divides the length of data.

Example:
>>> list_split([1,2,3,4],2)
([[1, 2], [3, 4]], None)
>>> list_split([1,2,3,4],3)
([[1, 2, 3]], [4])

error_wrapper

ding.utils.default_helper.error_wrapper(fn, default_ret, warning_msg='')[源代码]
Overview:

Wrap the function so that any Exception raised in it is caught and default_ret is returned instead.

Arguments:
  • fn (Callable): The function to be wrapped.

  • default_ret (obj): The default return value when an Exception occurs in the function.

Returns:
  • wrapper (Callable): The wrapped function.

Examples:
>>> # Used to check for FakeLink (refer to utils.linklink_dist_helper.py)
>>> def get_rank():  # Get the rank of linklink model, return 0 if use FakeLink.
>>>    if is_fake_link:
>>>        return 0
>>>    return error_wrapper(link.get_rank, 0)()
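Since the example above depends on linklink, a self-contained sketch of the same wrapping pattern may help. It assumes that warning_msg is logged when an exception is swallowed, which the source does not state. The name error_wrapper_sketch is hypothetical.

```python
import functools
import logging
from typing import Any, Callable


def error_wrapper_sketch(fn: Callable, default_ret: Any, warning_msg: str = "") -> Callable:
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            # Assumed behaviour: report the failure, then fall back to the default.
            if warning_msg:
                logging.warning("%s (%r)", warning_msg, e)
            return default_ret

    return wrapper


def divide(a: float, b: float) -> float:
    return a / b


safe_divide = error_wrapper_sketch(divide, default_ret=0.0, warning_msg="division failed")
print(safe_divide(6, 3), safe_divide(1, 0))
```

Because every Exception is swallowed, this pattern suits optional features such as a missing backend, and is a poor fit for code whose failures should surface.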

LimitedSpaceContainer

class ding.utils.default_helper.LimitedSpaceContainer(min_val: int, max_val: int)[源代码]
Overview:

A space simulator.

Interfaces:

__init__, get_residual_space, release_space

__init__(min_val: int, max_val: int) None[源代码]
Overview:

Set min_val and max_val of the container, also set cur to min_val for initialization.

Arguments:
  • min_val (int): Min volume of the container, usually 0.

  • max_val (int): Max volume of the container.

acquire_space() bool[源代码]
Overview:

Try to get one piece of space. If one is available, return True; otherwise return False.

Returns:
  • flag (bool): Whether there is any piece of residual space.

decrease_space() None[源代码]
Overview:

Decrease one piece in space. Decrement max_val.

get_residual_space() int[源代码]
Overview:

Get all residual pieces of space and set cur to max_val.

Returns:
  • ret (int): Residual space, calculated by max_val - cur.

increase_space() None[源代码]
Overview:

Increase one piece in space. Increment max_val.

release_space() None[源代码]
Overview:

Release only one piece of space. Decrement cur, but ensure it won’t be negative.
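The counter semantics described by these methods can be summarised in a short sketch. It follows the descriptions above and is not the library's code; in particular, flooring cur at min_val in release_space is an assumption (the text only says cur will not become negative). The class name LimitedSpaceContainerSketch is hypothetical.

```python
class LimitedSpaceContainerSketch:
    """Counter `cur` that moves between `min_val` and `max_val`."""

    def __init__(self, min_val: int, max_val: int) -> None:
        assert max_val >= min_val
        self.min_val, self.max_val = min_val, max_val
        self.cur = min_val

    def acquire_space(self) -> bool:
        # Take one piece if any is left.
        if self.cur < self.max_val:
            self.cur += 1
            return True
        return False

    def release_space(self) -> None:
        # Give one piece back without dropping below the floor (assumed to be min_val).
        self.cur = max(self.min_val, self.cur - 1)

    def get_residual_space(self) -> int:
        # Hand out everything that is left in one call.
        ret = self.max_val - self.cur
        self.cur = self.max_val
        return ret

    def increase_space(self) -> None:
        self.max_val += 1

    def decrease_space(self) -> None:
        self.max_val -= 1


slots = LimitedSpaceContainerSketch(0, 2)
results = [slots.acquire_space() for _ in range(3)]
slots.release_space()
residual = slots.get_residual_space()
```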

deep_merge_dicts

ding.utils.default_helper.deep_merge_dicts(original: dict, new_dict: dict) dict[源代码]
Overview:

Merge two dicts by calling deep_update

Arguments:
  • original (dict): Dict 1.

  • new_dict (dict): Dict 2.

Returns:
  • merged_dict (dict): A new dict that is original and new_dict deeply merged.

deep_update

ding.utils.default_helper.deep_update(original: dict, new_dict: dict, new_keys_allowed: bool = False, whitelist: List[str] | None = None, override_all_if_type_changes: List[str] | None = None)[源代码]
Overview:

Update original dict with values from new_dict recursively.

Arguments:
  • original (dict): Dictionary with default values.

  • new_dict (dict): Dictionary with values to be updated

  • new_keys_allowed (bool): Whether new keys are allowed.

  • whitelist (Optional[List[str]]):

    List of keys that correspond to dict values where new subkeys can be introduced. This is only at the top level.

  • override_all_if_type_changes(Optional[List[str]]):

    List of top level keys with value=dict, for which we always simply override the entire value (dict), if the “type” key in that value dict changes.

Note

If a new key is introduced in new_dict and new_keys_allowed is not True, an error is thrown. For sub-dicts, if the key is in the whitelist, new subkeys can be introduced.
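A simplified sketch of the recursive update may clarify how nested configs are merged. It covers only new_keys_allowed; the whitelist and override_all_if_type_changes options are omitted, and raising KeyError for an unknown key is an assumption of this sketch. Both function names are hypothetical.

```python
import copy


def deep_update_sketch(original: dict, new_dict: dict, new_keys_allowed: bool = False) -> dict:
    """Recursively write the values of `new_dict` into `original` (in place)."""
    for key, value in new_dict.items():
        if key not in original and not new_keys_allowed:
            raise KeyError("Unknown config key: {!r}".format(key))
        if isinstance(original.get(key), dict) and isinstance(value, dict):
            # Both sides are dicts: descend instead of replacing the whole sub-dict.
            deep_update_sketch(original[key], value, new_keys_allowed)
        else:
            original[key] = value
    return original


def deep_merge_sketch(original: dict, new_dict: dict) -> dict:
    """Return a merged copy and leave `original` untouched."""
    merged = copy.deepcopy(original)
    return deep_update_sketch(merged, new_dict, new_keys_allowed=True)


defaults = {'policy': {'lr': 0.1, 'cuda': False}, 'seed': 0}
merged = deep_merge_sketch(defaults, {'policy': {'lr': 0.01}})
print(merged)
```

Unlike `dict.update`, the nested 'policy' dict keeps its untouched 'cuda' entry after the merge.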

flatten_dict

ding.utils.default_helper.flatten_dict(data: dict, delimiter: str = '/') dict[源代码]
Overview:

Flatten a nested dict into a single-level dict whose keys are joined by the delimiter; see the example.

Arguments:
  • data (dict): Original nested dict

  • delimiter (str): Delimiter of the keys of the new dict

Returns:
  • data (dict): Flattened nested dict

Example:
>>> a
{'a': {'b': 100}}
>>> flatten_dict(a)
{'a/b': 100}

set_pkg_seed

ding.utils.default_helper.set_pkg_seed(seed: int, use_cuda: bool = True) None[源代码]
Overview:

Side-effect function that sets the seed for random, numpy random, and torch's manual seed. It is usually used in the entry script, in the section that sets the random seed for all packages and instances.

Arguments:
  • seed (int): The seed to set.

  • use_cuda (bool): Whether to use CUDA.

Examples:
>>> # ../entry/xxxenv_xxxpolicy_main.py
>>> ...
>>> # Set random seed for all package and instance
>>> collector_env.seed(seed)
>>> evaluator_env.seed(seed, dynamic_seed=False)
>>> set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
>>> ...
>>> # Set up RL Policy, etc.
>>> ...

one_time_warning

ding.utils.default_helper.one_time_warning(warning_msg: str) None[源代码]
Overview:

Print warning message only once.

Arguments:
  • warning_msg (str): Warning message.

split_fn

ding.utils.default_helper.split_fn(data, indices, start, end)[源代码]
Overview:

Split data by indices

Arguments:
  • data (Union[List, Dict, torch.Tensor, ttorch.Tensor]): data to be analysed

  • indices (np.ndarray): indices to split

  • start (int): start index

  • end (int): end index

split_data_generator

ding.utils.default_helper.split_data_generator(data: dict, split_size: int, shuffle: bool = True) dict[源代码]
Overview:

Split data into batches

Arguments:
  • data (dict): data to be analysed

  • split_size (int): split size

  • shuffle (bool): whether shuffle

RunningMeanStd

class ding.utils.default_helper.RunningMeanStd(epsilon=0.0001, shape=(), device=device(type='cpu'))[源代码]
Overview:

Wrapper that maintains a running mean, variance, and count.

Interfaces:

__init__, update, reset, new_shape

Properties:
  • mean, std, _epsilon, _shape, _mean, _var, _count

__init__(epsilon=0.0001, shape=(), device=device(type='cpu'))[源代码]
Overview:

Initialize the object and set up its properties.

Arguments:
  • epsilon (float): The epsilon used when computing the std output.

  • shape (tuple): The shape of the mean and variance attributes.

  • device (torch.device): The device used for the statistics. Defaults to CPU.

property mean: ndarray
Overview:

Property mean gotten from self._mean

static new_shape(obs_shape, act_shape, rew_shape)[源代码]
Overview:

Get new shape of observation, action, and reward; in this case unchanged.

Arguments:

obs_shape (Any), act_shape (Any), rew_shape (Any)

Returns:

obs_shape (Any), act_shape (Any), rew_shape (Any)

reset()[源代码]
Overview:

Reset the properties _mean, _var, and _count.

property std: ndarray
Overview:

Property std calculated from self._var and the epsilon value of self._epsilon

update(x)[源代码]
Overview:

Update mean, variance, and count.

Arguments:
  • x: the batch
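A scalar sketch of how such a batch update can merge running statistics is given below. It assumes the standard parallel mean and variance combination and an initial state of zero mean, unit variance, and zero count; the real class works on arrays and its exact initial values are not stated here. The class name RunningMeanStdSketch is hypothetical.

```python
import math
from typing import Sequence


class RunningMeanStdSketch:
    """Scalar running mean and variance, merged batch by batch."""

    def __init__(self, epsilon: float = 1e-4) -> None:
        self._epsilon = epsilon
        self.reset()

    def reset(self) -> None:
        # Assumed initial state: zero mean, unit variance, no samples seen.
        self._mean, self._var, self._count = 0.0, 1.0, 0

    def update(self, x: Sequence[float]) -> None:
        batch_count = len(x)
        batch_mean = sum(x) / batch_count
        batch_var = sum((v - batch_mean) ** 2 for v in x) / batch_count
        new_count = self._count + batch_count
        delta = batch_mean - self._mean
        # Sum of squared deviations of both parts, plus a term for the shift in mean.
        m2 = (self._var * self._count + batch_var * batch_count
              + delta ** 2 * self._count * batch_count / new_count)
        self._mean += delta * batch_count / new_count
        self._var = m2 / new_count
        self._count = new_count

    @property
    def mean(self) -> float:
        return self._mean

    @property
    def std(self) -> float:
        # Epsilon keeps the value strictly positive for later divisions.
        return math.sqrt(self._var) + self._epsilon


rms = RunningMeanStdSketch()
rms.update([1.0, 2.0])
rms.update([3.0, 4.0])
print(rms.mean, rms.std)
```

Merging batches this way gives the same mean and population variance as processing all samples at once, without storing them.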

make_key_as_identifier

ding.utils.default_helper.make_key_as_identifier(data: Dict[str, Any]) Dict[str, Any][源代码]
Overview:

Convert the keys of the dict into legal Python identifier strings so that they are compatible with Python magic methods such as __getattr__.

Arguments:
  • data (Dict[str, Any]): The original dict data.

Return:
  • new_data (Dict[str, Any]): The new dict data with legal identifier keys.

remove_illegal_item

ding.utils.default_helper.remove_illegal_item(data: Dict[str, Any]) Dict[str, Any][源代码]
Overview:

Remove illegal items from the info dict, such as str values, which are not compatible with Tensor.

Arguments:
  • data (Dict[str, Any]): The original dict data.

Return:
  • new_data (Dict[str, Any]): The new dict data without illegal items.

design_helper

Please refer to ding/utils/design_helper for more details.

SingletonMetaclass

class ding.utils.design_helper.SingletonMetaclass(name, bases, namespace, **kwargs)[源代码]
Overview:

Metaclass that returns the same single instance each time the input class is called.

Interfaces:

__call__

instances = {<class 'ding.framework.parallel.Parallel'>: <ding.framework.parallel.Parallel object>}
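The class name and the instances mapping suggest the classic singleton-by-metaclass pattern. A minimal sketch of that pattern, written independently of the library's code (SingletonMetaSketch and Settings are hypothetical names):

```python
class SingletonMetaSketch(type):
    # One shared mapping from class to its single instance.
    instances = {}

    def __call__(cls, *args, **kwargs):
        # Build the instance on the first call only; later calls reuse it.
        if cls not in SingletonMetaSketch.instances:
            SingletonMetaSketch.instances[cls] = super().__call__(*args, **kwargs)
        return SingletonMetaSketch.instances[cls]


class Settings(metaclass=SingletonMetaSketch):
    def __init__(self, value: int = 0) -> None:
        self.value = value


first = Settings(1)
second = Settings(2)  # arguments are ignored because the instance already exists
print(first is second, second.value)
```

Note that constructor arguments passed after the first call are silently dropped in this sketch.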

fast_copy

Please refer to ding/utils/fast_copy for more details.

_FastCopy

class ding.utils.fast_copy._FastCopy[源代码]
Overview:

The idea of this class comes from this article https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list. We use recursive calls to copy each object that needs to be copied, which will be 5x faster than copy.deepcopy.

Interfaces:

__init__, _copy_list, _copy_dict, _copy_tensor, _copy_ndarray, copy.

__init__()[源代码]
Overview:

Initialize the _FastCopy object.

_copy_dict(d: dict) dict[源代码]
Overview:

Copy the dict.

Arguments:
  • d (dict): The dict to be copied.

_copy_list(l: List) dict[源代码]
Overview:

Copy the list.

Arguments:
  • l (List): The list to be copied.

_copy_ndarray(a: ndarray) ndarray[源代码]
Overview:

Copy the ndarray.

Arguments:
  • a (np.ndarray): The ndarray to be copied.

_copy_tensor(t: Tensor) Tensor[源代码]
Overview:

Copy the tensor.

Arguments:
  • t (torch.Tensor): The tensor to be copied.

copy(sth: Any) Any[源代码]
Overview:

Copy the object.

Arguments:
  • sth (Any): The object to be copied.
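The recursive idea can be sketched with the standard library alone. This version handles only dicts and lists and returns every other object unchanged (shared, not copied); the documented class additionally has tensor and ndarray branches, which are omitted here. The function name fast_copy_sketch is hypothetical.

```python
from typing import Any


def fast_copy_sketch(obj: Any) -> Any:
    """Copy nested dicts and lists; leave all other objects shared."""
    if isinstance(obj, dict):
        return {key: fast_copy_sketch(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [fast_copy_sketch(item) for item in obj]
    # Immutable leaves such as int, float, or str can be shared safely.
    return obj


original = {'obs': [1, 2, [3, 4]], 'info': {'done': False}}
copied = fast_copy_sketch(original)
copied['obs'][2].append(5)
copied['info']['done'] = True
print(original)
```

Skipping the memo bookkeeping and type dispatch of copy.deepcopy is where the speed-up would come from, at the cost of not handling cyclic references.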

file_helper

Please refer to ding/utils/file_helper for more details.

read_from_ceph

ding.utils.file_helper.read_from_ceph(path: str) object[源代码]
Overview:

Read file from ceph

Arguments:
  • path (str): File path in ceph, start with "s3://"

Returns:
  • (data): Deserialized data

_get_redis

ding.utils.file_helper._get_redis(host='localhost', port=6379)[源代码]
Overview:

Ensures redis usage

Arguments:
  • host (str): Host string

  • port (int): Port number

Returns:
  • (Redis(object)): Redis object with given host, port, and db=0

read_from_redis

ding.utils.file_helper.read_from_redis(path: str) object[源代码]
Overview:

Read file from redis

Arguments:
  • path (str): File path in redis, could be a string key

Returns:
  • (data): Deserialized data

_ensure_rediscluster

ding.utils.file_helper._ensure_rediscluster(startup_nodes=[{'host': '127.0.0.1', 'port': '7000'}])[源代码]
Overview:

Ensures rediscluster usage

Arguments:
  • startup_nodes (List[dict]): List of startup nodes, each a dict of
    • host (str): Host string

    • port (int): Port number

Returns:
  • (RedisCluster(object)): RedisCluster object with given host, port, and False for decode_responses in default.

read_from_rediscluster

ding.utils.file_helper.read_from_rediscluster(path: str) object[源代码]
Overview:

Read file from rediscluster

Arguments:
  • path (str): File path in rediscluster, could be a string key

Returns:
  • (data): Deserialized data

read_from_file

ding.utils.file_helper.read_from_file(path: str) object[源代码]
Overview:

Read file from local file system

Arguments:
  • path (str): File path in local file system

Returns:
  • (data): Deserialized data

_ensure_memcached

ding.utils.file_helper._ensure_memcached()[源代码]
Overview:

Ensures memcache usage

Returns:
  • (MemcachedClient instance): MemcachedClient’s class instance built with current memcached_client’s server_list.conf and client.conf files

read_from_mc

ding.utils.file_helper.read_from_mc(path: str, flush=False) object[源代码]
Overview:

Read file from memcache, file must be saved by torch.save()

Arguments:
  • path (str): File path in local system

Returns:
  • (data): Deserialized data

read_from_path

ding.utils.file_helper.read_from_path(path: str)[源代码]
Overview:

Read file from ceph or the local file system

Arguments:
  • path (str): File path in ceph, start with "s3://", or use local file system

Returns:
  • (data): Deserialized data

save_file_ceph

ding.utils.file_helper.save_file_ceph(path, data)[源代码]
Overview:

Save pickle dumped data file to ceph

Arguments:
  • path (str): File path in ceph, starting with "s3://"; otherwise the local file system is used

  • data (Any): Could be dict, list or tensor etc.

save_file_redis

ding.utils.file_helper.save_file_redis(path, data)[源代码]
Overview:

Save pickle dumped data file to redis

Arguments:
  • path (str): File path (could be a string key) in redis

  • data (Any): Could be dict, list or tensor etc.

save_file_rediscluster

ding.utils.file_helper.save_file_rediscluster(path, data)[源代码]
Overview:

Save pickle dumped data file to rediscluster

Arguments:
  • path (str): File path (could be a string key) in redis

  • data (Any): Could be dict, list or tensor etc.

read_file

ding.utils.file_helper.read_file(path: str, fs_type: str | None = None, use_lock: bool = False) object[源代码]
Overview:

Read file from path

Arguments:
  • path (str): The path of file to read

  • fs_type (str or None): The file system type, support {'normal', 'ceph'}

  • use_lock (bool): Whether to use a lock in the local normal file system

save_file

ding.utils.file_helper.save_file(path: str, data: object, fs_type: str | None = None, use_lock: bool = False) None[源代码]
Overview:

Save data to file of path

Arguments:
  • path (str): The path of file to save to

  • data (object): The data to save

  • fs_type (str or None): The file system type, support {'normal', 'ceph'}

  • use_lock (bool): Whether to use a lock in the local normal file system

remove_file

ding.utils.file_helper.remove_file(path: str, fs_type: str | None = None) None[源代码]
Overview:

Remove file

Arguments:
  • path (str): The path of file you want to remove

  • fs_type (str or None): The file system type, support {'normal', 'ceph'}

import_helper

Please refer to ding/utils/import_helper for more details.

try_import_ceph

ding.utils.import_helper.try_import_ceph()[源代码]
Overview:

Try import ceph module, if failed, return None

Returns:
  • (Module): Imported module, or None when ceph not found

try_import_mc

ding.utils.import_helper.try_import_mc()[源代码]
Overview:

Try import mc module, if failed, return None

Returns:
  • (Module): Imported module, or None when mc not found

try_import_redis

ding.utils.import_helper.try_import_redis()[源代码]
Overview:

Try import redis module, if failed, return None

Returns:
  • (Module): Imported module, or None when redis not found

try_import_rediscluster

ding.utils.import_helper.try_import_rediscluster()[源代码]
Overview:

Try import rediscluster module, if failed, return None

Returns:
  • (Module): Imported module, or None when rediscluster not found
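The four try_import_* helpers share one pattern: attempt the import and fall back to None. A generic sketch of that pattern with importlib, assuming a warning is logged on failure (the helper name try_import_sketch and the logging call are assumptions, not the library's code):

```python
import importlib
import logging
from types import ModuleType
from typing import Optional


def try_import_sketch(module_name: str) -> Optional[ModuleType]:
    """Return the imported module, or None when it is not installed."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so it is covered too.
        logging.warning("Module %s is not available; related features are disabled.", module_name)
        return None


json_module = try_import_sketch("json")
missing_module = try_import_sketch("surely_not_an_installed_module_xyz")
```

Callers then check for None before using the optional backend, which keeps the dependency optional at import time.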

import_module

ding.utils.import_helper.import_module(modules: List[str]) None[源代码]
Overview:

Import several modules given as a list

Arguments:
  • modules (List[str]): List of module names

k8s_helper

Please refer to ding/utils/k8s_helper for more details.

get_operator_server_kwargs

ding.utils.k8s_helper.get_operator_server_kwargs(cfg: EasyDict) dict[源代码]
Overview:

Get kwarg dict from config file

Arguments:
  • cfg (EasyDict): System config

Returns:
  • result (dict): Containing api_version, namespace, name, port, host.

exist_operator_server

ding.utils.k8s_helper.exist_operator_server() bool[源代码]
Overview:

Check if the ‘KUBERNETES_SERVER_URL’ environment variable exists.

pod_exec_command

ding.utils.k8s_helper.pod_exec_command(kubeconfig: str, name: str, namespace: str, cmd: str) Tuple[int, str][源代码]
Overview:

Execute command in pod

Arguments:
  • kubeconfig (str): The path of the kubeconfig file

  • name (str): The name of the pod

  • namespace (str): The namespace of the pod

  • cmd (str): The command to execute in the pod

K8sType

class ding.utils.k8s_helper.K8sType(value)[源代码]

An enumeration.

K3s = 2
Local = 1

K8sLauncher

class ding.utils.k8s_helper.K8sLauncher(config_path: str)
Overview:

Object to manage the K8s cluster.

Interfaces:

__init__, _load, create_cluster, _check_k3d_tools, delete_cluster, preload_images

__init__(config_path: str) -> None
Overview:

Initialize the K8sLauncher object.

Arguments:
  • config_path (str): The path of the config file.

_check_k3d_tools() -> None
Overview:

Check if the k3d tools exist.

_load(config_path: str) -> None
Overview:

Load the config file.

Arguments:
  • config_path (str): The path of the config file.

create_cluster() -> None
Overview:

Create the k8s cluster.

delete_cluster() -> None
Overview:

Delete the k8s cluster.

preload_images(images: list) -> None
Overview:

Preload images.

lock_helper

Please refer to ding/utils/lock_helper for more details.

LockContextType

class ding.utils.lock_helper.LockContextType(value)
Overview:

Enum to express the type of the lock.

PROCESS_LOCK = 2
THREAD_LOCK = 1

LockContext

class ding.utils.lock_helper.LockContext(lock_type: LockContextType = LockContextType.THREAD_LOCK)
Overview:

Generate a LockContext to ensure thread safety.

Interfaces:

__init__, __enter__, __exit__.

Example:
>>> with LockContext() as lock:
...     print("Do something here.")

__init__(lock_type: LockContextType = LockContextType.THREAD_LOCK)
Overview:

Init the lock according to the given type.

Arguments:
  • lock_type (LockContextType): The type of lock to be used. Defaults to LockContextType.THREAD_LOCK.

acquire()
Overview:

Acquires the lock.

release()
Overview:

Releases the lock.
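
To illustrate how acquire and release pair up with the context-manager protocol, here is a minimal sketch that covers only the thread-lock case. SimpleLockContext is a hypothetical stand-in, not the DI-engine class, and the process-lock variant is omitted.

```python
import threading


class SimpleLockContext:
    """Hypothetical thread-only stand-in for LockContext."""

    def __init__(self) -> None:
        self._lock = threading.Lock()

    def acquire(self) -> None:
        self._lock.acquire()

    def release(self) -> None:
        self._lock.release()

    def __enter__(self) -> "SimpleLockContext":
        self.acquire()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Always release, even if the body raised.
        self.release()


state = {"count": 0}
lock = SimpleLockContext()


def add_many(n: int) -> None:
    for _ in range(n):
        with lock:  # the read-modify-write below is protected by the lock
            state["count"] += 1


threads = [threading.Thread(target=add_many, args=(1000,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```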

get_rw_file_lock

ding.utils.lock_helper.get_rw_file_lock(name: str, op: str)
Overview:

Get the generated read-write file lock for the given name and operation.

Arguments:
  • name (str): Lock’s name.

  • op (str): Assigned operator, i.e. read or write.

Returns:
  • (RWLockFairD): Generated rwlock

FcntlContext

class ding.utils.lock_helper.FcntlContext(lock_path: str)
Overview:

A context manager that acquires an exclusive lock on a file using fcntl. This is useful for preventing multiple processes from running the same code.

Interfaces:

__init__, __enter__, __exit__.

Example:
>>> lock_path = "/path/to/lock/file"
>>> with FcntlContext(lock_path) as lock:
...     # Perform operations while the lock is held
...     pass

__init__(lock_path: str) -> None
Overview:

Initialize the FcntlContext object.

Arguments:
  • lock_path (str): The path to the lock file.
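
An exclusive fcntl-based file lock of the kind described above can be sketched as follows. This is an illustrative, Unix-only sketch under the assumption that the lock is taken with fcntl.flock; SimpleFcntlLock and the lock_fd attribute are hypothetical names, not the DI-engine implementation.

```python
import fcntl
import os
import tempfile


class SimpleFcntlLock:
    """Hypothetical context manager holding an exclusive flock on a file."""

    def __init__(self, lock_path: str) -> None:
        self.lock_path = lock_path
        self.lock_fd = None

    def __enter__(self) -> "SimpleFcntlLock":
        # Create the lock file if needed, then block until the lock is ours.
        self.lock_fd = os.open(self.lock_path, os.O_CREAT | os.O_RDWR)
        fcntl.flock(self.lock_fd, fcntl.LOCK_EX)
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        fcntl.flock(self.lock_fd, fcntl.LOCK_UN)
        os.close(self.lock_fd)
        self.lock_fd = None


lock_path = os.path.join(tempfile.mkdtemp(), "demo.lock")
with SimpleFcntlLock(lock_path) as lock:
    held_fd = lock.lock_fd  # a valid file descriptor while the lock is held
released_fd = lock.lock_fd  # reset to None after leaving the block
```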

get_file_lock

ding.utils.lock_helper.get_file_lock(name: str, op: str) -> FcntlContext
Overview:

Acquires a file lock for the specified file.

Arguments:
  • name (str): The name of the file.

  • op (str): The operation to perform on the file lock.

log_helper

Please refer to ding/utils/log_helper for more details.

build_logger

ding.utils.log_helper.build_logger(path: str, name: str | None = None, need_tb: bool = True, need_text: bool = True, text_level: int | str = 20) -> Tuple[Logger | None, SummaryWriter | None]
Overview:

Build text logger and tensorboard logger.

Arguments:
  • path (str): Logger(Textlogger & SummaryWriter)’s saved dir

  • name (str): The logger file name

  • need_tb (bool): Whether SummaryWriter instance would be created and returned

  • need_text (bool): Whether a logging.Logger instance would be created and returned

  • text_level (int or str): Logging level of logging.Logger, default set to logging.INFO

Returns:
  • logger (Optional[logging.Logger]): Logger that displays terminal output

  • tb_logger (Optional['SummaryWriter']): Saves output to TensorBoard, only returned when need_tb is True.

TBLoggerFactory

class ding.utils.log_helper.TBLoggerFactory
Overview:

TBLoggerFactory is a factory class for SummaryWriter.

Interfaces:

create_logger

Properties:
  • tb_loggers (Dict[str, SummaryWriter]): A dict that stores SummaryWriter instances.

classmethod create_logger(logdir: str) -> DistributedWriter
tb_loggers = {}

LoggerFactory

class ding.utils.log_helper.LoggerFactory
Overview:

LoggerFactory is a factory class for logging.Logger.

Interfaces:

create_logger, get_tabulate_vars, get_tabulate_vars_hor

classmethod create_logger(path: str, name: str = 'default', level: int | str = 20) -> Logger
Overview:

Create logger using logging

Arguments:
  • name (str): Logger’s name

  • path (str): Logger’s save dir

  • level (int or str): Used to set the level. Reference: Logger.setLevel method.

Returns:
  • (logging.Logger): new logging logger
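
For orientation, a text logger with a save directory, a name, and a level can be built with the standard logging module roughly as below. This is a sketch only: create_text_logger, the "<name>.txt" file name, and the message format are assumptions for illustration and may differ from what LoggerFactory.create_logger actually does.

```python
import logging
import os
import tempfile


def create_text_logger(path: str, name: str = "default", level: int = logging.INFO) -> logging.Logger:
    """Hypothetical helper: a named logger that writes to <path>/<name>.txt."""
    os.makedirs(path, exist_ok=True)
    logger = logging.getLogger(name)
    logger.setLevel(level)
    if not logger.handlers:  # avoid attaching duplicate handlers on repeated calls
        handler = logging.FileHandler(os.path.join(path, name + ".txt"))
        handler.setFormatter(logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s"))
        logger.addHandler(handler)
    return logger


log_dir = tempfile.mkdtemp()
logger = create_text_logger(log_dir, name="ding_doc_demo_logger")
logger.info("training started")
for handler in logger.handlers:
    handler.flush()
log_file = os.path.join(log_dir, "ding_doc_demo_logger.txt")
```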

static get_tabulate_vars(variables: Dict[str, Any]) -> str
Overview:

Get the text description in tabular form of all vars

Arguments:
  • variables (Dict[str, Any]): The vars to tabulate, keyed by name.

Returns:
  • string (str): Text description in tabular form of all vars

static get_tabulate_vars_hor(variables: Dict[str, Any]) -> str
Overview:

Get the text description in horizontal tabular form of all vars

Arguments:
  • variables (Dict[str, Any]): The vars to tabulate, keyed by name.

pretty_print

ding.utils.log_helper.pretty_print(result: dict, direct_print: bool = True) -> str
Overview:

Print a dict result in a pretty way

Arguments:
  • result (dict): The result to print

  • direct_print (bool): Whether to print directly

Returns:
  • string (str): The pretty-printed result in str format

log_writer_helper

Please refer to ding/utils/log_writer_helper for more details.

DistributedWriter

class ding.utils.log_writer_helper.DistributedWriter(*args, **kwargs)
Overview:

A simple subclass of SummaryWriter that supports writing to one process in multi-process mode. The best way is to use it in conjunction with the router to take advantage of the message and event components of the router (see writer.plugin).

Interfaces:

get_instance, plugin, initialize, __del__

__del__()
Overview:

Close the file writer.

classmethod get_instance(*args, **kwargs) -> DistributedWriter
Overview:

Get the instance, setting the root-level instance on the first call. If args and kwargs are empty, this method returns the root instance.

Arguments:
  • args (Tuple): The arguments passed to the __init__ function of the parent class, SummaryWriter.

  • kwargs (Dict): The keyword arguments passed to the __init__ function of the parent class, SummaryWriter.

plugin(router: Parallel, is_writer: bool = False) -> DistributedWriter
Overview:

Plug in a router. When this writer is used with an active router, it automatically sends requests to the main writer instead of writing to disk, so data from multiple processes can be collected and written into one file.

Arguments:
  • router (Parallel): The router to be plugged in.

  • is_writer (bool): Whether this writer is the main writer.

Examples:
>>> DistributedWriter().plugin(router, is_writer=True)

enable_parallel

ding.utils.log_writer_helper.enable_parallel(fn_name, fn)
Overview:

Decorator to enable parallel writing.

Arguments:
  • fn_name (str): The name of the function to be called.

  • fn (Callable): The function to be called.

normalizer_helper

Please refer to ding/utils/normalizer_helper for more details.

DatasetNormalizer

class ding.utils.normalizer_helper.DatasetNormalizer(dataset: ndarray, normalizer: str, path_lengths: list | None = None)
Overview:

The DatasetNormalizer class provides functionality to normalize and unnormalize data in a dataset. It takes a dataset as input and applies a normalizer function to each key in the dataset.

Interfaces:

__init__, __repr__, normalize, unnormalize.

__init__(dataset: ndarray, normalizer: str, path_lengths: list | None = None)
Overview:

Initialize the DatasetNormalizer object.

Arguments:
  • dataset (np.ndarray): The dataset to be normalized.

  • normalizer (str): The type of normalizer to be used. Can be a string representing the name of the normalizer class.

  • path_lengths (list): The length of the paths in the dataset. Defaults to None.

normalize(x: ndarray, key: str) -> ndarray
Overview:

Normalize the input data using the specified key.

Arguments:
  • x (np.ndarray): The input data to be normalized.

  • key (str): The key to identify the normalizer.

Returns:
  • ret (np.ndarray): The normalized value of the input data.

unnormalize(x: ndarray, key: str) -> ndarray
Overview:

Unnormalizes the given value x using the specified key.

Arguments:
  • x (np.ndarray): The value to be unnormalized.

  • key (str): The key to identify the normalizer.

Returns:
  • ret (np.ndarray): The unnormalized value.

flatten

ding.utils.normalizer_helper.flatten(dataset: dict, path_lengths: list) -> dict
Overview:

Flattens a dataset of { key: [ n_episodes x max_path_length x dim ] } to { key: [ sum(path_lengths) x dim ] }

Arguments:
  • dataset (dict): The dataset to be flattened.

  • path_lengths (list): A list of path lengths for each episode.

Returns:
  • flattened (dict): The flattened dataset.
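
The flattening can be pictured with plain nested lists. The sketch below assumes that padding beyond each episode's path length is dropped, so the output has one row per valid step; flatten_episodes is a hypothetical pure-Python stand-in that does not use NumPy.

```python
from typing import Dict, List


def flatten_episodes(dataset: Dict[str, List[List[list]]], path_lengths: List[int]) -> Dict[str, list]:
    """Hypothetical sketch: concatenate the valid steps of every episode per key."""
    flattened = {}
    for key, episodes in dataset.items():
        assert len(episodes) == len(path_lengths)
        rows = []
        for episode, length in zip(episodes, path_lengths):
            rows.extend(episode[:length])  # keep only the first `length` steps
        flattened[key] = rows
    return flattened


# Two episodes padded to max_path_length = 3, with valid lengths 2 and 1.
dataset = {"observations": [[[1.0], [2.0], [0.0]], [[3.0], [0.0], [0.0]]]}
flat = flatten_episodes(dataset, path_lengths=[2, 1])
```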

Normalizer

class ding.utils.normalizer_helper.Normalizer(X)
Overview:

Parent class; subclasses define the normalize and unnormalize methods.

Interfaces:

__init__, __repr__, normalize, unnormalize.

__init__(X)
Overview:

Initialize the Normalizer object.

Arguments:
  • X (np.ndarray): The data to be normalized.

normalize(*args, **kwargs)
Overview:

Normalize the input data.

Arguments:
  • args (list): The arguments passed to the normalize function.

  • kwargs (dict): The keyword arguments passed to the normalize function.

unnormalize(*args, **kwargs)
Overview:

Unnormalize the input data.

Arguments:
  • args (list): The arguments passed to the unnormalize function.

  • kwargs (dict): The keyword arguments passed to the unnormalize function.

GaussianNormalizer

class ding.utils.normalizer_helper.GaussianNormalizer(*args, **kwargs)
Overview:

A class that normalizes data to zero mean and unit variance.

Interfaces:

__init__, __repr__, normalize, unnormalize.

__init__(*args, **kwargs)
Overview:

Initialize the GaussianNormalizer object.

Arguments:
  • args (list): The arguments passed to the __init__ function of the parent class, i.e., the Normalizer class.

  • kwargs (dict): The keyword arguments passed to the __init__ function of the parent class, i.e., the Normalizer class.

normalize(x: ndarray) -> ndarray
Overview:

Normalize the input data.

Arguments:
  • x (np.ndarray): The input data to be normalized.

Returns:
  • ret (np.ndarray): The normalized data.

unnormalize(x: ndarray) -> ndarray
Overview:

Unnormalize the input data.

Arguments:
  • x (np.ndarray): The input data to be unnormalized.

Returns:
  • ret (np.ndarray): The unnormalized data.
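
Normalizing to zero mean and unit variance, and inverting that map, amounts to the following for one-dimensional data. TinyGaussianNormalizer is a hypothetical pure-Python sketch; it assumes the population standard deviation and is not the NumPy-based class documented here.

```python
import statistics
from typing import List


class TinyGaussianNormalizer:
    """Hypothetical 1-D sketch of a zero-mean, unit-variance normalizer."""

    def __init__(self, data: List[float]) -> None:
        self.mean = statistics.fmean(data)
        self.std = statistics.pstdev(data)  # population standard deviation (assumption)

    def normalize(self, x: List[float]) -> List[float]:
        return [(v - self.mean) / self.std for v in x]

    def unnormalize(self, x: List[float]) -> List[float]:
        return [v * self.std + self.mean for v in x]


data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
normalizer = TinyGaussianNormalizer(data)
normalized = normalizer.normalize(data)
restored = normalizer.unnormalize(normalized)  # round trip back to the input
```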

CDFNormalizer

class ding.utils.normalizer_helper.CDFNormalizer(X)
Overview:

A class that makes training data uniform (over each dimension) by transforming it with marginal CDFs.

Interfaces:

__init__, __repr__, normalize, unnormalize.

__init__(X)
Overview:

Initialize the CDFNormalizer object.

Arguments:
  • X (np.ndarray): The data to be normalized.

normalize(x: ndarray) -> ndarray
Overview:

Normalizes the input data.

Arguments:
  • x (np.ndarray): The input data.

Returns:
  • ret (np.ndarray): The normalized data.

unnormalize(x: ndarray) -> ndarray
Overview:

Unnormalizes the input data.

Arguments:
  • x (np.ndarray): The input data.

Returns:
  • ret (np.ndarray): The unnormalized data.

wrap(fn_name: str, x: ndarray) -> ndarray
Overview:

Wraps the given function name and applies it to the input data.

Arguments:
  • fn_name (str): The name of the function to be applied.

  • x (np.ndarray): The input data.

Returns:
  • ret: The output of the function applied to the input data.

CDFNormalizer1d

class ding.utils.normalizer_helper.CDFNormalizer1d(X: ndarray)
Overview:

CDF normalizer for a single dimension. This class provides methods to normalize and unnormalize data using the Cumulative Distribution Function (CDF) approach.

Interfaces:

__init__, __repr__, normalize, unnormalize.

__init__(X: ndarray)
Overview:

Initialize the CDFNormalizer1d object.

Arguments:
  • X (np.ndarray): The data to be normalized.

normalize(x: ndarray) -> ndarray
Overview:

Normalize the input data.

Arguments:
  • x (np.ndarray): The data to be normalized.

Returns:
  • ret (np.ndarray): The normalized data.

unnormalize(x: ndarray, eps: float = 0.0001) -> ndarray
Overview:

Unnormalize the input data.

Arguments:
  • x (np.ndarray): The data to be unnormalized.

  • eps (float): A small value used for numerical stability. Defaults to 1e-4.

Returns:
  • ret (np.ndarray): The unnormalized data.

empirical_cdf

ding.utils.normalizer_helper.empirical_cdf(sample: numpy.ndarray) -> Tuple[numpy.ndarray, numpy.ndarray]
Overview:

Compute the empirical cumulative distribution function (CDF) of a given sample.

Arguments:
  • sample (np.ndarray): The input sample for which to compute the empirical CDF.

Returns:
  • quantiles (np.ndarray): The unique values in the sample.

  • cumprob (np.ndarray): The cumulative probabilities corresponding to the quantiles.
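
As a worked illustration of the two return values, the empirical CDF of a small sample can be computed by counting. empirical_cdf_sketch is a hypothetical pure-Python stand-in that works on lists rather than NumPy arrays.

```python
from collections import Counter
from typing import List, Tuple


def empirical_cdf_sketch(sample: List[float]) -> Tuple[List[float], List[float]]:
    """Hypothetical sketch: unique sorted values and their cumulative probabilities."""
    counts = Counter(sample)
    quantiles = sorted(counts)
    total = len(sample)
    cumprob = []
    running = 0
    for q in quantiles:
        running += counts[q]  # number of samples <= q
        cumprob.append(running / total)
    return quantiles, cumprob


quantiles, cumprob = empirical_cdf_sketch([3.0, 1.0, 2.0, 2.0])
# quantiles is [1.0, 2.0, 3.0] and cumprob is [0.25, 0.75, 1.0]
```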

atleast_2d

ding.utils.normalizer_helper.atleast_2d(x: ndarray) -> ndarray
Overview:

Ensure that the input array has at least two dimensions.

Arguments:
  • x (np.ndarray): The input array.

Returns:
  • ret (np.ndarray): The input array with at least two dimensions.

LimitsNormalizer

class ding.utils.normalizer_helper.LimitsNormalizer(X)
Overview:

A class that normalizes and unnormalizes values within specified limits. This class maps values within the range [xmin, xmax] to the range [-1, 1].

Interfaces:

__init__, __repr__, normalize, unnormalize.

normalize(x: ndarray) -> ndarray
Overview:

Normalizes the input values.

Arguments:
  • x (np.ndarray): The input values to be normalized.

Returns:
  • ret (np.ndarray): The normalized values.

unnormalize(x: ndarray, eps: float = 0.0001) -> ndarray
Overview:

Unnormalizes the input values.

Arguments:
  • x (np.ndarray): The input values to be unnormalized.

  • eps (float): A small value used for clipping. Defaults to 1e-4.

Returns:
  • ret (np.ndarray): The unnormalized values.
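
The mapping from [xmin, xmax] to [-1, 1] and its inverse can be written out as below. TinyLimitsNormalizer is a hypothetical 1-D sketch; it always clips to [-1, 1] before unnormalizing, whereas the role of eps in the real class is not reproduced here.

```python
from typing import List


class TinyLimitsNormalizer:
    """Hypothetical 1-D sketch mapping [xmin, xmax] to [-1, 1] and back."""

    def __init__(self, data: List[float]) -> None:
        self.xmin = min(data)
        self.xmax = max(data)

    def normalize(self, x: List[float]) -> List[float]:
        span = self.xmax - self.xmin
        return [2.0 * (v - self.xmin) / span - 1.0 for v in x]

    def unnormalize(self, x: List[float]) -> List[float]:
        span = self.xmax - self.xmin
        out = []
        for v in x:
            v = min(max(v, -1.0), 1.0)  # clip to the valid normalized range
            out.append((v + 1.0) / 2.0 * span + self.xmin)
        return out


limits = TinyLimitsNormalizer([2.0, 4.0, 6.0])
normalized = limits.normalize([2.0, 4.0, 6.0])
restored = limits.unnormalize([-1.0, 0.0, 1.0, 1.5])  # 1.5 is clipped to 1.0
```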

orchestrator_launcher

Please refer to ding/utils/orchestrator_launcher for more details.

OrchestratorLauncher

class ding.utils.orchestrator_launcher.OrchestratorLauncher(version: str, name: str = 'di-orchestrator', cluster: K8sLauncher | None = None, registry: str = 'diorchestrator', cert_manager_version: str = 'v1.3.1', cert_manager_registry: str = 'quay.io/jetstack')
Overview:

Object to manage di-orchestrator in existing k8s cluster

Interfaces:

__init__, create_orchestrator, delete_orchestrator

__init__(version: str, name: str = 'di-orchestrator', cluster: K8sLauncher | None = None, registry: str = 'diorchestrator', cert_manager_version: str = 'v1.3.1', cert_manager_registry: str = 'quay.io/jetstack') -> None
Overview:

Initialize the OrchestratorLauncher object.

Arguments:
  • version (str): The version of di-orchestrator.

  • name (str): The name of di-orchestrator.

  • cluster (K8sLauncher): The k8s cluster to deploy di-orchestrator.

  • registry (str): The docker registry to pull images.

  • cert_manager_version (str): The version of cert-manager.

  • cert_manager_registry (str): The docker registry to pull cert-manager images.

_check_kubectl_tools() -> None
Overview:

Check if the kubectl tools are installed.

create_orchestrator() -> None
Overview:

Create di-orchestrator in k8s cluster.

delete_orchestrator() -> None
Overview:

Delete di-orchestrator in k8s cluster.

create_components_from_config

ding.utils.orchestrator_launcher.create_components_from_config(config: str) -> None
Overview:

Create components from config file.

Arguments:
  • config (str): The config file.

wait_to_be_ready

ding.utils.orchestrator_launcher.wait_to_be_ready(namespace: str, component: str, timeout: int = 120) -> None
Overview:

Wait for the component to be ready.

Arguments:
  • namespace (str): The namespace of the component.

  • component (str): The name of the component.

  • timeout (int): The timeout of waiting.

profiler_helper

Please refer to ding/utils/profiler_helper for more details.

Profiler

class ding.utils.profiler_helper.Profiler
Overview:

A class for profiling code execution. It can be used as a context manager or a decorator.

Interfaces:

__init__, mkdir, write_profile, profile.

__init__()
Overview:

Initialize the Profiler object.

mkdir(directory: str)
Overview:

Create a directory if it doesn’t exist.

Arguments:
  • directory (str): The path of the directory to be created.

profile(folder_path='./tmp')
Overview:

Enable profiling and save the results to files.

Arguments:
  • folder_path (str): The path of the folder where the profiling files will be saved. Defaults to “./tmp”.

write_profile(pr: Profile, folder_path: str)
Overview:

Write the profiling results to files.

Arguments:
  • pr (cProfile.Profile): The profiler object containing the profiling results.

  • folder_path (str): The path of the folder where the profiling files will be saved.
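
Since write_profile receives a cProfile.Profile object, the underlying mechanism is presumably the standard cProfile and pstats modules. The sketch below shows that mechanism in isolation; profile_call is a hypothetical helper that returns the report as a string instead of writing files, and it is not the Profiler class itself.

```python
import cProfile
import io
import pstats
from typing import Any, Callable, Tuple


def profile_call(fn: Callable, *args: Any, **kwargs: Any) -> Tuple[Any, str]:
    """Hypothetical helper: run fn under cProfile and return (result, text report)."""
    pr = cProfile.Profile()
    pr.enable()
    try:
        result = fn(*args, **kwargs)
    finally:
        pr.disable()
    buf = io.StringIO()
    # Sort by cumulative time and keep the five most expensive entries.
    pstats.Stats(pr, stream=buf).sort_stats("cumulative").print_stats(5)
    return result, buf.getvalue()


def busy_sum(n: int) -> int:
    return sum(range(n))


result, report = profile_call(busy_sum, 1000)
```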

pytorch_ddp_dist_helper

Please refer to ding/utils/pytorch_ddp_dist_helper for more details.

get_rank

ding.utils.pytorch_ddp_dist_helper.get_rank() -> int
Overview:

Get the rank of current process in total world_size

get_world_size

ding.utils.pytorch_ddp_dist_helper.get_world_size() -> int
Overview:

Get the world_size (the total number of processes in data parallel training)

allreduce

ding.utils.pytorch_ddp_dist_helper.allreduce(x: Tensor) -> None
Overview:

All reduce the tensor x in the world

Arguments:
  • x (torch.Tensor): the tensor to be reduced

allreduce_async

ding.utils.pytorch_ddp_dist_helper.allreduce_async(name: str, x: Tensor) -> None
Overview:

All reduce the tensor x in the world asynchronously

Arguments:
  • name (str): the name of the tensor

  • x (torch.Tensor): the tensor to be reduced

reduce_data

ding.utils.pytorch_ddp_dist_helper.reduce_data(x: int | float | Tensor, dst: int) -> int | float | Tensor
Overview:

Reduce the tensor x to the destination process dst

Arguments:
  • x (Union[int, float, torch.Tensor]): the tensor to be reduced

  • dst (int): the destination process

allreduce_data

ding.utils.pytorch_ddp_dist_helper.allreduce_data(x: int | float | Tensor, op: str) -> int | float | Tensor
Overview:

All reduce the tensor x in the world

Arguments:
  • x (Union[int, float, torch.Tensor]): the tensor to be reduced

  • op (str): the operation to perform on data, support ['sum', 'avg']

get_group

ding.utils.pytorch_ddp_dist_helper.get_group(group_size: int) -> List
Overview:

Get the group segmentation, with group_size processes in each group

Arguments:
  • group_size (int): The number of processes in each group

dist_mode

ding.utils.pytorch_ddp_dist_helper.dist_mode(func: Callable) -> Callable
Overview:

Wrap the function so that distributed initialization and finalization run automatically for each call

Arguments:
  • func (Callable): the function to be wrapped
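
The wrapping idea can be sketched without any distributed backend. In the sketch below, fake_dist_init and fake_dist_finalize are hypothetical placeholders for the real setup and teardown, and the init-call-finalize ordering is an assumption about how the decorator behaves.

```python
import functools
from typing import Callable, List

events: List[str] = []


def fake_dist_init() -> None:
    """Hypothetical stand-in for the real distributed initialization."""
    events.append("init")


def fake_dist_finalize() -> None:
    """Hypothetical stand-in for the real distributed finalization."""
    events.append("finalize")


def dist_mode_sketch(func: Callable) -> Callable:
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        fake_dist_init()
        try:
            return func(*args, **kwargs)
        finally:
            # Finalize even if the wrapped function raises.
            fake_dist_finalize()

    return wrapper


@dist_mode_sketch
def train_step(x: int) -> int:
    events.append("run")
    return x * 2


result = train_step(21)
```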

dist_init

ding.utils.pytorch_ddp_dist_helper.dist_init(backend: str = 'nccl', addr: str | None = None, port: str | None = None, rank: int | None = None, world_size: int | None = None) -> Tuple[int, int]
Overview:

Initialize the distributed training setting

Arguments:
  • backend (str): The backend of the distributed training, support ['nccl', 'gloo']

  • addr (str): The address of the master node

  • port (str): The port of the master node

  • rank (int): The rank of current process

  • world_size (int): The total number of processes

dist_finalize

ding.utils.pytorch_ddp_dist_helper.dist_finalize() -> None
Overview:

Finalize distributed training resources

DDPContext

class ding.utils.pytorch_ddp_dist_helper.DDPContext
Overview:

A context manager for PyTorch DDP distributed training

Interfaces:

__init__, __enter__, __exit__

__init__() -> None
Overview:

Initialize the DDPContext

simple_group_split

ding.utils.pytorch_ddp_dist_helper.simple_group_split(world_size: int, rank: int, num_groups: int) -> List
Overview:

Split the group according to world_size, rank and num_groups

Arguments:
  • world_size (int): The world size

  • rank (int): The rank

  • num_groups (int): The number of groups

Note

Invalid input raises the error "array split does not result in an equal division".
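
The rank bookkeeping behind such a split can be sketched in pure Python. split_ranks is a hypothetical helper that only returns the list of ranks in the caller's group; the real function presumably builds torch.distributed process groups, which this sketch does not attempt.

```python
from typing import List


def split_ranks(world_size: int, rank: int, num_groups: int) -> List[int]:
    """Hypothetical sketch: ranks of the equally sized group that contains `rank`."""
    if world_size % num_groups != 0:
        # Mirrors the error message quoted in the note above.
        raise ValueError("array split does not result in an equal division")
    group_size = world_size // num_groups
    groups = [list(range(i * group_size, (i + 1) * group_size)) for i in range(num_groups)]
    return groups[rank // group_size]


my_group = split_ranks(world_size=8, rank=5, num_groups=2)  # second half of the ranks

try:
    split_ranks(world_size=8, rank=0, num_groups=3)
    uneven_rejected = False
except ValueError:
    uneven_rejected = True
```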

to_ddp_config

ding.utils.pytorch_ddp_dist_helper.to_ddp_config(cfg: EasyDict) -> EasyDict
Overview:

Convert the config to ddp config

Arguments:
  • cfg (EasyDict): The config to be converted

registry

Please refer to ding/utils/registry for more details.

Registry

class ding.utils.registry.Registry(*args, **kwargs)
Overview:

A helper class for managing registered modules. It extends a dictionary and provides a register function.

Interfaces:

__init__, register, get, build, query, query_details

Examples (creating):
>>> some_registry = Registry({"default": default_module})
Examples (registering: normal way):
>>> def foo():
...     ...
>>> some_registry.register("foo_module", foo)
Examples (registering: decorator way):
>>> @some_registry.register("foo_module")
... @some_registry.register("foo_module_nickname")
... def foo():
...     ...
Examples (accessing):
>>> f = some_registry["foo_module"]

__init__(*args, **kwargs) -> None
Overview:

Initialize the Registry object.

Arguments:
  • args (Tuple): The arguments passed to the __init__ function of the parent class, dict.

  • kwargs (Dict): The keyword arguments passed to the __init__ function of the parent class, dict.

static _register_generic(module_dict: dict, module_name: str, module: Callable, force_overwrite: bool = False) -> None
Overview:

Register the module.

Arguments:
  • module_dict (dict): The dict to store the module.

  • module_name (str): The name of the module.

  • module (Callable): The module to be registered.

  • force_overwrite (bool): Whether to overwrite the module with the same name.

build(obj_type: str, *obj_args, **obj_kwargs) -> object
Overview:

Build the object.

Arguments:
  • obj_type (str): The type of the object.

  • obj_args (Tuple): The arguments passed to the object.

  • obj_kwargs (Dict): The keyword arguments passed to the object.

get(module_name: str) -> Callable
Overview:

Get the module.

Arguments:
  • module_name (str): The name of the module.

query() -> Iterable
Overview:

Get all registered module names.

query_details(aliases: Iterable | None = None) -> OrderedDict
Overview:

Get the details of the registered modules.

Arguments:
  • aliases (Optional[Iterable]): The aliases of the modules.

register(module_name: str | None = None, module: Callable | None = None, force_overwrite: bool = False) -> Callable
Overview:

Register the module.

Arguments:
  • module_name (Optional[str]): The name of the module.

  • module (Optional[Callable]): The module to be registered.

  • force_overwrite (bool): Whether to overwrite the module with the same name.
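
A dictionary-backed registry whose register method works both as a plain call and as a decorator can be sketched as follows. TinyRegistry is a hypothetical, simplified stand-in; the duplicate-name check and the return values are assumptions for illustration and may differ from the real Registry.

```python
from typing import Callable, Optional


class TinyRegistry(dict):
    """Hypothetical sketch of a dict-based registry."""

    def _add(self, name: str, module: Callable, force_overwrite: bool) -> None:
        if not force_overwrite and name in self:
            raise KeyError("'{}' is already registered".format(name))
        self[name] = module

    def register(self, module_name: Optional[str] = None, module: Optional[Callable] = None,
                 force_overwrite: bool = False) -> Callable:
        if module is not None:
            # Normal way: register(name, module)
            self._add(module_name or module.__name__, module, force_overwrite)
            return module

        def decorator(fn: Callable) -> Callable:
            # Decorator way: @register(name)
            self._add(module_name or fn.__name__, fn, force_overwrite)
            return fn

        return decorator

    def build(self, obj_type: str, *obj_args, **obj_kwargs) -> object:
        return self[obj_type](*obj_args, **obj_kwargs)


registry = TinyRegistry()


@registry.register("double")
def double(x: int) -> int:
    return 2 * x


registry.register("triple", lambda x: 3 * x)
built = registry.build("double", 4)
```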

render_helper

Please refer to ding/utils/render_helper for more details.

render_env

ding.utils.render_helper.render_env(env, render_mode: str | None = 'rgb_array') -> ndarray
Overview:

Render the environment’s current frame.

Arguments:
  • env (gym.Env): DI-engine env instance.

  • render_mode (str): Render mode.

Returns:
  • frame (numpy.ndarray): [H * W * C]

render

ding.utils.render_helper.render(env: BaseEnv, render_mode: str | None = 'rgb_array') -> ndarray
Overview:

Render the environment’s current frame.

Arguments:
  • env (BaseEnv): DI-engine env instance.

  • render_mode (str): Render mode.

Returns:
  • frame (numpy.ndarray): [H * W * C]

get_env_fps

ding.utils.render_helper.get_env_fps(env) -> int
Overview:

Get the environment’s fps.

Arguments:
  • env (gym.Env): DI-engine env instance.

Returns:
  • fps (int).

fps

ding.utils.render_helper.fps(env_manager: BaseEnvManager) -> int
Overview:

Get the environment’s fps.

Arguments:
  • env_manager (BaseEnvManager): DI-engine env manager instance.

Returns:
  • fps (int).

scheduler_helper

Please refer to ding/utils/scheduler_helper for more details.

Scheduler

class ding.utils.scheduler_helper.Scheduler(merged_scheduler_config: EasyDict)
Overview:

Update learning parameters when the trueskill metrics have stopped improving. For example, models often benefit from reducing the entropy weight once the learning process stagnates. This scheduler reads a metrics quantity and, if no improvement is seen for a ‘patience’ number of epochs, increases or decreases the corresponding parameter, depending on ‘schedule_mode’.

Arguments:
  • schedule_flag (bool): Indicates whether to use scheduler in training pipeline. Default: False

  • schedule_mode (str): One of ‘reduce’, ‘add’, ‘multi’, ‘div’. The schedule_mode decides the way of updating the parameters. Default: ‘reduce’.

  • factor (float): Amount (greater than 0) by which the parameter will be increased/decreased. Default: 0.05

  • change_range (list): Indicates the minimum and maximum value the parameter can reach respectively. Default: [-1, 1]

  • threshold (float): Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4.

  • optimize_mode (str): One of ‘min’, ‘max’, which indicates the sign of optimization objective. Dynamic_threshold = last_metrics + threshold in max mode or last_metrics - threshold in min mode. Default: ‘min’

  • patience (int): Number of epochs with no improvement after which the parameter will be updated. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only update the parameter after the 3rd epoch if the metrics still hasn’t improved then. Default: 10.

  • cooldown (int): Number of epochs to wait before resuming normal operation after the parameter has been updated. Default: 0.

Interfaces:

__init__, update_param, step

Property:

in_cooldown, is_better

__init__(merged_scheduler_config: EasyDict) -> None
Overview:

Initialize the scheduler.

Arguments:
  • merged_scheduler_config (EasyDict): the scheduler config, which merges the user config and default config

config = {'change_range': [-1, 1], 'cooldown': 0, 'factor': 0.05, 'optimize_mode': 'min', 'patience': 10, 'schedule_flag': False, 'schedule_mode': 'reduce', 'threshold': 0.0001}
property in_cooldown: bool
Overview:

Checks whether the scheduler is in cooldown period. If in cooldown, the scheduler will ignore any bad epochs.

is_better(cur: float) -> bool
Overview:

Checks whether the current metrics is better than the last metrics with respect to the threshold.

Arguments:
  • cur (float): current metrics

step(metrics: float, param: float) -> float
Overview:

Decides whether to update the scheduled parameter

Arguments:
  • metrics (float): current input metrics

  • param (float): parameter to be updated

Returns:
  • step_param (float): parameter after one step

update_param(param: float) -> float
Overview:

Update the scheduled parameter

Arguments:
  • param (float): parameter to be updated

Returns:
  • updated_param (float): parameter after updating
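
To make the interplay of patience, threshold, and cooldown concrete, here is a simplified sketch of such a plateau scheduler. PlateauScheduler is a hypothetical stand-in: it takes keyword arguments instead of an EasyDict config, treats the first metrics as an improvement, and its handling of the modes and clipping is an assumption rather than the DI-engine implementation.

```python
from typing import Optional, Sequence


class PlateauScheduler:
    """Hypothetical sketch of a patience-based parameter scheduler."""

    def __init__(self, schedule_mode: str = "reduce", factor: float = 0.05,
                 change_range: Sequence[float] = (-1.0, 1.0), threshold: float = 1e-4,
                 optimize_mode: str = "min", patience: int = 10, cooldown: int = 0) -> None:
        self.schedule_mode = schedule_mode
        self.factor = factor
        self.change_range = change_range
        self.threshold = threshold
        self.optimize_mode = optimize_mode
        self.patience = patience
        self.cooldown = cooldown
        self.last_metrics: Optional[float] = None
        self.bad_epochs = 0
        self.cooldown_counter = 0

    @property
    def in_cooldown(self) -> bool:
        return self.cooldown_counter > 0

    def is_better(self, cur: float) -> bool:
        if self.last_metrics is None:  # assumption: the first value counts as better
            return True
        if self.optimize_mode == "min":
            return cur < self.last_metrics - self.threshold
        return cur > self.last_metrics + self.threshold

    def update_param(self, param: float) -> float:
        if self.schedule_mode == "reduce":
            param -= self.factor
        elif self.schedule_mode == "add":
            param += self.factor
        elif self.schedule_mode == "multi":
            param *= self.factor
        elif self.schedule_mode == "div":
            param /= self.factor
        low, high = self.change_range
        return min(max(param, low), high)  # keep the parameter inside change_range

    def step(self, metrics: float, param: float) -> float:
        self.bad_epochs = 0 if self.is_better(metrics) else self.bad_epochs + 1
        self.last_metrics = metrics
        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.bad_epochs = 0  # bad epochs are ignored during cooldown
        if self.bad_epochs > self.patience:
            param = self.update_param(param)
            self.cooldown_counter = self.cooldown
            self.bad_epochs = 0
        return param


scheduler = PlateauScheduler(factor=0.25, patience=2)
param = 1.0
history = []
for metrics in [1.0, 1.0, 1.0, 1.0]:  # the metrics never improves after the first value
    param = scheduler.step(metrics, param)
    history.append(param)
```

With patience = 2, the parameter is left alone for the first two stagnant epochs and reduced on the third, which matches the patience description above.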

segment_tree

Please refer to ding/utils/segment_tree for more details.

njit

ding.utils.segment_tree.njit()
Overview:

Decorator to compile a function using numba.

SegmentTree

class ding.utils.segment_tree.SegmentTree(capacity: int, operation: Callable, neutral_element: float | None = None)
Overview:

Segment tree data structure, implemented as a tree-like array. Only the leaf nodes hold real values; each non-leaf node stores the result of applying the operation to its left and right children.

Interfaces:

__init__, reduce, __setitem__, __getitem__

__init__(capacity: int, operation: Callable, neutral_element: float | None = None) -> None
Overview:

Initialize the segment tree. Tree’s root node is at index 1.

Arguments:
  • capacity (int): Capacity of the tree (the number of the leaf nodes), should be the power of 2.

  • operation (function): The operation function to construct the tree, e.g. sum, max, min, etc.

  • neutral_element (float or None): The value of the neutral element, which is used to init all nodes value in the tree.

_compile() -> None
Overview:

Compile the functions using numba.

reduce(start: int = 0, end: int | None = None) -> float
Overview:

Reduce the tree in range [start, end)

Arguments:
  • start (int): Start index(relative index, the first leaf node is 0), default set to 0

  • end (int or None): End index(relative index), default set to self.capacity

Returns:
  • reduce_result (float): The reduce result value, which is dependent on data type and operation

SumSegmentTree

class ding.utils.segment_tree.SumSegmentTree(capacity: int)
Overview:

Sum segment tree, which is inherited from SegmentTree. Init by passing operation='sum'.

Interfaces:

__init__, find_prefixsum_idx

__init__(capacity: int) -> None
Overview:

Init sum segment tree by passing operation='sum'

Arguments:
  • capacity (int): Capacity of the tree (the number of the leaf nodes).

find_prefixsum_idx(prefixsum: float, trust_caller: bool = True) -> int
Overview:

Find the highest non-zero index i such that sum_{j} leaf[j] <= prefixsum (where 0 <= j < i) and sum_{j} leaf[j] > prefixsum (where 0 <= j < i+1)

Arguments:
  • prefixsum (float): The target prefixsum.

  • trust_caller (bool): Whether to trust the caller. If False, check by calling the reduce function that this tree’s sum is greater than the input prefixsum. Default set to True.

Returns:
  • idx (int): Eligible index.
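
The prefix-sum lookup walks from the root down to a leaf, going left when the left subtree already exceeds the remaining prefix sum and otherwise subtracting it and going right. The sketch below assumes the root-at-index-1 array layout; build_sum_tree and find_prefixsum_idx_sketch are hypothetical pure-Python helpers, not the numba-compiled functions.

```python
from typing import List, Sequence


def build_sum_tree(leaves: Sequence[float]) -> List[float]:
    """Hypothetical helper: array-based sum tree, root at index 1, index 0 unused."""
    capacity = len(leaves)
    assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be a power of 2"
    tree = [0.0] * capacity + list(leaves)
    for idx in range(capacity - 1, 0, -1):
        tree[idx] = tree[2 * idx] + tree[2 * idx + 1]
    return tree


def find_prefixsum_idx_sketch(tree: List[float], capacity: int, prefixsum: float) -> int:
    """Return leaf index i with sum(leaf[:i]) <= prefixsum < sum(leaf[:i + 1])."""
    idx = 1
    while idx < capacity:  # stop once idx points at a leaf
        left = 2 * idx
        if tree[left] > prefixsum:
            idx = left  # the target lies in the left subtree
        else:
            prefixsum -= tree[left]  # skip the whole left subtree
            idx = left + 1
    return idx - capacity


tree = build_sum_tree([1.0, 2.0, 3.0, 4.0])
indices = [find_prefixsum_idx_sketch(tree, 4, p) for p in (0.5, 1.0, 3.0, 6.5)]
```

Drawing prefixsum uniformly from [0, total) then picks each leaf with probability proportional to its value.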

MinSegmentTree

class ding.utils.segment_tree.MinSegmentTree(capacity: int)
Overview:

Min segment tree, which is inherited from SegmentTree. Init by passing operation='min'.

Interfaces:

__init__

__init__(capacity: int) -> None
Overview:

Initialize min segment tree by passing operation='min'

Arguments:
  • capacity (int): Capacity of the tree (the number of the leaf nodes).

_setitem

ding.utils.segment_tree._setitem(tree: ndarray, idx: int, val: float, operation: str) -> None
Overview:

Set tree[idx] = val; Then update the related nodes.

Arguments:
  • tree (np.ndarray): The tree array.

  • idx (int): The index of the leaf node.

  • val (float): The value that will be assigned to leaf[idx].

  • operation (str): The operation function to construct the tree, e.g. sum, max, min, etc.

_reduce

ding.utils.segment_tree._reduce(tree: ndarray, start: int, end: int, neutral_element: float, operation: str) -> float
Overview:

Reduce the tree in range [start, end)

Arguments:
  • tree (np.ndarray): The tree array.

  • start (int): Start index(relative index, the first leaf node is 0).

  • end (int): End index(relative index).

  • neutral_element (float): The value of the neutral element, which is used to init all nodes value in the tree.

  • operation (str): The operation function to construct the tree, e.g. sum, max, min, etc.

_find_prefixsum_idx

ding.utils.segment_tree._find_prefixsum_idx(tree: ndarray, capacity: int, prefixsum: float, neutral_element: float) -> int
Overview:

Find the highest non-zero index i such that sum_{j} leaf[j] <= prefixsum (where 0 <= j < i) and sum_{j} leaf[j] > prefixsum (where 0 <= j < i+1)

Arguments:
  • tree (np.ndarray): The tree array.

  • capacity (int): Capacity of the tree (the number of the leaf nodes).

  • prefixsum (float): The target prefixsum.

  • neutral_element (float): The value of the neutral element, which is used to init all nodes value in the tree.

slurm_helper

Please refer to ding/utils/slurm_helper for more details.

get_ip

ding.utils.slurm_helper.get_ip() -> str
Overview:

Get the ip of the current node

get_manager_node_ip

ding.utils.slurm_helper.get_manager_node_ip(node_ip: str | None = None) -> str
Overview:

Look up the manager node of the slurm cluster and return the node ip

Arguments:
  • node_ip (Optional[str]): The ip of the current node

get_cls_info

ding.utils.slurm_helper.get_cls_info() -> Dict[str, list]
Overview:

Get the cluster info

node_to_partition

ding.utils.slurm_helper.node_to_partition(target_node: str) -> Tuple[str, str]
Overview:

Get the partition of the target node

Arguments:
  • target_node (str): The target node

node_to_host

ding.utils.slurm_helper.node_to_host(node: str) -> str [source]
Overview:

Get the host of the node

Arguments:
  • node (str): The node

find_free_port_slurm

ding.utils.slurm_helper.find_free_port_slurm(node: str) -> int [source]
Overview:

Find a free port on the node

Arguments:
  • node (str): The node

system_helper

Please refer to ding/utils/system_helper for more details.

get_ip

ding.utils.system_helper.get_ip() -> str [source]
Overview:

Get the IP address (host) of the socket.

Returns:
  • ip (str): The corresponding IP address.

get_pid

ding.utils.system_helper.get_pid() -> int [source]
Overview:

Return the current process ID, as given by os.getpid().

get_task_uid

ding.utils.system_helper.get_task_uid() -> str [source]
Overview:

Get the slurm job_id, pid and uid

PropagatingThread

class ding.utils.system_helper.PropagatingThread(group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None) [source]
Overview:

Subclass of Thread that propagates an exception raised inside the thread to the caller.

Interfaces:

run, join

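Examples:
>>> from ding.utils.system_helper import PropagatingThread
>>> def func():
...     raise Exception()
>>> t = PropagatingThread(target=func, args=())
>>> t.start()
>>> t.join()  # the exception raised in func is propagated to the caller here

join() -> Any [source]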
Overview:

Join the thread

run() -> None [source]
Overview:

Run the thread
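
One common way to build such a thread is to catch the exception in run and re-raise it in join. The sketch below is self-contained and only illustrates that pattern; the class name PropagatingThreadSketch and its attributes are hypothetical and are not taken from the library.

```python
import threading


class PropagatingThreadSketch(threading.Thread):
    """Illustrative stand-in: stores the target's exception and re-raises it in join()."""

    def __init__(self, target, args=(), kwargs=None):
        super().__init__()
        self._fn = target
        self._fn_args = args
        self._fn_kwargs = kwargs or {}
        self.exc = None
        self.ret = None

    def run(self):
        # Runs in the worker thread: remember either the result or the exception.
        try:
            self.ret = self._fn(*self._fn_args, **self._fn_kwargs)
        except Exception as e:
            self.exc = e

    def join(self, timeout=None):
        # Runs in the caller's thread: surface the stored exception, if any.
        super().join(timeout)
        if self.exc is not None:
            raise self.exc
        return self.ret


def fail():
    raise ValueError("boom")


t = PropagatingThreadSketch(target=fail)
t.start()
try:
    t.join()
    caught = None
except ValueError as e:
    caught = str(e)

ok = PropagatingThreadSketch(target=lambda a, b: a + b, args=(1, 2))
ok.start()
value = ok.join()
```

Returning the target's result from join mirrors the Any return type documented above.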

find_free_port

ding.utils.system_helper.find_free_port(host: str) -> int [source]
Overview:

Look up the free port list and return one

Arguments:
  • host (str): The host
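
A common technique for obtaining a free port is to bind a socket to port 0 and let the operating system choose. The sketch below shows that technique only; it is an assumption that the library works this way, and the name find_free_port_sketch is hypothetical.

```python
import socket
from contextlib import closing


def find_free_port_sketch(host="127.0.0.1"):
    # Binding to port 0 asks the OS for any currently unused port.
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind((host, 0))
        return s.getsockname()[1]


port = find_free_port_sketch()
```

The port is released when the socket closes, so another process could claim it before the caller binds it again.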

time_helper_base

Please refer to ding/utils/time_helper_base for more details.

TimeWrapper

class ding.utils.time_helper_base.TimeWrapper [source]
Overview:

Abstract class that defines the TimeWrapper interface.

Interfaces:

wrapper, start_time, end_time

classmethod end_time() [source]
Overview:

Abstract classmethod that stops timing.

classmethod start_time() [source]
Overview:

Abstract classmethod that starts timing.

classmethod wrapper(fn) [source]
Overview:

Classmethod wrapper that wraps a function and automatically returns its running time.

Arguments:
  • fn (function): The function to be wrapped and timed.
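
The wrapper pattern can be sketched as follows. This is an illustration under assumptions: the class names TimeWrapperSketch and WallClockWrapper are hypothetical, and returning a (result, elapsed) pair from the wrapped function is an assumed convention, not something the text above confirms.

```python
import time


class TimeWrapperSketch:
    """Illustrative base class: subclasses supply start_time and end_time."""

    @classmethod
    def wrapper(cls, fn):
        def timed(*args, **kwargs):
            cls.start_time()
            ret = fn(*args, **kwargs)
            elapsed = cls.end_time()
            # Assumed convention: hand back both the result and the running time.
            return ret, elapsed

        return timed

    @classmethod
    def start_time(cls):
        raise NotImplementedError

    @classmethod
    def end_time(cls):
        raise NotImplementedError


class WallClockWrapper(TimeWrapperSketch):
    @classmethod
    def start_time(cls):
        cls._start = time.perf_counter()

    @classmethod
    def end_time(cls):
        return time.perf_counter() - cls._start


@WallClockWrapper.wrapper
def add(a, b):
    return a + b


result, elapsed = add(1, 2)
```

Because the start time is stored on the class, nested or concurrent timed calls would overwrite each other in this sketch.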

time_helper_cuda

Please refer to ding/utils/time_helper_cuda for more details.

get_cuda_time_wrapper

ding.utils.time_helper_cuda.get_cuda_time_wrapper() -> Callable[[], TimeWrapper] [source]
Overview:

Return the TimeWrapperCuda class. This wrapper function aims to ensure compatibility on machines without a CUDA device.

Returns:
  • TimeWrapperCuda (class): See the TimeWrapperCuda class.

Note

torch.cuda.synchronize() must be used when timing; reference: <https://blog.csdn.net/u013548568/article/details/81368019>

time_helper

Please refer to ding/utils/time_helper for more details.

build_time_helper

ding.utils.time_helper.build_time_helper(cfg: EasyDict | None = None, wrapper_type: str | None = None) -> Callable[[], TimeWrapper] [source]
Overview:

Build the time helper.

Arguments:
  • cfg (dict):

    The config, a multilevel dict with large domains such as evaluate, common, model, and train, each of which has its own smaller domains.

  • wrapper_type (str): The type of wrapper to return; supports ['time', 'cuda'].

Returns:
  • time_wrapper (TimeWrapper):

    The corresponding time wrapper. Reference: ding.utils.time_helper.TimeWrapperTime and ding.utils.time_helper_cuda.get_cuda_time_wrapper.

EasyTimer

class ding.utils.time_helper.EasyTimer(cuda=True) [source]
Overview:

A decent timer wrapper that can be used easily.

Interfaces:

__init__, __enter__, __exit__

Example:
>>> wait_timer = EasyTimer()
>>> with wait_timer:
...     func(...)
>>> time_ = wait_timer.value  # in seconds

__init__(cuda=True) [source]
Overview:

Init class EasyTimer

Arguments:
  • cuda (bool): Whether to build timer with cuda type
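
The context-manager behaviour can be sketched with the standard library alone. EasyTimerSketch is a hypothetical stand-in that ignores the cuda option and simply measures wall-clock time; it is not the library's implementation.

```python
import time


class EasyTimerSketch:
    """Illustrative timer: the elapsed seconds are stored in .value on exit."""

    def __init__(self):
        self.value = 0.0

    def __enter__(self):
        self.value = 0.0
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.value = time.perf_counter() - self._start
        return False  # never swallow exceptions from the timed block


timer = EasyTimerSketch()
with timer:
    time.sleep(0.01)
elapsed = timer.value
```

The same timer object can be reused for several with blocks, since the value is reset on entry.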

TimeWrapperTime

class ding.utils.time_helper.TimeWrapperTime [source]
Overview:

A class that inherits from the TimeWrapper class.

Interfaces:

start_time, end_time

classmethod end_time() [source]
Overview:

Implement and override the end_time method in TimeWrapper class

Returns:
  • time (float): The time between start_time and end_time

classmethod start_time() [source]
Overview:

Implement and override the start_time method in TimeWrapper class

WatchDog

class ding.utils.time_helper.WatchDog(timeout: int = 1) [source]
Overview:

Simple watchdog timer to detect timeouts

Arguments:
  • timeout (int): Timeout value of the watchdog [seconds].

Note

If the watchdog is not reset before this value is exceeded, a TimeoutError is raised.

Interfaces:

start, stop

Examples:
>>> watchdog = WatchDog(x)  # x is a timeout value in seconds
>>> ...
>>> watchdog.start()
>>> ...  # Some function
>>> watchdog.stop()

__init__(timeout: int = 1) [source]
Overview:

Initialize watchdog with timeout value.

Arguments:
  • timeout (int): Timeout value of the watchdog [seconds].

static _event(signum: Any, frame: Any) [source]
Overview:

Event handler for watchdog.

Arguments:
  • signum (Any): Signal number.

  • frame (Any): Current stack frame.

start() [source]
Overview:

Start watchdog.

stop() [source]
Overview:

Stop the watchdog by cancelling the alarm with alarm(0) and resetting the SIGALRM handler to SIG_DFL.
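
A signal-based watchdog of this kind can be sketched with the standard signal module. The sketch assumes a Unix platform and the main thread (SIGALRM is not available elsewhere); WatchDogSketch is a hypothetical name and the code is an illustration rather than the library's implementation.

```python
import signal
import time


class WatchDogSketch:
    """Illustrative watchdog: raises TimeoutError if not stopped within timeout seconds."""

    def __init__(self, timeout=1):
        self._timeout = int(timeout)

    @staticmethod
    def _event(signum, frame):
        # Signal handler: runs in the main thread when the alarm fires.
        raise TimeoutError("watchdog timed out")

    def start(self):
        signal.signal(signal.SIGALRM, self._event)
        signal.alarm(self._timeout)

    def stop(self):
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, signal.SIG_DFL)  # restore the default handler


watchdog = WatchDogSketch(timeout=1)
watchdog.start()
try:
    time.sleep(3)  # stands in for a call that takes too long
    timed_out = False
except TimeoutError:
    timed_out = True
finally:
    watchdog.stop()
```

Calling stop in a finally clause keeps a stale alarm from firing later in unrelated code.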

loader.base

Please refer to ding/utils/loader/base for more details.

ILoaderClass

class ding.utils.loader.base.ILoaderClass [source]
Overview:

Base class of loader.

Interfaces:

__init__, _load, load, check, __call__, __and__, __or__, __rshift__

__check(value: _ValueType) -> bool
Overview:

Check whether the value is valid.

Arguments:
  • value (_ValueType): The value to be checked.

__load(value: _ValueType) -> _ValueType
Overview:

Load the value.

Arguments:
  • value (_ValueType): The value to be loaded.

abstract _load(value: _ValueType) -> _ValueType [source]
Overview:

Load the value.

Arguments:
  • value (_ValueType): The value to be loaded.

check(value: _ValueType) -> bool [source]
Overview:

Check whether the value is valid.

Arguments:
  • value (_ValueType): The value to be checked.

load(value: _ValueType) -> _ValueType [source]
Overview:

Load the value.

Arguments:
  • value (_ValueType): The value to be loaded.
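
The interfaces listed for ILoaderClass include load, check, __call__, __and__, __or__ and __rshift__. The sketch below shows one plausible way such loaders compose. The operator semantics are assumptions (& requires both loaders to succeed, | falls back to the second loader, >> pipes the first result into the second), and SketchLoader is a hypothetical class, not the library's.

```python
class SketchLoader:
    """Illustrative loader: wraps a function that returns a value or raises."""

    def __init__(self, fn):
        self._fn = fn

    def load(self, value):
        return self._fn(value)

    def check(self, value):
        # check() reports validity as a bool instead of raising.
        try:
            self.load(value)
            return True
        except Exception:
            return False

    def __call__(self, value):
        return self.load(value)

    def __and__(self, other):
        def _load(value):
            self.load(value)  # must succeed first
            return other.load(value)

        return SketchLoader(_load)

    def __or__(self, other):
        def _load(value):
            try:
                return self.load(value)
            except Exception:
                return other.load(value)

        return SketchLoader(_load)

    def __rshift__(self, other):
        return SketchLoader(lambda value: other.load(self.load(value)))


def _is_int(value):
    if not isinstance(value, int):
        raise TypeError("int expected")
    return value


def _positive(value):
    if value <= 0:
        raise ValueError("positive value expected")
    return value


is_int = SketchLoader(_is_int)
is_positive = SketchLoader(_positive)
to_str = SketchLoader(str)

positive_int = is_int & is_positive
pipeline = positive_int >> to_str
fallback = positive_int | to_str
```

Small loaders built this way can be combined into a validation pipeline without writing nested if statements.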

loader.collection

Please refer to ding/utils/loader/collection for more details.

CollectionError

class ding.utils.loader.collection.CollectionError(errors: List[Tuple[int, Exception]]) [source]
Overview:

Collection error.

Interfaces:

__init__, errors

Properties:

errors

__init__(errors: List[Tuple[int, Exception]]) [source]
Overview:

Initialize the CollectionError.

Arguments:
  • errors (COLLECTION_ERRORS): The errors.

property errors: List[Tuple[int, Exception]]
Overview:

Get the errors.

collection

ding.utils.loader.collection.collection(loader, type_back: bool = True) -> ILoaderClass [source]
Overview:

Create a collection loader.

Arguments:
  • loader (ILoaderClass): The loader.

  • type_back (bool): Whether to convert the type back.
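
A collection loader that reports per-item failures can be sketched as below. The error shape follows the List[Tuple[int, Exception]] signature of CollectionError documented above; everything else is an assumption, including the names CollectionErrorSketch and collection_sketch, the use of a plain callable as the item loader, and the reading of type_back as "rebuild the result with the input's container type".

```python
class CollectionErrorSketch(Exception):
    """Illustrative error carrying (index, exception) pairs for failed items."""

    def __init__(self, errors):
        super().__init__("{} item(s) failed to load".format(len(errors)))
        self.errors = errors


def collection_sketch(load_item, type_back=True):
    def _load(values):
        results, errors = [], []
        for index, item in enumerate(values):
            try:
                results.append(load_item(item))
            except Exception as err:
                errors.append((index, err))
        if errors:
            raise CollectionErrorSketch(errors)
        # Assumed meaning of type_back: return the same container type as the input.
        return type(values)(results) if type_back else results

    return _load


load_ints = collection_sketch(int)
ok = load_ints(("1", "2", "3"))
as_list = collection_sketch(int, type_back=False)(("1", "2"))
try:
    load_ints(["1", "x", "3", "y"])
    failed = []
except CollectionErrorSketch as e:
    failed = [index for index, _ in e.errors]
```

Collecting all failures before raising lets the caller see every bad index in one pass.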

tuple_

ding.utils.loader.collection.tuple_(*loaders) -> ILoaderClass [source]
Overview:

Create a tuple loader.

Arguments:
  • loaders (tuple): The loaders.

length

ding.utils.loader.collection.length(min_length: int | None = None, max_length: int | None = None) -> ILoaderClass [source]
Overview:

Create a length loader that checks the length against optional minimum and maximum bounds.

Arguments:
  • min_length (int): The minimum length.

  • max_length (int): The maximum length.

length_is

ding.utils.loader.collection.length_is(length_: int) -> ILoaderClass [source]
Overview:

Create a length loader that checks for an exact length.

Arguments:
  • length_ (int): The length.

contains

ding.utils.loader.collection.contains(content) -> ILoaderClass [source]
Overview:

Create a contains loader.

Arguments:
  • content (Any): The content.

cofilter

ding.utils.loader.collection.cofilter(checker: Callable[[Any], bool], type_back: bool = True) -> ILoaderClass [source]
Overview:

Create a cofilter loader.

Arguments:
  • checker (Callable[[Any], bool]): The checker.

  • type_back (bool): Whether to convert the type back.

tpselector

ding.utils.loader.collection.tpselector(*indices) -> ILoaderClass [source]
Overview:

Create a tuple selector loader.

Arguments:
  • indices (tuple): The indices.

loader.dict

Please refer to ding/utils/loader/dict for more details.

DictError

class ding.utils.loader.dict.DictError(errors: Mapping[str, Exception]) [source]
Overview:

Dict error.

Interfaces:

__init__, errors

Properties:

errors

__init__(errors: Mapping[str, Exception]) [source]
Overview:

Initialize the DictError.

Arguments:
  • errors (DICT_ERRORS): The errors.

property errors: Mapping[str, Exception]
Overview:

Get the errors.

dict_

ding.utils.loader.dict.dict_(**kwargs) -> ILoaderClass [source]
Overview:

Create a dict loader.

Arguments:
  • kwargs (Mapping[str, ILoaderClass]): The loaders.

loader.exception

Please refer to ding/utils/loader/exception for more details.

CompositeStructureError

class ding.utils.loader.exception.CompositeStructureError [source]
Overview:

Composite structure error.

Interfaces:

__init__, errors

Properties:

errors

abstract property errors: List[Tuple[int | str, Exception]]
Overview:

Get the errors.

loader.mapping

Please refer to ding/utils/loader/mapping for more details.

MappingError

class ding.utils.loader.mapping.MappingError(key_errors: List[Tuple[str, Exception]], value_errors: List[Tuple[str, Exception]]) [source]
Overview:

Mapping error.

Interfaces:

__init__, errors

__init__(key_errors: List[Tuple[str, Exception]], value_errors: List[Tuple[str, Exception]]) [source]
Overview:

Initialize the MappingError.

Arguments:
  • key_errors (MAPPING_ERRORS): The key errors.

  • value_errors (MAPPING_ERRORS): The value errors.

errors() -> List[Tuple[str, Exception]] [source]
Overview:

Get the errors.

key_errors() -> List[Tuple[str, Exception]] [source]
Overview:

Get the key errors.

value_errors() -> List[Tuple[str, Exception]] [source]
Overview:

Get the value errors.

mapping

ding.utils.loader.mapping.mapping(key_loader, value_loader, type_back: bool = True) -> ILoaderClass [source]
Overview:

Create a mapping loader.

Arguments:
  • key_loader (ILoaderClass): The key loader.

  • value_loader (ILoaderClass): The value loader.

  • type_back (bool): Whether to convert the type back.

mpfilter

ding.utils.loader.mapping.mpfilter(check: Callable[[Any, Any], bool], type_back: bool = True) -> ILoaderClass [source]
Overview:

Create a mapping filter loader.

Arguments:
  • check (Callable[[Any, Any], bool]): The check function.

  • type_back (bool): Whether to convert the type back.

mpkeys

ding.utils.loader.mapping.mpkeys() -> ILoaderClass [source]
Overview:

Create a mapping keys loader.

mpvalues

ding.utils.loader.mapping.mpvalues() -> ILoaderClass [source]
Overview:

Create a mapping values loader.

mpitems

ding.utils.loader.mapping.mpitems() -> ILoaderClass [source]
Overview:

Create a mapping items loader.

item

ding.utils.loader.mapping.item(key) -> ILoaderClass [source]
Overview:

Create an item loader.

Arguments:
  • key (Any): The key.

item_or

ding.utils.loader.mapping.item_or(key, default) -> ILoaderClass [source]
Overview:

Create an item_or loader.

Arguments:
  • key (Any): The key.

  • default (Any): The default value.
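
The difference between item and item_or can be sketched with plain functions. It is an assumption that item fails on a missing key while item_or substitutes the default; item_sketch and item_or_sketch are hypothetical stand-ins for illustration.

```python
def item_sketch(key):
    def _load(mapping):
        # Assumed behaviour: a missing key is an error.
        if key not in mapping:
            raise KeyError(key)
        return mapping[key]

    return _load


def item_or_sketch(key, default):
    def _load(mapping):
        # Assumed behaviour: a missing key yields the default value.
        return mapping[key] if key in mapping else default

    return _load


cfg = {"lr": 0.01}
lr = item_sketch("lr")(cfg)
batch_size = item_or_sketch("batch_size", 64)(cfg)
try:
    item_sketch("batch_size")(cfg)
    missing = False
except KeyError:
    missing = True
```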

loader.norm

Please refer to ding/utils/loader/norm for more details.

_callable_to_norm

ding.utils.loader.norm._callable_to_norm(func: Callable[[Any], Any]) -> INormClass [source]
Overview:

Convert callable to norm.

Arguments:
  • func (Callable[[Any], Any]): The callable to be converted.

norm

ding.utils.loader.norm.norm(value) -> INormClass [source]
Overview:

Convert value to norm.

Arguments:
  • value (Any): The value to be converted.

normfunc

ding.utils.loader.norm.normfunc(func) [source]
Overview:

Convert function to norm function.

Arguments:
  • func (Callable[[Any], Any]): The function to be converted.

_unary

ding.utils.loader.norm._unary(a: INormClass, func: Callable[[Any], Any]) -> INormClass [source]
Overview:

Create a unary norm.

Arguments:
  • a (INormClass): The norm.

  • func (UNARY_FUNC): The function.

_binary

ding.utils.loader.norm._binary(a: INormClass, b: INormClass, func: Callable[[Any, Any], Any]) -> INormClass [source]
Overview:

Create a binary norm.

Arguments:
  • a (INormClass): The first norm.

  • b (INormClass): The second norm.

  • func (BINARY_FUNC): The function.

_binary_reducing

ding.utils.loader.norm._binary_reducing(func: Callable[[Any, Any], Any], zero) [source]
Overview:

Create a binary reducing norm.

Arguments:
  • func (BINARY_FUNC): The function.

  • zero (Any): The zero value.

INormClass

class ding.utils.loader.norm.INormClass [source]
Overview:

The norm class.

Interfaces:

__call__, __add__, __radd__, __sub__, __rsub__, __mul__, __rmul__, __matmul__, __rmatmul__, __truediv__, __rtruediv__, __floordiv__, __rfloordiv__, __mod__, __rmod__, __pow__, __rpow__, __lshift__, __rlshift__, __rshift__, __rrshift__, __and__, __rand__, __or__, __ror__, __xor__, __rxor__, __invert__, __pos__, __neg__, __eq__, __ne__, __lt__, __le__, __gt__, __ge__

abstract _call(value) [source]
Overview:

Call the norm.

Arguments:
  • value (Any): The value to be normalized.

lcmp

ding.utils.loader.norm.lcmp(first, *items) [source]
Overview:

Compare the items.

Arguments:
  • first (Any): The first item.

  • items (Any): The other items.

loader.number

Please refer to ding/utils/loader/number for more details.

numeric

ding.utils.loader.number.numeric(int_ok: bool = True, float_ok: bool = True, inf_ok: bool = True) -> ILoaderClass [source]
Overview:

Create a numeric loader.

Arguments:
  • int_ok (bool): Whether int is allowed.

  • float_ok (bool): Whether float is allowed.

  • inf_ok (bool): Whether inf is allowed.

interval

ding.utils.loader.number.interval(left: int | float | None = None, right: int | float | None = None, left_ok: bool = True, right_ok: bool = True, eps=0.0) -> ILoaderClass [source]
Overview:

Create an interval loader.

Arguments:
  • left (Optional[NUMBER_TYPING]): The left bound.

  • right (Optional[NUMBER_TYPING]): The right bound.

  • left_ok (bool): Whether left bound is allowed.

  • right_ok (bool): Whether right bound is allowed.

  • eps (float): The epsilon.
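
The interaction of the bounds, the left_ok and right_ok flags and eps can be sketched as follows. The interpretation is an assumption: left_ok and right_ok are read as "the bound itself is accepted", and eps as a tolerance within which a value counts as equal to a bound. interval_sketch is a hypothetical helper, not the library function.

```python
import math


def interval_sketch(left=None, right=None, left_ok=True, right_ok=True, eps=0.0):
    left = -math.inf if left is None else left
    right = math.inf if right is None else right
    eps = abs(eps)

    def compare(a, b):
        # Returns 1, -1 or 0; values within eps of each other count as equal.
        if a - b > eps:
            return 1
        if a - b < -eps:
            return -1
        return 0

    def _load(value):
        lo = compare(value, left)
        hi = compare(value, right)
        if lo < 0 or (lo == 0 and not left_ok):
            raise ValueError("value {!r} is below the interval".format(value))
        if hi > 0 or (hi == 0 and not right_ok):
            raise ValueError("value {!r} is above the interval".format(value))
        return value

    return _load


unit = interval_sketch(0, 1, right_ok=False)  # the half-open interval [0, 1)
inside = [unit(0), unit(0.5)]
try:
    unit(1)
    rejected = False
except ValueError:
    rejected = True
```

Leaving left or right as None makes that side unbounded in this sketch.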

is_negative

ding.utils.loader.number.is_negative() -> ILoaderClass [source]
Overview:

Create a negative loader.

is_positive

ding.utils.loader.number.is_positive() -> ILoaderClass [source]
Overview:

Create a positive loader.

non_negative

ding.utils.loader.number.non_negative() -> ILoaderClass [source]
Overview:

Create a non-negative loader.

non_positive

ding.utils.loader.number.non_positive() -> ILoaderClass [source]
Overview:

Create a non-positive loader.

negative

ding.utils.loader.number.negative() -> ILoaderClass [source]
Overview:

Create a negative loader.

positive

ding.utils.loader.number.positive() -> ILoaderClass [source]
Overview:

Create a positive loader.

_math_binary

ding.utils.loader.number._math_binary(func: Callable[[Any, Any], Any], attachment) -> ILoaderClass [source]
Overview:

Create a math binary loader.

Arguments:
  • func (Callable[[Any, Any], Any]): The function.

  • attachment (Any): The attachment.

plus

ding.utils.loader.number.plus(addend) -> ILoaderClass [source]
Overview:

Create a plus loader.

Arguments:
  • addend (Any): The addend.

minus

ding.utils.loader.number.minus(subtrahend) -> ILoaderClass [source]
Overview:

Create a minus loader that subtracts the subtrahend from the loaded value.

Arguments:
  • subtrahend (Any): The subtrahend.

minus_with

ding.utils.loader.number.minus_with(minuend) -> ILoaderClass [source]
Overview:

Create a minus loader that subtracts the loaded value from the minuend.

Arguments:
  • minuend (Any): The minuend.

multi

ding.utils.loader.number.multi(multiplier) -> ILoaderClass [source]
Overview:

Create a multi loader.

Arguments:
  • multiplier (Any): The multiplier.

divide

ding.utils.loader.number.divide(divisor) -> ILoaderClass [source]
Overview:

Create a divide loader that divides the loaded value by the divisor.

Arguments:
  • divisor (Any): The divisor.

divide_with

ding.utils.loader.number.divide_with(dividend) -> ILoaderClass [source]
Overview:

Create a divide loader that divides the dividend by the loaded value.

Arguments:
  • dividend (Any): The dividend.

power

ding.utils.loader.number.power(index) -> ILoaderClass [source]
Overview:

Create a power loader that raises the loaded value to the given index (exponent).

Arguments:
  • index (Any): The index.

power_with

ding.utils.loader.number.power_with(base) -> ILoaderClass [source]
Overview:

Create a power loader that raises the base to the power of the loaded value.

Arguments:
  • base (Any): The base.
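
The paired loaders above differ only in which operand the loaded value supplies. The sketch below infers the operand order from the parameter names (subtrahend, minuend, divisor, dividend, index, base); the *_sketch functions are hypothetical stand-ins written for illustration, not the library functions.

```python
def minus_sketch(subtrahend):
    return lambda value: value - subtrahend


def minus_with_sketch(minuend):
    return lambda value: minuend - value


def divide_sketch(divisor):
    return lambda value: value / divisor


def divide_with_sketch(dividend):
    return lambda value: dividend / value


def power_sketch(index):
    return lambda value: value ** index


def power_with_sketch(base):
    return lambda value: base ** value


results = {
    "minus": minus_sketch(3)(10),            # 10 - 3
    "minus_with": minus_with_sketch(3)(10),  # 3 - 10
    "divide": divide_sketch(4)(2),           # 2 / 4
    "divide_with": divide_with_sketch(4)(2), # 4 / 2
    "power": power_sketch(2)(3),             # 3 ** 2
    "power_with": power_with_sketch(2)(3),   # 2 ** 3
}
```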

msum

ding.utils.loader.number.msum(*items) -> ILoaderClass [source]
Overview:

Create a sum loader.

Arguments:
  • items (tuple): The items.

mmulti

ding.utils.loader.number.mmulti(*items) -> ILoaderClass [source]
Overview:

Create a multi loader.

Arguments:
  • items (tuple): The items.

_msinglecmp

ding.utils.loader.number._msinglecmp(first, op, second) -> ILoaderClass [source]
Overview:

Create a single compare loader.

Arguments:
  • first (Any): The first item.

  • op (str): The operator.

  • second (Any): The second item.

mcmp

ding.utils.loader.number.mcmp(first, *items) -> ILoaderClass [source]
Overview:

Create a multi compare loader.

Arguments:
  • first (Any): The first item.

  • items (tuple): The items.

loader.string

Please refer to ding/utils/loader/string for more details.

enum

ding.utils.loader.string.enum(*items, case_sensitive: bool = True) -> ILoaderClass [source]
Overview:

Create an enum loader.

Arguments:
  • items (Iterable[str]): The items.

  • case_sensitive (bool): Whether case sensitive.

_to_regexp

ding.utils.loader.string._to_regexp(regexp) -> Pattern [source]
Overview:

Convert regexp to re.Pattern.

Arguments:
  • regexp (Union[str, re.Pattern]): The regexp.

rematch

ding.utils.loader.string.rematch(regexp: str | Pattern) -> ILoaderClass [source]
Overview:

Create a rematch loader.

Arguments:
  • regexp (Union[str, re.Pattern]): The regexp.

regrep

ding.utils.loader.string.regrep(regexp: str | Pattern, group: int = 0) -> ILoaderClass [source]
Overview:

Create a regrep loader.

Arguments:
  • regexp (Union[str, re.Pattern]): The regexp.

  • group (int): The group.
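
A regular-expression check and a group extraction can be sketched with the standard re module. The exact matching rules are assumptions: the sketch uses a full match for rematch_sketch and a search for regrep_sketch, and all three function names are hypothetical.

```python
import re


def to_regexp_sketch(regexp):
    # Accept either a pattern string or an already compiled pattern.
    return re.compile(regexp) if isinstance(regexp, str) else regexp


def rematch_sketch(regexp):
    pattern = to_regexp_sketch(regexp)

    def _load(value):
        # Assumption: the whole string has to match.
        if pattern.fullmatch(value) is None:
            raise ValueError("{!r} does not match {!r}".format(value, pattern.pattern))
        return value

    return _load


def regrep_sketch(regexp, group=0):
    pattern = to_regexp_sketch(regexp)

    def _load(value):
        # Assumption: the first occurrence anywhere in the string is used.
        match = pattern.search(value)
        if match is None:
            raise ValueError("{!r} not found in {!r}".format(pattern.pattern, value))
        return match.group(group)

    return _load


digits = rematch_sketch(r"\d+")("123")
height = regrep_sketch(r"(\d+)x(\d+)", group=2)("size 640x480")
try:
    rematch_sketch(r"\d+")("12a")
    mismatch = False
except ValueError:
    mismatch = True
```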

loader.types

Please refer to ding/utils/loader/types for more details.

is_type

ding.utils.loader.types.is_type(type_: type) -> ILoaderClass [source]
Overview:

Create a type loader that checks whether the value is an instance of the given type.

Arguments:
  • type_ (type): The type.

to_type

ding.utils.loader.types.to_type(type_: type) -> ILoaderClass [source]
Overview:

Create a type loader that converts the value to the given type.

Arguments:
  • type_ (type): The type.

is_callable

ding.utils.loader.types.is_callable() -> ILoaderClass [source]
Overview:

Create a callable loader.

prop

ding.utils.loader.types.prop(attr_name: str) -> ILoaderClass [source]
Overview:

Create an attribute loader.

Arguments:
  • attr_name (str): The attribute name.

method

ding.utils.loader.types.method(method_name: str) -> ILoaderClass [source]
Overview:

Create a method loader.

Arguments:
  • method_name (str): The method name.

fcall

ding.utils.loader.types.fcall(*args, **kwargs) -> ILoaderClass [source]
Overview:

Create a function loader.

Arguments:
  • args (Tuple[Any]): The args.

  • kwargs (Dict[str, Any]): The kwargs.

fpartial

ding.utils.loader.types.fpartial(*args, **kwargs) -> ILoaderClass [source]
Overview:

Create a partial function loader.

Arguments:
  • args (Tuple[Any]): The args.

  • kwargs (Dict[str, Any]): The kwargs.

loader.utils

Please refer to ding/utils/loader/utils for more details.

keep

ding.utils.loader.utils.keep() -> ILoaderClass [source]
Overview:

Create a keep loader.

raw

ding.utils.loader.utils.raw(value) -> ILoaderClass [source]
Overview:

Create a raw loader.

optional

ding.utils.loader.utils.optional(loader) -> ILoaderClass [source]
Overview:

Create an optional loader.

Arguments:
  • loader (ILoaderClass): The loader.

check_only

ding.utils.loader.utils.check_only(loader) -> ILoaderClass [source]
Overview:

Create a check only loader.

Arguments:
  • loader (ILoaderClass): The loader.

check

ding.utils.loader.utils.check(loader) -> ILoaderClass [source]
Overview:

Create a check loader.

Arguments:
  • loader (ILoaderClass): The loader.