ding.model.template.hpt

HPT
Bases: Module
Overview
The HPT model for reinforcement learning, which consists of a Policy Stem and a Dueling Head. The Policy Stem utilizes cross-attention to process input data, and the Dueling Head computes Q-values for discrete action spaces.
Interfaces
__init__, forward
GitHub: https://github.com/liruiw/HPT/blob/main/hpt/models/policy_stem.py
__init__(state_dim, action_dim)
Overview
Initialize the HPT model, including the Policy Stem and the Dueling Head.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state_dim | int | Dimension of the input state. | required |
| action_dim | int | Dimension of the discrete action space. | required |
.. note:: The Policy Stem is initialized with cross-attention, and the Dueling Head is set to process the resulting tokens.
forward(x)
Overview
Forward pass of the HPT model. Computes latent tokens from the input state and passes them through the Dueling Head.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input state tensor. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Q-values for each discrete action. |
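The dueling aggregation performed by the head can be sketched as follows. This is a minimal NumPy sketch with made-up weight matrices and dimensions, not the model's actual learned networks; the state-value/advantage split is the standard dueling-DQN formulation:

```python
import numpy as np

# Illustrative dimensions only; the real head operates on the stem's latent tokens.
rng = np.random.default_rng(0)
latent = rng.standard_normal(128)        # flattened latent features from the stem
W_value = rng.standard_normal((1, 128))  # value branch weights (hypothetical)
W_adv = rng.standard_normal((6, 128))    # advantage branch weights, 6 actions

value = W_value @ latent                         # V(s), shape (1,)
advantage = W_adv @ latent                       # A(s, a), shape (6,)
q_values = value + advantage - advantage.mean()  # dueling aggregation

# Subtracting the mean advantage makes the decomposition identifiable:
# the mean of the Q-values equals V(s).
print(np.allclose(q_values.mean(), value[0]))  # True
```

Splitting Q into value and advantage lets the network learn the state's worth independently of per-action differences, which typically stabilizes learning for discrete action spaces.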
PolicyStem
Bases: Module
Overview
The Policy Stem module processes input features and generates latent tokens: it first extracts features from the input, then applies cross-attention against a set of learnable latent tokens to produce a fixed-size latent representation.
Interfaces
__init__, init_cross_attn, compute_latent, forward
.. note:: This module is inspired by the implementation in the Perceiver IO model and uses attention mechanisms for feature extraction.
device
property
Returns the device on which the model parameters are located.
__init__(feature_dim=8, token_dim=128)
Overview
Initialize the Policy Stem module with a feature extractor and cross-attention mechanism.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feature_dim | int | Dimension of the input features. | 8 |
| token_dim | int | Dimension of each latent token. | 128 |
init_cross_attn()
Initialize cross-attention module and learnable tokens.
compute_latent(x)
Overview
Compute latent representations of the input data using the feature extractor and cross-attention.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Latent token representations. |
forward(x)
Overview
Forward pass to compute latent tokens.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Latent tokens. |
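How the stem compresses a variable-length feature sequence into a fixed set of latent tokens can be sketched in NumPy. This is a single-head sketch without learned projections; the dimensions and variable names are illustrative, not the module's actual internals:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
token_dim = 128
seq_len, num_tokens = 40, 16          # input features of any length -> 16 tokens

features = rng.standard_normal((seq_len, token_dim))   # extracted input features
tokens = rng.standard_normal((num_tokens, token_dim))  # learnable latent tokens

# The latent tokens act as queries; the input features supply keys and values.
scores = tokens @ features.T / np.sqrt(token_dim)
attn = softmax(scores, axis=-1)       # (num_tokens, seq_len), rows sum to 1
latents = attn @ features             # fixed-size output: (num_tokens, token_dim)

print(latents.shape)  # (16, 128)
```

Because only the query side is learnable and fixed in size, the output shape is independent of the input sequence length, which is the Perceiver IO idea the note above refers to.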
CrossAttention
Bases: Module
__init__(query_dim, heads=8, dim_head=64, dropout=0.0)
Overview
CrossAttention module used in the Perceiver IO model. It computes the attention between the query and context tensors and returns the output tensor after applying attention.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| query_dim | int | Dimension of the query input. | required |
| heads | int | Number of attention heads. | 8 |
| dim_head | int | Dimension of each attention head. | 64 |
| dropout | float | Dropout rate. | 0.0 |
forward(x, context, mask=None)
Overview
Forward pass of the CrossAttention module. Computes the attention between the query and context tensors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Query input tensor. | required |
| context | Tensor | Context (key/value) input tensor. | required |
| mask | Tensor | Optional attention mask. | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor after applying attention. |
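A single-head NumPy sketch of the scaled dot-product attention this module computes, including how an optional boolean mask can exclude context positions. No learned projections, multiple heads, or dropout are included; shapes and names are purely illustrative:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(1)
dim_head = 64
x = rng.standard_normal((4, dim_head))         # queries
context = rng.standard_normal((10, dim_head))  # keys/values
mask = np.ones(10, dtype=bool)
mask[7:] = False                               # hide the last three context positions

scores = x @ context.T / np.sqrt(dim_head)     # (4, 10) similarity scores
scores = np.where(mask, scores, -np.inf)       # masked positions get zero weight
attn = softmax(scores, axis=-1)                # rows sum to 1 over unmasked positions
out = attn @ context                           # (4, 64) attended output

print(attn[:, 7:].max())  # 0.0 -- masked positions contribute nothing
```

Setting masked scores to negative infinity before the softmax is the standard way to make excluded positions receive exactly zero attention weight.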
Full Source Code
../ding/model/template/hpt.py