ding.model.template.hpt

HPT

Bases: Module

Overview

The HPT model for reinforcement learning, which consists of a Policy Stem and a Dueling Head. The Policy Stem utilizes cross-attention to process input data, and the Dueling Head computes Q-values for discrete action spaces.

Interfaces

__init__, forward

GitHub: [https://github.com/liruiw/HPT/blob/main/hpt/models/policy_stem.py]

__init__(state_dim, action_dim)

Overview

Initialize the HPT model, including the Policy Stem and the Dueling Head.

Parameters:

- state_dim (:obj:`int`): The dimension of the input state. (required)
- action_dim (:obj:`int`): The dimension of the action space. (required)

.. note:: The Policy Stem is initialized with cross-attention, and the Dueling Head is set to process the resulting tokens.

forward(x)

Overview

Forward pass of the HPT model. Computes latent tokens from the input state and passes them through the Dueling Head.

Parameters:

- x (:obj:`torch.Tensor`): The input tensor representing the state. (required)

Returns:

- q_values (:obj:`torch.Tensor`): The predicted Q-values for each action.
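The forward flow above (latent tokens, flatten, dueling head) can be sketched with plain PyTorch. `SimpleDuelingHead` below is a stand-in for ding's `DuelingHead`, not its real API; the tokens are random placeholders for the Policy Stem's output.

```python
import torch
import torch.nn as nn


# Stand-in for ding's DuelingHead (assumption: the real head also
# aggregates a value branch and an advantage branch into Q-values).
class SimpleDuelingHead(nn.Module):

    def __init__(self, hidden_size: int, output_size: int):
        super().__init__()
        self.value = nn.Linear(hidden_size, 1)          # state-value branch V(s)
        self.adv = nn.Linear(hidden_size, output_size)  # advantage branch A(s, a)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        v, a = self.value(x), self.adv(x)
        # Dueling aggregation: Q = V + (A - mean(A))
        return v + a - a.mean(dim=-1, keepdim=True)


B, action_dim = 4, 6
tokens = torch.randn(B, 16, 128)       # shape of PolicyStem.compute_latent output
flat = tokens.view(B, -1)              # [B, 16*128] = [B, 2048], as in HPT.forward
q_values = SimpleDuelingHead(16 * 128, action_dim)(flat)
print(q_values.shape)                  # torch.Size([4, 6])
```

The flatten step is why `DuelingHead` is constructed with `hidden_size=16 * 128` in `HPT.__init__`: all 16 latent tokens are concatenated before Q-value prediction.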

PolicyStem

Bases: Module

Overview

The Policy Stem module is responsible for processing input features and generating latent tokens using a cross-attention mechanism. It extracts features from the input and then applies cross-attention to generate a set of latent tokens.

Interfaces

__init__, init_cross_attn, compute_latent, forward

.. note:: This module is inspired by the implementation in the Perceiver IO model and uses attention mechanisms for feature extraction.

device property

Returns the device on which the model parameters are located.

__init__(feature_dim=8, token_dim=128)

Overview

Initialize the Policy Stem module with a feature extractor and cross-attention mechanism.

Parameters:

- feature_dim (:obj:`int`): The dimension of the input features. Defaults to 8.
- token_dim (:obj:`int`): The dimension of the latent tokens generated by the attention mechanism. Defaults to 128.

init_cross_attn()

Initialize cross-attention module and learnable tokens.

compute_latent(x)

Overview

Compute latent representations of the input data using the feature extractor and cross-attention.

Parameters:

- x (:obj:`torch.Tensor`): Input tensor with shape [B, T, ..., F]. (required)

Returns:

- stem_tokens (:obj:`torch.Tensor`): Latent tokens with shape [B, 16, 128].
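The shape bookkeeping in `compute_latent` can be walked through with a sketch. Here `nn.MultiheadAttention` stands in for the module's custom `CrossAttention` (an assumption for illustration; the real module uses its own einsum-based attention), and all dimensions follow the defaults (`feature_dim=8`, `token_dim=128`, 16 learnable tokens).

```python
import torch
import torch.nn as nn

B, T, F, token_dim, token_num = 2, 5, 8, 128, 16
x = torch.randn(B, T, F)                       # input [B, T, F]
feat = nn.Linear(F, token_dim)(x)              # feature extractor -> [B, T, 128]
feat = feat.reshape(B, -1, token_dim)          # [B, N, 128] (N = T here)

# Learnable latent tokens, broadcast across the batch as in compute_latent.
tokens = (torch.randn(1, token_num, token_dim) * 0.02).repeat(B, 1, 1)  # [B, 16, 128]

# Stand-in cross-attention: tokens are queries, features are keys/values.
attn = nn.MultiheadAttention(token_dim, num_heads=8, batch_first=True)
stem_tokens, _ = attn(tokens, feat, feat)      # [B, 16, 128]
print(stem_tokens.shape)                       # torch.Size([2, 16, 128])
```

Note the output shape is fixed at [B, 16, 128] regardless of how many feature positions N the input flattens to; the learnable queries determine the token count.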

forward(x)

Overview

Forward pass to compute latent tokens.

Parameters:

- x (:obj:`torch.Tensor`): Input tensor. (required)

Returns:

- torch.Tensor: Latent tokens tensor.

CrossAttention

Bases: Module

__init__(query_dim, heads=8, dim_head=64, dropout=0.0)

Overview

CrossAttention module used in the Perceiver IO model. It computes the attention between the query and context tensors, and returns the output tensor after applying attention.

Parameters:

- query_dim (:obj:`int`): The dimension of the query input. (required)
- heads (:obj:`int`, optional): The number of attention heads. Defaults to 8.
- dim_head (:obj:`int`, optional): The dimension of each attention head. Defaults to 64.
- dropout (:obj:`float`, optional): The dropout probability. Defaults to 0.0.

forward(x, context, mask=None)

Overview

Forward pass of the CrossAttention module. Computes the attention between the query and context tensors.

Parameters:

- x (:obj:`torch.Tensor`): The query input tensor. (required)
- context (:obj:`torch.Tensor`): The context input tensor. (required)
- mask (:obj:`Optional[torch.Tensor]`, optional): The attention mask tensor. Defaults to None.

Returns:

- torch.Tensor: The output tensor after applying attention.

Full Source Code

../ding/model/template/hpt.py

from typing import Union, Optional, Dict, Callable, List
from einops import rearrange, repeat
import torch
import torch.nn as nn
from ding.model.common.head import DuelingHead
from ding.utils import MODEL_REGISTRY, squeeze


@MODEL_REGISTRY.register('hpt')
class HPT(nn.Module):
    """
    Overview:
        The HPT model for reinforcement learning, which consists of a Policy Stem and a Dueling Head.
        The Policy Stem utilizes cross-attention to process input data,
        and the Dueling Head computes Q-values for discrete action spaces.

    Interfaces:
        ``__init__``, ``forward``

    GitHub: [https://github.com/liruiw/HPT/blob/main/hpt/models/policy_stem.py]
    """

    def __init__(self, state_dim: int, action_dim: int):
        """
        Overview:
            Initialize the HPT model, including the Policy Stem and the Dueling Head.

        Arguments:
            - state_dim (:obj:`int`): The dimension of the input state.
            - action_dim (:obj:`int`): The dimension of the action space.

        .. note::
            The Policy Stem is initialized with cross-attention,
            and the Dueling Head is set to process the resulting tokens.
        """
        super(HPT, self).__init__()
        # Initialise Policy Stem
        self.policy_stem = PolicyStem(state_dim, 128)
        self.policy_stem.init_cross_attn()

        action_dim = squeeze(action_dim)
        # Dueling Head, input is 16*128, output is action dimension
        self.head = DuelingHead(hidden_size=16 * 128, output_size=action_dim)

    def forward(self, x: torch.Tensor):
        """
        Overview:
            Forward pass of the HPT model.
            Computes latent tokens from the input state and passes them through the Dueling Head.

        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor representing the state.

        Returns:
            - q_values (:obj:`torch.Tensor`): The predicted Q-values for each action.
        """
        # Policy Stem Outputs [B, 16, 128]
        tokens = self.policy_stem.compute_latent(x)
        # Flatten Operation
        tokens_flattened = tokens.view(tokens.size(0), -1)  # [B, 16*128]
        # Enter to Dueling Head
        q_values = self.head(tokens_flattened)
        return q_values


class PolicyStem(nn.Module):
    """
    Overview:
        The Policy Stem module is responsible for processing input features
        and generating latent tokens using a cross-attention mechanism.
        It extracts features from the input and then applies cross-attention
        to generate a set of latent tokens.

    Interfaces:
        ``__init__``, ``init_cross_attn``, ``compute_latent``, ``forward``

    .. note::
        This module is inspired by the implementation in the Perceiver IO model
        and uses attention mechanisms for feature extraction.
    """
    INIT_CONST = 0.02

    def __init__(self, feature_dim: int = 8, token_dim: int = 128):
        """
        Overview:
            Initialize the Policy Stem module with a feature extractor and cross-attention mechanism.

        Arguments:
            - feature_dim (:obj:`int`): The dimension of the input features.
            - token_dim (:obj:`int`): The dimension of the latent tokens generated
              by the attention mechanism.
        """
        super().__init__()
        # Initialise the feature extraction module
        self.feature_extractor = nn.Linear(feature_dim, token_dim)
        # Initialise CrossAttention
        self.init_cross_attn()

    def init_cross_attn(self):
        """Initialize cross-attention module and learnable tokens."""
        token_num = 16
        self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * self.INIT_CONST)
        self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1)

    def compute_latent(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Compute latent representations of the input data using
            the feature extractor and cross-attention.

        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor with shape [B, T, ..., F].

        Returns:
            - stem_tokens (:obj:`torch.Tensor`): Latent tokens with shape [B, 16, 128].
        """
        # Using the Feature Extractor
        stem_feat = self.feature_extractor(x)
        stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1])  # (B, N, 128)
        # Calculating latent tokens using CrossAttention
        stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1)  # (B, 16, 128)
        stem_tokens = self.cross_attention(stem_tokens, stem_feat)  # (B, 16, 128)
        return stem_tokens

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Forward pass to compute latent tokens.

        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.

        Returns:
            - torch.Tensor: Latent tokens tensor.
        """
        return self.compute_latent(x)

    @property
    def device(self) -> torch.device:
        """Returns the device on which the model parameters are located."""
        return next(self.parameters()).device


class CrossAttention(nn.Module):

    def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0):
        """
        Overview:
            CrossAttention module used in the Perceiver IO model.
            It computes the attention between the query and context tensors,
            and returns the output tensor after applying attention.

        Arguments:
            - query_dim (:obj:`int`): The dimension of the query input.
            - heads (:obj:`int`, optional): The number of attention heads. Defaults to 8.
            - dim_head (:obj:`int`, optional): The dimension of each attention head. Defaults to 64.
            - dropout (:obj:`float`, optional): The dropout probability. Defaults to 0.0.
        """
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = query_dim
        # Scaling factor for the attention logits to ensure stable gradients.
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Overview:
            Forward pass of the CrossAttention module.
            Computes the attention between the query and context tensors.

        Arguments:
            - x (:obj:`torch.Tensor`): The query input tensor.
            - context (:obj:`torch.Tensor`): The context input tensor.
            - mask (:obj:`Optional[torch.Tensor]`): The attention mask tensor. Defaults to None.

        Returns:
            - torch.Tensor: The output tensor after applying attention.
        """
        h = self.heads
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

        if mask is not None:
            # fill in the masks with negative values
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        # dropout
        attn = self.dropout(attn)
        out = torch.einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)