ding.model.template.edac
EDAC
Bases: Module
Overview
The Q-value Actor-Critic network with an ensemble of Q-networks (critics), used in EDAC.
Interfaces:
__init__, forward, compute_actor, compute_critic
__init__(obs_shape, action_shape, ensemble_num=2, actor_head_hidden_size=64, actor_head_layer_num=1, critic_head_hidden_size=64, critic_head_layer_num=1, activation=nn.ReLU(), norm_type=None, **kwargs)
Overview
Initialize the EDAC model according to the input arguments.
Arguments:
- obs_shape (:obj:Union[int, SequenceType]): Observation's shape, such as 128, (156, ).
- action_shape (:obj:Union[int, SequenceType, EasyDict]): Action's shape, such as 4, (3, ), EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
- ensemble_num (:obj:int): Number of Q-networks in the critic ensemble.
- actor_head_hidden_size (:obj:Optional[int]): The hidden size of the actor head.
- actor_head_layer_num (:obj:int): The number of layers in the actor head network.
- critic_head_hidden_size (:obj:Optional[int]): The hidden size of the critic head.
- critic_head_layer_num (:obj:int): The number of layers in the critic head network, which computes the Q value output.
- activation (:obj:Optional[nn.Module]): The activation function used in the MLP after each FC layer. If None, it defaults to nn.ReLU().
- norm_type (:obj:Optional[str]): The type of normalization to use after each network layer (FC, Conv); see ding.torch_utils.network for more details.
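The ensemble_num and critic_head_* arguments describe a critic made of several independent Q-networks. The sketch below shows one common way to hold such an ensemble in PyTorch: every weight and bias carries a leading ensemble dimension, so all members run in a single batched matmul. EnsembleMLPSketch is hypothetical and is not DI-engine's internal critic head; reading layer_num as "number of hidden layers" is an assumption made only for this illustration.

```python
import torch
import torch.nn as nn


class EnsembleMLPSketch(nn.Module):
    """Hypothetical ensemble critic: ensemble_num independent MLPs evaluated
    in one batched pass. This is NOT DI-engine's own critic head."""

    def __init__(self, in_dim, hidden_size=64, layer_num=1, ensemble_num=2, activation=None):
        super().__init__()
        self.activation = activation if activation is not None else nn.ReLU()
        # Assumption: layer_num hidden layers of width hidden_size, then a scalar output.
        sizes = [in_dim] + [hidden_size] * layer_num + [1]
        # Each parameter has a leading ensemble dimension, so every member owns its weights.
        self.weights = nn.ParameterList(
            [nn.Parameter(torch.randn(ensemble_num, i, o) / i ** 0.5)
             for i, o in zip(sizes[:-1], sizes[1:])]
        )
        self.biases = nn.ParameterList(
            [nn.Parameter(torch.zeros(ensemble_num, 1, o)) for o in sizes[1:]]
        )

    def forward(self, x):
        # x: (ensemble_num, B, in_dim) -> returns (ensemble_num, B)
        last = len(self.weights) - 1
        for k, (w, b) in enumerate(zip(self.weights, self.biases)):
            x = torch.baddbmm(b, x, w)  # per-member affine map: b + x @ w
            if k < last:
                x = self.activation(x)
        return x.squeeze(-1)


# Usage with obs_shape=8 and action_shape=1, as in the compute_critic example.
critic = EnsembleMLPSketch(in_dim=8 + 1)
obs, act = torch.randn(4, 8), torch.randn(4, 1)
x = torch.cat([obs, act], dim=-1).unsqueeze(0).repeat(2, 1, 1)  # same batch for both members
q_value = critic(x)  # shape (2, 4), i.e. (ensemble_num, B)
```

Because the members share one input but hold different parameters, each row of q_value is a separate Q estimate for the same batch.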
forward(inputs, mode)
Overview
The unique execution (forward) method of the EDAC model. Pass a different mode to run a different computation graph; EDAC supports compute_actor and compute_critic.
Mode compute_actor:
Arguments:
- inputs (:obj:torch.Tensor): Observation tensor.
Returns:
- output (:obj:Dict): Output dict data, whose key-values differ across action spaces.
Mode compute_critic:
Arguments:
- inputs (:obj:Dict): Input dict data, including obs and action tensor.
Returns:
- output (:obj:Dict): Output dict data, including q_value tensor.
.. note::
For specific examples, one can refer to API doc of compute_actor and compute_critic respectively.
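forward itself only selects which computation graph to run. A minimal sketch of that dispatch-by-name pattern follows; ModeDispatchSketch is a hypothetical toy model whose single linear layers stand in for the real actor and critic, and the getattr-based dispatch is an assumed illustration rather than a copy of DI-engine's code.

```python
import torch
import torch.nn as nn


class ModeDispatchSketch(nn.Module):
    """Hypothetical toy model: forward(inputs, mode) calls the method named by mode."""

    mode = ['compute_actor', 'compute_critic']

    def __init__(self, obs_dim=8, action_dim=1):
        super().__init__()
        self.actor = nn.Linear(obs_dim, action_dim)       # placeholder actor
        self.critic = nn.Linear(obs_dim + action_dim, 1)  # placeholder critic

    def forward(self, inputs, mode):
        if mode not in self.mode:
            raise ValueError(f'unsupported mode: {mode}')
        return getattr(self, mode)(inputs)  # e.g. mode='compute_actor' -> self.compute_actor

    def compute_actor(self, obs):
        return {'logit': self.actor(obs)}

    def compute_critic(self, inputs):
        x = torch.cat([inputs['obs'], inputs['action']], dim=-1)
        return {'q_value': self.critic(x).squeeze(-1)}


model = ModeDispatchSketch()
obs = torch.randn(4, 8)
actor_out = model(obs, 'compute_actor')  # {'logit': tensor of shape (4, 1)}
critic_out = model({'obs': obs, 'action': torch.randn(4, 1)}, 'compute_critic')  # {'q_value': (4,)}
```

Keeping one forward entry point with a mode string lets a policy call a single model object for both acting and value estimation.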
compute_actor(obs)
Overview
The forward computation graph of compute_actor mode, which uses the observation tensor to produce actor output,
such as action, logit and so on.
Arguments:
- obs (:obj:torch.Tensor): Observation tensor data, now supports a batch of 1-dim vector data, i.e. (B, obs_shape).
Returns:
- outputs (:obj:Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]): Actor output, which varies with the action space; EDAC uses the reparameterization case.
ReturnsKeys:
- logit (:obj:List[torch.Tensor]): Reparameterization logit [mu, sigma], usually in SAC.
  - mu (:obj:torch.Tensor): Mean of the reparameterized Gaussian distribution.
  - sigma (:obj:torch.Tensor): Standard deviation of the reparameterized Gaussian distribution.
Shapes:
- obs (:obj:torch.Tensor): :math:(B, N0), B is batch size and N0 corresponds to obs_shape.
- action (:obj:torch.Tensor): :math:(B, N1), B is batch size and N1 corresponds to action_shape.
- logit.mu (:obj:torch.Tensor): :math:(B, N1), B is batch size and N1 corresponds to action_shape.
- logit.sigma (:obj:torch.Tensor): :math:(B, N1), B is batch size and N1 corresponds to action_shape.
- logit (:obj:torch.Tensor): :math:(B, N2), B is batch size and N2 corresponds to action_shape.action_type_shape.
- action_args (:obj:torch.Tensor): :math:(B, N3), B is batch size and N3 corresponds to action_shape.action_args_shape.
Examples:
>>> model = EDAC(64, 64)
>>> obs = torch.randn(4, 64)
>>> actor_outputs = model(obs, 'compute_actor')
>>> assert actor_outputs['logit'][0].shape == torch.Size([4, 64])  # mu
>>> assert actor_outputs['logit'][1].shape == torch.Size([4, 64])  # sigma
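compute_actor only returns the Gaussian parameters; an action still has to be sampled from them. The sketch below does this with the reparameterization trick and tanh squashing, following the usual SAC convention. sample_squashed_action is a hypothetical helper, and whether DI-engine's EDAC policy applies exactly these steps (including the eps constant) is an assumption.

```python
import torch
from torch.distributions import Independent, Normal


def sample_squashed_action(mu, sigma, eps=1e-6):
    """Hypothetical helper: reparameterized, tanh-squashed sample from [mu, sigma].
    Returns an action in [-1, 1] and its log-probability."""
    dist = Independent(Normal(mu, sigma), 1)  # diagonal Gaussian over the action dims
    pre_tanh = dist.rsample()                 # mu + sigma * noise, differentiable in mu and sigma
    action = torch.tanh(pre_tanh)
    # Change of variables for tanh: subtract sum_i log(1 - tanh(u_i)^2); eps avoids log(0).
    log_prob = dist.log_prob(pre_tanh) - torch.log(1 - action.pow(2) + eps).sum(-1)
    return action, log_prob


mu = torch.zeros(4, 64, requires_grad=True)  # stands in for logit[0] from compute_actor
sigma = torch.ones(4, 64)                    # stands in for logit[1]
action, log_prob = sample_squashed_action(mu, sigma)
log_prob.sum().backward()                    # gradients reach mu through rsample and tanh
```

Using rsample instead of sample is what lets the policy loss backpropagate into the actor head.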
compute_critic(inputs)
Overview
The forward computation graph of compute_critic mode, which uses the observation and action tensors to produce critic
output, such as q_value.
Arguments:
- inputs (:obj:Dict[str, torch.Tensor]): Dict of input data, including the obs and action tensors.
Returns:
- outputs (:obj:Dict[str, torch.Tensor]): Critic output, such as q_value.
ArgumentsKeys:
- obs (:obj:torch.Tensor): Observation tensor data, now supports a batch of 1-dim vector data.
- action (:obj:Union[torch.Tensor, Dict]): Continuous action with same size as action_shape.
ReturnKeys:
- q_value (:obj:torch.Tensor): Q value tensor of shape (Ensemble_num, B), with one row per ensemble member.
Shapes:
- obs (:obj:torch.Tensor): :math:(B, N1) or :math:(Ensemble_num, B, N1), where B is batch size and N1 is obs_shape.
- action (:obj:torch.Tensor): :math:(B, N2) or :math:(Ensemble_num, B, N2), where B is batch size and N2 is action_shape.
- q_value (:obj:torch.Tensor): :math:(Ensemble_num, B), where B is batch size.
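The shapes above allow obs and action to carry a leading Ensemble_num dimension, so each Q-network can be fed its own batch. One way to build such inputs from flat buffers is sketched below; stack_member_batches is a hypothetical helper, and drawing an independent minibatch per member is an illustrative choice, not something this API prescribes.

```python
import torch


def stack_member_batches(obs_buffer, act_buffer, ensemble_num, batch_size, generator=None):
    """Hypothetical helper: draw an independent minibatch for each ensemble member
    and stack them into the (Ensemble_num, B, N) layout."""
    n = obs_buffer.shape[0]
    # One row of sample indices per member: (ensemble_num, batch_size)
    idx = torch.randint(0, n, (ensemble_num, batch_size), generator=generator)
    # Indexing with a 2-D index tensor keeps the leading (E, B) dims.
    return {'obs': obs_buffer[idx], 'action': act_buffer[idx]}


g = torch.Generator().manual_seed(0)
batch = stack_member_batches(torch.randn(100, 8), torch.randn(100, 1),
                             ensemble_num=2, batch_size=4, generator=g)
# batch['obs']: (2, 4, 8), batch['action']: (2, 4, 1)
```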
Examples:
>>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
>>> model = EDAC(obs_shape=(8, ), action_shape=1)
>>> q_value = model(inputs, mode='compute_critic')['q_value']
>>> assert q_value.shape == torch.Size([2, 4])  # (Ensemble_num, B) with the default ensemble_num=2
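EDAC builds on SAC-N, which reduces the ensemble's Q values with a minimum over members to form a pessimistic TD target. The sketch below applies that reduction to a q_value tensor laid out as (Ensemble_num, B). pessimistic_soft_target is hypothetical; the gamma, alpha and done masking follow standard SAC, and EDAC's additional gradient-diversity penalty on the ensemble is not shown.

```python
import torch


def pessimistic_soft_target(q_next, reward, done, log_prob_next, gamma=0.99, alpha=0.2):
    """Hypothetical SAC-N-style soft TD target from ensemble Q values.

    q_next: (ensemble_num, B), the q_value layout of compute_critic
    reward, done, log_prob_next: (B,)
    """
    min_q = q_next.min(dim=0).values            # most pessimistic member per sample, (B,)
    soft_value = min_q - alpha * log_prob_next  # entropy-regularized next-state value
    return reward + gamma * (1.0 - done) * soft_value


q_next = torch.tensor([[1.0, 2.0], [0.5, 3.0]])  # ensemble_num=2, B=2
target = pessimistic_soft_target(
    q_next,
    reward=torch.tensor([0.0, 1.0]),
    done=torch.tensor([0.0, 1.0]),
    log_prob_next=torch.zeros(2),
    gamma=0.5,
)
# min over members -> [0.5, 2.0]; target -> [0.25, 1.0]
```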
Full Source Code
../ding/model/template/edac.py