# ding.policy.mbpolicy.utils
RewardEMA
¶
Bases: object
running mean and std
q_evaluation(obss, actions, q_critic_fn)
¶
Overview
Evaluate (observation, action) pairs along the trajectory
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| obss | :obj:`torch.Tensor` | the observations along the trajectory | required |
| actions | :obj:`torch.Tensor` | the actions along the trajectory | required |
| q_critic_fn | :obj:`Callable` | the unified API :math:`Q(S_t, A_t)` | required |
Returns:

| Type | Description |
|---|---|
| Union[Tensor, Tuple[Tensor, Tensor]] | the action-value function evaluated along the trajectory |
Shapes:

- :math:`N`: time step
- :math:`B`: batch size
- :math:`O`: observation dimension
- :math:`A`: action dimension
- obss: [N, B, O]
- actions: [N, B, A]
- q_value: [N, B]
Full Source Code
../ding/policy/mbpolicy/utils.py
from typing import Callable, List, Tuple, Union

import torch
from torch import Tensor

from ding.torch_utils import fold_batch, unfold_batch
from ding.rl_utils import generalized_lambda_returns
from ding.torch_utils.network.dreamer import static_scan


def q_evaluation(obss: Tensor, actions: Tensor,
                 q_critic_fn: Callable[[Tensor, Tensor], Tensor]) -> Union[Tensor, List[Tensor]]:
    """
    Overview:
        Evaluate (observation, action) pairs along the trajectory.
    Arguments:
        - obss (:obj:`torch.Tensor`): the observations along the trajectory
        - actions (:obj:`torch.Tensor`): the actions along the trajectory
        - q_critic_fn (:obj:`Callable`): the unified API :math:`Q(S_t, A_t)`
    Returns:
        - q_value (:obj:`torch.Tensor`): the action-value function evaluated along the trajectory; \
          a list of two tensors when ``q_critic_fn`` is a twin critic
    Shapes:
        :math:`N`: time step
        :math:`B`: batch size
        :math:`O`: observation dimension
        :math:`A`: action dimension

        - obss: [N, B, O]
        - actions: [N, B, A]
        - q_value: [N, B]
    """
    # Fold the time dimension into the batch so the critic sees flat [N*B, ...] inputs.
    obss, dim = fold_batch(obss, 1)
    actions, _ = fold_batch(actions, 1)
    q_values = q_critic_fn(obss, actions)
    # Twin critic returns a list of two estimates; unfold each back to [N, B].
    if isinstance(q_values, list):
        return [unfold_batch(q_values[0], dim), unfold_batch(q_values[1], dim)]
    return unfold_batch(q_values, dim)


def imagine(cfg, world_model, start, actor, horizon, repeats=None):
    """
    Overview:
        Roll out imagined latent trajectories of length ``horizon`` with the world model
        dynamics, selecting actions by sampling from ``actor``.
    Arguments:
        - cfg: config; reads ``cfg.imag_sample`` (whether ``img_step`` samples)
        - world_model: model exposing ``dynamics.get_feat`` and ``dynamics.img_step``
        - start (:obj:`dict`): initial latent state tensors, each with leading (time, batch) dims
        - actor (:obj:`Callable`): maps features to an action distribution
        - horizon (:obj:`int`): number of imagination steps
        - repeats: unused; kept for interface compatibility
    Returns:
        - feats, states, actions of the imagined trajectory
    """
    dynamics = world_model.dynamics
    # Collapse the leading (time, batch) dims of every start tensor into one batch dim.
    flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
    start = {k: flatten(v) for k, v in start.items()}

    def step(prev, _):
        # One imagination step: state -> feature -> sampled action -> successor state.
        state, _, _ = prev
        feat = dynamics.get_feat(state)
        # Detach so policy gradients do not flow back into the dynamics model here.
        inp = feat.detach()
        action = actor(inp).sample()
        succ = dynamics.img_step(state, action, sample=cfg.imag_sample)
        return succ, feat, action

    succ, feats, actions = static_scan(step, [torch.arange(horizon)], (start, None, None))
    # Prepend the start state and drop the last successor so states align with feats/actions.
    states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}

    return feats, states, actions


def compute_target(cfg, world_model, critic, imag_feat, imag_state, reward, actor_ent, state_ent):
    """
    Overview:
        Compute lambda-return targets and discount-product weights for the imagined trajectory.
    Arguments:
        - cfg: config; reads ``cfg.discount`` and ``cfg.lambda_``
        - world_model: model whose optional ``discount`` head predicts per-step discounts
        - critic: value network; ``critic(feat).mode()`` gives state values
        - imag_feat (:obj:`torch.Tensor`): imagined features, [horizon, B, ...]
        - imag_state: imagined latent states
        - reward (:obj:`torch.Tensor`): imagined rewards, [horizon, B, ...]
        - actor_ent, state_ent: unused here; kept for interface compatibility
    Returns:
        - target: lambda returns for steps [0, horizon-1)
        - weights: cumulative discount products (detached), used to weight the losses
        - value[:-1]: baseline values aligned with ``target``
    """
    if "discount" in world_model.heads:
        inp = world_model.dynamics.get_feat(imag_state)
        discount = cfg.discount * world_model.heads["discount"](inp).mean
        # TODO whether to detach
        discount = discount.detach()
    else:
        # No learned termination head: use a constant discount everywhere.
        discount = cfg.discount * torch.ones_like(reward)

    value = critic(imag_feat).mode()
    # value(imag_horizon, 16*64, 1)
    # action(imag_horizon, 16*64, ch)
    # discount(imag_horizon, 16*64, 1)
    target = generalized_lambda_returns(value, reward[:-1], discount[:-1], cfg.lambda_)
    # weights[t] = prod_{k<t} discount[k]; detached so they act as fixed loss weights.
    weights = torch.cumprod(torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0).detach()
    return target, weights, value[:-1]


def compute_actor_loss(
    cfg,
    actor,
    reward_ema,
    imag_feat,
    imag_state,
    imag_action,
    target,
    actor_ent,
    state_ent,
    weights,
    base,
):
    """
    Overview:
        Compute the (negative) weighted advantage loss for the imagination actor, with
        optional reward-EMA normalization and entropy bonuses.
    Arguments:
        - cfg: config; reads ``reward_EMA``, ``actor_entropy``, ``actor_state_entropy``
        - actor (:obj:`Callable`): maps detached features to an action distribution
        - reward_ema (:obj:`RewardEMA`): running quantile normalizer for the targets
        - imag_feat (:obj:`torch.Tensor`): imagined features
        - imag_state, imag_action: unused here; kept for interface compatibility
        - target (:obj:`torch.Tensor`): lambda-return targets
        - actor_ent, state_ent: entropy terms (``actor_ent`` is recomputed from the policy)
        - weights (:obj:`torch.Tensor`): discount-product loss weights
        - base (:obj:`torch.Tensor`): value baseline aligned with ``target``
    Returns:
        - actor_loss (:obj:`torch.Tensor`): scalar loss
        - metrics (:obj:`dict`): logging statistics
    """
    metrics = {}
    # Detach features so the actor loss does not backprop into the world model.
    inp = imag_feat.detach()
    policy = actor(inp)
    actor_ent = policy.entropy()
    # Q-val for actor is not transformed using symlog
    if cfg.reward_EMA:
        # Normalize returns by the running [5%, 95%] quantile range (DreamerV3-style).
        offset, scale = reward_ema(target)
        normed_target = (target - offset) / scale
        normed_base = (base - offset) / scale
        adv = normed_target - normed_base
        metrics.update(tensorstats(normed_target, "normed_target"))
        values = reward_ema.values
        metrics["EMA_005"] = values[0].detach().cpu().numpy().item()
        metrics["EMA_095"] = values[1].detach().cpu().numpy().item()
    else:
        # Fix: the original left ``adv`` undefined on this path (NameError when
        # reward_EMA is disabled). Without normalization the advantage is the
        # raw difference between the return target and the value baseline.
        adv = target - base

    actor_target = adv
    if cfg.actor_entropy > 0:
        actor_entropy = cfg.actor_entropy * actor_ent[:-1][:, :, None]
        actor_target += actor_entropy
        metrics["actor_entropy"] = torch.mean(actor_entropy).detach().cpu().numpy().item()
    if cfg.actor_state_entropy > 0:
        state_entropy = cfg.actor_state_entropy * state_ent[:-1]
        actor_target += state_entropy
        metrics["actor_state_entropy"] = torch.mean(state_entropy).detach().cpu().numpy().item()
    # Maximize the weighted target, i.e. minimize its negative mean.
    actor_loss = -torch.mean(weights[:-1] * actor_target)

    return actor_loss, metrics


class RewardEMA(object):
    """
    Overview:
        Running estimate of the [5%, 95%] quantile range of returns, updated by an
        exponential moving average; used to normalize actor targets.
    """

    def __init__(self, device, alpha=1e-2):
        # device: torch device for the internal buffers
        # alpha: EMA update rate for the quantile estimates
        self.device = device
        # values[0] tracks the 5% quantile, values[1] the 95% quantile.
        self.values = torch.zeros((2, )).to(device)
        self.alpha = alpha
        self.range = torch.tensor([0.05, 0.95]).to(device)

    def __call__(self, x):
        """
        Overview:
            Update the EMA quantiles with ``x`` and return (offset, scale) for
            normalization: offset is the low quantile, scale the (clipped) range.
        """
        flat_x = torch.flatten(x.detach())
        x_quantile = torch.quantile(input=flat_x, q=self.range)
        self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values
        # Clip so normalization never amplifies returns whose spread is < 1.
        scale = torch.clip(self.values[1] - self.values[0], min=1.0)
        offset = self.values[0]
        return offset.detach(), scale.detach()


def tensorstats(tensor, prefix=None):
    """
    Overview:
        Summary statistics (mean/std/min/max) of ``tensor`` for logging.
    Arguments:
        - tensor (:obj:`torch.Tensor`): the tensor to summarize
        - prefix (:obj:`str`, optional): key prefix; when given, values are Python scalars
    Returns:
        - metrics (:obj:`dict`): note that without ``prefix`` the values stay as
          numpy 0-d arrays (no ``.item()``) — preserved for backward compatibility.
    """
    metrics = {
        'mean': torch.mean(tensor).detach().cpu().numpy(),
        'std': torch.std(tensor).detach().cpu().numpy(),
        'min': torch.min(tensor).detach().cpu().numpy(),
        'max': torch.max(tensor).detach().cpu().numpy(),
    }
    if prefix:
        metrics = {f'{prefix}_{k}': v.item() for k, v in metrics.items()}
    return metrics