
ding.model.template.bcq


BCQ

Bases: Module

Overview

Model of BCQ (Batch-Constrained deep Q-learning), from the paper "Off-Policy Deep Reinforcement Learning without Exploration". https://arxiv.org/abs/1812.02900

Interface: forward, compute_actor, compute_critic, compute_vae, compute_eval

Property: mode

__init__(obs_shape, action_shape, actor_head_hidden_size=[400, 300], critic_head_hidden_size=[400, 300], activation=nn.ReLU(), vae_hidden_dims=[750, 750], phi=0.05)

Overview

Initialize the networks of the model: two critic (Q) networks, the perturbation actor, and the VAE that generates candidate actions.

Arguments:
- obs_shape (:obj:int): Dimension of the observation.
- action_shape (:obj:int): Dimension of the action.
- actor_head_hidden_size (:obj:list): Hidden layer sizes of the actor, defaults to [400, 300].
- critic_head_hidden_size (:obj:list): Hidden layer sizes of each critic, defaults to [400, 300].
- activation (:obj:nn.Module): Activation function in network, defaults to nn.ReLU().
- vae_hidden_dims (:obj:list): Hidden layer sizes of the VAE, defaults to [750, 750].
- phi (:obj:float): Scale of the perturbation that the actor adds to the input action, defaults to 0.05.
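As a rough illustration of the sizes involved, the sketch below lists the (in, out) dimensions of the linear layers built for each critic, following the construction loop in the source listing at the end of this page. The helper name `mlp_layer_dims` is hypothetical and not part of DI-engine.

```python
from typing import List, Tuple


def mlp_layer_dims(input_size: int, hidden_sizes: List[int], output_size: int = 1) -> List[Tuple[int, int]]:
    """Return (in_features, out_features) for each linear layer of an MLP."""
    dims = []
    d = input_size
    for h in hidden_sizes:
        dims.append((d, h))
        d = h
    # Final projection from the last hidden size to the output size.
    dims.append((d, output_size))
    return dims


obs_shape, action_shape = 32, 6
# Each critic takes the observation and the action concatenated along the last dim.
critic_input_size = obs_shape + action_shape
critic_dims = mlp_layer_dims(critic_input_size, [400, 300])
print(critic_dims)  # [(38, 400), (400, 300), (300, 1)]
```

The actor in the listing is built with the same input size and the same single-output final layer.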

forward(inputs, mode)

Overview

The single entry point (forward) of the BCQ model. The mode argument selects which computation graph to run: compute_actor, compute_critic, compute_vae or compute_eval.

Mode compute_actor:
Arguments:
- inputs (:obj:Dict): Input dict data, including obs and action tensor.
Returns:
- output (:obj:Dict): Output dict data, including action tensor.

Mode compute_critic:
Arguments:
- inputs (:obj:Dict): Input dict data, including obs and action tensor.
Returns:
- output (:obj:Dict): Output dict data, including q_value.

Mode compute_vae:
Arguments:
- inputs (:obj:Dict): Input dict data, including obs and action tensor.
Returns:
- outputs (:obj:Dict): Dict containing keywords recons_action (:obj:torch.Tensor), prediction_residual (:obj:torch.Tensor), input (:obj:torch.Tensor), mu (:obj:torch.Tensor), log_var (:obj:torch.Tensor) and z (:obj:torch.Tensor).

Mode compute_eval:
Arguments:
- inputs (:obj:Dict): Input dict data; only the obs tensor is read in this mode.
Returns:
- output (:obj:Dict): Output dict data, including action tensor.

Examples:
>>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
>>> model = BCQ(32, 6)
>>> outputs = model(inputs, mode='compute_actor')
>>> outputs = model(inputs, mode='compute_critic')
>>> outputs = model(inputs, mode='compute_vae')
>>> outputs = model(inputs, mode='compute_eval')

.. note:: For specific examples, one can refer to API doc of compute_actor and compute_critic respectively.
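The mode dispatch can be pictured with a small stand-alone class. This is a sketch of the pattern only: `ToyModel` and the values it returns are invented for illustration, while the assert-then-getattr body mirrors the two lines of forward in the source listing.

```python
from typing import Any, Dict


class ToyModel:
    """Hypothetical stand-in showing string-based mode dispatch."""

    # Names of the methods that forward is allowed to call.
    mode = ['compute_actor', 'compute_critic']

    def forward(self, inputs: Dict[str, Any], mode: str) -> Dict[str, Any]:
        # Reject unknown modes early, then call the method of the same name.
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {'action': inputs['action']}

    def compute_critic(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {'q_value': 0.0}


model = ToyModel()
data = {'obs': [0.1], 'action': [0.5]}
actor_out = model.forward(data, mode='compute_actor')
critic_out = model.forward(data, mode='compute_critic')
```

Passing a string that is not listed in mode fails the assertion instead of silently calling an arbitrary attribute.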

compute_critic(inputs)

Overview

Use the two critic networks to compute Q-values for the given observation-action pairs.

Arguments:
- inputs (:obj:Dict): Input dict data, including obs and action tensor.
Returns:
- outputs (:obj:Dict): Dict containing keyword q_value (:obj:List[torch.Tensor]), a list with one Q-value tensor per critic (two in total).
Shapes:
- inputs (:obj:Dict): :math:(B, N, D), where B is batch size, N is sample number, D is input dimension.
- outputs (:obj:Dict): :math:(B, N) for each tensor in q_value.
Examples:
>>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
>>> model = BCQ(32, 6)
>>> outputs = model.compute_critic(inputs)
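In the source listing, q_value holds one estimate per critic, so the caller decides how to combine them. The BCQ paper forms its targets from a weighted mix of the minimum and maximum of the two estimates. The sketch below shows that combination on plain floats. The function name is hypothetical, the 0.75 weight is the paper's default, and this page does not show where DI-engine applies it.

```python
def combine_twin_q(q1: float, q2: float, lmbda: float = 0.75) -> float:
    """Weighted min/max of two Q estimates (hypothetical helper, not a DI-engine API)."""
    # lmbda = 1 gives the pessimistic minimum, lmbda = 0.5 the plain average.
    return lmbda * min(q1, q2) + (1.0 - lmbda) * max(q1, q2)


mixed = combine_twin_q(1.0, 3.0)  # 0.75 * 1.0 + 0.25 * 3.0
```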

compute_actor(inputs)

Overview

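Use the actor network to perturb the input action: the actor output is passed through tanh, scaled by phi, added to the input action, and the result is clamped to [-1, 1].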

Arguments:
- inputs (:obj:Dict): Input dict data, including obs and action tensor.
Returns:
- outputs (:obj:Dict): Dict containing keyword action (:obj:torch.Tensor).
Shapes:
- inputs (:obj:Dict): :math:(B, N, D), where B is batch size, N is sample number, D is input dimension.
- outputs (:obj:Dict): action has the same shape as the input action tensor.
Examples:
>>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
>>> model = BCQ(32, 6)
>>> outputs = model.compute_actor(inputs)
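The perturbation step can be written out for a single action vector in plain Python. This is a minimal sketch with a hypothetical helper name. It treats the actor output as one scalar, which matches the single-output final layer of the actor in the source listing, where the same correction is broadcast over all action components.

```python
import math
from typing import List


def perturb_action(action: List[float], actor_out: float, phi: float = 0.05) -> List[float]:
    """Add the bounded correction phi * tanh(actor_out) to each component, then clamp to [-1, 1]."""
    delta = phi * math.tanh(actor_out)  # always within [-phi, phi]
    return [max(-1.0, min(1.0, a + delta)) for a in action]


# A large actor output saturates tanh, so the correction is close to +phi.
perturbed = perturb_action([0.2, -0.99, 0.98], actor_out=10.0)
```

Because tanh is bounded, no component can move by more than phi, which keeps the output close to the action proposed by the generative model.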

compute_vae(inputs)

Overview

Run the VAE forward pass on the observation-action pair, returning the reconstructed action together with the latent statistics.

Arguments:
- inputs (:obj:Dict): Input dict data, including obs and action tensor.
Returns:
- outputs (:obj:Dict): Dict containing keywords recons_action (:obj:torch.Tensor), prediction_residual (:obj:torch.Tensor), input (:obj:torch.Tensor), mu (:obj:torch.Tensor), log_var (:obj:torch.Tensor) and z (:obj:torch.Tensor).
Shapes:
- inputs (:obj:Dict): :math:(B, N, D), where B is batch size, N is sample number, D is input dimension.
- outputs (:obj:Dict): :math:(B, N).
Examples:
>>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
>>> model = BCQ(32, 6)
>>> outputs = model.compute_vae(inputs)
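The mu, log_var and z entries are the usual quantities of a variational autoencoder. As a reminder of how they typically relate, here is the standard reparameterization trick in plain Python. That VanillaVAE computes z in exactly this way is an assumption, since its source is not shown on this page, and the helper name is hypothetical.

```python
import math
import random
from typing import List


def reparameterize(mu: List[float], log_var: List[float], rng: random.Random) -> List[float]:
    """Sample z = mu + std * eps with eps ~ N(0, 1) and std = exp(0.5 * log_var)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


rng = random.Random(0)
# The second component has a tiny variance, so its sample stays next to its mean.
z = reparameterize(mu=[0.0, 1.0], log_var=[0.0, -50.0], rng=rng)
```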

compute_eval(inputs)

Overview

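Select the evaluation action. For each observation, decode 100 candidate actions from the VAE using latent samples clamped to [-0.5, 0.5], perturb them with the actor, score them with the first critic, and return the candidate with the highest Q-value.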

Arguments:
- inputs (:obj:Dict): Input dict data; only the obs tensor is read.
Returns:
- outputs (:obj:Dict): Dict containing keyword action (:obj:torch.Tensor).
Shapes:
- inputs (:obj:Dict): obs is :math:(B, D), where B is batch size and D is the observation dimension.
- outputs (:obj:Dict): action is :math:(B, A), where A is the action dimension.
Examples:
>>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
>>> model = BCQ(32, 6)
>>> outputs = model.compute_eval(inputs)
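The selection logic amounts to "sample several candidates, score them, keep the best". The sketch below shows this for one observation in plain Python. All names are hypothetical: `toy_sampler` stands in for the VAE decoder followed by the actor perturbation, and `toy_q` stands in for the first critic.

```python
import random
from typing import Callable, List


def select_action(
    obs: List[float],
    sample_candidate: Callable[[List[float]], List[float]],
    q_fn: Callable[[List[float], List[float]], float],
    num_candidates: int = 100,
) -> List[float]:
    """Draw candidates for one observation and return the one with the highest Q-value."""
    candidates = [sample_candidate(obs) for _ in range(num_candidates)]
    return max(candidates, key=lambda a: q_fn(obs, a))


rng = random.Random(0)


def toy_sampler(obs: List[float]) -> List[float]:
    # Stand-in for "decode a clamped latent with the VAE, then perturb with the actor".
    return [rng.uniform(-1.0, 1.0) for _ in range(2)]


def toy_q(obs: List[float], action: List[float]) -> float:
    # Toy critic that prefers actions close to (0.5, -0.5).
    return -((action[0] - 0.5) ** 2 + (action[1] + 0.5) ** 2)


best = select_action([0.0] * 4, toy_sampler, toy_q)
```

The model does the same thing for a whole batch at once by stacking the 100 copies of obs along a new leading dimension and using argmax and gather over that dimension.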

Full Source Code

../ding/model/template/bcq.py

from typing import Union, Dict, Optional, List
from easydict import EasyDict
import numpy as np
import torch
import torch.nn as nn

from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import RegressionHead, ReparameterizationHead
from .vae import VanillaVAE


@MODEL_REGISTRY.register('bcq')
class BCQ(nn.Module):
    """
    Overview:
        Model of BCQ (Batch-Constrained deep Q-learning).
        Off-Policy Deep Reinforcement Learning without Exploration.
        https://arxiv.org/abs/1812.02900
    Interface:
        ``forward``, ``compute_actor``, ``compute_critic``, ``compute_vae``, ``compute_eval``
    Property:
        ``mode``
    """

    mode = ['compute_actor', 'compute_critic', 'compute_vae', 'compute_eval']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            actor_head_hidden_size: List = [400, 300],
            critic_head_hidden_size: List = [400, 300],
            activation: Optional[nn.Module] = nn.ReLU(),
            vae_hidden_dims: List = [750, 750],
            phi: float = 0.05
    ) -> None:
        """
        Overview:
            Initialize neural network, i.e. agent Q network and actor.
        Arguments:
            - obs_shape (:obj:`int`): the dimension of observation state
            - action_shape (:obj:`int`): the dimension of action shape
            - actor_hidden_size (:obj:`list`): the list of hidden size of actor
            - critic_hidden_size (:obj:'list'): the list of hidden size of critic
            - activation (:obj:`nn.Module`): Activation function in network, defaults to nn.ReLU().
            - vae_hidden_dims (:obj:`list`): the list of hidden size of vae
        """
        super(BCQ, self).__init__()
        obs_shape: int = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.input_size = obs_shape
        self.phi = phi

        critic_input_size = self.input_size + action_shape
        self.critic = nn.ModuleList()
        for _ in range(2):
            net = []
            d = critic_input_size
            for dim in critic_head_hidden_size:
                net.append(nn.Linear(d, dim))
                net.append(activation)
                d = dim
            net.append(nn.Linear(d, 1))
            self.critic.append(nn.Sequential(*net))

        net = []
        d = critic_input_size
        for dim in actor_head_hidden_size:
            net.append(nn.Linear(d, dim))
            net.append(activation)
            d = dim
        net.append(nn.Linear(d, 1))
        self.actor = nn.Sequential(*net)

        self.vae = VanillaVAE(action_shape, obs_shape, action_shape * 2, vae_hidden_dims)

    def forward(self, inputs: Dict[str, torch.Tensor], mode: str) -> Dict[str, torch.Tensor]:
        """
        Overview:
            The unique execution (forward) method of BCQ method, and one can indicate different modes to implement \
            different computation graph, including ``compute_actor`` and ``compute_critic`` in BCQ.
        Mode compute_actor:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including action tensor.
        Mode compute_critic:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including q_value tensor.
        Mode compute_vae:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` \
                    (:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`), \
                    ``input`` (:obj:`torch.Tensor`), ``mu`` (:obj:`torch.Tensor`), \
                    ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
        Mode compute_eval:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including action tensor.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
            >>> model = BCQ(32, 6)
            >>> outputs = model(inputs, mode='compute_actor')
            >>> outputs = model(inputs, mode='compute_critic')
            >>> outputs = model(inputs, mode='compute_vae')
            >>> outputs = model(inputs, mode='compute_eval')

        .. note::
            For specific examples, one can refer to API doc of ``compute_actor`` and ``compute_critic`` respectively.
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Use critic network to compute q value.
        Arguments:
            - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``q_value`` (:obj:`torch.Tensor`).
        Shapes:
            - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension.
            - outputs (:obj:`Dict`): :math:`(B, N)`.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
            >>> model = BCQ(32, 6)
            >>> outputs = model.compute_critic(inputs)
        """
        obs, action = inputs['obs'], inputs['action']
        if len(action.shape) == 1:  # (B, ) -> (B, 1)
            action = action.unsqueeze(1)
        x = torch.cat([obs, action], dim=-1)
        x = [m(x).squeeze() for m in self.critic]
        return {'q_value': x}

    def compute_actor(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Overview:
            Use actor network to compute action.
        Arguments:
            - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``action`` (:obj:`torch.Tensor`).
        Shapes:
            - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension.
            - outputs (:obj:`Dict`): :math:`(B, N)`.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
            >>> model = BCQ(32, 6)
            >>> outputs = model.compute_actor(inputs)
        """
        input = torch.cat([inputs['obs'], inputs['action']], -1)
        x = self.actor(input)
        action = self.phi * 1 * torch.tanh(x)
        action = (action + inputs['action']).clamp(-1, 1)
        return {'action': action}

    def compute_vae(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Use vae network to compute action.
        Arguments:
            - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` (:obj:`torch.Tensor`), \
                ``prediction_residual`` (:obj:`torch.Tensor`), ``input`` (:obj:`torch.Tensor`), \
                ``mu`` (:obj:`torch.Tensor`), ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
        Shapes:
            - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension.
            - outputs (:obj:`Dict`): :math:`(B, N)`.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
            >>> model = BCQ(32, 6)
            >>> outputs = model.compute_vae(inputs)
        """
        return self.vae.forward(inputs)

    def compute_eval(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Use actor network to compute action.
        Arguments:
            - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``action`` (:obj:`torch.Tensor`).
        Shapes:
            - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension.
            - outputs (:obj:`Dict`): :math:`(B, N)`.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
            >>> model = BCQ(32, 6)
            >>> outputs = model.compute_eval(inputs)
        """
        obs = inputs['obs']
        obs_rep = obs.clone().unsqueeze(0).repeat_interleave(100, dim=0)
        z = torch.randn((obs_rep.shape[0], obs_rep.shape[1], self.action_shape * 2)).to(obs.device).clamp(-0.5, 0.5)
        sample_action = self.vae.decode_with_obs(z, obs_rep)['reconstruction_action']
        action = self.compute_actor({'obs': obs_rep, 'action': sample_action})['action']
        q = self.compute_critic({'obs': obs_rep, 'action': action})['q_value'][0]
        idx = q.argmax(dim=0).unsqueeze(0).unsqueeze(-1)
        idx = idx.repeat_interleave(action.shape[-1], dim=-1)
        action = action.gather(0, idx).squeeze()
        return {'action': action}