ding.model.template.pdqn¶
PDQN¶
Bases: Module
Overview
The neural network and computation graph of the PDQN (https://arxiv.org/abs/1810.06394v1) and MPDQN (https://arxiv.org/abs/1905.04388) algorithms for parameterized action spaces. This model supports a parameterized action space with a discrete action_type and continuous action_args. In principle, PDQN consists of an x network (continuous action-parameter network) and a Q network (discrete action-type network). For simplicity, however, the code is organized into encoder and actor_head modules, which contain the encoders and heads of the two networks respectively.
Interface:
__init__, forward, compute_discrete, compute_continuous.
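The encoder + head split described in the Overview can be sketched with plain PyTorch layers. This is a minimal, hypothetical illustration of the two-network structure (an x head producing continuous args, and a Q head that consumes the observation features concatenated with those args), not DI-engine's actual implementation; all layer sizes here are assumptions.

```python
import torch
import torch.nn as nn

# Assumed toy dimensions: flat obs of size 4, 3 discrete types, 5 continuous args.
obs_dim, n_types, n_args = 4, 3, 5

encoder = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU())  # shared observation encoder
x_head = nn.Linear(64, n_args)            # continuous action-parameter ("x") head
q_head = nn.Linear(64 + n_args, n_types)  # discrete Q head sees features + args

obs = torch.randn(8, obs_dim)
feat = encoder(obs)
action_args = torch.tanh(x_head(feat))              # bounded continuous args, shape (8, 5)
q = q_head(torch.cat([feat, action_args], dim=-1))  # Q values over types, shape (8, 3)
```

The key point mirrored here is the data flow: the continuous args are computed first and then fed, together with the observation features, into the discrete Q network.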
__init__(obs_shape, action_shape, encoder_hidden_size_list=[128, 128, 64], dueling=True, head_hidden_size=None, head_layer_num=1, activation=nn.ReLU(), norm_type=None, multi_pass=False, action_mask=None)¶
Overview
Init the PDQN (encoder + head) Model according to input arguments.
Arguments:
- obs_shape (:obj:Union[int, SequenceType]): Observation space shape, such as 8 or [4, 84, 84].
- action_shape (:obj:EasyDict): Action space shape in dict type, such as EasyDict({'action_type_shape': 3, 'action_args_shape': 5}).
- encoder_hidden_size_list (:obj:SequenceType): The collection of hidden_size values passed to the Encoder; the last element must match head_hidden_size.
- dueling (:obj:bool): Whether to choose DuelingHead or DiscreteHead (default).
- head_hidden_size (:obj:Optional[int]): The hidden_size of head network.
- head_layer_num (:obj:int): The number of layers used in the head network to compute Q value output.
- activation (:obj:Optional[nn.Module]): The type of activation function in networks; if None, it defaults to nn.ReLU().
- norm_type (:obj:Optional[str]): The type of normalization in networks, see ding.torch_utils.fc_block for more details.
- multi_pass (:obj:Optional[bool]): Whether to use the multi-pass version.
- action_mask (:obj:Optional[list]): An action mask indicating how action args are associated with each discrete action. For example, if there are 3 discrete actions and 4 continuous action args, and the first discrete action is associated with the first continuous action arg, the second discrete action with the second action arg, and the third discrete action with the remaining 2 action args, then the action mask is [[1,0,0,0],[0,1,0,0],[0,0,1,1]], with shape 3*4.
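The mask layout above can be built programmatically. The helper below is hypothetical (not part of the PDQN API); it turns a per-action list of owned arg indices into the 0/1 mask described in the argument documentation.

```python
def build_action_mask(assignments, num_args):
    """Hypothetical helper: assignments[i] lists the continuous-arg indices
    owned by discrete action i; returns a len(assignments) x num_args 0/1 mask."""
    return [[1 if j in owned else 0 for j in range(num_args)]
            for owned in assignments]

# 3 discrete actions, 4 continuous args, as in the example above.
mask = build_action_mask([[0], [1], [2, 3]], num_args=4)
# mask == [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]]
```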
forward(inputs, mode)¶
Overview
PDQN forward computation graph: input an observation tensor to predict the q_value of discrete action types or the values of continuous action_args, depending on mode.
Arguments:
- inputs (:obj:Union[torch.Tensor, Dict, EasyDict]): Inputs including observation and other info according to mode.
- mode (:obj:str): Name of the forward mode, i.e. 'compute_discrete' or 'compute_continuous'.
Shapes:
- inputs (:obj:torch.Tensor): :math:(B, N), where B is batch size and N is obs_shape.
compute_continuous(inputs)¶
Overview
Use observation tensor to predict continuous action args.
Arguments:
- inputs (:obj:torch.Tensor): Observation inputs.
Returns:
- outputs (:obj:Dict): A dict with key 'action_args'.
- 'action_args' (:obj:torch.Tensor): The continuous action args.
Shapes:
- inputs (:obj:torch.Tensor): :math:(B, N), where B is batch size and N is obs_shape.
- action_args (:obj:torch.Tensor): :math:(B, M), where M is action_args_shape.
Examples:
>>> act_shape = EasyDict({'action_type_shape': (3, ), 'action_args_shape': (5, )})
>>> model = PDQN(4, act_shape)
>>> inputs = torch.randn(64, 4)
>>> outputs = model.forward(inputs, mode='compute_continuous')
>>> assert outputs['action_args'].shape == torch.Size([64, 5])
compute_discrete(inputs)¶
Overview
Use observation tensor and continuous action args to predict discrete action types.
Arguments:
- inputs (:obj:Union[Dict, EasyDict]): A dict with keys 'state', 'action_args'.
- state (:obj:torch.Tensor): Observation inputs.
- action_args (:obj:torch.Tensor): Continuous action parameters, which are concatenated with the observation to serve as input to the discrete action-type network.
Returns:
- outputs (:obj:Dict): A dict with keys 'logit', 'action_args'.
- 'logit' (:obj:torch.Tensor): The logit value for each discrete action type.
- 'action_args' (:obj:torch.Tensor): The continuous action args (same as inputs['action_args']), returned for later usage.
Examples:
>>> act_shape = EasyDict({'action_type_shape': (3, ), 'action_args_shape': (5, )})
>>> model = PDQN(4, act_shape)
>>> inputs = {'state': torch.randn(64, 4), 'action_args': torch.randn(64, 5)}
>>> outputs = model.forward(inputs, mode='compute_discrete')
>>> assert outputs['logit'].shape == torch.Size([64, 3])
>>> assert outputs['action_args'].shape == torch.Size([64, 5])
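Downstream, the outputs of the two modes are typically combined into a single hybrid action: the discrete type is chosen from the logits and paired with the predicted continuous args. The snippet below is an illustrative, hypothetical post-processing step on a single sample; the values and the `hybrid_action` structure are assumptions, not part of the PDQN API.

```python
# Illustrative outputs for one sample (not from a real model pass).
logit = [0.2, 1.5, -0.3]                  # Q values over 3 discrete action types
action_args = [0.1, -0.4, 0.7, 0.2, 0.9]  # 5 continuous action parameters

# Greedy discrete selection; the args accompany whichever type is chosen.
action_type = max(range(len(logit)), key=lambda i: logit[i])
hybrid_action = {'action_type': action_type, 'action_args': action_args}
```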
Full Source Code
../ding/model/template/pdqn.py