ding.rl_utils.exploration

BaseNoise

Bases: ABC

Overview

Base class for action noise

Interface: __init__, __call__

Examples:
>>> noise_generator = OUNoise()  # init one type of noise
>>> noise = noise_generator(action.shape, action.device)  # generate noise

__init__()

Overview

Initialization method.

__call__(shape, device) abstractmethod

Overview

Generate noise according to action tensor's shape, device.

Arguments:
- shape (:obj:`tuple`): Size of the action tensor; the output noise has the same size.
- device (:obj:`str`): Device of the action tensor; the output noise is placed on the same device.

Returns:
- noise (:obj:`torch.Tensor`): Generated action noise, with the same shape and device as the input action tensor.

GaussianNoise

Bases: BaseNoise

Overview

Derived class for generating Gaussian noise, which satisfies :math:`X \sim N(\mu, \sigma^2)`.

Interface: __init__, __call__

__init__(mu=0.0, sigma=1.0)

Overview

Initialize :math:`\mu` and :math:`\sigma` of the Gaussian distribution.

Arguments:
- mu (:obj:`float`): :math:`\mu`, the mean value.
- sigma (:obj:`float`): :math:`\sigma`, the standard deviation; must be non-negative (the constructor asserts ``sigma >= 0``).

__call__(shape, device)

Overview

Generate Gaussian noise matching the action tensor's shape and device.

Arguments:
- shape (:obj:`tuple`): Size of the action tensor; the output noise has the same size.
- device (:obj:`str`): Device of the action tensor; the output noise is placed on the same device.

Returns:
- noise (:obj:`torch.Tensor`): Generated action noise, with the same shape and device as the input action tensor.
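The sampling rule can be sketched without tensors: draw from a standard normal, then scale by sigma and shift by mu. The helper `gaussian_noise` below is a hypothetical plain-Python stand-in (it returns a list instead of a `torch.Tensor` and has no `device` argument), intended only to illustrate the arithmetic.

```python
import random


def gaussian_noise(n, mu=0.0, sigma=1.0, rng=None):
    """Return n samples of N(mu, sigma^2) as a list (hypothetical helper, illustration only)."""
    assert sigma >= 0, "sigma should be non-negative"
    if rng is None:
        rng = random.Random()
    # Same two steps as the tensor version: draw from N(0, 1), then apply sigma and mu.
    return [rng.gauss(0.0, 1.0) * sigma + mu for _ in range(n)]


# A seeded generator keeps the sketch repeatable.
samples = gaussian_noise(10000, mu=0.5, sigma=0.1, rng=random.Random(0))
sample_mean = sum(samples) / len(samples)
print(round(sample_mean, 2))
```

With `sigma=0.0` every sample equals `mu`, which is a quick way to confirm that the shift is applied after the scaling.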

OUNoise

Bases: BaseNoise

Overview

Derived class for generating Ornstein-Uhlenbeck process noise. It satisfies :math:`dx_t = \theta(\mu - x_t)dt + \sigma dW_t`, where :math:`W_t` denotes a Wiener process, acting as a random perturbation term.

Interface: __init__, reset, __call__

x0 property writable

Overview

Get ``self._x0``. Setting this property also resets the internal state ``_x`` to the new value.

__init__(mu=0.0, sigma=0.3, theta=0.15, dt=0.01, x0=0.0)

Overview

Initialize ``_alpha`` :math:`= \theta \cdot dt` and ``_beta`` :math:`= \sigma \sqrt{dt}` for the Ornstein-Uhlenbeck process.

Arguments:
- mu (:obj:`float`): :math:`\mu`, the mean value.
- sigma (:obj:`float`): :math:`\sigma`, the standard deviation of the perturbation noise.
- theta (:obj:`float`): Mean-reversion rate, i.e. how strongly the state is pulled back toward :math:`\mu` after a perturbation; a greater value means a stronger reaction.
- dt (:obj:`float`): The time step used to discretize the process.
- x0 (:obj:`Union[float, torch.Tensor]`): The initial state of the noise; should be a scalar or a tensor with the same shape as the action tensor.
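The two constants can be read off a one-step discretization of the process definition above. The sketch below assumes a fixed step :math:`dt` and independent standard normal draws, written here as :math:`\varepsilon_t` (a symbol introduced only for this illustration).

```latex
\begin{aligned}
x_{t+dt} &= x_t + \theta(\mu - x_t)\,dt + \sigma\sqrt{dt}\,\varepsilon_t,
  \qquad \varepsilon_t \sim N(0, 1) \\
         &= x_t + \alpha(\mu - x_t) + \beta\,\varepsilon_t,
  \qquad \alpha = \theta\,dt,\quad \beta = \sigma\sqrt{dt}
\end{aligned}
```

With the default arguments this gives :math:`\alpha = 0.15 \times 0.01 = 0.0015` and :math:`\beta = 0.3 \times \sqrt{0.01} = 0.03`.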

reset()

Overview

Reset ``_x`` to the initial state ``_x0``.

__call__(shape, device, mu=None)

Overview

Advance the Ornstein-Uhlenbeck process by one step and generate noise matching the action tensor's shape and device.

Arguments:
- shape (:obj:`tuple`): Size of the action tensor; the output noise has the same size.
- device (:obj:`str`): Device of the action tensor; the output noise is placed on the same device.
- mu (:obj:`float`): A new mean value :math:`\mu` for this call; leave it as ``None`` to use the value given at initialization.

Returns:
- noise (:obj:`torch.Tensor`): Generated action noise, with the same shape and device as the input action tensor.
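A scalar sketch of the update may make the stateful behaviour easier to follow. `ScalarOUNoise` is a hypothetical stand-in written for this page: it works on a single float instead of a tensor and drops the `shape` and `device` arguments. As in the source listing, each call returns the per-step increment while the internal state accumulates it.

```python
import math
import random


class ScalarOUNoise:
    """Hypothetical scalar stand-in for the tensor-based OUNoise (illustration only)."""

    def __init__(self, mu=0.0, sigma=0.3, theta=0.15, dt=1e-2, x0=0.0):
        self._mu = mu
        self._alpha = theta * dt            # mean-reversion factor per step
        self._beta = sigma * math.sqrt(dt)  # scale of the random perturbation per step
        self._x0 = x0
        self._rng = random.Random(0)        # fixed seed, only to keep the sketch repeatable
        self.reset()

    def reset(self):
        # Return the internal state to its initial value, e.g. at the start of an episode.
        self._x = self._x0

    def __call__(self, mu=None):
        if mu is None:
            mu = self._mu
        # One discretized step: drift toward mu plus a scaled standard normal draw.
        step = self._alpha * (mu - self._x) + self._beta * self._rng.gauss(0.0, 1.0)
        self._x += step
        return step


# With sigma=0 the step is purely the drift term, which makes the update easy to check.
drift_only = ScalarOUNoise(sigma=0.0, x0=1.0)
first_step = drift_only()
print(first_step, drift_only._x)
```

Because the state persists between calls, consecutive outputs are correlated; calling `reset()` discards that history.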

get_epsilon_greedy_fn(start, end, decay, type_='exp')

Overview

Generate an epsilon-greedy schedule with decay: a function that takes the current timestep and returns the current epsilon.

Arguments:
- start (:obj:`float`): Epsilon start value. For ``linear``, it should be 1.0.
- end (:obj:`float`): Epsilon end value.
- decay (:obj:`int`): Controls how fast epsilon decreases from ``start`` to ``end``. We recommend decaying epsilon according to env step rather than iteration.
- type_ (:obj:`str`): How epsilon decays; currently supports ``'linear'`` and ``'exp'`` (exponential).

Returns:
- eps_fn (:obj:`function`): The epsilon-greedy function with decay.
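To compare the two decay types, the sketch below restates both rules in a standalone function. `epsilon_schedule` is a hypothetical name used so the example runs without the library installed; the arithmetic follows the source listing for this module.

```python
import math


def epsilon_schedule(start, end, decay, type_="exp"):
    """Hypothetical standalone restatement of the two decay rules (illustration only)."""
    assert type_ in ("linear", "exp"), type_
    if type_ == "exp":
        # Exponential: the gap (start - end) shrinks by a factor of e every `decay` steps.
        return lambda step: (start - end) * math.exp(-step / decay) + end

    def linear(step):
        # Linear: interpolate from start to end over `decay` steps, then hold at end.
        if step >= decay:
            return end
        return (start - end) * (1 - step / decay) + end

    return linear


linear_eps = epsilon_schedule(start=1.0, end=0.1, decay=100, type_="linear")
exp_eps = epsilon_schedule(start=1.0, end=0.1, decay=100, type_="exp")

for step in (0, 50, 100, 200):
    print(step, round(linear_eps(step), 4), round(exp_eps(step), 4))
```

The linear schedule reaches `end` exactly at `step == decay`, whereas the exponential schedule only approaches `end` asymptotically, so `decay` acts as a time constant there rather than a cutoff.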

create_noise_generator(noise_type, noise_kwargs)

Overview

Given the key (``noise_type``), create a new noise generator instance if the key is one of ``noise_mapping``'s keys; otherwise raise a ``KeyError``. In other words, a derived noise generator must first be registered in ``noise_mapping``; ``create_noise_generator`` can then be called to get an instance.

Arguments:
- noise_type (:obj:`str`): The type of noise generator to be created.
- noise_kwargs (:obj:`dict`): Keyword arguments passed to the constructor of the selected noise generator.

Returns:
- noise (:obj:`BaseNoise`): The newly created noise generator, an instance of one of ``noise_mapping``'s values.
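The register-then-create pattern can be sketched in a few lines. Every name below (`Noise`, `ConstantNoise`, `registry`, `create`) is hypothetical and stands in for the library's own classes and mapping so that the example runs on its own; `device` is accepted but ignored because no tensors are involved.

```python
from abc import ABC, abstractmethod


class Noise(ABC):
    """Hypothetical base class standing in for BaseNoise."""

    @abstractmethod
    def __call__(self, shape, device):
        raise NotImplementedError


class ConstantNoise(Noise):
    """Hypothetical generator used only to demonstrate registration."""

    def __init__(self, value=0.0):
        self.value = value

    def __call__(self, shape, device):
        # Flat list of length shape[0]; `device` is ignored in this pure-Python sketch.
        return [self.value] * shape[0]


# Registration step: a key must be present here before it can be created.
registry = {"constant": ConstantNoise}


def create(noise_type, noise_kwargs):
    if noise_type not in registry:
        raise KeyError("not support noise type: {}".format(noise_type))
    # The kwargs dict is unpacked into the constructor of the selected class.
    return registry[noise_type](**noise_kwargs)


generator = create("constant", {"value": 0.5})
noise = generator((3,), "cpu")
print(noise)
```

Keeping the mapping in one dictionary means a configuration file only needs to carry a string key and a kwargs dict, and an unknown key fails early with a `KeyError`.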

Full Source Code

../ding/rl_utils/exploration.py

import math
from abc import ABC, abstractmethod
from typing import Callable, Union, Optional
from copy import deepcopy
from ding.torch_utils.data_helper import to_device

import torch


def get_epsilon_greedy_fn(start: float, end: float, decay: int, type_: str = 'exp') -> Callable:
    """
    Overview:
        Generate an epsilon_greedy function with decay, which inputs current timestep and outputs current epsilon.
    Arguments:
        - start (:obj:`float`): Epsilon start value. For ``linear`` , it should be 1.0.
        - end (:obj:`float`): Epsilon end value.
        - decay (:obj:`int`): Controls the speed that epsilon decreases from ``start`` to ``end``. \
            We recommend epsilon decays according to env step rather than iteration.
        - type (:obj:`str`): How epsilon decays, now supports ``['linear', 'exp'(exponential)]`` .
    Returns:
        - eps_fn (:obj:`function`): The epsilon greedy function with decay.
    """
    assert type_ in ['linear', 'exp'], type_
    if type_ == 'exp':
        return lambda x: (start - end) * math.exp(-1 * x / decay) + end
    elif type_ == 'linear':

        def eps_fn(x):
            if x >= decay:
                return end
            else:
                return (start - end) * (1 - x / decay) + end

        return eps_fn


class BaseNoise(ABC):
    r"""
    Overview:
        Base class for action noise
    Interface:
        __init__, __call__
    Examples:
        >>> noise_generator = OUNoise()  # init one type of noise
        >>> noise = noise_generator(action.shape, action.device)  # generate noise
    """

    def __init__(self) -> None:
        """
        Overview:
            Initialization method.
        """
        super().__init__()

    @abstractmethod
    def __call__(self, shape: tuple, device: str) -> torch.Tensor:
        """
        Overview:
            Generate noise according to action tensor's shape, device.
        Arguments:
            - shape (:obj:`tuple`): size of the action tensor, output noise's size should be the same.
            - device (:obj:`str`): device of the action tensor, output noise's device should be the same as it.
        Returns:
            - noise (:obj:`torch.Tensor`): generated action noise, \
                have the same shape and device with the input action tensor.
        """
        raise NotImplementedError


class GaussianNoise(BaseNoise):
    """
    Overview:
        Derived class for generating gaussian noise, which satisfies :math:`X \sim N(\mu, \sigma^2)`
    Interface:
        __init__, __call__
    """

    def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None:
        """
        Overview:
            Initialize :math:`\mu` and :math:`\sigma` in Gaussian Distribution.
        Arguments:
            - mu (:obj:`float`): :math:`\mu` , mean value.
            - sigma (:obj:`float`): :math:`\sigma` , standard deviation, should be positive.
        """
        super(GaussianNoise, self).__init__()
        self._mu = mu
        assert sigma >= 0, "GaussianNoise's sigma should be positive."
        self._sigma = sigma

    def __call__(self, shape: tuple, device: str) -> torch.Tensor:
        """
        Overview:
            Generate gaussian noise according to action tensor's shape, device
        Arguments:
            - shape (:obj:`tuple`): size of the action tensor, output noise's size should be the same
            - device (:obj:`str`): device of the action tensor, output noise's device should be the same as it
        Returns:
            - noise (:obj:`torch.Tensor`): generated action noise, \
                have the same shape and device with the input action tensor
        """
        noise = torch.randn(shape, device=device)
        noise = noise * self._sigma + self._mu
        return noise


class OUNoise(BaseNoise):
    r"""
    Overview:
        Derived class for generating Ornstein-Uhlenbeck process noise.
        Satisfies :math:`dx_t=\theta(\mu-x_t)dt + \sigma dW_t`,
        where :math:`W_t` denotes Weiner Process, acting as a random perturbation term.
    Interface:
        __init__, reset, __call__
    """

    def __init__(
            self,
            mu: float = 0.0,
            sigma: float = 0.3,
            theta: float = 0.15,
            dt: float = 1e-2,
            x0: Optional[Union[float, torch.Tensor]] = 0.0,
    ) -> None:
        """
        Overview:
            Initialize ``_alpha`` :math:`=\theta * dt\`,
            ``beta`` :math:`= \sigma * \sqrt{dt}`, in Ornstein-Uhlenbeck process.
        Arguments:
            - mu (:obj:`float`): :math:`\mu` , mean value.
            - sigma (:obj:`float`): :math:`\sigma` , standard deviation of the perturbation noise.
            - theta (:obj:`float`): How strongly the noise reacts to perturbations, \
                greater value means stronger reaction.
            - dt (:obj:`float`): The derivative of time t.
            - x0 (:obj:`Union[float, torch.Tensor]`): The initial state of the noise, \
                should be a scalar or tensor with the same shape as the action tensor.
        """
        super().__init__()
        self._mu = mu
        self._alpha = theta * dt
        self._beta = sigma * math.sqrt(dt)
        self._x0 = x0
        self.reset()

    def reset(self) -> None:
        """
        Overview:
            Reset ``_x`` to the initial state ``_x0``.
        """
        self._x = deepcopy(self._x0)

    def __call__(self, shape: tuple, device: str, mu: Optional[float] = None) -> torch.Tensor:
        """
        Overview:
            Generate gaussian noise according to action tensor's shape, device.
        Arguments:
            - shape (:obj:`tuple`): The size of the action tensor, output noise's size should be the same.
            - device (:obj:`str`): The device of the action tensor, output noise's device should be the same as it.
            - mu (:obj:`float`): The new mean value :math:`\mu`, you can set it to `None` if don't need it.
        Returns:
            - noise (:obj:`torch.Tensor`): generated action noise, \
                have the same shape and device with the input action tensor.
        """
        if self._x is None or \
                (isinstance(self._x, torch.Tensor) and self._x.shape != shape):
            self._x = torch.zeros(shape)
        if mu is None:
            mu = self._mu
        noise = self._alpha * (mu - self._x) + self._beta * torch.randn(shape)
        self._x += noise
        noise = to_device(noise, device)
        return noise

    @property
    def x0(self) -> Union[float, torch.Tensor]:
        """
        Overview:
            Get ``self._x0``.
        """
        return self._x0

    @x0.setter
    def x0(self, _x0: Union[float, torch.Tensor]) -> None:
        """
        Overview:
            Set ``self._x0`` and reset ``self.x`` to ``self._x0`` as well.
        """
        self._x0 = _x0
        self.reset()


noise_mapping = {'gauss': GaussianNoise, 'ou': OUNoise}


def create_noise_generator(noise_type: str, noise_kwargs: dict) -> BaseNoise:
    """
    Overview:
        Given the key (noise_type), create a new noise generator instance if in noise_mapping's values,
        or raise an KeyError. In other words, a derived noise generator must first register,
        then call ``create_noise generator`` to get the instance object.
    Arguments:
        - noise_type (:obj:`str`): the type of noise generator to be created.
    Returns:
        - noise (:obj:`BaseNoise`): the created new noise generator, should be an instance of one of \
            noise_mapping's values.
    """
    if noise_type not in noise_mapping.keys():
        raise KeyError("not support noise type: {}".format(noise_type))
    else:
        return noise_mapping[noise_type](**noise_kwargs)