ding.utils
autolog
Please refer to ding/utils/autolog for more details.
TimeMode
- class ding.utils.autolog.TimeMode(value)
- Overview:
Mode used to decide the time format of the range_values function.
ABSOLUTE: use absolute time.
RELATIVE_LIFECYCLE: use time relative to the property's lifecycle.
RELATIVE_CURRENT_TIME: use time relative to the current time.
- ABSOLUTE = 0
- RELATIVE_CURRENT_TIME = 2
- RELATIVE_LIFECYCLE = 1
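To make the three modes concrete, here is a sketch that re-expresses `((begin, end), value)` records under each mode. The record layout is taken from the AvgList example in LoggedModel. The `Mode` enum, `shift_records`, and the offsets it uses (the lifecycle start for RELATIVE_LIFECYCLE, the current time for RELATIVE_CURRENT_TIME) are assumptions for illustration, not the library code.

```python
from enum import IntEnum
from typing import List, Tuple


class Mode(IntEnum):
    """Hypothetical mirror of ding.utils.autolog.TimeMode (same values)."""
    ABSOLUTE = 0
    RELATIVE_LIFECYCLE = 1
    RELATIVE_CURRENT_TIME = 2


Record = Tuple[Tuple[float, float], float]  # ((begin, end), value)


def shift_records(records: List[Record], mode: Mode,
                  lifecycle_start: float, current_time: float) -> List[Record]:
    """Re-express the (begin, end) range of every record under `mode` (assumed offsets)."""
    if mode == Mode.ABSOLUTE:
        offset = 0.0  # keep timestamps untouched
    elif mode == Mode.RELATIVE_LIFECYCLE:
        offset = lifecycle_start  # measure from when the property's lifecycle began
    else:
        offset = current_time  # measure from "now"; past ranges become negative
    return [((begin - offset, end - offset), value) for (begin, end), value in records]
```

Under these assumptions, records spanning 10 to 15 seen at current time 15 map to the range -5 to 0 in RELATIVE_CURRENT_TIME mode.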
RangedData
- class ding.utils.autolog.RangedData(expire: float, use_pickle: bool = False)
- Overview:
A data structure that stores data for a limited period of time.
- Interfaces:
__init__, append, extend, current, history, expire, __bool__, _get_time
- Properties:
expire (float): The expire time.
- __append(time_: float, data: _Tp)
- Overview:
Append the data.
- __append_item(time_: float, data: _Tp)
- Overview:
Append the data item.
- Arguments:
time_ (float): The time.
data (_Tp): The data item.
- __check_expire()
- Overview:
Check the expire time.
- __current()
- Overview:
Get the current data.
- __flush_history()
- Overview:
Flush the history data.
- __get_data_item(data_id: int) -> _Tp
- Overview:
Get the data item.
- Arguments:
data_id (int): The data id.
- __history()
- Overview:
Get the history data.
- __history_yield()
- Overview:
Yield the history data.
- __init__(expire: float, use_pickle: bool = False)
- Overview:
Initialize the RangedData object.
- Arguments:
expire (float): The expire time of the data.
use_pickle (bool): Whether to use pickle to serialize the data.
- __registry_data_item(data: _Tp) -> int
- Overview:
Register the data item.
- Arguments:
data (_Tp): The data item.
- __remove_data_item(data_id: int)
- Overview:
Remove the data item.
- Arguments:
data_id (int): The data id.
- property expire: float
- Overview:
Get the expire time.
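RangedData keeps only the records that fall inside its expire window. The sketch below illustrates that idea with a plain deque. `ExpiringLog` and its pruning rule (drop entries older than `now - expire`, but always keep the newest one) are assumptions for illustration; the real RangedData additionally supports pickled storage and history yielding.

```python
from collections import deque
from typing import Deque, Generic, List, Tuple, TypeVar

T = TypeVar("T")


class ExpiringLog(Generic[T]):
    """Hypothetical sketch: keeps (time, value) pairs that are at most `expire` old."""

    def __init__(self, expire: float) -> None:
        self.expire = expire
        self._items: Deque[Tuple[float, T]] = deque()

    def append(self, time_: float, data: T) -> None:
        self._items.append((time_, data))
        self._check_expire(time_)

    def _check_expire(self, now: float) -> None:
        # Drop the oldest entries outside the window, but always keep the
        # newest one so that current() stays defined.
        while len(self._items) > 1 and self._items[0][0] < now - self.expire:
            self._items.popleft()

    def current(self) -> T:
        return self._items[-1][1]

    def history(self) -> List[Tuple[float, T]]:
        return list(self._items)

    def __bool__(self) -> bool:
        return bool(self._items)
```

With `expire=10.0`, appending at times 0, 5 and 12 leaves only the records at 5 and 12.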
TimeRangedData
- class ding.utils.autolog.TimeRangedData(time_: BaseTime, expire: float)
- Overview:
A data structure that stores data for a limited period of time.
- Interfaces:
__init__, _get_time, append, extend, current, history, expire, __bool__
- Properties:
time (BaseTime): The time.
expire (float): The expire time.
- property time
- Overview:
Get the time.
LoggedModel
- class ding.utils.autolog.LoggedModel(time_: _TimeObjectType, expire: _TimeType)
- Overview:
A model with a timeline (integer time such as 1st, 2nd, 3rd; a self-defined discrete time, such as the implementation of TickTime, can also be used). Several associated values can be maintained together by using LoggedModel.
- Example:
Define the AvgList model like this:
>>> from functools import partial
>>> from typing import Union
>>> from ding.utils.autolog import BaseTime, LoggedValue, LoggedModel
>>> class AvgList(LoggedModel):
>>>     value = LoggedValue(float)
>>>     __property_names = ['value']
>>>
>>>     def __init__(self, time_: BaseTime, expire: Union[int, float]):
>>>         LoggedModel.__init__(self, time_, expire)
>>>         # Attention: the initial value must be set in __init__, otherwise it will not
>>>         # be activated, and the timeline of this value will be unexpectedly affected.
>>>         self.value = 0.0
>>>         self.__register()
>>>
>>>     def __register(self):
>>>         def __avg_func(prop_name: str) -> float:  # function to calculate average value of properties
>>>             records = self.range_values[prop_name]()
>>>             (_start_time, _), _ = records[0]
>>>             (_, _end_time), _ = records[-1]
>>>
>>>             _duration = _end_time - _start_time
>>>             _sum = sum([_value * (_end_time - _begin_time) for (_begin_time, _end_time), _value in records])
>>>
>>>             return _sum / _duration
>>>
>>>         for _prop_name in self.__property_names:
>>>             self.register_attribute_value('avg', _prop_name, partial(__avg_func, prop_name=_prop_name))
Use it like this:
>>> from ding.utils.autolog import NaturalTime, TimeMode
>>>
>>> if __name__ == "__main__":
>>>     _time = NaturalTime()
>>>     ll = AvgList(_time, expire=10)
>>>
>>>     # just do something here ...
>>>
>>>     print(ll.range_values['value']())  # original range_values function in LoggedModel of last 10 secs
>>>     print(ll.range_values['value'](TimeMode.ABSOLUTE))  # use absolute time
>>>     print(ll.avg['value']())  # average value of last 10 secs
- Interfaces:
__init__, time, expire, fixed_time, current_time, freeze, unfreeze, register_attribute_value, __getattr__, get_property_attribute
- Property:
time (BaseTime): The time.
expire (float): The expire time.
- __get_property_ranged_data(name: str) -> TimeRangedData
- Overview:
Get ranged data of one property.
- Arguments:
name (str): The property name.
- __get_range_values_func(name: str)
- Overview:
Get range_values function of one property.
- Arguments:
name (str): The property name.
- __init_properties()
- Overview:
Initialize all properties.
- property __properties: List[str]
- Overview:
Get all property names.
- __register_default_funcs()
- Overview:
Register default functions.
- current_time() -> float | int
- Overview:
Get the current time (the real time, regardless of whether the time proxy is frozen).
- Returns:
int or float: current time
- property expire: _TimeType
- Overview:
Get the expire time.
- Returns:
int or float: the time after which old value records expire
- fixed_time() -> float | int
- Overview:
Get the fixed time (this is the frozen time if the time proxy is frozen). This feature can be useful when adding a value replay feature (in the future).
- Returns:
int or float: fixed time
- freeze()
- Overview:
Freeze the time proxy object. This feature can be useful when adding a value replay feature (in the future).
- get_property_attribute(property_name: str) -> List[str]
- Overview:
Find all registered attributes of one given property (except the common "range_values" attribute, since "range_values" is not added to self.__prop2attr).
- Arguments:
property_name (str): Name of the property whose attributes are queried.
- Returns:
attr_list (List[str]): The registered attributes of the input property.
- register_attribute_value(attribute_name: str, property_name: str, value: Any)
- Overview:
Register a new attribute for one of the values. An example can be found in the overview of the class.
- Arguments:
attribute_name (str): Name of the attribute.
property_name (str): Name of the property.
value (Any): Value of the attribute.
- property time: _TimeObjectType
- Overview:
Get the original time object passed in; its methods (such as step()) can be called through this property.
- Returns:
BaseTime: time object used by this model
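The `__avg_func` helper in the AvgList example above computes a time-weighted mean over the records returned by range_values. Writing each record as $((b_i, e_i), v_i)$ for $i = 1, \dots, N$, it evaluates the expression below; when the records are contiguous ($e_i = b_{i+1}$), the denominator equals the sum of the individual durations $e_i - b_i$.

```latex
\bar{v} \;=\; \frac{\sum_{i=1}^{N} v_i \,(e_i - b_i)}{e_N - b_1}
```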
BaseTime
NaturalTime
TickTime
- class ding.utils.autolog.TickTime(init: int = 0)
- Overview:
Tick time object.
- Interfaces:
__init__, step, time
- Example:
>>> from ding.utils.autolog.time_ctl import TickTime
>>> time_ = TickTime()
- __init__(init: int = 0)
- Overview:
Constructor of TickTime.
- Arguments:
init (int): Initial time, default is 0.
- step(delta: int = 1) -> int
- Overview:
Step the time forward for this TickTime.
- Arguments:
delta (int): Number of steps to move forward, default is 1.
- Returns:
time (int): New time after stepping.
- Example:
>>> from ding.utils.autolog.time_ctl import TickTime
>>> time_ = TickTime(0)
>>> time_.step()
1
>>> time_.step(2)
3
TimeProxy
- class ding.utils.autolog.TimeProxy(time_: BaseTime, frozen: bool = False, lock_type: LockContextType = LockContextType.THREAD_LOCK)
- Overview:
Proxy of a time object that can freeze time, which is sometimes useful when reproducing results. This object is thread-safe, and freeze and unfreeze operations are strictly ordered.
- Interfaces:
__init__, freeze, unfreeze, time, current_time
- Examples:
>>> from ding.utils.autolog.time_ctl import TickTime, TimeProxy
>>> tick_time_ = TickTime()
>>> time_ = TimeProxy(tick_time_)
>>> tick_time_.step()
1
>>> print(tick_time_.time(), time_.time(), time_.current_time())
1 1 1
>>> time_.freeze()
>>> tick_time_.step()
2
>>> print(tick_time_.time(), time_.time(), time_.current_time())
2 1 2
>>> time_.unfreeze()
>>> print(tick_time_.time(), time_.time(), time_.current_time())
2 2 2
- __init__(time_: BaseTime, frozen: bool = False, lock_type: LockContextType = LockContextType.THREAD_LOCK)
- current_time() -> int | float
- Overview:
Get the current time (never the frozen time).
- Returns:
int or float: current time
- property is_frozen: bool
- Overview:
Get whether this time proxy object is frozen.
- Returns:
bool: True if it is frozen, otherwise False.
LoggedValue
- class ding.utils.autolog.LoggedValue(type_: Type[_ValueType] = object)
- Overview:
LoggedValue can be used as a property in LoggedModel, since it has __get__ and __set__ methods. Instances of this class are associated with their owner LoggedModel instance, and all LoggedValues of one LoggedModel share a single time object (defined in time_ctl), so that the timeline can be managed properly.
- Interfaces:
__init__, __get__, __set__
- Properties:
__property_name (str): The name of the property.
- __get_ranged_data(instance) -> TimeRangedData
- Overview:
Get the ranged data.
- Interfaces:
__get_ranged_data
- __init__(type_: Type[_ValueType] = object)
- Overview:
Initialize the LoggedValue object.
- Interfaces:
__init__
- property __property_name
- Overview:
Get the name of the property.
data.structure
Please refer to ding/utils/data/structure for more details.
Cache
- class ding.utils.data.structure.Cache(maxlen: int, timeout: float, monitor_interval: float = 1.0, _debug: bool = False)
- Overview:
Data cache for reducing concurrent pressure, with a timeout and full-queue eject mechanism.
- Interfaces:
__init__, push_data, get_cached_data_iter, run, close
- Property:
remain_data_count
- __init__(maxlen: int, timeout: float, monitor_interval: float = 1.0, _debug: bool = False) -> None
- Overview:
Initialize the cache object.
- Arguments:
maxlen (int): Maximum length of the cache queue.
timeout (float): Maximum number of seconds that data can remain in the cache.
monitor_interval (float): Interval at which the timeout monitor thread checks the time.
_debug (bool): Whether to use debug mode, which enables debug print info.
- _warn_if_timeout() -> bool
- Overview:
Return whether a timeout has occurred.
- Returns:
result (bool): Whether a timeout has occurred.
- close() -> None
- Overview:
Shut down the cache's internal thread and send the end flag to the send queue's iterator.
- dprint(s: str) -> None
- Overview:
In debug mode, print the debug string.
- Arguments:
s (str): Debug info to be printed.
- get_cached_data_iter() -> callable_iterator
- Overview:
Get the iterator of the send queue. Once data is pushed into the send queue, it can be accessed by this iterator. 'STOP' is the end flag of this iterator.
- Returns:
iterator (callable_iterator): The send queue iterator.
- push_data(data: Any) -> None
- Overview:
Push data into the receive queue. If the receive queue is full (after the push), push all the data in the receive queue into the send queue.
- Arguments:
data (Any): The data to be added into the receive queue.
Tip
This method is thread-safe.
- property remain_data_count: int
- Overview:
Return the receive queue's remaining data count.
- Returns:
count (int): The size of the receive queue.
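The full-queue eject rule of push_data can be sketched without threads. `MiniCache` is hypothetical: it leaves out the timeout monitor thread, locking, and the 'STOP' end flag of the real Cache, and models the send queue as a plain list.

```python
from collections import deque
from typing import Any, Deque, List


class MiniCache:
    """Hypothetical, single-threaded sketch of Cache's full-queue eject rule."""

    def __init__(self, maxlen: int) -> None:
        self.maxlen = maxlen
        self.receive_queue: Deque[Any] = deque()
        self.send_queue: List[Any] = []

    def push_data(self, data: Any) -> None:
        self.receive_queue.append(data)
        # Once the receive queue is full (after the push), move everything
        # into the send queue, preserving arrival order.
        if len(self.receive_queue) >= self.maxlen:
            self.send_queue.extend(self.receive_queue)
            self.receive_queue.clear()

    @property
    def remain_data_count(self) -> int:
        return len(self.receive_queue)
```

In the real class, the timeout monitor would also flush the receive queue when data waits longer than `timeout`, so consumers are not starved by a queue that never fills.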
LifoDeque
data.base_dataloader
Please refer to ding/utils/data/base_dataloader for more details.
IDataLoader
- class ding.utils.data.base_dataloader.IDataLoader
- Overview:
Base class of data loaders.
- Interfaces:
__init__, __next__, __iter__, _get_data, close
data.collate_fn
Please refer to ding/utils/data/collate_fn for more details.
ttorch_collate
- ding.utils.data.collate_fn.ttorch_collate(x, json: bool = False, cat_1dim: bool = True)
- Overview:
Collates a list of tensors or nested dictionaries of tensors into a single tensor or nested dictionary of tensors.
- Arguments:
x: The input list of tensors or nested dictionaries of tensors.
json (bool): If True, converts the output to JSON format. Defaults to False.
cat_1dim (bool): If True, reshapes tensors with shape (B, 1) into shape (B). Defaults to True.
- Returns:
The collated output tensor or nested dictionary of tensors.
- Examples:
>>> # case 1: Collate a list of tensors
>>> tensors = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9])]
>>> collated = ttorch_collate(tensors)
collated = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> # case 2: Collate a nested dictionary of tensors
>>> nested_dict = {
        'a': torch.tensor([1, 2, 3]),
        'b': torch.tensor([4, 5, 6]),
        'c': torch.tensor([7, 8, 9])
    }
>>> collated = ttorch_collate(nested_dict)
collated = {
    'a': torch.tensor([1, 2, 3]),
    'b': torch.tensor([4, 5, 6]),
    'c': torch.tensor([7, 8, 9])
}
>>> # case 3: Collate a list of nested dictionaries of tensors
>>> nested_dicts = [
        {'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 6])},
        {'a': torch.tensor([7, 8, 9]), 'b': torch.tensor([10, 11, 12])}
    ]
>>> collated = ttorch_collate(nested_dicts)
collated = {
    'a': torch.tensor([[1, 2, 3], [7, 8, 9]]),
    'b': torch.tensor([[4, 5, 6], [10, 11, 12]])
}
default_collate
- ding.utils.data.collate_fn.default_collate(batch: Sequence, cat_1dim: bool = True, ignore_prefix: list = ['collate_ignore']) -> Tensor | Mapping | Sequence
- Overview:
Put each data field into a tensor with outer dimension batch size.
- Arguments:
batch (Sequence): A data sequence whose length is the batch size and whose elements are individual pieces of data.
cat_1dim (bool): Whether to concatenate tensors with shape (B, 1) to (B), defaults to True.
ignore_prefix (list): A list of prefixes to ignore when collating dictionaries, defaults to ['collate_ignore'].
- Returns:
ret (Union[torch.Tensor, Mapping, Sequence]): The collated data, with the batch size applied to each data field. The return dtype depends on the original element dtype and can be one of [torch.Tensor, Mapping, Sequence].
- Example:
>>> # a list with B tensors shaped (m, n) -->> a tensor shaped (B, m, n)
>>> a = [torch.zeros(2,3) for _ in range(4)]
>>> default_collate(a).shape
torch.Size([4, 2, 3])
>>>
>>> # a list with B lists, each list contains m elements -->> a list of m tensors, each with shape (B, )
>>> a = [[0 for __ in range(3)] for _ in range(4)]
>>> default_collate(a)
[tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
>>>
>>> # a list with B dicts, whose values are tensors shaped (m, n) -->>
>>> # a dict whose values are tensors with shape (B, m, n)
>>> a = [{i: torch.zeros(i,i+1) for i in range(2, 4)} for _ in range(4)]
>>> print(a[0][2].shape, a[0][3].shape)
torch.Size([2, 3]) torch.Size([3, 4])
>>> b = default_collate(a)
>>> print(b[2].shape, b[3].shape)
torch.Size([4, 2, 3]) torch.Size([4, 3, 4])
timestep_collate
- ding.utils.data.collate_fn.timestep_collate(batch: List[Dict[str, Any]]) -> Dict[str, Tensor | list]
- Overview:
Collates a batch of timestepped data fields into tensors with the outer dimension being the batch size. Each timestepped data field is represented as a tensor with shape [T, B, any_dims], where T is the length of the sequence, B is the batch size, and any_dims represents the shape of the tensor at each timestep.
- Arguments:
batch (List[Dict[str, Any]]): A list of dictionaries with length B, where each dictionary represents a timestepped data field. Each dictionary contains key-value pairs, where the key is the name of the data field and the value is a sequence of torch.Tensor objects with any shape.
- Returns:
ret (Dict[str, Union[torch.Tensor, list]]): The collated data, with the timestep and batch size incorporated into each data field. The shape of each data field is [T, B, dim1, dim2, …].
- Examples:
>>> batch = [
        {'data': [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]},
        {'data': [torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])]}
    ]
>>> collated_data = timestep_collate(batch)
>>> print(collated_data['data'].shape)
torch.Size([2, 2, 3])
diff_shape_collate
- ding.utils.data.collate_fn.diff_shape_collate(batch: Sequence) -> Tensor | Mapping | Sequence
- Overview:
Collates a batch of data with different shapes. This function is similar to default_collate, but it allows tensors in the batch to have None values, which is common in StarCraft observations.
- Arguments:
batch (Sequence): A sequence of data, where each element is a piece of data.
- Returns:
ret (Union[torch.Tensor, Mapping, Sequence]): The collated data, with the batch size applied to each data field. The return type depends on the original element type and can be a torch.Tensor, Mapping, or Sequence.
- Examples:
>>> # a list with B tensors shaped (m, n) -->> a tensor shaped (B, m, n)
>>> a = [torch.zeros(2,3) for _ in range(4)]
>>> diff_shape_collate(a).shape
torch.Size([4, 2, 3])
>>>
>>> # a list with B lists, each list contains m elements -->> a list of m tensors, each with shape (B, )
>>> a = [[0 for __ in range(3)] for _ in range(4)]
>>> diff_shape_collate(a)
[tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
>>>
>>> # a list with B dicts, whose values are tensors shaped (m, n) -->>
>>> # a dict whose values are tensors with shape (B, m, n)
>>> a = [{i: torch.zeros(i,i+1) for i in range(2, 4)} for _ in range(4)]
>>> print(a[0][2].shape, a[0][3].shape)
torch.Size([2, 3]) torch.Size([3, 4])
>>> b = diff_shape_collate(a)
>>> print(b[2].shape, b[3].shape)
torch.Size([4, 2, 3]) torch.Size([4, 3, 4])
default_decollate
- ding.utils.data.collate_fn.default_decollate(batch: Tensor | Sequence | Mapping, ignore: List[str] = ['prev_state', 'prev_actor_state', 'prev_critic_state']) -> List[Any]
- Overview:
Split collated data along its batch dimension to decollate it. This is the reverse operation of default_collate.
- Arguments:
batch (Union[torch.Tensor, Sequence, Mapping]): The collated data batch. It can be a tensor, sequence, or mapping.
ignore (List[str]): A list of names to be ignored. Only applicable if the input batch is a dictionary. If a key is in this list, its value will remain the same without decollation. Defaults to ['prev_state', 'prev_actor_state', 'prev_critic_state'].
- Returns:
ret (List[Any]): A list with B elements, where B is the batch size.
- Examples:
>>> batch = {
        'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
        'b': torch.tensor([[7, 8, 9], [10, 11, 12]]),
    }
>>> default_decollate(batch)
[{'a': tensor([1, 2, 3]), 'b': tensor([7, 8, 9])}, {'a': tensor([4, 5, 6]), 'b': tensor([10, 11, 12])}]
data.dataloader
Please refer to ding/utils/data/dataloader for more details.
AsyncDataLoader
- class ding.utils.data.dataloader.AsyncDataLoader(data_source: Callable | dict, batch_size: int, device: str, chunk_size: int | None = None, collate_fn: Callable | None = None, num_workers: int = 0)
- Overview:
An asynchronous dataloader.
- Interfaces:
__init__, __iter__, __next__, _get_data, _async_loop, _worker_loop, _cuda_loop, close
- __init__(data_source: Callable | dict, batch_size: int, device: str, chunk_size: int | None = None, collate_fn: Callable | None = None, num_workers: int = 0) -> None
- Overview:
Init the dataloader with input parameters. If data_source is a dict, data will only be processed in get_data_thread and put into async_train_queue. If data_source is a Callable, data will be processed by calling the function, in one of two ways:
num_workers == 0 or 1: Only the main worker processes the data and puts it into async_train_queue.
num_workers > 1: The main worker divides a job into several pieces and pushes every piece into job_queue; the slave workers then get the jobs and execute them; finally they push the processed data into async_train_queue.
At the last step, if device contains "cuda", data in async_train_queue will be transferred to cuda_queue for the user to access.
- Arguments:
data_source (Union[Callable, dict]): The data source, e.g. a function to be called (Callable), a replay buffer's real data (dict), etc.
batch_size (int): Batch size.
device (str): Device.
chunk_size (int): The size of a chunked piece in a batch. It should exactly divide batch_size and only takes effect when there is more than 1 worker.
collate_fn (Callable): The function used to collate the batch into each data field.
num_workers (int): Number of extra workers. 0 or 1 means there is only 1 main worker and no extra ones, i.e. multiprocessing is disabled. More than 1 means multiple workers implemented by multiprocessing process the data respectively.
- _async_loop(p: multiprocessing.connection, c: multiprocessing.connection) -> None
- Overview:
Main worker process, run through self.async_process. First, get data from self.get_data_thread. With multiple workers, put the data in self.job_queue for further multiprocessing; with only one worker, process the data and put it directly into self.async_train_queue.
- Arguments:
p (tm.multiprocessing.connection): Parent connection.
c (tm.multiprocessing.connection): Child connection.
- _cuda_loop() -> None
- Overview:
Only run (as a thread through self.cuda_thread) when using cuda. Get data from self.async_train_queue, change its device, and put it into self.cuda_queue.
- _get_data(p: multiprocessing.connection, c: multiprocessing.connection) -> None
- Overview:
Init dataloader with input parameters. Will run as a thread through self.get_data_thread.
- Arguments:
p (tm.multiprocessing.connection): Parent connection.
c (tm.multiprocessing.connection): Child connection.
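With more than one worker, the main worker divides each batch into pieces of chunk_size, which must exactly divide batch_size. A minimal sketch of that splitting rule, assuming each job is a contiguous range of sample indices; `split_job` is a hypothetical helper, not part of AsyncDataLoader.

```python
from typing import List


def split_job(batch_size: int, chunk_size: int) -> List[range]:
    """Hypothetical helper: split one batch of sample indices into equal
    chunks of `chunk_size`, one chunk per job for the worker processes."""
    if batch_size % chunk_size != 0:
        raise ValueError("chunk_size must exactly divide batch_size")
    return [range(start, start + chunk_size) for start in range(0, batch_size, chunk_size)]
```

For example, `split_job(8, 2)` yields four jobs covering indices 0-1, 2-3, 4-5 and 6-7; the processed pieces are then reassembled and collated into one batch.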
data.dataset
Please refer to ding/utils/data/dataset for more details.
DatasetStatistics
NaiveRLDataset
D4RLDataset
- class ding.utils.data.dataset.D4RLDataset(cfg: dict)
- Overview:
D4RL dataset, which is used for offline RL algorithms.
- Interfaces:
__init__, __len__, __getitem__
- Properties:
mean (np.ndarray): Mean of the dataset.
std (np.ndarray): Std of the dataset.
action_bounds (np.ndarray): Action bounds of the dataset.
statistics (dict): Statistics of the dataset.
- _cal_statistics(dataset, env, eps=0.001, add_action_buffer=True)
- Overview:
Calculate the statistics of the dataset.
- Arguments:
dataset (Dict[str, np.ndarray]): The d4rl dataset.
env (gym.Env): The environment.
eps (float): Epsilon.
- _load_d4rl(dataset: Dict[str, ndarray]) -> None
- Overview:
Load the d4rl dataset.
- Arguments:
dataset (Dict[str, np.ndarray]): The d4rl dataset.
- _normalize_states(dataset)
- Overview:
Normalize the states.
- Arguments:
dataset (Dict[str, np.ndarray]): The d4rl dataset.
- property action_bounds: ndarray
- Overview:
Get the action bounds of the dataset.
- property data: List
- property mean
- Overview:
Get the mean of the dataset.
- property statistics: dict
- Overview:
Get the statistics of the dataset.
- property std
- Overview:
Get the std of the dataset.
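_normalize_states together with the mean and std properties suggests per-dimension z-score normalization of states. The sketch below assumes, without confirmation from the source, that an epsilon (as in `_cal_statistics(eps=0.001)`) is added to the standard deviation to avoid division by zero; `normalize_states` is a hypothetical helper working on a plain `(N, obs_dim)` array.

```python
import numpy as np


def normalize_states(observations: np.ndarray, eps: float = 1e-3):
    """Hypothetical sketch: z-score every state dimension over the dataset."""
    mean = observations.mean(axis=0)
    # eps keeps the divisor positive for dimensions that never change
    std = observations.std(axis=0) + eps
    return (observations - mean) / std, mean, std
```

Returning `mean` and `std` alongside the normalized array mirrors the idea that the same statistics must be reused to normalize observations at evaluation time.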
HDF5Dataset
- class ding.utils.data.dataset.HDF5Dataset(cfg: dict)
- Overview:
HDF5Dataset loads data saved in HDF5 format and is used for offline RL algorithms. HDF5 is a common format for storing large numerical arrays in Python. For more details, please refer to https://support.hdfgroup.org/HDF5/.
- Interfaces:
__init__, __len__, __getitem__
- Properties:
mean (np.ndarray): Mean of the dataset.
std (np.ndarray): Std of the dataset.
action_bounds (np.ndarray): Action bounds of the dataset.
statistics (dict): Statistics of the dataset.
- __getitem__(idx: int) -> Dict[str, Tensor]
- Overview:
Get the item of the dataset.
- Arguments:
idx (int): The index of the dataset.
- _cal_statistics(eps: float = 0.001)
- Overview:
Calculate the statistics of the dataset.
- Arguments:
eps (float): Epsilon.
- _load_data(dataset: Dict[str, ndarray]) -> None
- Overview:
Load the dataset.
- Arguments:
dataset (Dict[str, np.ndarray]): The dataset.
- property action_bounds: ndarray
- Overview:
Get the action bounds of the dataset.
- property mean
- Overview:
Get the mean of the dataset.
- property statistics: dict
- Overview:
Get the statistics of the dataset.
- property std
- Overview:
Get the std of the dataset.
D4RLTrajectoryDataset
- class ding.utils.data.dataset.D4RLTrajectoryDataset(cfg: dict)
- Overview:
D4RL trajectory dataset, which is used for offline RL algorithms.
- Interfaces:
__init__, __len__, __getitem__
- D4RL_DATASET_STATS = {'halfcheetah-medium-expert-v2': {'state_mean': [-0.05667462572455406, 0.024369969964027405, -0.061670560389757156, -0.22351515293121338, -0.2675151228904724, -0.07545716315507889, -0.05809682980179787, -0.027675075456500053, 8.110626220703125, -0.06136331334710121, -0.17986927926540375, 0.25175222754478455, 0.24186332523822784, 0.2519369423389435, 0.5879552960395813, -0.24090635776519775, -0.030184272676706314], 'state_std': [0.06103534251451492, 0.36054104566574097, 0.45544400811195374, 0.38476887345314026, 0.2218363732099533, 0.5667523741722107, 0.3196682929992676, 0.2852923572063446, 3.443821907043457, 0.6728139519691467, 1.8616976737976074, 9.575807571411133, 10.029894828796387, 5.903450012207031, 12.128185272216797, 6.4811787605285645, 6.378620147705078]}, 'halfcheetah-medium-replay-v2': {'state_mean': [-0.12880703806877136, 0.3738119602203369, -0.14995987713336945, -0.23479078710079193, -0.2841278612613678, -0.13096535205841064, -0.20157982409000397, -0.06517726927995682, 3.4768247604370117, -0.02785065770149231, -0.015035249292850494, 0.07697279006242752, 0.01266712136566639, 0.027325302362442017, 0.02316424623131752, 0.010438721626996994, -0.015839405357837677], 'state_std': [0.17019015550613403, 1.284424901008606, 0.33442774415016174, 0.3672759234905243, 0.26092398166656494, 0.4784106910228729, 0.3181420564651489, 0.33552637696266174, 2.0931615829467773, 0.8037433624267578, 1.9044333696365356, 6.573209762573242, 7.572863578796387, 5.069749355316162, 9.10555362701416, 6.085654258728027, 7.25300407409668]}, 'halfcheetah-medium-v2': {'state_mean': [-0.06845773756504059, 0.016414547339081764, -0.18354906141757965, -0.2762460708618164, -0.34061527252197266, -0.09339715540409088, -0.21321271359920502, -0.0877423882484436, 5.173007488250732, -0.04275195300579071, -0.036108363419771194, 0.14053793251514435, 0.060498327016830444, 0.09550975263118744, 0.06739100068807602, 0.005627387668937445, 0.013382787816226482], 'state_std': 
[0.07472999393939972, 0.3023499846458435, 0.30207309126853943, 0.34417077898979187, 0.17619241774082184, 0.507205605506897, 0.2567007839679718, 0.3294812738895416, 1.2574149370193481, 0.7600541710853577, 1.9800915718078613, 6.565362453460693, 7.466367721557617, 4.472222805023193, 10.566964149475098, 5.671932697296143, 7.4982590675354]}, 'hopper-medium-expert-v2': {'state_mean': [1.3293815851211548, -0.09836531430482864, -0.5444297790527344, -0.10201650857925415, 0.02277466468513012, 2.3577215671539307, -0.06349576264619827, -0.00374026270583272, -0.1766270101070404, -0.11862941086292267, -0.12097819894552231], 'state_std': [0.17012375593185425, 0.05159067362546921, 0.18141433596611023, 0.16430604457855225, 0.6023368239402771, 0.7737284898757935, 1.4986555576324463, 0.7483318448066711, 1.7953159809112549, 2.0530025959014893, 5.725032806396484]}, 'hopper-medium-replay-v2': {'state_mean': [1.2305138111114502, -0.04371410980820656, -0.44542956352233887, -0.09370097517967224, 0.09094487875699997, 1.3694725036621094, -0.19992674887180328, -0.022861352190375328, -0.5287045240402222, -0.14465883374214172, -0.19652697443962097], 'state_std': [0.1756512075662613, 0.0636928603053093, 0.3438323438167572, 0.19566889107227325, 0.5547984838485718, 1.051029920578003, 1.158307671546936, 0.7963128685951233, 1.4802359342575073, 1.6540331840515137, 5.108601093292236]}, 'hopper-medium-v2': {'state_mean': [1.311279058456421, -0.08469521254301071, -0.5382719039916992, -0.07201576232910156, 0.04932365566492081, 2.1066856384277344, -0.15017354488372803, 0.008783451281487942, -0.2848185896873474, -0.18540096282958984, -0.28461286425590515], 'state_std': [0.17790751159191132, 0.05444620922207832, 0.21297138929367065, 0.14530418813228607, 0.6124444007873535, 0.8517446517944336, 1.4515252113342285, 0.6751695871353149, 1.5362390279769897, 1.616074562072754, 5.607253551483154]}, 'walker2d-medium-expert-v2': {'state_mean': [1.2294334173202515, 0.16869689524173737, -0.07089081406593323, 
-0.16197483241558075, 0.37101927399635315, -0.012209027074277401, -0.42461398243904114, 0.18986578285694122, 3.162475109100342, -0.018092676997184753, 0.03496946766972542, -0.013921679928898811, -0.05937029421329498, -0.19549426436424255, -0.0019200450042262673, -0.062483321875333786, -0.27366524934768677], 'state_std': [0.09932824969291687, 0.25981399416923523, 0.15062759816646576, 0.24249176681041718, 0.6758718490600586, 0.1650741547346115, 0.38140663504600525, 0.6962361335754395, 1.3501490354537964, 0.7641991376876831, 1.534574270248413, 2.1785972118377686, 3.276582717895508, 4.766193866729736, 1.1716983318328857, 4.039782524108887, 5.891613960266113]}, 'walker2d-medium-replay-v2': {'state_mean': [1.209364652633667, 0.13264022767543793, -0.14371201395988464, -0.2046516090631485, 0.5577612519264221, -0.03231537342071533, -0.2784661054611206, 0.19130706787109375, 1.4701707363128662, -0.12504704296588898, 0.0564953051507473, -0.09991033375263214, -0.340340256690979, 0.03546293452382088, -0.08934258669614792, -0.2992438077926636, -0.5984178185462952], 'state_std': [0.11929835379123688, 0.3562574088573456, 0.25852200388908386, 0.42075422406196594, 0.5202291011810303, 0.15685082972049713, 0.36770978569984436, 0.7161387801170349, 1.3763766288757324, 0.8632221817970276, 2.6364643573760986, 3.0134117603302, 3.720684051513672, 4.867283821105957, 2.6681625843048096, 3.845186948776245, 5.4768385887146]}, 'walker2d-medium-v2': {'state_mean': [1.218966007232666, 0.14163373410701752, -0.03704913705587387, -0.13814310729503632, 0.5138224363327026, -0.04719110205769539, -0.47288352251052856, 0.042254164814949036, 2.3948874473571777, -0.03143199160695076, 0.04466355964541435, -0.023907244205474854, -0.1013401448726654, 0.09090937674045563, -0.004192637279629707, -0.12120571732521057, -0.5497063994407654], 'state_std': [0.12311358004808426, 0.3241879940032959, 0.11456084251403809, 0.2623065710067749, 0.5640279054641724, 0.2271878570318222, 0.3837319612503052, 0.7373676896095276, 
1.2387926578521729, 0.798020601272583, 1.5664079189300537, 1.8092705011367798, 3.025604248046875, 4.062486171722412, 1.4586567878723145, 3.7445690631866455, 5.5851287841796875]}}
- REF_MAX_SCORE = {'halfcheetah': 12135.0, 'hopper': 3234.3, 'walker2d': 4592.3}
- REF_MIN_SCORE = {'halfcheetah': -280.178953, 'hopper': -20.272305, 'walker2d': 1.629008}
- __getitem__(idx: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
- Overview:
Get the item of the dataset.
- Arguments:
idx (int): The index of the dataset.
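REF_MAX_SCORE and REF_MIN_SCORE have the shape of the reference returns used by the standard D4RL normalized score, which maps the reference minimum to 0 and the reference maximum to 100. Whether D4RLTrajectoryDataset applies exactly this formula is an assumption; `normalized_score` is a hypothetical helper that copies the two dictionaries listed above.

```python
REF_MAX_SCORE = {'halfcheetah': 12135.0, 'hopper': 3234.3, 'walker2d': 4592.3}
REF_MIN_SCORE = {'halfcheetah': -280.178953, 'hopper': -20.272305, 'walker2d': 1.629008}


def normalized_score(env_name: str, score: float) -> float:
    """Hypothetical helper: map a raw episode return onto the D4RL-style
    scale, where the reference minimum is 0 and the reference maximum is 100."""
    lo, hi = REF_MIN_SCORE[env_name], REF_MAX_SCORE[env_name]
    return 100.0 * (score - lo) / (hi - lo)
```

A raw hopper return equal to 3234.3 therefore scores 100, and returns above the reference maximum score above 100.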
D4RLDiffuserDataset
FixedReplayBuffer
- class ding.utils.data.dataset.FixedReplayBuffer(data_dir: str, replay_suffix: int, *args, **kwargs)
- Overview:
Object composed of a list of OutOfGraphReplayBuffers.
- Interfaces:
__init__, get_transition_elements, sample_transition_batch
- __init__(data_dir: str, replay_suffix: int, *args, **kwargs)
- Overview:
Initialize the FixedReplayBuffer class.
- Arguments:
data_dir (str): Log directory from which to load the replay buffer.
replay_suffix (int): If not None, only load the replay buffer corresponding to this specific suffix in the data directory.
args (list): Arbitrary extra arguments.
kwargs (dict): Arbitrary keyword arguments.
- _load_buffer(suffix)
- Overview:
Load an OutOfGraphReplayBuffer replay buffer.
- Arguments:
suffix (int): The suffix of the replay buffer.
PCDataset¶
- class ding.utils.data.dataset.PCDataset(all_data)[源代码]¶
- Overview:
Dataset for Procedure Cloning.
- Interfaces:
__init__,__len__,__getitem__
- __getitem__(item)[源代码]¶
- Overview:
Get the item of the dataset.
- Arguments:
item (
int): The index of the dataset.
load_bfs_datasets
BCODataset
- class ding.utils.data.dataset.BCODataset(data=None)
- Overview:
Dataset for Behavioral Cloning from Observation.
- Interfaces:
__init__, __len__, __getitem__
- Properties:
obs (np.ndarray): The observation array.
action (np.ndarray): The action array.
- __getitem__(idx)
- Overview:
Get the item of the dataset.
- Arguments:
idx (int): The index of the dataset.
- __init__(data=None)
- Overview:
Initialization method of BCODataset.
- Arguments:
data (dict): The data dict.
- property action
- Overview:
Get the action array.
- property obs
- Overview:
Get the observation array.
SequenceDataset
- class ding.utils.data.dataset.SequenceDataset(cfg)
- Overview:
Dataset for the diffuser.
- Interfaces:
__init__, __len__, __getitem__
- __getitem__(idx, eps=0.0001)
- Overview:
Get the item of the dataset.
- Arguments:
idx (int): The index of the dataset.
eps (float): The epsilon.
- __init__(cfg)
- Overview:
Initialization method of SequenceDataset.
- Arguments:
cfg (dict): The config dict.
- get_conditions(observations)
- Overview:
Get the conditions on the current observation for planning.
- Arguments:
observations (np.ndarray): The observation array.
- make_indices(path_lengths, horizon)
- Overview:
Make indices for sampling from the dataset. Each index maps to a datapoint.
- Arguments:
path_lengths (np.ndarray): The path length array.
horizon (int): The horizon.
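The indexing idea behind make_indices can be illustrated with a minimal sketch. It is a simplification that assumes no padding: every window of horizon consecutive steps that fits entirely inside a path becomes one (path_index, start, end) datapoint. The function name make_indices_sketch is hypothetical, and the real method may additionally handle padding or a maximum path length.

```python
import numpy as np


def make_indices_sketch(path_lengths, horizon):
    """Hypothetical sketch: enumerate every full window of `horizon` steps.

    Each row is (path_index, start, end) with `end` exclusive, so
    observations[path_index, start:end] is one training datapoint.
    """
    indices = []
    for i, path_length in enumerate(path_lengths):
        # The last valid start keeps the whole window inside the path (no padding assumed).
        for start in range(path_length - horizon + 1):
            indices.append((i, start, start + horizon))
    # reshape keeps a (0, 3) shape when no window fits.
    return np.array(indices, dtype=np.int64).reshape(-1, 3)


idx = make_indices_sketch(np.array([5, 3]), horizon=3)
print(idx.tolist())  # [[0, 0, 3], [0, 1, 4], [0, 2, 5], [1, 0, 3]]
```

Paths shorter than the horizon simply contribute no windows in this sketch.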
- maze2d_set_terminals(env, dataset)
- Overview:
Set the terminals for maze2d.
- Arguments:
env (gym.Env): The gym env.
dataset (dict): The dataset dict.
- normalize(keys=['observations', 'actions'])
- Overview:
Normalize the dataset, i.e. normalize the fields that will be predicted by the diffusion model.
- Arguments:
keys (list): The list of keys.
- normalize_value(value)
- Overview:
Normalize the value.
- Arguments:
value (np.ndarray): The value array.
hdf5_save
naive_save
offline_data_save_type
create_dataset
bfs_helper
Please refer to ding/utils/bfs_helper for more details.
get_vi_sequence
- ding.utils.bfs_helper.get_vi_sequence(env: Env, observation: ndarray) -> Tuple[ndarray, List]
- Overview:
Given an instance of the maze environment and the current observation, use the Breadth-First Search (BFS) algorithm to plan an optimal path and record the result.
- Arguments:
env (Env): The instance of the maze environment.
observation (np.ndarray): The current observation.
- Returns:
output (Tuple[np.ndarray, List]): The BFS result. output[0] contains the BFS map after each iteration and output[1] contains the optimal actions before reaching the finishing point.
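The BFS planning that get_vi_sequence describes can be sketched on a plain grid maze. This is an illustration under stated assumptions, not the library's implementation: the maze is a 2D array (0 = free, 1 = wall), the four actions are encoded as hypothetical indices 0 to 3, and only the final distance map is returned rather than the map after each iteration.

```python
from collections import deque

import numpy as np

# Hypothetical action encoding: index -> (row delta, col delta).
ACTIONS = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # up, down, left, right


def bfs_plan(grid, start, goal):
    """Sketch of BFS path planning on a grid maze (0 = free, 1 = wall).

    Returns the distance-to-goal map (-1 where unreachable) and the list of
    action indices along one shortest path from `start` to `goal`.
    """
    grid = np.asarray(grid)
    rows, cols = grid.shape
    goal = tuple(goal)
    dist = np.full(grid.shape, -1, dtype=np.int64)
    dist[goal] = 0
    queue = deque([goal])
    # Breadth-first expansion outward from the goal fills in shortest distances.
    while queue:
        r, c = queue.popleft()
        for dr, dc in ACTIONS:
            nr, nc = r + dr, c + dc
            if 0 <= nr < rows and 0 <= nc < cols and grid[nr, nc] == 0 and dist[nr, nc] == -1:
                dist[nr, nc] = dist[r, c] + 1
                queue.append((nr, nc))

    actions = []
    pos = tuple(start)
    if dist[pos] == -1:
        return dist, actions  # goal unreachable from start
    # Greedily step to any neighbour that is exactly one step closer to the goal.
    while pos != goal:
        for a, (dr, dc) in enumerate(ACTIONS):
            nr, nc = pos[0] + dr, pos[1] + dc
            if 0 <= nr < rows and 0 <= nc < cols and dist[nr, nc] == dist[pos] - 1:
                actions.append(a)
                pos = (nr, nc)
                break
    return dist, actions


maze = [[0, 0, 0],
        [1, 1, 0],
        [0, 0, 0]]
dist, actions = bfs_plan(maze, start=(0, 0), goal=(2, 0))
print(actions)  # [3, 3, 1, 1, 2, 2]
```

Searching backward from the goal means one BFS pass yields the optimal distance for every cell, so the action sequence can be read off greedily from any reachable start.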
collection_helper
Please refer to ding/utils/collection_helper for more details.
iter_mapping
- ding.utils.collection_helper.iter_mapping(iter_: Iterable[_IterType], mapping: Callable[[_IterType], _IterTargetType])
- Overview:
Apply the mapping callable to each element of iter_ and return an iterable over the results.
- Arguments:
iter_ (Iterable[_IterType]): The iterable to map over.
mapping (Callable[[_IterType], _IterTargetType]): A callable applied to each element.
- Return:
(iter_mapping object): The iteration results.
- Example:
>>> iterable_list = [1, 2, 3, 4, 5]
>>> _iter = iter_mapping(iterable_list, lambda x: x ** 2)
>>> print(list(_iter))
[1, 4, 9, 16, 25]
compression_helper
Please refer to ding/utils/compression_helper for more details.
CloudPickleWrapper
dummy_compressor
zlib_data_compressor
- ding.utils.compression_helper.zlib_data_compressor(data: Any) -> bytes
- Overview:
Take the input data and return it compressed with the zlib compressor, in binary format.
- Arguments:
data (Any): The input data of the compressor.
- Returns:
output (bytes): The compressed byte-like result.
- Examples:
>>> zlib_data_compressor("Hello")
lz4_data_compressor
- ding.utils.compression_helper.lz4_data_compressor(data: Any) -> bytes
- Overview:
Return the data compressed with the lz4 compressor. The compressor outputs in binary format.
- Arguments:
data (Any): The input data of the compressor.
- Returns:
output (bytes): The compressed byte-like result.
- Examples:
>>> lz4.block.compress(pickle.dumps("Hello"))
b'R Hello.'
jpeg_data_compressor
- ding.utils.compression_helper.jpeg_data_compressor(data: ndarray) -> bytes
- Overview:
To reduce memory usage, the buffer can store the JPEG strings of images instead of numpy arrays. This function encodes an observation numpy array into a JPEG string.
- Arguments:
data (np.array): The observation numpy array.
- Returns:
img_str (bytes): The compressed byte-like result.
get_data_compressor
- ding.utils.compression_helper.get_data_compressor(name: str)
- Overview:
Get the data compressor according to the input name.
- Arguments:
name (str): Name of the compressor, one of ['lz4', 'zlib', 'jpeg', 'none'].
- Return:
compressor (Callable): The corresponding data compressor, which takes input data and returns compressed data.
- Example:
>>> compress_fn = get_data_compressor('lz4')
>>> compressed_data = compress_fn(input_data)
dummy_decompressor
lz4_data_decompressor
zlib_data_decompressor
jpeg_data_decompressor
- ding.utils.compression_helper.jpeg_data_decompressor(compressed_data: bytes, gray_scale=False) -> ndarray
- Overview:
To reduce memory usage, the buffer can store the JPEG strings of images instead of numpy arrays. This function decodes an observation numpy array from a JPEG string.
- Arguments:
compressed_data (bytes): The JPEG string.
gray_scale (bool): Set gray_scale=True if the observation is grayscale, and gray_scale=False if it is RGB.
- Returns:
arr (np.ndarray): The decompressed numpy array.
get_data_decompressor
- ding.utils.compression_helper.get_data_decompressor(name: str) -> Callable
- Overview:
Get the data decompressor according to the input name.
- Arguments:
name (str): Name of the decompressor, one of ['lz4', 'zlib', 'none'].
Note
All decompressors require a bytes-like object as input.
- Returns:
decompressor (Callable): The corresponding data decompressor.
- Examples:
>>> decompress_fn = get_data_decompressor('lz4')
>>> origin_data = decompress_fn(compressed_data)
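The name-based lookup used by get_data_compressor and get_data_decompressor can be sketched with the standard library alone. The lz4 example above pickles the data before compressing it, so this sketch assumes the zlib pair does the same. The registries and function names are hypothetical, and only 'zlib' and 'none' are shown because lz4 and JPEG need third-party packages.

```python
import pickle
import zlib
from typing import Any, Callable, Dict

# Hypothetical registries mirroring the name-based lookup described above.
_COMPRESSORS: Dict[str, Callable[[Any], Any]] = {
    "zlib": lambda data: zlib.compress(pickle.dumps(data)),  # serialize, then compress
    "none": lambda data: data,  # pass-through
}
_DECOMPRESSORS: Dict[str, Callable[[Any], Any]] = {
    "zlib": lambda blob: pickle.loads(zlib.decompress(blob)),  # reverse order
    "none": lambda blob: blob,
}


def get_compressor_sketch(name: str) -> Callable[[Any], Any]:
    # Fail loudly on unsupported names instead of silently skipping compression.
    if name not in _COMPRESSORS:
        raise KeyError(f"unsupported compressor: {name!r}")
    return _COMPRESSORS[name]


def get_decompressor_sketch(name: str) -> Callable[[Any], Any]:
    if name not in _DECOMPRESSORS:
        raise KeyError(f"unsupported decompressor: {name!r}")
    return _DECOMPRESSORS[name]


for name in ("zlib", "none"):
    compress_fn = get_compressor_sketch(name)
    decompress_fn = get_decompressor_sketch(name)
    assert decompress_fn(compress_fn({"obs": [0.0] * 100})) == {"obs": [0.0] * 100}
```

Pairing each compressor with a decompressor under the same name keeps buffers configurable by a single string while guaranteeing that the round trip restores the original object.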
default_helper
Please refer to ding/utils/default_helper for more details.
get_shape0
- ding.utils.default_helper.get_shape0(data: List | Dict | Tensor | Tensor) -> int
- Overview:
Get shape[0] of the data's torch tensor or treetensor.
- Arguments:
data (Union[List, Dict, torch.Tensor, ttorch.Tensor]): The data to be analysed.
- Returns:
shape[0] (int): The first dimension length of the data, usually the batch size.
lists_to_dicts
- ding.utils.default_helper.lists_to_dicts(data: List[dict | NamedTuple] | Tuple[dict | NamedTuple], recursive: bool = False) -> Mapping[object, object] | NamedTuple
- Overview:
Transform a list of dicts to a dict of lists.
- Arguments:
data (Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]]): A list of dicts to be transformed.
recursive (bool): Whether to recursively transform dict elements.
- Returns:
newdata (Union[Mapping[object, object], NamedTuple]): The resulting dict of lists.
- Example:
>>> from ding.utils import *
>>> lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}])
{1: [1, 2], 10: [3, 4]}
dicts_to_lists
- ding.utils.default_helper.dicts_to_lists(data: Mapping[object, List[object]]) -> List[Mapping[object, object]]
- Overview:
Transform a dict of lists to a list of dicts.
- Arguments:
data (Mapping[object, list]): A dict of lists to be transformed.
- Returns:
newdata (List[Mapping[object, object]]): The resulting list of dicts.
- Example:
>>> from ding.utils import *
>>> dicts_to_lists({1: [1, 2], 10: [3, 4]})
[{1: 1, 10: 3}, {1: 2, 10: 4}]
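The two transformations above are inverses of each other. Here is a minimal, non-recursive sketch of both using plain dicts. The function names are hypothetical, and unlike lists_to_dicts this sketch neither handles NamedTuple inputs nor the recursive option. It assumes every dict shares the first dict's keys and every list has the same length.

```python
from typing import Any, Dict, List


def lists_to_dicts_sketch(data: List[Dict[Any, Any]]) -> Dict[Any, List[Any]]:
    # Gather the values of each key, in list order; keys come from the first dict.
    if not data:
        return {}
    return {k: [d[k] for d in data] for k in data[0]}


def dicts_to_lists_sketch(data: Dict[Any, List[Any]]) -> List[Dict[Any, Any]]:
    # zip(*values) walks all the lists in lockstep, producing one row per index.
    keys = list(data.keys())
    return [dict(zip(keys, row)) for row in zip(*data.values())]


print(lists_to_dicts_sketch([{1: 1, 10: 3}, {1: 2, 10: 4}]))  # {1: [1, 2], 10: [3, 4]}
print(dicts_to_lists_sketch({1: [1, 2], 10: [3, 4]}))  # [{1: 1, 10: 3}, {1: 2, 10: 4}]
```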
override
squeeze
default_get
- ding.utils.default_helper.default_get(data: dict, name: str, default_value: Any | None = None, default_fn: Callable | None = None, judge_fn: Callable | None = None) -> Any
- Overview:
Get the value of name from data, with generic checks on the inputs (at least data and name are required). If name exists in data, return the value at name; otherwise, add name to default_get_set and use the value generated by default_fn (or default_value directly), which judge_fn checks for legality.
- Arguments:
data (dict): The input data dictionary.
name (str): The key name.
default_value (Optional[Any]) = None
default_fn (Optional[Callable]) = None
judge_fn (Optional[Callable]) = None
- Returns:
ret (Any): The value at name, or the default value if name is absent from data.
list_split
- ding.utils.default_helper.list_split(data: list, step: int) -> List[list]
- Overview:
Split a list of data by step.
- Arguments:
data (list): List of data to split.
step (int): Number of elements in each split.
- Returns:
ret (list): List of split data.
residual (list): Residual list. This value is None when len(data) is divisible by step.
- Example:
>>> list_split([1,2,3,4],2)
([[1, 2], [3, 4]], None)
>>> list_split([1,2,3,4],3)
([[1, 2, 3]], [4])
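The return convention in the example above (full chunks plus an optional residual) can be reproduced with a short sketch. The name list_split_sketch is hypothetical. It is written only to match the documented examples and does not claim to be the library's code.

```python
from typing import List, Optional, Tuple


def list_split_sketch(data: list, step: int) -> Tuple[List[list], Optional[list]]:
    # Number of chunks that contain exactly `step` elements.
    n_full = len(data) // step
    ret = [data[i * step:(i + 1) * step] for i in range(n_full)]
    # Leftover tail; an empty tail becomes None, matching the documented examples.
    residual = data[n_full * step:] or None
    return ret, residual


print(list_split_sketch([1, 2, 3, 4], 2))  # ([[1, 2], [3, 4]], None)
print(list_split_sketch([1, 2, 3, 4], 3))  # ([[1, 2, 3]], [4])
```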
error_wrapper
- ding.utils.default_helper.error_wrapper(fn, default_ret, warning_msg='')
- Overview:
Wrap the function so that any Exception raised inside it is caught and default_ret is returned instead.
- Arguments:
fn (Callable): The function to be wrapped.
default_ret (obj): The default return value when an Exception occurs in the function.
- Returns:
wrapper (Callable): The wrapped function.
- Examples:
>>> # Used to check for FakeLink (refer to utils.linklink_dist_helper.py)
>>> def get_rank():  # Get the rank of the linklink model; return 0 if FakeLink is used.
>>>     if is_fake_link:
>>>         return 0
>>>     return error_wrapper(link.get_rank, 0)()
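The wrapping behaviour described above amounts to a try/except around the call. The following sketch uses the hypothetical name error_wrapper_sketch. How the real function uses warning_msg is not documented here, so emitting it through the warnings module is an assumption.

```python
import functools
import warnings


def error_wrapper_sketch(fn, default_ret, warning_msg=""):
    """Sketch: call fn, but return default_ret if it raises any Exception."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            # Assumption: a non-empty warning_msg is surfaced as a Python warning.
            if warning_msg:
                warnings.warn(f"{warning_msg}: {e}")
            return default_ret

    return wrapper


def parse_int(text):
    return int(text)


safe_parse = error_wrapper_sketch(parse_int, default_ret=-1)
print(safe_parse("42"), safe_parse("not a number"))  # 42 -1
```

Catching Exception rather than BaseException leaves KeyboardInterrupt and SystemExit untouched, which is usually what such a fallback wrapper should do.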
LimitedSpaceContainer
- class ding.utils.default_helper.LimitedSpaceContainer(min_val: int, max_val: int)
- Overview:
A space simulator.
- Interfaces:
__init__, get_residual_space, release_space
- __init__(min_val: int, max_val: int) -> None
- Overview:
Set min_val and max_val of the container, and initialize cur to min_val.
- Arguments:
min_val (int): Min volume of the container, usually 0.
max_val (int): Max volume of the container.
- acquire_space() -> bool
- Overview:
Try to get one piece of space. If there is one, return True; otherwise return False.
- Returns:
flag (bool): Whether there is any piece of residual space.
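The container above behaves like a bounded counter. Here is a minimal sketch of that idea with a hypothetical class name. The exact semantics of get_residual_space are assumed in this sketch (it claims every remaining piece and reports how many there were) and may differ from the library.

```python
class LimitedSpaceContainerSketch:
    """Hypothetical sketch of a bounded counter that hands out unit pieces of space."""

    def __init__(self, min_val: int, max_val: int) -> None:
        self.min_val = min_val
        self.max_val = max_val
        self.cur = min_val  # pieces currently in use

    def acquire_space(self) -> bool:
        # Take one piece if any remains.
        if self.cur < self.max_val:
            self.cur += 1
            return True
        return False

    def get_residual_space(self) -> int:
        # Assumption: claim all remaining pieces at once and report how many.
        residual = self.max_val - self.cur
        self.cur = self.max_val
        return residual

    def release_space(self) -> None:
        # Give one piece back, never dropping below min_val.
        self.cur = max(self.min_val, self.cur - 1)


box = LimitedSpaceContainerSketch(0, 2)
print(box.acquire_space(), box.acquire_space(), box.acquire_space())  # True True False
box.release_space()
print(box.get_residual_space())  # 1
```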
deep_merge_dicts
deep_update
- ding.utils.default_helper.deep_update(original: dict, new_dict: dict, new_keys_allowed: bool = False, whitelist: List[str] | None = None, override_all_if_type_changes: List[str] | None = None)
- Overview:
Recursively update the original dict with values from new_dict.
- Arguments:
original (dict): Dictionary with default values.
new_dict (dict): Dictionary with values to be updated.
new_keys_allowed (bool): Whether new keys are allowed.
whitelist (Optional[List[str]]): List of keys that correspond to dict values where new subkeys can be introduced. This applies only at the top level.
override_all_if_type_changes (Optional[List[str]]): List of top-level keys with dict values for which the entire value (dict) is always overridden if the "type" key in that value dict changes.
Note
If a new key is introduced in new_dict and new_keys_allowed is not True, an error will be thrown. Further, for sub-dicts, if the key is in the whitelist, new subkeys can be introduced.
flatten_dict
- ding.utils.default_helper.flatten_dict(data: dict, delimiter: str = '/') -> dict
- Overview:
Flatten a nested dict; see the example.
- Arguments:
data (dict): Original nested dict.
delimiter (str): Delimiter of the keys of the new dict.
- Returns:
data (dict): Flattened dict.
- Example:
>>> a
{'a': {'b': 100}}
>>> flatten_dict(a)
{'a/b': 100}
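The flattening in the example above can be sketched recursively: nested keys are joined with the delimiter and non-dict values become leaves. The name flatten_dict_sketch is hypothetical. Note that this sketch drops empty nested dicts entirely.

```python
def flatten_dict_sketch(data: dict, delimiter: str = "/") -> dict:
    """Sketch: join nested keys with `delimiter`; non-dict values are leaves."""
    flat = {}
    for key, value in data.items():
        if isinstance(value, dict):
            # Recurse, then prefix each child key with the parent key.
            for sub_key, sub_value in flatten_dict_sketch(value, delimiter).items():
                flat[f"{key}{delimiter}{sub_key}"] = sub_value
        else:
            flat[key] = value
    return flat


print(flatten_dict_sketch({"a": {"b": 100}, "c": 1}))  # {'a/b': 100, 'c': 1}
```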
set_pkg_seed
- ding.utils.default_helper.set_pkg_seed(seed: int, use_cuda: bool = True) -> None
- Overview:
Side-effect function that sets the seed for random, numpy random, and torch's manual seed. This is usually used in the entry script, in the section that sets the random seed for all packages and instances.
- Arguments:
seed (int): The seed to set.
use_cuda (bool): Whether CUDA is used.
- Examples:
>>> # ../entry/xxxenv_xxxpolicy_main.py
>>> ...
>>> # Set random seed for all packages and instances
>>> collector_env.seed(seed)
>>> evaluator_env.seed(seed, dynamic_seed=False)
>>> set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
>>> ...
>>> # Set up RL Policy, etc.
>>> ...
one_time_warning
split_fn
split_data_generator
RunningMeanStd
- class ding.utils.default_helper.RunningMeanStd(epsilon=0.0001, shape=(), device=device(type='cpu'))
- Overview:
Wrapper that updates the running variance, mean, and count.
- Interfaces:
__init__, update, reset, new_shape
- Properties:
mean, std, _epsilon, _shape, _mean, _var, _count
- __init__(epsilon=0.0001, shape=(), device=device(type='cpu'))
- Overview:
Initialize self (see help(type(self)) for the accurate signature) and set up the properties.
- Arguments:
epsilon (float): The epsilon used when computing the std output.
shape (np.array): The shape of the mean and variance arrays tracked by this wrapper.
device (torch.device): The device; defaults to CPU.
- property mean: ndarray
- Overview:
Property mean, obtained from self._mean.
- static new_shape(obs_shape, act_shape, rew_shape)
- Overview:
Get the new shapes of the observation, action, and reward; in this case they are unchanged.
- Arguments:
obs_shape (Any), act_shape (Any), rew_shape (Any)
- Returns:
obs_shape (Any), act_shape (Any), rew_shape (Any)
- reset()
- Overview:
Reset the properties _mean, _var and _count.
- property std: ndarray
- Overview:
Property std, calculated from self._var and the epsilon value self._epsilon.
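The running statistics above can be maintained batch by batch with the parallel-variance merge formula of Chan et al. This NumPy sketch uses the hypothetical class name RunningMeanStdSketch. It assumes that epsilon also serves as a tiny initial pseudo-count and is added inside the square root for std. Both are assumptions, since the text above only says std is derived from _var and _epsilon.

```python
import numpy as np


class RunningMeanStdSketch:
    """Sketch of a running mean/variance tracker using batched (parallel) updates."""

    def __init__(self, epsilon: float = 1e-4, shape=()):
        self._epsilon = epsilon
        self._shape = shape
        self.reset()

    def reset(self) -> None:
        self._mean = np.zeros(self._shape, dtype=np.float64)
        self._var = np.ones(self._shape, dtype=np.float64)
        self._count = self._epsilon  # assumption: tiny pseudo-count avoids division by zero

    def update(self, x: np.ndarray) -> None:
        batch_mean, batch_var, batch_count = x.mean(axis=0), x.var(axis=0), x.shape[0]
        delta = batch_mean - self._mean
        total = self._count + batch_count
        # Merge the two sets of moments (Chan et al. parallel variance formula).
        new_mean = self._mean + delta * batch_count / total
        m2 = (self._var * self._count + batch_var * batch_count
              + delta ** 2 * self._count * batch_count / total)
        self._mean, self._var, self._count = new_mean, m2 / total, total

    @property
    def mean(self) -> np.ndarray:
        return self._mean

    @property
    def std(self) -> np.ndarray:
        # Assumption: epsilon is added inside the square root for numerical safety.
        return np.sqrt(self._var + self._epsilon)


rng = np.random.default_rng(0)
data = rng.normal(3.0, 2.0, size=(1000, 2))
rms = RunningMeanStdSketch(shape=(2,))
for chunk in np.split(data, 10):
    rms.update(chunk)
print(np.allclose(rms.mean, data.mean(axis=0), atol=1e-3))  # True
```

Merging whole batches gives the same result as a per-sample Welford update, but each update costs only a couple of vectorized reductions.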
make_key_as_identifier
- ding.utils.default_helper.make_key_as_identifier(data: Dict[str, Any]) -> Dict[str, Any]
- Overview:
Turn the keys of a dict into legal Python identifier strings so that the dict is compatible with Python magic methods such as __getattr__.
- Arguments:
data (Dict[str, Any]): The original dict data.
- Return:
new_data (Dict[str, Any]): The new dict data with legal identifier keys.
remove_illegal_item
- ding.utils.default_helper.remove_illegal_item(data: Dict[str, Any]) -> Dict[str, Any]
- Overview:
Remove illegal items from the dict info, such as str values, which are not compatible with Tensor.
- Arguments:
data (Dict[str, Any]): The original dict data.
- Return:
new_data (Dict[str, Any]): The new dict data without illegal items.
design_helper
Please refer to ding/utils/design_helper for more details.
SingletonMetaclass
fake_linklink
Please refer to ding/utils/fake_linklink for more details.
FakeClass
FakeNN
FakeLink
- class ding.utils.fake_linklink.FakeLink
- Overview:
Fake link class.
- class allreduceOp_t(Sum, Max)
- Max
Alias for field number 1
- Sum
Alias for field number 0
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('Sum', 'Max')
- classmethod _make(iterable)
Make a new allreduceOp_t object from a sequence or iterable.
- _replace(**kwds)
Return a new allreduceOp_t object replacing specified fields with new values.
- nn = <ding.utils.fake_linklink.FakeNN object>
- syncbnVarMode_t = syncbnVarMode_t(L2=None)
fast_copy
Please refer to ding/utils/fast_copy for more details.
_FastCopy
- class ding.utils.fast_copy._FastCopy
- Overview:
The idea of this class comes from this article: https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list. It uses recursive calls to copy each object that needs to be copied, which is 5x faster than copy.deepcopy.
- Interfaces:
__init__, _copy_list, _copy_dict, _copy_tensor, _copy_ndarray, copy.
- _copy_ndarray(a: ndarray) -> ndarray
- Overview:
Copy the ndarray.
- Arguments:
a (np.ndarray): The ndarray to be copied.
file_helper
Please refer to ding/utils/file_helper for more details.
read_from_ceph
_get_redis
read_from_redis
_ensure_rediscluster
- ding.utils.file_helper._ensure_rediscluster(startup_nodes=[{'host': '127.0.0.1', 'port': '7000'}])
- Overview:
Ensure Redis cluster usage.
- Arguments:
startup_nodes (list of dict): List of startup nodes, each containing:
host (str): Host string.
port (int): Port number.
- Returns:
(RedisCluster(object)): A RedisCluster object with the given host and port, and decode_responses set to False by default.
read_from_rediscluster
read_from_file
_ensure_memcached
read_from_mc
read_from_path
save_file_ceph
save_file_redis
save_file_rediscluster
read_file
- ding.utils.file_helper.read_file(path: str, fs_type: str | None = None, use_lock: bool = False) -> object
- Overview:
Read a file from path.
- Arguments:
path (str): The path of the file to read.
fs_type (str or None): The file system type, one of {'normal', 'ceph'}.
use_lock (bool): Whether to use a lock in the local normal file system.
save_file
- ding.utils.file_helper.save_file(path: str, data: object, fs_type: str | None = None, use_lock: bool = False) -> None
- Overview:
Save data to the file at path.
- Arguments:
path (str): The path of the file to save to.
data (object): The data to save.
fs_type (str or None): The file system type, one of {'normal', 'ceph'}.
use_lock (bool): Whether to use a lock in the local normal file system.
remove_file
import_helper
Please refer to ding/utils/import_helper for more details.
try_import_ceph
try_import_mc
try_import_redis
try_import_rediscluster
try_import_link
import_module
k8s_helper
Please refer to ding/utils/k8s_helper for more details.
get_operator_server_kwargs
exist_operator_server
pod_exec_command
K8sType
K8sLauncher
- class ding.utils.k8s_helper.K8sLauncher(config_path: str)
- Overview:
Object to manage the K8s cluster.
- Interfaces:
__init__, _load, create_cluster, _check_k3d_tools, delete_cluster, preload_images
- __init__(config_path: str) -> None
- Overview:
Initialize the K8sLauncher object.
- Arguments:
config_path (str): The path of the config file.
linklink_dist_helper
Please refer to ding/utils/linklink_dist_helper for more details.
get_rank
get_world_size
broadcast
allreduce
allreduce_async
get_group
dist_mode
dist_init
dist_finalize
DistContext
simple_group_split
- ding.utils.linklink_dist_helper.simple_group_split(world_size: int, rank: int, num_groups: int) -> List
- Overview:
Split the group according to world_size, rank and num_groups.
- Arguments:
world_size (int): The world size.
rank (int): The rank.
num_groups (int): The number of groups.
Note
With faulty input, this raises the error "array split does not result in an equal division".
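The error message quoted in the note is the one NumPy's np.split raises for an uneven split, which suggests the ranks are divided with np.split. The following sketch works under that assumption. The name simple_group_split_sketch is hypothetical, and it returns plain rank lists, whereas the real helper works with communication groups of the distributed backend.

```python
import numpy as np


def simple_group_split_sketch(world_size: int, rank: int, num_groups: int) -> list:
    """Sketch: return the ranks that share a group with `rank`.

    np.split raises ValueError ("array split does not result in an equal
    division") when world_size is not divisible by num_groups.
    """
    groups = np.split(np.arange(world_size), num_groups)
    group_size = world_size // num_groups
    # Consecutive ranks form a group, so integer division finds the group index.
    return groups[rank // group_size].tolist()


print(simple_group_split_sketch(world_size=8, rank=5, num_groups=2))  # [4, 5, 6, 7]
```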
synchronize
lock_helper
Please refer to ding/utils/lock_helper for more details.
LockContextType
LockContext
- class ding.utils.lock_helper.LockContext(lock_type: LockContextType = LockContextType.THREAD_LOCK)
- Overview:
Generate a LockContext to ensure thread safety.
- Interfaces:
__init__, __enter__, __exit__.
- Example:
>>> with LockContext() as lock:
>>>     print("Do something here.")
- __init__(lock_type: LockContextType = LockContextType.THREAD_LOCK)
- Overview:
Initialize the lock according to the given type.
- Arguments:
lock_type (LockContextType): The type of lock to be used. Defaults to LockContextType.THREAD_LOCK.
get_rw_file_lock
FcntlContext
- class ding.utils.lock_helper.FcntlContext(lock_path: str)
- Overview:
A context manager that acquires an exclusive lock on a file using fcntl. This is useful for preventing multiple processes from running the same code at the same time.
- Interfaces:
__init__, __enter__, __exit__.
- Example:
>>> lock_path = "/path/to/lock/file"
>>> with FcntlContext(lock_path) as lock:
>>>     # Perform operations while the lock is held
>>>     pass
get_file_lock
- ding.utils.lock_helper.get_file_lock(name: str, op: str) -> FcntlContext
- Overview:
Acquire a file lock for the specified file.
- Arguments:
name (str): The name of the file.
op (str): The operation to perform on the file lock.
log_helper
Please refer to ding/utils/log_helper for more details.
build_logger
- ding.utils.log_helper.build_logger(path: str, name: str | None = None, need_tb: bool = True, need_text: bool = True, text_level: int | str = 20) -> Tuple[Logger | None, SummaryWriter | None]
- Overview:
Build a text logger and a TensorBoard logger.
- Arguments:
path (str): The directory where the loggers (text logger and SummaryWriter) save their output.
name (str): The logger file name.
need_tb (bool): Whether a SummaryWriter instance is created and returned.
need_text (bool): Whether a logging.Logger instance is created and returned.
text_level (int or str): Logging level of the logging.Logger; defaults to logging.INFO.
- Returns:
logger (Optional[logging.Logger]): Logger that displays terminal output.
tb_logger (Optional['SummaryWriter']): Saves output to TensorBoard; only returned when need_tb is True.
TBLoggerFactory
- class ding.utils.log_helper.TBLoggerFactory
- Overview:
TBLoggerFactory is a factory class for SummaryWriter.
- Interfaces:
create_logger
- Properties:
tb_loggers (Dict[str, SummaryWriter]): A dict that stores SummaryWriter instances.
- classmethod create_logger(logdir: str) -> DistributedWriter
- tb_loggers = {}
LoggerFactory
- class ding.utils.log_helper.LoggerFactory
- Overview:
LoggerFactory is a factory class for logging.Logger.
- Interfaces:
create_logger, get_tabulate_vars, get_tabulate_vars_hor
- classmethod create_logger(path: str, name: str = 'default', level: int | str = 20) -> Logger
- Overview:
Create a logger using logging.
- Arguments:
name (str): The logger's name.
path (str): The logger's save directory.
level (int or str): The logging level. Reference: the Logger.setLevel method.
- Returns:
(logging.Logger): A new logging logger.
pretty_print
log_writer_helper
Please refer to ding/utils/log_writer_helper for more details.
DistributedWriter
- class ding.utils.log_writer_helper.DistributedWriter(*args, **kwargs)
- Overview:
A simple subclass of SummaryWriter that supports writing from a single process in multi-process mode. It works best together with the router, taking advantage of the router's message and event components (see writer.plugin).
- Interfaces:
get_instance, plugin, initialize, __del__
- classmethod get_instance(*args, **kwargs) -> DistributedWriter
- Overview:
Get the instance, and set the root-level instance on the first call. If args and kwargs are empty, this method returns the root instance.
- Arguments:
args (Tuple): The arguments passed to the __init__ function of the parent class, SummaryWriter.
kwargs (Dict): The keyword arguments passed to the __init__ function of the parent class, SummaryWriter.
- plugin(router: Parallel, is_writer: bool = False) -> DistributedWriter
- Overview:
Plug in the router, so that when this writer is used with an active router, it automatically sends requests to the main writer instead of writing to disk. This lets data from multiple processes be collected and written into one file.
- Arguments:
router (Parallel): The router to be plugged in.
is_writer (bool): Whether this writer is the main writer.
- Examples:
>>> DistributedWriter().plugin(router, is_writer=True)
enable_parallel
normalizer_helper
Please refer to ding/utils/normalizer_helper for more details.
DatasetNormalizer
- class ding.utils.normalizer_helper.DatasetNormalizer(dataset: ndarray, normalizer: str, path_lengths: list | None = None)
- Overview:
The DatasetNormalizer class normalizes and unnormalizes data in a dataset. It takes a dataset as input and applies a normalizer to each key in the dataset.
- Interfaces:
__init__, __repr__, normalize, unnormalize.
- __init__(dataset: ndarray, normalizer: str, path_lengths: list | None = None)
- Overview:
Initialize the DatasetNormalizer object.
- Arguments:
dataset (np.ndarray): The dataset to be normalized.
normalizer (str): The type of normalizer to be used, given as the name of the normalizer class.
path_lengths (list): The lengths of the paths in the dataset. Defaults to None.
flatten
- ding.utils.normalizer_helper.flatten(dataset: dict, path_lengths: list) -> dict
- Overview:
Flatten a dataset of { key: [ n_episodes x max_path_length x dim ] } to { key: [ sum(path_lengths) x dim ] }.
- Arguments:
dataset (dict): The dataset to be flattened.
path_lengths (list): A list of path lengths for each episode.
- Returns:
flattened (dict): The flattened dataset.
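The shape change performed by flatten amounts to removing each episode's padding and concatenating the episodes. Here is a minimal NumPy sketch with the hypothetical name flatten_sketch. It assumes path_lengths has one entry per episode and that each entry does not exceed max_path_length.

```python
import numpy as np


def flatten_sketch(dataset: dict, path_lengths) -> dict:
    """Sketch: drop padded steps and stack all episodes along one axis."""
    flattened = {}
    for key, xs in dataset.items():
        # xs has shape [n_episodes, max_path_length, dim]; keep only the real steps.
        flattened[key] = np.concatenate(
            [x[:length] for x, length in zip(xs, path_lengths)], axis=0
        )
    return flattened


obs = np.zeros((2, 4, 3))  # 2 episodes, padded to length 4, dim 3
print(flatten_sketch({"observations": obs}, [4, 2])["observations"].shape)  # (6, 3)
```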
Normalizer
- class ding.utils.normalizer_helper.Normalizer(X)
- Overview:
Parent class; subclasses define the normalize and unnormalize methods.
- Interfaces:
__init__, __repr__, normalize, unnormalize.
- __init__(X)
- Overview:
Initialize the Normalizer object.
- Arguments:
X (np.ndarray): The data to be normalized.
GaussianNormalizer
- class ding.utils.normalizer_helper.GaussianNormalizer(*args, **kwargs)
- Overview:
A class that normalizes data to zero mean and unit variance.
- Interfaces:
__init__, __repr__, normalize, unnormalize.
- __init__(*args, **kwargs)
- Overview:
Initialize the GaussianNormalizer object.
- Arguments:
args (list): The arguments passed to the __init__ function of the parent class, i.e., the Normalizer class.
kwargs (dict): The keyword arguments passed to the __init__ function of the parent class, i.e., the Normalizer class.
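Zero-mean, unit-variance normalization is a per-dimension z-score fitted on the data X. This sketch uses the hypothetical name GaussianNormalizerSketch and population statistics, and it does not guard against zero-variance dimensions. It illustrates the technique and is not the library's class.

```python
import numpy as np


class GaussianNormalizerSketch:
    """Sketch: per-dimension z-score normalization fitted on data X."""

    def __init__(self, X: np.ndarray):
        X = X.astype(np.float64)
        self.means = X.mean(axis=0)
        self.stds = X.std(axis=0)  # population std; zero-variance dims are not handled

    def normalize(self, x: np.ndarray) -> np.ndarray:
        return (x - self.means) / self.stds

    def unnormalize(self, x: np.ndarray) -> np.ndarray:
        # Exact inverse of normalize.
        return x * self.stds + self.means


rng = np.random.default_rng(0)
X = rng.normal(5.0, 3.0, size=(500, 2))
norm = GaussianNormalizerSketch(X)
z = norm.normalize(X)
print(np.round(z.mean(axis=0), 6), np.round(z.std(axis=0), 6))
```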
CDFNormalizer
- class ding.utils.normalizer_helper.CDFNormalizer(X)
- Overview:
A class that makes training data uniform (over each dimension) by transforming it with marginal CDFs.
- Interfaces:
__init__, __repr__, normalize, unnormalize.
- __init__(X)
- Overview:
Initialize the CDFNormalizer object.
- Arguments:
X (np.ndarray): The data to be normalized.
- normalize(x: ndarray) -> ndarray
- Overview:
Normalize the input data.
- Arguments:
x (np.ndarray): The input data.
- Returns:
ret (np.ndarray): The normalized data.
CDFNormalizer1d
- class ding.utils.normalizer_helper.CDFNormalizer1d(X: ndarray)
- Overview:
CDF normalizer for a single dimension. This class provides methods to normalize and unnormalize data using the cumulative distribution function (CDF) approach.
- Interfaces:
__init__, __repr__, normalize, unnormalize.
- __init__(X: ndarray)
- Overview:
Initialize the CDFNormalizer1d object.
- Arguments:
X (np.ndarray): The data to be normalized.
empirical_cdf
- ding.utils.normalizer_helper.empirical_cdf(sample: np.ndarray) -> (np.ndarray, np.ndarray)
- Overview:
Compute the empirical cumulative distribution function (CDF) of a given sample.
- Arguments:
sample (np.ndarray): The input sample for which to compute the empirical CDF.
- Returns:
quantiles (np.ndarray): The unique values in the sample.
cumprob (np.ndarray): The cumulative probabilities corresponding to the quantiles.
- References:
Stack Overflow: https://stackoverflow.com/a/33346366
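The empirical CDF described above pairs each unique sample value with the fraction of the sample at or below it. Here is a NumPy sketch of that computation using np.unique with return_counts. The function name is hypothetical. The last lines show, as an illustration only, how a CDF-style normalizer could map values into [0, 1] through linear interpolation.

```python
import numpy as np


def empirical_cdf_sketch(sample: np.ndarray):
    """Sketch: unique sorted values and the fraction of the sample <= each one."""
    quantiles, counts = np.unique(sample, return_counts=True)
    cumprob = np.cumsum(counts).astype(np.float64) / sample.size
    return quantiles, cumprob


q, p = empirical_cdf_sketch(np.array([3, 1, 2, 2]))
print(q.tolist(), p.tolist())  # [1, 2, 3] [0.25, 0.75, 1.0]

# Illustration: interpolate through the CDF to map raw values into [0, 1].
u = np.interp(np.array([1.5, 3.0]), q, p)
print(u.tolist())  # [0.5, 1.0]
```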
atleast_2d
LimitsNormalizer
- class ding.utils.normalizer_helper.LimitsNormalizer(X)
- Overview:
A class that normalizes and unnormalizes values within specified limits. It maps values in the range [xmin, xmax] to the range [-1, 1].
- Interfaces:
__init__, __repr__, normalize, unnormalize.
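The mapping from [xmin, xmax] to [-1, 1] is the affine transform 2 * (x - xmin) / (xmax - xmin) - 1, applied per dimension. In this sketch the class name LimitsNormalizerSketch is hypothetical, and clipping inputs to [-1, 1] before unnormalizing is an added assumption, not documented behaviour.

```python
import numpy as np


class LimitsNormalizerSketch:
    """Sketch: affine map of [xmin, xmax] (per dimension) onto [-1, 1]."""

    def __init__(self, X: np.ndarray):
        self.mins = X.min(axis=0)
        self.maxs = X.max(axis=0)

    def normalize(self, x: np.ndarray) -> np.ndarray:
        # Scale to [0, 1], then stretch and shift to [-1, 1].
        return 2 * (x - self.mins) / (self.maxs - self.mins) - 1

    def unnormalize(self, x: np.ndarray) -> np.ndarray:
        # Assumption: clip slightly out-of-range inputs (e.g. model outputs) first.
        x = np.clip(x, -1, 1)
        return (x + 1) / 2 * (self.maxs - self.mins) + self.mins


X = np.array([[0.0, 10.0], [4.0, 20.0]])
n = LimitsNormalizerSketch(X)
print(n.normalize(np.array([2.0, 15.0])).tolist())  # [0.0, 0.0]
```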
orchestrator_launcher
Please refer to ding/utils/orchestrator_launcher for more details.
OrchestratorLauncher
- class ding.utils.orchestrator_launcher.OrchestratorLauncher(version: str, name: str = 'di-orchestrator', cluster: K8sLauncher | None = None, registry: str = 'diorchestrator', cert_manager_version: str = 'v1.3.1', cert_manager_registry: str = 'quay.io/jetstack')
- Overview:
Object to manage di-orchestrator in an existing k8s cluster.
- Interfaces:
__init__, create_orchestrator, delete_orchestrator
- __init__(version: str, name: str = 'di-orchestrator', cluster: K8sLauncher | None = None, registry: str = 'diorchestrator', cert_manager_version: str = 'v1.3.1', cert_manager_registry: str = 'quay.io/jetstack') -> None
- Overview:
Initialize the OrchestratorLauncher object.
- Arguments:
version (str): The version of di-orchestrator.
name (str): The name of di-orchestrator.
cluster (K8sLauncher): The k8s cluster to deploy di-orchestrator to.
registry (str): The docker registry to pull images from.
cert_manager_version (str): The version of cert-manager.
cert_manager_registry (str): The docker registry to pull cert-manager images from.
create_components_from_config
wait_to_be_ready
- ding.utils.orchestrator_launcher.wait_to_be_ready(namespace: str, component: str, timeout: int = 120) -> None
- Overview:
Wait for the component to be ready.
- Arguments:
namespace (str): The namespace of the component.
component (str): The name of the component.
timeout (int): The timeout for waiting.
profiler_helper
Please refer to ding/utils/profiler_helper for more details.
Profiler
- class ding.utils.profiler_helper.Profiler
- Overview:
A class for profiling code execution. It can be used as a context manager or a decorator.
- Interfaces:
__init__, mkdir, write_profile, profile.
- mkdir(directory: str)
- Overview:
Create a directory if it doesn't exist.
- Arguments:
directory (str): The path of the directory to be created.
pytorch_ddp_dist_helper
Please refer to ding/utils/pytorch_ddp_dist_helper for more details.
get_rank
get_world_size
allreduce
allreduce_async
reduce_data
allreduce_data
get_group
dist_mode
dist_init
- ding.utils.pytorch_ddp_dist_helper.dist_init(backend: str = 'nccl', addr: str | None = None, port: str | None = None, rank: int | None = None, world_size: int | None = None) -> Tuple[int, int]
- Overview:
Initialize the distributed training setting.
- Arguments:
backend (str): The backend of distributed training, one of ['nccl', 'gloo'].
addr (str): The address of the master node.
port (str): The port of the master node.
rank (int): The rank of the current process.
world_size (int): The total number of processes.
dist_finalize
DDPContext
simple_group_split
- ding.utils.pytorch_ddp_dist_helper.simple_group_split(world_size: int, rank: int, num_groups: int) -> List
- Overview:
Split the group according to world_size, rank and num_groups.
- Arguments:
world_size (int): The world size.
rank (int): The rank.
num_groups (int): The number of groups.
Note
With faulty input, this raises the error "array split does not result in an equal division".
to_ddp_config¶
registry¶
Please refer to ding/utils/registry for more details.
Registry¶
- class ding.utils.registry.Registry(*args, **kwargs)[源代码]¶
- Overview:
A helper class for managing registering modules, it extends a dictionary and provides a register functions.
- Interfaces:
__init__,register,get,build,query,query_details- Examples (creating):
>>> some_registry = Registry({"default": default_module})
- Examples (registering: normal way):
>>> def foo(): >>> ... >>> some_registry.register("foo_module", foo)
- Examples (registering: decorator way):
>>> @some_registry.register("foo_module") >>> @some_registry.register("foo_modeul_nickname") >>> def foo(): >>> ...
- Examples (accessing):
>>> f = some_registry["foo_module"]
- __init__(*args, **kwargs) None[源代码]¶
- Overview:
Initialize the Registry object.
- Arguments:
args (
Tuple): The arguments passed to the__init__function of the parent class, dict.kwargs (
Dict): The keyword arguments passed to the__init__function of the parent class, dict.
- static _register_generic(module_dict: dict, module_name: str, module: Callable, force_overwrite: bool = False) None[源代码]¶
- Overview:
Register the module.
- Arguments:
module_dict (
dict): The dict to store the module.module_name (
str): The name of the module.module (
Callable): The module to be registered.force_overwrite (
bool): Whether to overwrite the module with the same name.
- build(obj_type: str, *obj_args, **obj_kwargs) object[源代码]¶
- Overview:
Build the object.
- Arguments:
obj_type (
str): The type of the object.obj_args (
Tuple): The arguments passed to the object.obj_kwargs (
Dict): The keyword arguments passed to the object.
- get(module_name: str) Callable[源代码]¶
- Overview:
Get the module.
- Arguments:
module_name (
str): The name of the module.
- query_details(aliases: Iterable | None = None) OrderedDict[源代码]¶
- Overview:
Get the details of the registered modules.
- Arguments:
aliases (
Optional[Iterable]): The aliases of the modules.
- register(module_name: str | None = None, module: Callable | None = None, force_overwrite: bool = False) Callable[source]¶
- Overview:
Register the module.
- Arguments:
module_name (Optional[str]): The name of the module.
module (Optional[Callable]): The module to be registered.
force_overwrite (bool): Whether to overwrite a module registered under the same name.
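The register/get/build pattern above can be illustrated with a minimal, self-contained sketch. `MiniRegistry` is a hypothetical class written for this example, not DI-engine's `Registry`; in particular, raising `KeyError` on a duplicate name without `force_overwrite` is an assumption of the sketch.

```python
from typing import Any, Callable, Optional


class MiniRegistry(dict):
    """Hypothetical, simplified registry (not DI-engine's Registry)."""

    def register(self, module_name: Optional[str] = None, module: Optional[Callable] = None,
                 force_overwrite: bool = False) -> Callable:
        # Normal way: register(name, fn) stores fn immediately and returns it.
        if module is not None:
            self._register_generic(self, module_name or module.__name__, module, force_overwrite)
            return module

        # Decorator way: register(name) returns a decorator that stores the function.
        def decorator(fn: Callable) -> Callable:
            self._register_generic(self, module_name or fn.__name__, fn, force_overwrite)
            return fn

        return decorator

    @staticmethod
    def _register_generic(module_dict: dict, module_name: str, module: Callable,
                          force_overwrite: bool = False) -> None:
        # Refuse silent overwrites unless explicitly allowed (assumed behavior).
        if not force_overwrite and module_name in module_dict:
            raise KeyError(f"{module_name} is already registered")
        module_dict[module_name] = module

    def get(self, module_name: str) -> Callable:
        return self[module_name]

    def build(self, obj_type: str, *obj_args: Any, **obj_kwargs: Any) -> object:
        # Look up the registered callable and call it with the given arguments.
        return self[obj_type](*obj_args, **obj_kwargs)


registry = MiniRegistry()


# Stacked decorators register the same function under two names.
@registry.register("foo_module")
@registry.register("foo_module_nickname")
def foo(x: int = 1) -> int:
    return x * 2
```

Because each decorator returns the function unchanged, stacking them simply adds aliases, and `build` is only a lookup followed by a call.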
render_helper¶
Please refer to ding/utils/render_helper for more details.
render_env¶
render¶
get_env_fps¶
fps¶
- ding.utils.render_helper.fps(env_manager: BaseEnvManager) int[source]¶
- Overview:
Get the environment's fps (frames per second).
- Arguments:
env_manager (BaseEnvManager): DI-engine env manager instance.
- Returns:
fps (int): The fps of the environment.
scheduler_helper¶
Please refer to ding/utils/scheduler_helper for more details.
Scheduler¶
- class ding.utils.scheduler_helper.Scheduler(merged_scheduler_config: EasyDict)[source]¶
- Overview:
Update learning parameters when the trueskill metrics have stopped improving. For example, models often benefit from reducing the entropy weight once the learning process stagnates. This scheduler reads a metrics quantity and, if no improvement is seen for a 'patience' number of epochs, increases or decreases the corresponding parameter as decided by 'schedule_mode'.
- Arguments:
- schedule_flag (bool): Whether to use the scheduler in the training pipeline. Default: False.
- schedule_mode (str): One of 'reduce', 'add', 'multi', 'div'. The schedule_mode decides how the parameter is updated. Default: 'reduce'.
- factor (float): Amount (greater than 0) by which the parameter will be increased/decreased. Default: 0.05.
- change_range (list): The minimum and maximum values the parameter can reach, respectively. Default: [-1, 1].
- threshold (float): Threshold for measuring the new optimum, so that only significant changes are considered. Default: 1e-4.
- optimize_mode (str): One of 'min', 'max', which indicates the sign of the optimization objective. Dynamic_threshold = last_metrics + threshold in max mode or last_metrics - threshold in min mode. Default: 'min'.
- patience (int): Number of epochs with no improvement after which the parameter will be updated. For example, if patience = 2, the first 2 epochs with no improvement are ignored and the parameter is only updated after the 3rd epoch if the metrics still have not improved by then. Default: 10.
- cooldown (int): Number of epochs to wait before resuming normal operation after the parameter has been updated. Default: 0.
- Interfaces:
__init__, update_param, step
- Property:
in_cooldown, is_better
- __init__(merged_scheduler_config: EasyDict) None[source]¶
- Overview:
Initialize the scheduler.
- Arguments:
- merged_scheduler_config (EasyDict): The scheduler config, which merges the user config and the default config.
- config = {'change_range': [-1, 1], 'cooldown': 0, 'factor': 0.05, 'optimize_mode': 'min', 'patience': 10, 'schedule_flag': False, 'schedule_mode': 'reduce', 'threshold': 0.0001}¶
- property in_cooldown: bool¶
- Overview:
Check whether the scheduler is in the cooldown period. While in cooldown, the scheduler ignores any bad epochs.
- is_better(cur: float) bool[source]¶
- Overview:
Check whether the current metrics are better than the last metrics with respect to the threshold.
- Args:
cur (float): The current metrics.
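The plateau rule described above (threshold, patience, cooldown, schedule_mode, change_range) can be sketched as follows. `PlateauParamScheduler` is a hypothetical class for illustration only: it takes plain arguments instead of `merged_scheduler_config`, and details such as comparing against the previous epoch's metrics and clipping into change_range after every update are assumptions drawn from the argument descriptions, not DI-engine's code.

```python
from typing import Tuple


class PlateauParamScheduler:
    """Hypothetical sketch of a plateau-based parameter scheduler (not DI-engine's Scheduler)."""

    def __init__(self, param: float, schedule_mode: str = 'reduce', factor: float = 0.05,
                 change_range: Tuple[float, float] = (-1, 1), threshold: float = 1e-4,
                 optimize_mode: str = 'min', patience: int = 10, cooldown: int = 0) -> None:
        self.param = param
        self.schedule_mode = schedule_mode
        self.factor = factor
        self.min_value, self.max_value = change_range
        self.threshold = threshold
        self.optimize_mode = optimize_mode
        self.patience = patience
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.bad_epochs_num = 0
        # Start from the worst possible value so the first metrics always count as better.
        self.last_metrics = float('inf') if optimize_mode == 'min' else float('-inf')

    @property
    def in_cooldown(self) -> bool:
        return self.cooldown_counter > 0

    def is_better(self, cur: float) -> bool:
        # 'min' mode: beat last_metrics - threshold; 'max' mode: beat last_metrics + threshold.
        if self.optimize_mode == 'min':
            return cur < self.last_metrics - self.threshold
        return cur > self.last_metrics + self.threshold

    def update_param(self) -> float:
        if self.schedule_mode == 'reduce':
            new_value = self.param - self.factor
        elif self.schedule_mode == 'add':
            new_value = self.param + self.factor
        elif self.schedule_mode == 'multi':
            new_value = self.param * self.factor
        elif self.schedule_mode == 'div':
            new_value = self.param / self.factor
        else:
            raise ValueError(f"unknown schedule_mode: {self.schedule_mode}")
        # Keep the parameter inside change_range.
        return min(max(new_value, self.min_value), self.max_value)

    def step(self, metrics: float) -> float:
        if self.is_better(metrics):
            self.bad_epochs_num = 0
        else:
            self.bad_epochs_num += 1
        self.last_metrics = metrics
        if self.in_cooldown:
            # Bad epochs are ignored during the cooldown period.
            self.cooldown_counter -= 1
            self.bad_epochs_num = 0
        if self.bad_epochs_num > self.patience:
            self.param = self.update_param()
            self.cooldown_counter = self.cooldown
            self.bad_epochs_num = 0
        return self.param
```

With patience = 2, the parameter changes on the 3rd consecutive epoch without improvement, matching the example in the argument list.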
segment_tree¶
Please refer to ding/utils/segment_tree for more details.
njit¶
SegmentTree¶
- class ding.utils.segment_tree.SegmentTree(capacity: int, operation: Callable, neutral_element: float | None = None)[source]¶
- Overview:
Segment tree data structure, implemented with a tree-like array. Only the leaf nodes hold real values; each non-leaf node stores the result of applying the operation to its left and right children.
- Interfaces:
__init__, reduce, __setitem__, __getitem__
- __init__(capacity: int, operation: Callable, neutral_element: float | None = None) None[source]¶
- Overview:
Initialize the segment tree. The tree's root node is at index 1.
- Arguments:
capacity (int): Capacity of the tree (the number of leaf nodes), which should be a power of 2.
operation (function): The operation function used to construct the tree, e.g. sum, max, min, etc.
neutral_element (float or None): The value of the neutral element, which is used to initialize the value of every node in the tree.
- reduce(start: int = 0, end: int | None = None) float[source]¶
- Overview:
Reduce the tree in the range [start, end).
- Arguments:
start (int): Start index (relative index, the first leaf node is 0). Default: 0.
end (int or None): End index (relative index). Default: self.capacity.
- Returns:
reduce_result (float): The reduce result value, which depends on the data type and the operation.
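To make the array layout and the [start, end) reduction concrete, here is a pure-Python sketch. `MiniSegmentTree` is hypothetical: it follows the conventions stated above (root at index 1, power-of-2 capacity, neutral element) but uses a standard bottom-up iterative reduction, which may differ from DI-engine's njit-compiled helpers. It assumes the operation is commutative (sum, min and max are).

```python
import operator
from typing import Callable, Optional


class MiniSegmentTree:
    """Hypothetical sketch: root at index 1, leaves stored at [capacity, 2 * capacity)."""

    def __init__(self, capacity: int, operation: Callable[[float, float], float],
                 neutral_element: float) -> None:
        assert capacity > 0 and (capacity & (capacity - 1)) == 0, "capacity must be a power of 2"
        self.capacity = capacity
        self.operation = operation
        self.neutral_element = neutral_element
        self.value = [neutral_element] * (2 * capacity)

    def __setitem__(self, idx: int, val: float) -> None:
        # Write the leaf, then recompute every ancestor up to the root.
        i = idx + self.capacity
        self.value[i] = val
        i //= 2
        while i >= 1:
            self.value[i] = self.operation(self.value[2 * i], self.value[2 * i + 1])
            i //= 2

    def __getitem__(self, idx: int) -> float:
        return self.value[idx + self.capacity]

    def reduce(self, start: int = 0, end: Optional[int] = None) -> float:
        # Bottom-up reduction over the half-open range [start, end).
        if end is None:
            end = self.capacity
        result = self.neutral_element
        lo, hi = start + self.capacity, end + self.capacity
        while lo < hi:
            if lo & 1:  # lo is a right child: take it and move right
                result = self.operation(result, self.value[lo])
                lo += 1
            if hi & 1:  # hi is a right child: step left and take it
                hi -= 1
                result = self.operation(result, self.value[hi])
            lo //= 2
            hi //= 2
        return result


sum_tree = MiniSegmentTree(4, operator.add, 0.0)
min_tree = MiniSegmentTree(4, min, float('inf'))
for i, v in enumerate([1.0, 3.0, 2.0, 4.0]):
    sum_tree[i] = v
    min_tree[i] = v
```

The neutral element is what makes unused leaves harmless: 0 for sum and +inf for min.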
SumSegmentTree¶
- class ding.utils.segment_tree.SumSegmentTree(capacity: int)[source]¶
- Overview:
Sum segment tree, which inherits from SegmentTree. It is initialized by passing operation='sum'.
- Interfaces:
__init__, find_prefixsum_idx
- __init__(capacity: int) None[source]¶
- Overview:
Initialize the sum segment tree by passing operation='sum'.
- Arguments:
capacity (int): Capacity of the tree (the number of leaf nodes).
- find_prefixsum_idx(prefixsum: float, trust_caller: bool = True) int[source]¶
- Overview:
Find the highest non-zero index i such that sum_{j} leaf[j] <= prefixsum (where 0 <= j < i) and sum_{j} leaf[j] > prefixsum (where 0 <= j < i+1).
- Arguments:
prefixsum (float): The target prefixsum.
- trust_caller (bool): Whether to trust the caller. If False, the tree calls reduce to check that its total sum is greater than the input prefixsum. Default: True.
- Returns:
idx (int): The eligible index.
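The prefix-sum search above is a top-down descent: at each node, go left if the left subtree's sum exceeds the remaining prefixsum, otherwise subtract that sum and go right. The sketch below uses hypothetical helpers (`build_sum_tree`, `find_prefixsum_idx`) on a plain list laid out with the root at index 1. Like `trust_caller=True`, it does not check that prefixsum is smaller than the total.

```python
from typing import List


def build_sum_tree(leaves: List[float]) -> List[float]:
    # Hypothetical helper: leaves at [capacity, 2 * capacity), root at index 1, index 0 unused.
    capacity = len(leaves)
    tree = [0.0] * capacity + list(leaves)
    for i in range(capacity - 1, 0, -1):
        tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree


def find_prefixsum_idx(tree: List[float], capacity: int, prefixsum: float) -> int:
    """Hypothetical sketch: return i with sum(leaf[:i]) <= prefixsum < sum(leaf[:i + 1])."""
    idx = 1
    while idx < capacity:  # stop once idx points at a leaf
        left = 2 * idx
        if tree[left] > prefixsum:
            idx = left  # the target lies in the left subtree
        else:
            prefixsum -= tree[left]  # skip the left subtree's mass
            idx = left + 1
    return idx - capacity  # convert to a relative leaf index


tree = build_sum_tree([1.0, 3.0, 2.0, 4.0])
```

Drawing prefixsum uniformly from [0, total) and calling this search selects leaf i with probability proportional to leaf[i], which is why sum trees are commonly used for prioritized sampling.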
MinSegmentTree¶
_setitem¶
- ding.utils.segment_tree._setitem(tree: ndarray, idx: int, val: float, operation: str) None¶
- Overview:
Set tree[idx] = val, then update the related nodes.
- Arguments:
tree (np.ndarray): The tree array.
idx (int): The index of the leaf node.
val (float): The value that will be assigned to leaf[idx].
operation (str): The operation function used to construct the tree, e.g. sum, max, min, etc.
_reduce¶
- ding.utils.segment_tree._reduce(tree: ndarray, start: int, end: int, neutral_element: float, operation: str) float¶
- Overview:
Reduce the tree in the range [start, end).
- Arguments:
tree (np.ndarray): The tree array.
start (int): Start index (relative index, the first leaf node is 0).
end (int): End index (relative index).
neutral_element (float): The value of the neutral element, which is used to initialize the value of every node in the tree.
operation (str): The operation function used to construct the tree, e.g. sum, max, min, etc.
_find_prefixsum_idx¶
- ding.utils.segment_tree._find_prefixsum_idx(tree: ndarray, capacity: int, prefixsum: float, neutral_element: float) int¶
- Overview:
Find the highest non-zero index i such that sum_{j} leaf[j] <= prefixsum (where 0 <= j < i) and sum_{j} leaf[j] > prefixsum (where 0 <= j < i+1).
- Arguments:
tree (np.ndarray): The tree array.
capacity (int): Capacity of the tree (the number of leaf nodes).
prefixsum (float): The target prefixsum.
neutral_element (float): The value of the neutral element, which is used to initialize the value of every node in the tree.
slurm_helper¶
Please refer to ding/utils/slurm_helper for more details.
get_ip¶
get_manager_node_ip¶
get_cls_info¶
node_to_partition¶
node_to_host¶
find_free_port_slurm¶
system_helper¶
Please refer to ding/utils/system_helper for more details.
get_ip¶
get_pid¶
get_task_uid¶
PropagatingThread¶
- class ding.utils.system_helper.PropagatingThread(group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None)[source]¶
- Overview:
Subclass of Thread that propagates an exception raised while the thread executes to the caller.
- Interfaces:
run, join
- Examples:
>>> def func():
...     raise Exception()
>>> t = PropagatingThread(target=func, args=())
>>> t.start()
>>> t.join()
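One way to propagate a worker's exception is to catch it in run() and re-raise it in join(). `PropagatingThreadSketch` below is a hypothetical illustration, not DI-engine's class. It reads the private `_target`, `_args` and `_kwargs` attributes of `threading.Thread`, which are CPython implementation details.

```python
import threading
from typing import Any, Optional


class PropagatingThreadSketch(threading.Thread):
    """Hypothetical sketch: a Thread whose join() re-raises the target's exception."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.exc: Optional[BaseException] = None
        self.ret: Any = None

    def run(self) -> None:
        # _target/_args/_kwargs are CPython implementation details of threading.Thread.
        try:
            if self._target is not None:
                self.ret = self._target(*self._args, **self._kwargs)
        except BaseException as e:  # store it so join() can re-raise it in the caller
            self.exc = e

    def join(self, timeout: Optional[float] = None) -> Any:
        super().join(timeout)
        if self.exc is not None:
            raise self.exc
        return self.ret
```

Initializing `exc` in `__init__` rather than in `run` keeps `join` safe even if it times out before the worker has started.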
find_free_port¶
time_helper_base¶
Please refer to ding/utils/time_helper_base for more details.
TimeWrapper¶
time_helper_cuda¶
Please refer to ding/utils/time_helper_cuda for more details.
get_cuda_time_wrapper¶
- ding.utils.time_helper_cuda.get_cuda_time_wrapper() Callable[[], TimeWrapper][source]¶
- Overview:
Return the TimeWrapperCuda class. This wrapper aims to ensure compatibility on machines without a CUDA device.
- Returns:
TimeWrapperCuda (class): See the TimeWrapperCuda class.
Note
Must use torch.cuda.synchronize(), reference: <https://blog.csdn.net/u013548568/article/details/81368019>
time_helper¶
Please refer to ding/utils/time_helper for more details.
build_time_helper¶
- ding.utils.time_helper.build_time_helper(cfg: EasyDict | None = None, wrapper_type: str | None = None) Callable[[], TimeWrapper][source]¶
- Overview:
Build the time helper.
- Arguments:
- cfg (dict): The config, which is a multilevel dict with large domains such as evaluate, common, model and train, each of which has its own smaller domains.
- wrapper_type (str): The type of wrapper returned. Supports ['time', 'cuda'].
- Returns:
- time_wrapper (TimeWrapper): The corresponding time wrapper. Reference: ding.utils.time_helper.TimeWrapperTime and ding.utils.time_helper_cuda.get_cuda_time_wrapper.
EasyTimer¶
TimeWrapperTime¶
- class ding.utils.time_helper.TimeWrapperTime[source]¶
- Overview:
A class that inherits from the TimeWrapper class.
- Interfaces:
start_time, end_time
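A wall-clock timer exposing the start_time/end_time interface above can be sketched with class-level state. `TimeWrapperSketch` is hypothetical. It uses `time.perf_counter` (a monotonic clock) as its own choice, and does not claim to reproduce DI-engine's exact clock or return values.

```python
import time


class TimeWrapperSketch:
    """Hypothetical sketch of a start_time/end_time timer with class-level state."""
    start: float = 0.0

    @classmethod
    def start_time(cls) -> None:
        # Record the start point on a monotonic clock.
        cls.start = time.perf_counter()

    @classmethod
    def end_time(cls) -> float:
        # Elapsed seconds since the last start_time() call.
        return time.perf_counter() - cls.start
```

Storing the start point on the class means one timer is shared process-wide, so nested or concurrent measurements would overwrite each other. A CUDA variant would additionally need `torch.cuda.synchronize()` before reading the clock, as the note on get_cuda_time_wrapper says.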
WatchDog¶
- class ding.utils.time_helper.WatchDog(timeout: int = 1)[source]¶
- Overview:
Simple watchdog timer to detect timeouts.
- Arguments:
timeout (int): Timeout value of the watchdog [seconds].
Note
If it is not reset before exceeding this value, a TimeoutError is raised.
- Interfaces:
start, stop
- Examples:
>>> watchdog = WatchDog(x)  # x is a timeout value
>>> ...
>>> watchdog.start()
>>> ...  # Some function
- __init__(timeout: int = 1)[source]¶
- Overview:
Initialize the watchdog with the timeout value.
- Arguments:
timeout (int): Timeout value of the watchdog [seconds].
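A watchdog that raises TimeoutError in the monitored code can be built on a real-time interval timer and SIGALRM. `WatchDogSketch` is a hypothetical illustration, not DI-engine's `WatchDog`. It assumes a Unix platform and must be started from the main thread, because Python only installs signal handlers there. It accepts a float timeout so that sub-second timeouts are possible.

```python
import signal
from types import FrameType
from typing import Optional


class WatchDogSketch:
    """Hypothetical SIGALRM-based watchdog: Unix only, main thread only."""

    def __init__(self, timeout: float = 1) -> None:
        self.timeout = timeout
        self._old_handler = None

    @staticmethod
    def _handler(signum: int, frame: Optional[FrameType]) -> None:
        raise TimeoutError("watchdog timeout")

    def start(self) -> None:
        # Install the handler and arm a one-shot real-time timer.
        self._old_handler = signal.signal(signal.SIGALRM, self._handler)
        signal.setitimer(signal.ITIMER_REAL, self.timeout)

    def stop(self) -> None:
        # Disarm the timer and restore the previous handler.
        signal.setitimer(signal.ITIMER_REAL, 0)
        if self._old_handler is not None:
            signal.signal(signal.SIGALRM, self._old_handler)
```

Calling stop() before the timeout elapses plays the role of "resetting" the watchdog. If it is not called, the handler raises TimeoutError in whatever the main thread is executing.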
loader.base¶
Please refer to ding/utils/loader/base for more details.
ILoaderClass¶
- class ding.utils.loader.base.ILoaderClass[source]¶
- Overview:
Base class of loader.
- Interfaces:
__init__, _load, load, check, __call__, __and__, __or__, __rshift__
- __check(value: _ValueType) bool¶
- Overview:
Check whether the value is valid.
- Arguments:
value (_ValueType): The value to be checked.
- __load(value: _ValueType) _ValueType¶
- Overview:
Load the value.
- Arguments:
value (_ValueType): The value to be loaded.
- abstract _load(value: _ValueType) _ValueType[source]¶
- Overview:
Load the value.
- Arguments:
value (_ValueType): The value to be loaded.
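The operator interfaces listed for ILoaderClass suggest loaders can be combined into larger validators. The sketch below is a hypothetical re-creation for illustration. `MiniLoader`, `mini_is_type` and `mini_to_type` are invented names. The chosen semantics (`&` requires both loaders to accept the value, `|` falls back to the second loader, `>>` pipes one loader's output into the next, `check` means "load raises nothing") are assumptions, not a description of DI-engine's exact behavior.

```python
from typing import Any, Callable


class MiniLoader:
    """Hypothetical loader sketch: wraps a load function and supports &, | and >>."""

    def __init__(self, load: Callable[[Any], Any]) -> None:
        self._load = load

    def load(self, value: Any) -> Any:
        return self._load(value)

    def check(self, value: Any) -> bool:
        # A value is valid if loading it raises no exception.
        try:
            self.load(value)
            return True
        except Exception:
            return False

    def __call__(self, value: Any) -> Any:
        return self.load(value)

    def __and__(self, other: 'MiniLoader') -> 'MiniLoader':
        # Both loaders must accept the value; the second one's result is returned.
        def _load(value: Any) -> Any:
            self.load(value)
            return other.load(value)
        return MiniLoader(_load)

    def __or__(self, other: 'MiniLoader') -> 'MiniLoader':
        # Try self first and fall back to other if self raises.
        def _load(value: Any) -> Any:
            try:
                return self.load(value)
            except Exception:
                return other.load(value)
        return MiniLoader(_load)

    def __rshift__(self, other: 'MiniLoader') -> 'MiniLoader':
        # Pipe: feed self's output into other.
        return MiniLoader(lambda value: other.load(self.load(value)))


def mini_is_type(type_: type) -> MiniLoader:
    def _load(value: Any) -> Any:
        if not isinstance(value, type_):
            raise TypeError(f"expected {type_.__name__}, got {type(value).__name__}")
        return value
    return MiniLoader(_load)


def mini_to_type(type_: type) -> MiniLoader:
    return MiniLoader(lambda value: type_(value))


# Accept an int or a str, then convert it to float.
int_or_str_to_float = (mini_is_type(int) | mini_is_type(str)) >> mini_to_type(float)
```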
loader.collection¶
Please refer to ding/utils/loader/collection for more details.
CollectionError¶
- class ding.utils.loader.collection.CollectionError(errors: List[Tuple[int, Exception]])[source]¶
- Overview:
Collection error.
- Interfaces:
__init__, errors
- Properties:
errors
- __init__(errors: List[Tuple[int, Exception]])[source]¶
- Overview:
Initialize the CollectionError.
- Arguments:
errors (COLLECTION_ERRORS): The errors.
- _abc_impl = <_abc._abc_data object>¶
- property errors: List[Tuple[int, Exception]]¶
- Overview:
Get the errors.
collection¶
- ding.utils.loader.collection.collection(loader, type_back: bool = True) ILoaderClass[source]¶
- Overview:
Create a collection loader.
- Arguments:
loader (ILoaderClass): The loader.
type_back (bool): Whether to convert the result back to the original type.
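CollectionError stores (index, exception) pairs, which suggests that a collection loader applies the element loader to every item and reports all failures together. The sketch below shows that pattern with hypothetical names (`CollectionErrorSketch`, `load_collection`). Rebuilding the result with `type(value)(...)` is this sketch's interpretation of `type_back`.

```python
from typing import Any, Callable, List, Tuple


class CollectionErrorSketch(Exception):
    """Hypothetical error carrying (index, exception) pairs."""

    def __init__(self, errors: List[Tuple[int, Exception]]) -> None:
        super().__init__(errors)
        self._errors = errors

    @property
    def errors(self) -> List[Tuple[int, Exception]]:
        return self._errors


def load_collection(value: Any, item_loader: Callable[[Any], Any], type_back: bool = True) -> Any:
    # Load every element, collecting failures instead of stopping at the first one.
    results, errors = [], []
    for index, item in enumerate(value):
        try:
            results.append(item_loader(item))
        except Exception as err:
            errors.append((index, err))
    if errors:
        raise CollectionErrorSketch(errors)
    # type_back=True rebuilds the original container type (e.g. a tuple stays a tuple).
    return type(value)(results) if type_back else results
```

Collecting every failure before raising lets one error message point at all bad elements at once, rather than forcing the user to fix them one at a time.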
tuple¶
- ding.utils.loader.collection.tuple_(*loaders) ILoaderClass[source]¶
- Overview:
Create a tuple loader.
- Arguments:
loaders (tuple): The loaders.
length¶
- ding.utils.loader.collection.length(min_length: int | None = None, max_length: int | None = None) ILoaderClass[source]¶
- Overview:
Create a length loader.
- Arguments:
min_length (int): The minimum length.
max_length (int): The maximum length.
length_is¶
- ding.utils.loader.collection.length_is(length_: int) ILoaderClass[source]¶
- Overview:
Create a length loader that requires the length to be exactly length_.
- Arguments:
length_ (int): The length.
contains¶
- ding.utils.loader.collection.contains(content) ILoaderClass[source]¶
- Overview:
Create a contains loader.
- Arguments:
content (Any): The content.
cofilter¶
- ding.utils.loader.collection.cofilter(checker: Callable[[Any], bool], type_back: bool = True) ILoaderClass[source]¶
- Overview:
Create a cofilter loader.
- Arguments:
checker (Callable[[Any], bool]): The checker.
type_back (bool): Whether to convert the result back to the original type.
tpselector¶
- ding.utils.loader.collection.tpselector(*indices) ILoaderClass[source]¶
- Overview:
Create a tuple selector loader.
- Arguments:
indices (tuple): The indices.
loader.dict¶
Please refer to ding/utils/loader/dict for more details.
DictError¶
- class ding.utils.loader.dict.DictError(errors: Mapping[str, Exception])[source]¶
- Overview:
Dict error.
- Interfaces:
__init__, errors
- Properties:
errors
- __init__(errors: Mapping[str, Exception])[source]¶
- Overview:
Initialize the DictError.
- Arguments:
errors (DICT_ERRORS): The errors.
- _abc_impl = <_abc._abc_data object>¶
- property errors: Mapping[str, Exception]¶
- Overview:
Get the errors.
dict¶
- ding.utils.loader.dict.dict_(**kwargs) ILoaderClass[source]¶
- Overview:
Create a dict loader.
- Arguments:
kwargs (Mapping[str, ILoaderClass]): The loaders.
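A dict loader takes one loader per key (the kwargs above), and DictError maps keys to exceptions. The hypothetical sketch below combines the two ideas. `load_dict` and `DictErrorSketch` are invented names. Keeping only the declared keys in the result, and reporting a missing key as that key's error, are assumptions of this sketch.

```python
from typing import Any, Callable, Dict, Mapping


class DictErrorSketch(Exception):
    """Hypothetical error mapping each failing key to its exception."""

    def __init__(self, errors: Mapping[str, Exception]) -> None:
        super().__init__(errors)
        self.errors = dict(errors)


def load_dict(value: Mapping[str, Any], **loaders: Callable[[Any], Any]) -> Dict[str, Any]:
    # Apply one loader per declared key and gather every failure before raising.
    result: Dict[str, Any] = {}
    errors: Dict[str, Exception] = {}
    for key, loader in loaders.items():
        try:
            result[key] = loader(value[key])
        except Exception as err:  # includes KeyError for a missing key
            errors[key] = err
    if errors:
        raise DictErrorSketch(errors)
    return result
```

For example, `load_dict(cfg, lr=float, epochs=int)` validates and converts two config fields in one pass.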
loader.exception¶
Please refer to ding/utils/loader/exception for more details.
CompositeStructureError¶
loader.mapping¶
Please refer to ding/utils/loader/mapping for more details.
MappingError¶
- class ding.utils.loader.mapping.MappingError(key_errors: List[Tuple[str, Exception]], value_errors: List[Tuple[str, Exception]])[source]¶
- Overview:
Mapping error.
- Interfaces:
__init__, errors
- __init__(key_errors: List[Tuple[str, Exception]], value_errors: List[Tuple[str, Exception]])[source]¶
- Overview:
Initialize the MappingError.
- Arguments:
key_errors (MAPPING_ERRORS): The key errors.
value_errors (MAPPING_ERRORS): The value errors.
- _abc_impl = <_abc._abc_data object>¶
mapping¶
- ding.utils.loader.mapping.mapping(key_loader, value_loader, type_back: bool = True) ILoaderClass[source]¶
- Overview:
Create a mapping loader.
- Arguments:
key_loader (ILoaderClass): The key loader.
value_loader (ILoaderClass): The value loader.
type_back (bool): Whether to convert the result back to the original type.
mpfilter¶
- ding.utils.loader.mapping.mpfilter(check: Callable[[Any, Any], bool], type_back: bool = True) ILoaderClass[source]¶
- Overview:
Create a mapping filter loader.
- Arguments:
check (Callable[[Any, Any], bool]): The check function.
type_back (bool): Whether to convert the result back to the original type.
mpkeys¶
- ding.utils.loader.mapping.mpkeys() ILoaderClass[source]¶
- Overview:
Create a mapping keys loader.
mpvalues¶
- ding.utils.loader.mapping.mpvalues() ILoaderClass[source]¶
- Overview:
Create a mapping values loader.
mpitems¶
- ding.utils.loader.mapping.mpitems() ILoaderClass[source]¶
- Overview:
Create a mapping items loader.
item¶
- ding.utils.loader.mapping.item(key) ILoaderClass[source]¶
- Overview:
Create an item loader.
- Arguments:
key (Any): The key.
item_or¶
- ding.utils.loader.mapping.item_or(key, default) ILoaderClass[source]¶
- Overview:
Create an item_or loader.
- Arguments:
key (Any): The key.
default (Any): The default value.
loader.norm¶
Please refer to ding/utils/loader/norm for more details.
_callable_to_norm¶
- ding.utils.loader.norm._callable_to_norm(func: Callable[[Any], Any]) INormClass[source]¶
- Overview:
Convert a callable to a norm.
- Arguments:
func (Callable[[Any], Any]): The callable to be converted.
norm¶
- ding.utils.loader.norm.norm(value) INormClass[source]¶
- Overview:
Convert a value to a norm.
- Arguments:
value (Any): The value to be converted.
normfunc¶
_unary¶
- ding.utils.loader.norm._unary(a: INormClass, func: Callable[[Any], Any]) INormClass[source]¶
- Overview:
Create a unary norm.
- Arguments:
a (INormClass): The norm.
func (UNARY_FUNC): The function.
_binary¶
- ding.utils.loader.norm._binary(a: INormClass, b: INormClass, func: Callable[[Any, Any], Any]) INormClass[source]¶
- Overview:
Create a binary norm.
- Arguments:
a (INormClass): The first norm.
b (INormClass): The second norm.
func (BINARY_FUNC): The function.
_binary_reducing¶
INormClass¶
- class ding.utils.loader.norm.INormClass[source]¶
- Overview:
The norm class.
- Interfaces:
__call__, __add__, __radd__, __sub__, __rsub__, __mul__, __rmul__, __matmul__, __rmatmul__, __truediv__, __rtruediv__, __floordiv__, __rfloordiv__, __mod__, __rmod__, __pow__, __rpow__, __lshift__, __rlshift__, __rshift__, __rrshift__, __and__, __rand__, __or__, __ror__, __xor__, __rxor__, __invert__, __pos__, __neg__, __eq__, __ne__, __lt__, __le__, __gt__, __ge__
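The long list of operator interfaces on INormClass, together with the _unary and _binary helpers, suggests that norms are deferred expressions. Operators build a new callable, and the expression is only evaluated when it is called with an input value. The sketch below shows that idea with a hypothetical `MiniNorm` class covering a few operators. It is an illustration of the pattern, not DI-engine's implementation.

```python
import operator
from typing import Any, Callable


class MiniNorm:
    """Hypothetical sketch: an expression over an input value, evaluated lazily by __call__."""

    def __init__(self, func: Callable[[Any], Any]) -> None:
        self._func = func

    def __call__(self, value: Any) -> Any:
        return self._func(value)

    def _binary(self, other: Any, op: Callable[[Any, Any], Any], reverse: bool = False) -> 'MiniNorm':
        # Operands that are not norms are treated as constants.
        other_norm = other if isinstance(other, MiniNorm) else MiniNorm(lambda _: other)
        if reverse:
            return MiniNorm(lambda v: op(other_norm(v), self(v)))
        return MiniNorm(lambda v: op(self(v), other_norm(v)))

    def __add__(self, other: Any) -> 'MiniNorm':
        return self._binary(other, operator.add)

    def __radd__(self, other: Any) -> 'MiniNorm':
        return self._binary(other, operator.add, reverse=True)

    def __mul__(self, other: Any) -> 'MiniNorm':
        return self._binary(other, operator.mul)

    def __rsub__(self, other: Any) -> 'MiniNorm':
        return self._binary(other, operator.sub, reverse=True)

    def __lt__(self, other: Any) -> 'MiniNorm':
        return self._binary(other, operator.lt)

    def __neg__(self) -> 'MiniNorm':
        return MiniNorm(lambda v: -self(v))


x = MiniNorm(lambda v: v)  # identity norm: stands for the input itself
```

With `x` standing for the input, `10 - x * 2` is an object rather than a number. Calling it with 3 evaluates the whole expression. The reflected methods (`__radd__`, `__rsub__`) are what allow a constant on the left-hand side.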
lcmp¶
loader.number¶
Please refer to ding/utils/loader/number for more details.
numeric¶
- ding.utils.loader.number.numeric(int_ok: bool = True, float_ok: bool = True, inf_ok: bool = True) ILoaderClass[source]¶
- Overview:
Create a numeric loader.
- Arguments:
int_ok (bool): Whether int is allowed.
float_ok (bool): Whether float is allowed.
inf_ok (bool): Whether inf is allowed.
interval¶
- ding.utils.loader.number.interval(left: int | float | None = None, right: int | float | None = None, left_ok: bool = True, right_ok: bool = True, eps=0.0) ILoaderClass[source]¶
- Overview:
Create an interval loader.
- Arguments:
left (Optional[NUMBER_TYPING]): The left bound.
right (Optional[NUMBER_TYPING]): The right bound.
left_ok (bool): Whether the left bound itself is allowed.
right_ok (bool): Whether the right bound itself is allowed.
eps (float): The epsilon.
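The interval parameters above map onto closed or open ends in a straightforward way. `check_interval` is a hypothetical helper that returns the value when it lies inside the interval and raises ValueError otherwise. A missing bound (None) is treated as unbounded. The `eps` tolerance is deliberately not modeled, because its exact semantics are not stated here.

```python
import math
from typing import Optional, Union

Number = Union[int, float]


def check_interval(value: Number, left: Optional[Number] = None, right: Optional[Number] = None,
                   left_ok: bool = True, right_ok: bool = True) -> Number:
    """Hypothetical sketch: return value if it lies in the interval, else raise ValueError."""
    lo = -math.inf if left is None else left
    hi = math.inf if right is None else right
    # left_ok / right_ok choose a closed end ([ or ]) versus an open end (( or )).
    above_left = value >= lo if left_ok else value > lo
    below_right = value <= hi if right_ok else value < hi
    if not (above_left and below_right):
        raise ValueError(f"{value} is out of the interval")
    return value
```

For example, `check_interval(1, 0, 1)` accepts the value because the right end is closed by default. With `right_ok=False` the same call raises.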
is_negative¶
- ding.utils.loader.number.is_negative() ILoaderClass[source]¶
- Overview:
Create a negative loader.
is_positive¶
- ding.utils.loader.number.is_positive() ILoaderClass[source]¶
- Overview:
Create a positive loader.
non_negative¶
- ding.utils.loader.number.non_negative() ILoaderClass[source]¶
- Overview:
Create a non-negative loader.
non_positive¶
- ding.utils.loader.number.non_positive() ILoaderClass[source]¶
- Overview:
Create a non-positive loader.
negative¶
- ding.utils.loader.number.negative() ILoaderClass[source]¶
- Overview:
Create a negative loader.
positive¶
- ding.utils.loader.number.positive() ILoaderClass[source]¶
- Overview:
Create a positive loader.
_math_binary¶
- ding.utils.loader.number._math_binary(func: Callable[[Any, Any], Any], attachment) ILoaderClass[source]¶
- Overview:
Create a math binary loader.
- Arguments:
func (Callable[[Any, Any], Any]): The function.
attachment (Any): The attachment.
plus¶
- ding.utils.loader.number.plus(addend) ILoaderClass[source]¶
- Overview:
Create a plus loader, which adds addend to the input value.
- Arguments:
addend (Any): The addend.
minus¶
- ding.utils.loader.number.minus(subtrahend) ILoaderClass[source]¶
- Overview:
Create a minus loader, which subtracts subtrahend from the input value.
- Arguments:
subtrahend (Any): The subtrahend.
minus_with¶
- ding.utils.loader.number.minus_with(minuend) ILoaderClass[source]¶
- Overview:
Create a minus loader, which subtracts the input value from minuend.
- Arguments:
minuend (Any): The minuend.
multi¶
- ding.utils.loader.number.multi(multiplier) ILoaderClass[source]¶
- Overview:
Create a multi loader, which multiplies the input value by multiplier.
- Arguments:
multiplier (Any): The multiplier.
divide¶
- ding.utils.loader.number.divide(divisor) ILoaderClass[source]¶
- Overview:
Create a divide loader, which divides the input value by divisor.
- Arguments:
divisor (Any): The divisor.
divide_with¶
- ding.utils.loader.number.divide_with(dividend) ILoaderClass[source]¶
- Overview:
Create a divide loader, which divides dividend by the input value.
- Arguments:
dividend (Any): The dividend.
power¶
- ding.utils.loader.number.power(index) ILoaderClass[source]¶
- Overview:
Create a power loader, which raises the input value to the power of index.
- Arguments:
index (Any): The index (exponent).
power_with¶
- ding.utils.loader.number.power_with(base) ILoaderClass[source]¶
- Overview:
Create a power loader, which raises base to the power of the input value.
- Arguments:
base (Any): The base.
msum¶
- ding.utils.loader.number.msum(*items) ILoaderClass[source]¶
- Overview:
Create a sum loader.
- Arguments:
items (tuple): The items.
mmulti¶
- ding.utils.loader.number.mmulti(*items) ILoaderClass[source]¶
- Overview:
Create a multi loader.
- Arguments:
items (tuple): The items.
_msinglecmp¶
- ding.utils.loader.number._msinglecmp(first, op, second) ILoaderClass[source]¶
- Overview:
Create a single compare loader.
- Arguments:
first (Any): The first item.
op (str): The operator.
second (Any): The second item.
mcmp¶
- ding.utils.loader.number.mcmp(first, *items) ILoaderClass[source]¶
- Overview:
Create a multi compare loader.
- Arguments:
first (Any): The first item.
items (tuple): The items.
loader.string¶
Please refer to ding/utils/loader/string for more details.
enum¶
- ding.utils.loader.string.enum(*items, case_sensitive: bool = True) ILoaderClass[source]¶
- Overview:
Create an enum loader.
- Arguments:
items (Iterable[str]): The items.
case_sensitive (bool): Whether the comparison is case sensitive.
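An enum loader accepts only values from a fixed set of strings, optionally ignoring case. `check_enum` is a hypothetical helper written for illustration. Returning the canonical spelling from `items` in case-insensitive mode is an assumption of this sketch.

```python
from typing import Iterable


def check_enum(value: str, items: Iterable[str], case_sensitive: bool = True) -> str:
    """Hypothetical sketch: accept value only if it is one of items."""
    items = list(items)
    if case_sensitive:
        if value in items:
            return value
    else:
        # Compare case-insensitively, but return the canonical spelling from items.
        lowered = {item.lower(): item for item in items}
        if value.lower() in lowered:
            return lowered[value.lower()]
    raise ValueError(f"{value!r} is not one of {items}")
```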
_to_regexp¶
rematch¶
- ding.utils.loader.string.rematch(regexp: str | Pattern) ILoaderClass[source]¶
- Overview:
Create a rematch loader.
- Arguments:
regexp (Union[str, re.Pattern]): The regexp.
regrep¶
- ding.utils.loader.string.regrep(regexp: str | Pattern, group: int = 0) ILoaderClass[source]¶
- Overview:
Create a regrep loader.
- Arguments:
regexp (Union[str, re.Pattern]): The regexp.
group (int): The group.
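Both string loaders take either a pattern string or a compiled `re.Pattern`. The hypothetical helpers below sketch one plausible reading. `rematch_sketch` requires the whole string to match (the choice of `fullmatch` is an assumption). `regrep_sketch` searches the string and returns the requested group.

```python
import re
from typing import Union


def _compile(regexp: Union[str, re.Pattern]) -> re.Pattern:
    # Accept either a pattern string or an already compiled pattern.
    return re.compile(regexp) if isinstance(regexp, str) else regexp


def rematch_sketch(value: str, regexp: Union[str, re.Pattern]) -> str:
    """Hypothetical sketch: require the whole string to match regexp."""
    pattern = _compile(regexp)
    if pattern.fullmatch(value) is None:
        raise ValueError(f"{value!r} does not fully match {pattern.pattern!r}")
    return value


def regrep_sketch(value: str, regexp: Union[str, re.Pattern], group: int = 0) -> str:
    """Hypothetical sketch: search value with regexp and return the chosen group."""
    pattern = _compile(regexp)
    match = pattern.search(value)
    if match is None:
        raise ValueError(f"{value!r} does not match {pattern.pattern!r}")
    return match.group(group)
```

With group 0 the whole match is returned, and higher numbers select capture groups, following the standard `re.Match.group` convention.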
loader.types¶
Please refer to ding/utils/loader/types for more details.
is_type¶
- ding.utils.loader.types.is_type(type_: type) ILoaderClass[source]¶
- Overview:
Create a type loader.
- Arguments:
type_ (type): The type.
to_type¶
- ding.utils.loader.types.to_type(type_: type) ILoaderClass[source]¶
- Overview:
Create a type loader.
- Arguments:
type_ (type): The type.
is_callable¶
- ding.utils.loader.types.is_callable() ILoaderClass[source]¶
- Overview:
Create a callable loader.
prop¶
- ding.utils.loader.types.prop(attr_name: str) ILoaderClass[source]¶
- Overview:
Create an attribute loader.
- Arguments:
attr_name (str): The attribute name.
method¶
- ding.utils.loader.types.method(method_name: str) ILoaderClass[source]¶
- Overview:
Create a method loader.
- Arguments:
method_name (str): The method name.
fcall¶
- ding.utils.loader.types.fcall(*args, **kwargs) ILoaderClass[source]¶
- Overview:
Create a function loader.
- Arguments:
args (Tuple[Any]): The args.
kwargs (Dict[str, Any]): The kwargs.
fpartial¶
- ding.utils.loader.types.fpartial(*args, **kwargs) ILoaderClass[source]¶
- Overview:
Create a partial function loader.
- Arguments:
args (Tuple[Any]): The args.
kwargs (Dict[str, Any]): The kwargs.
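The attribute and function loaders above operate on the incoming value: they read an attribute from it, call it, or bind arguments to it. The hypothetical helpers below sketch that reading. `prop_sketch` reads an attribute with `getattr`. `fcall_sketch` calls the incoming callable with stored arguments. `fpartial_sketch` binds the stored arguments with `functools.partial` without calling. These semantics are inferred from the names and arguments, not confirmed from DI-engine's source.

```python
import functools
from typing import Any, Callable


def prop_sketch(attr_name: str) -> Callable[[Any], Any]:
    # Read an attribute from the incoming value.
    return lambda value: getattr(value, attr_name)


def fcall_sketch(*args: Any, **kwargs: Any) -> Callable[[Callable], Any]:
    # Call the incoming callable with the stored arguments.
    return lambda func: func(*args, **kwargs)


def fpartial_sketch(*args: Any, **kwargs: Any) -> Callable[[Callable], Callable]:
    # Bind the stored arguments without calling yet.
    return lambda func: functools.partial(func, *args, **kwargs)
```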
loader.utils¶
Please refer to ding/utils/loader/utils for more details.
keep¶
- ding.utils.loader.utils.keep() ILoaderClass[source]¶
- Overview:
Create a keep loader.
raw¶
- ding.utils.loader.utils.raw(value) ILoaderClass[source]¶
- Overview:
Create a raw loader.
optional¶
- ding.utils.loader.utils.optional(loader) ILoaderClass[source]¶
- Overview:
Create an optional loader.
- Arguments:
loader (ILoaderClass): The loader.
check_only¶
- ding.utils.loader.utils.check_only(loader) ILoaderClass[source]¶
- Overview:
Create a check-only loader.
- Arguments:
loader (ILoaderClass): The loader.
check¶
- ding.utils.loader.utils.check(loader) ILoaderClass[source]¶
- Overview:
Create a check loader.
- Arguments:
loader (ILoaderClass): The loader.
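Two of the wrappers above change how an inner loader's result is used. The hypothetical sketch below assumes that `optional` passes None through untouched and otherwise delegates to the inner loader, and that `check_only` runs the inner loader purely for validation and returns the original value unchanged.

```python
from typing import Any, Callable, Optional


def optional_sketch(loader: Callable[[Any], Any]) -> Callable[[Optional[Any]], Optional[Any]]:
    # None passes through untouched; anything else goes through the loader.
    return lambda value: None if value is None else loader(value)


def check_only_sketch(loader: Callable[[Any], Any]) -> Callable[[Any], Any]:
    # Run the loader for validation only and hand back the original value.
    def _load(value: Any) -> Any:
        loader(value)
        return value
    return _load
```

The distinction matters when the inner loader also converts its input. For example, `check_only_sketch(int)("3")` confirms that "3" is convertible but still returns the string.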