
ding.rl_utils.adder


Adder

Bases: object

Overview

Adder is a component that handles different transformations and calculations for transitions in the Collector module (data generation and processing), such as GAE, n-step return, and transition sampling.

Interface: __init__, get_gae, get_gae_with_default_last_value, get_nstep_return_data, get_train_sample

get_gae(data, last_value, gamma, gae_lambda, cuda) classmethod

Overview

Get GAE advantage for stacked transitions (T timesteps, 1 batch). Calls gae for the calculation.

Arguments:

- data (list): Transitions list; each element is a transition dict with at least the keys ['value', 'reward'].
- last_value (torch.Tensor): The last value (i.e. the value at timestep T+1).
- gamma (float): The future discount factor; should be in [0, 1], defaults to 0.99.
- gae_lambda (float): GAE lambda parameter; should be in [0, 1], defaults to 0.97. As lambda -> 0 the estimate becomes more biased; as lambda -> 1 it has higher variance due to the sum of terms.
- cuda (bool): Whether to use CUDA in the GAE computation.

Returns:

- data (list): Transitions list like the input one, but each element owns an extra advantage key 'adv'.

Examples:

>>> B, T = 2, 3  # batch_size, timestep
>>> data = [dict(value=torch.randn(B), reward=torch.randn(B)) for _ in range(T)]
>>> last_value = torch.randn(B)
>>> gamma = 0.99
>>> gae_lambda = 0.95
>>> cuda = False
>>> data = Adder.get_gae(data, last_value, gamma, gae_lambda, cuda)
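For reference, the recursion underlying GAE is the standard generalized advantage estimator: the TD residual delta_t = r_t + gamma * V_{t+1} - V_t, then A_t = delta_t + gamma * lambda * A_{t+1} computed backward over the trajectory. The sketch below illustrates this computation; gae_sketch is a hypothetical helper written for this page, not the library's gae, and it ignores done flags (which this call site passes as None anyway):

```python
import torch


def gae_sketch(value: torch.Tensor, next_value: torch.Tensor, reward: torch.Tensor,
               gamma: float = 0.99, gae_lambda: float = 0.97) -> torch.Tensor:
    # value, next_value, reward: shape (T, B), stacked over T timesteps
    # TD residual: delta_t = r_t + gamma * V_{t+1} - V_t
    delta = reward + gamma * next_value - value
    adv = torch.zeros_like(value)
    running = torch.zeros_like(value[0])
    # backward recursion: A_t = delta_t + gamma * lambda * A_{t+1}
    for t in reversed(range(value.shape[0])):
        running = delta[t] + gamma * gae_lambda * running
        adv[t] = running
    return adv
```

With gamma = lambda = 1 and zero value estimates, the advantage reduces to the reversed cumulative sum of rewards, which makes for a quick sanity check.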

get_gae_with_default_last_value(data, done, gamma, gae_lambda, cuda) classmethod

Overview

Like get_gae above, gets the GAE advantage for stacked transitions. However, this function is designed for the case where last_value is not passed in. If the transition is not done yet, it assigns the last value in data as last_value, discards the last element of data (i.e. len(data) decreases by 1), and then calls get_gae. Otherwise it sets last_value to 0.

Arguments:

- data (deque): Transitions list; each element is a transition dict with at least the keys ['value', 'reward'].
- done (bool): Whether the transition reaches the end of an episode (i.e. whether the env is done).
- gamma (float): The future discount factor; should be in [0, 1], defaults to 0.99.
- gae_lambda (float): GAE lambda parameter; should be in [0, 1], defaults to 0.97. As lambda -> 0 the estimate becomes more biased; as lambda -> 1 it has higher variance due to the sum of terms.
- cuda (bool): Whether to use CUDA in the GAE computation.

Returns:

- data (List[Dict[str, Any]]): Transitions list like the input one, but each element owns an extra advantage key 'adv'.

Examples:

>>> B, T = 2, 3  # batch_size, timestep
>>> data = [dict(value=torch.randn(B), reward=torch.randn(B)) for _ in range(T)]
>>> done = False
>>> gamma = 0.99
>>> gae_lambda = 0.95
>>> cuda = False
>>> data = Adder.get_gae_with_default_last_value(data, done, gamma, gae_lambda, cuda)
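The bootstrap choice described above can be sketched in a few lines (default_last_value is a hypothetical helper for illustration, not part of the library): a finished episode bootstraps from 0, while a truncated one bootstraps from the value of its final transition, which is then removed from data.

```python
import torch


def default_last_value(data: list, done: bool):
    # Illustrative sketch of the branch in get_gae_with_default_last_value.
    if done:
        # episode ended: there is no future value to bootstrap from
        last_value = torch.zeros_like(data[-1]['value'])
    else:
        # episode truncated: bootstrap from the last transition's value
        # and drop that transition (len(data) decreases by 1)
        last_value = data.pop()['value']
    return data, last_value
```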

get_nstep_return_data(data, nstep, cum_reward=False, correct_terminate_gamma=True, gamma=0.99) classmethod

Overview

Process raw trajectory data by updating the keys ['next_obs', 'reward', 'done'] in each dict element of data.

Arguments:

- data (deque): Transitions list; each element is a transition dict.
- nstep (int): Number of steps. If equal to 1, return data directly; otherwise update each element with its n-step values.

Returns:

- data (deque): Transitions list like the input one, but each element updated with its n-step values.

Examples:

>>> B, T = 2, 3  # batch_size, timestep
>>> data = [dict(
>>>     obs=torch.randn(B),
>>>     reward=torch.randn(1),
>>>     next_obs=torch.randn(B),
>>>     done=False) for _ in range(T)]
>>> nstep = 2
>>> data = Adder.get_nstep_return_data(data, nstep)
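When cum_reward=True, the updated reward is the discounted n-step sum R_t = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1}. A minimal sketch of that sum (nstep_cum_reward is a hypothetical helper for illustration, not the library's implementation):

```python
import torch


def nstep_cum_reward(rewards: list, start: int, nstep: int, gamma: float = 0.99) -> torch.Tensor:
    # Discounted sum of the next nstep rewards starting at index start:
    #   R = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1}
    return sum(rewards[start + j] * (gamma ** j) for j in range(nstep))
```

With cum_reward=False the method instead concatenates the n raw rewards along the last dimension, leaving the discounting to the learner.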

get_train_sample(data, unroll_len, last_fn_type='last', null_transition=None) classmethod

Overview

Process raw trajectory data into training samples. If unroll_len equals 1, no processing is needed and data can be returned directly. Otherwise, data is split into pieces of length unroll_len, the residual part is processed according to last_fn_type, and lists_to_dicts is called to form the sampled training data.

Arguments:

- data (List[Dict[str, Any]]): Transitions list; each element is a transition dict.
- unroll_len (int): Training unroll length for learning.
- last_fn_type (str): The method for dealing with the last residual data of a trajectory after splitting; should be one of ['last', 'drop', 'null_padding'].
- null_transition (Optional[dict]): Dict-type null transition, used by 'null_padding'.

Returns:

- data (List[Dict[str, Any]]): Transitions list processed after unrolling.
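The splitting and residual handling can be sketched as follows. split_for_unroll is a hypothetical stand-in for the combination of list_split and the last_fn_type branches, shown here only for 'drop' and the common 'last' case where at least one full chunk exists:

```python
def split_for_unroll(data: list, unroll_len: int, last_fn_type: str = 'drop') -> list:
    # cut data into full chunks of length unroll_len
    chunks = [data[i:i + unroll_len] for i in range(0, len(data) - unroll_len + 1, unroll_len)]
    residual = data[len(chunks) * unroll_len:]
    if residual:
        if last_fn_type == 'drop':
            pass  # discard the incomplete tail
        elif last_fn_type == 'last':
            # borrow the tail of the previous chunk to pad the final sample
            miss_num = unroll_len - len(residual)
            chunks.append(chunks[-1][-miss_num:] + residual)
    return chunks
```

For example, with 5 transitions and unroll_len=2, 'drop' yields two samples while 'last' yields a third sample that reuses one transition from the previous chunk.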

Full Source Code

../ding/rl_utils/adder.py

```python
from typing import List, Dict, Any, Optional
from collections import deque
import copy
import torch

from ding.utils import list_split, lists_to_dicts
from ding.rl_utils.gae import gae, gae_data


class Adder(object):
    """
    Overview:
        Adder is a component that handles different transformations and calculations for transitions
        in Collector Module(data generation and processing), such as GAE, n-step return, transition sampling etc.
    Interface:
        __init__, get_gae, get_gae_with_default_last_value, get_nstep_return_data, get_train_sample
    """

    @classmethod
    def get_gae(cls, data: List[Dict[str, Any]], last_value: torch.Tensor, gamma: float, gae_lambda: float,
                cuda: bool) -> List[Dict[str, Any]]:
        """
        Overview:
            Get GAE advantage for stacked transitions(T timestep, 1 batch). Call ``gae`` for calculation.
        Arguments:
            - data (:obj:`list`): Transitions list, each element is a transition dict with at least \
                ``['value', 'reward']``.
            - last_value (:obj:`torch.Tensor`): The last value(i.e.: the T+1 timestep)
            - gamma (:obj:`float`): The future discount factor, should be in [0, 1], defaults to 0.99.
            - gae_lambda (:obj:`float`): GAE lambda parameter, should be in [0, 1], defaults to 0.97, \
                when lambda -> 0, it induces bias, but when lambda -> 1, it has high variance due to the sum of terms.
            - cuda (:obj:`bool`): Whether use cuda in GAE computation
        Returns:
            - data (:obj:`list`): transitions list like input one, but each element owns extra advantage key 'adv'
        Examples:
            >>> B, T = 2, 3  # batch_size, timestep
            >>> data = [dict(value=torch.randn(B), reward=torch.randn(B)) for _ in range(T)]
            >>> last_value = torch.randn(B)
            >>> gamma = 0.99
            >>> gae_lambda = 0.95
            >>> cuda = False
            >>> data = Adder.get_gae(data, last_value, gamma, gae_lambda, cuda)
        """
        value = torch.stack([d['value'] for d in data])
        next_value = torch.stack([d['value'] for d in data][1:] + [last_value])
        reward = torch.stack([d['reward'] for d in data])
        if cuda:
            value = value.cuda()
            next_value = next_value.cuda()
            reward = reward.cuda()

        adv = gae(gae_data(value, next_value, reward, None, None), gamma, gae_lambda)

        if cuda:
            adv = adv.cpu()
        for i in range(len(data)):
            data[i]['adv'] = adv[i]
        return data

    @classmethod
    def get_gae_with_default_last_value(cls, data: deque, done: bool, gamma: float, gae_lambda: float,
                                        cuda: bool) -> List[Dict[str, Any]]:
        """
        Overview:
            Like ``get_gae`` above to get GAE advantage for stacked transitions. However, this function is designed
            in case ``last_value`` is not passed. If transition is not done yet, it would assign last value in
            ``data`` as ``last_value``, discard the last element in ``data`` (i.e. len(data) would decrease by 1),
            and then call ``get_gae``. Otherwise it would make ``last_value`` equal to 0.
        Arguments:
            - data (:obj:`deque`): Transitions list, each element is a transition dict with \
                at least ['value', 'reward']
            - done (:obj:`bool`): Whether the transition reaches the end of an episode(i.e. whether the env is done)
            - gamma (:obj:`float`): The future discount factor, should be in [0, 1], defaults to 0.99.
            - gae_lambda (:obj:`float`): GAE lambda parameter, should be in [0, 1], defaults to 0.97, \
                when lambda -> 0, it induces bias, but when lambda -> 1, it has high variance due to the sum of terms.
            - cuda (:obj:`bool`): Whether use cuda in GAE computation
        Returns:
            - data (:obj:`List[Dict[str, Any]]`): transitions list like input one, but each element owns \
                extra advantage key 'adv'
        Examples:
            >>> B, T = 2, 3  # batch_size, timestep
            >>> data = [dict(value=torch.randn(B), reward=torch.randn(B)) for _ in range(T)]
            >>> done = False
            >>> gamma = 0.99
            >>> gae_lambda = 0.95
            >>> cuda = False
            >>> data = Adder.get_gae_with_default_last_value(data, done, gamma, gae_lambda, cuda)
        """
        if done:
            last_value = torch.zeros_like(data[-1]['value'])
        else:
            last_data = data.pop()
            last_value = last_data['value']
        return cls.get_gae(data, last_value, gamma, gae_lambda, cuda)

    @classmethod
    def get_nstep_return_data(
            cls,
            data: deque,
            nstep: int,
            cum_reward=False,
            correct_terminate_gamma=True,
            gamma=0.99,
    ) -> deque:
        """
        Overview:
            Process raw traj data by updating keys ``['next_obs', 'reward', 'done']`` in data's dict element.
        Arguments:
            - data (:obj:`deque`): Transitions list, each element is a transition dict
            - nstep (:obj:`int`): Number of steps. If equals to 1, return ``data`` directly; \
                Otherwise update with nstep value.
        Returns:
            - data (:obj:`deque`): Transitions list like input one, but each element updated with nstep value.
        Examples:
            >>> data = [dict(
            >>>     obs=torch.randn(B),
            >>>     reward=torch.randn(1),
            >>>     next_obs=torch.randn(B),
            >>>     done=False) for _ in range(T)]
            >>> nstep = 2
            >>> data = Adder.get_nstep_return_data(data, nstep)
        """
        if nstep == 1:
            return data
        fake_reward = torch.zeros_like(data[0]['reward'])
        next_obs_flag = 'next_obs' in data[0]
        for i in range(len(data) - nstep):
            # update keys ['next_obs', 'reward', 'done'] with their n-step value
            if next_obs_flag:
                data[i]['next_obs'] = data[i + nstep]['obs']  # do not need deepcopy
            if cum_reward:
                data[i]['reward'] = sum([data[i + j]['reward'] * (gamma ** j) for j in range(nstep)])
            else:
                # data[i]['reward'].shape = (1) or (agent_num, 1)
                # single agent env: shape (1) -> (n_step)
                # multi-agent env: shape (agent_num, 1) -> (agent_num, n_step)
                data[i]['reward'] = torch.cat([data[i + j]['reward'] for j in range(nstep)], dim=-1)
            data[i]['done'] = data[i + nstep - 1]['done']
            if correct_terminate_gamma:
                data[i]['value_gamma'] = gamma ** nstep
        for i in range(max(0, len(data) - nstep), len(data)):
            if next_obs_flag:
                data[i]['next_obs'] = data[-1]['next_obs']  # do not need deepcopy
            if cum_reward:
                data[i]['reward'] = sum([data[i + j]['reward'] * (gamma ** j) for j in range(len(data) - i)])
            else:
                data[i]['reward'] = torch.cat(
                    [data[i + j]['reward']
                     for j in range(len(data) - i)] + [fake_reward for _ in range(nstep - (len(data) - i))],
                    dim=-1
                )
            data[i]['done'] = data[-1]['done']
            if correct_terminate_gamma:
                data[i]['value_gamma'] = gamma ** (len(data) - i - 1)
        return data

    @classmethod
    def get_train_sample(
            cls,
            data: List[Dict[str, Any]],
            unroll_len: int,
            last_fn_type: str = 'last',
            null_transition: Optional[dict] = None
    ) -> List[Dict[str, Any]]:
        """
        Overview:
            Process raw traj data into training samples.
            If ``unroll_len`` equals to 1, which means no process is needed, can directly return ``data``.
            Otherwise, ``data`` will be split according to ``unroll_len``, process residual part according to
            ``last_fn_type`` and call ``lists_to_dicts`` to form sampled training data.
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): Transitions list, each element is a transition dict
            - unroll_len (:obj:`int`): Learn training unroll length
            - last_fn_type (:obj:`str`): The method type name for dealing with last residual data in a traj \
                after splitting, should be in ['last', 'drop', 'null_padding']
            - null_transition (:obj:`Optional[dict]`): Dict type null transition, used in ``null_padding``
        Returns:
            - data (:obj:`List[Dict[str, Any]]`): Transitions list processed after unrolling
        """
        if unroll_len == 1:
            return data
        else:
            # cut data into pieces whose length is unroll_len
            split_data, residual = list_split(data, step=unroll_len)

            def null_padding():
                template = copy.deepcopy(residual[0])
                template['null'] = True
                if isinstance(template['obs'], dict):
                    template['obs'] = {k: torch.zeros_like(v) for k, v in template['obs'].items()}
                else:
                    template['obs'] = torch.zeros_like(template['obs'])
                if 'action' in template:
                    template['action'] = torch.zeros_like(template['action'])
                template['done'] = True
                template['reward'] = torch.zeros_like(template['reward'])
                if 'value_gamma' in template:
                    template['value_gamma'] = 0.
                null_data = [cls._get_null_transition(template, null_transition) for _ in range(miss_num)]
                return null_data

            if residual is not None:
                miss_num = unroll_len - len(residual)
                if last_fn_type == 'drop':
                    # drop the residual part
                    pass
                elif last_fn_type == 'last':
                    if len(split_data) > 0:
                        # copy last datas from split_data's last element, and insert in front of residual
                        last_data = copy.deepcopy(split_data[-1][-miss_num:])
                        split_data.append(last_data + residual)
                    else:
                        # get null transitions using ``null_padding``, and insert behind residual
                        null_data = null_padding()
                        split_data.append(residual + null_data)
                elif last_fn_type == 'null_padding':
                    # same to the case of 'last' type and split_data is empty
                    null_data = null_padding()
                    split_data.append(residual + null_data)
            # collate unroll_len dicts according to keys
            if len(split_data) > 0:
                split_data = [lists_to_dicts(d, recursive=True) for d in split_data]
            return split_data

    @classmethod
    def _get_null_transition(cls, template: dict, null_transition: Optional[dict] = None) -> dict:
        """
        Overview:
            Get null transition for padding. If ``null_transition`` is None, return input ``template`` instead.
        Arguments:
            - template (:obj:`dict`): The template for null transition.
            - null_transition (:obj:`Optional[dict]`): Dict type null transition, used in ``null_padding``
        Returns:
            - null_transition (:obj:`dict`): The deepcopied null transition.
        """
        if null_transition is not None:
            return copy.deepcopy(null_transition)
        else:
            return copy.deepcopy(template)


get_gae = Adder.get_gae
get_gae_with_default_last_value = Adder.get_gae_with_default_last_value
get_nstep_return_data = Adder.get_nstep_return_data
get_train_sample = Adder.get_train_sample
```