ding.torch_utils.network.scatter_connection
ScatterConnection
Bases: Module
Overview
Scatter entity features to their corresponding spatial locations. In AlphaStar, each entity is embedded into a feature vector, and these vectors are scattered into a spatial feature map of the given map size.
Interfaces:
__init__, forward, xy_forward
__init__(scatter_type)
Overview
Initialize the ScatterConnection object.
Arguments:
- scatter_type (:obj:str): The scatter type, which decides the behavior when two entities share the same location. It must be either 'add' or 'cover'. With 'add', the features of entities at the same location are summed. With 'cover', the features of the later entity overwrite those of the earlier one.
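To illustrate the two collision policies, here is a minimal sketch on a flat 1-D map. The helper `scatter_1d` is hypothetical and not part of DI-engine. 'cover' is written as an explicit ordered loop so that "the later entity wins" is deterministic, because PyTorch's `Tensor.scatter_` does not guarantee which value is kept when indices repeat.

```python
import torch


def scatter_1d(values: torch.Tensor, index: torch.Tensor, size: int, scatter_type: str) -> torch.Tensor:
    """Scatter scalar entity values into a flat map of length `size`.

    Hypothetical helper illustrating the 'add' and 'cover' collision policies.
    """
    out = torch.zeros(size, dtype=values.dtype)
    if scatter_type == 'add':
        # Entities that share an index are summed.
        out.scatter_add_(0, index, values)
    elif scatter_type == 'cover':
        # Write entities in order, so a later entity overwrites an earlier one.
        for i in range(index.numel()):
            out[index[i]] = values[i]
    else:
        raise ValueError(f"unknown scatter_type: {scatter_type}")
    return out


values = torch.tensor([1.0, 2.0, 5.0])
index = torch.tensor([0, 0, 3])  # entities 0 and 1 share location 0
print(scatter_1d(values, index, 4, 'add').tolist())    # [3.0, 0.0, 0.0, 5.0]
print(scatter_1d(values, index, 4, 'cover').tolist())  # [2.0, 0.0, 0.0, 5.0]
```

With 'add' both colliding entities contribute (1 + 2 = 3). With 'cover' only the second entity's value survives, which is the information loss described in the notes on `forward` and `xy_forward`.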
forward(x, spatial_size, location)
Overview
Scatter the input tensor 'x' into a spatial feature map.
Arguments:
- x (:obj:torch.Tensor): The input tensor of shape (B, M, N), where B is the batch size, M is the number of entities, and N is the dimension of entity attributes.
- spatial_size (:obj:Tuple[int, int]): The size (H, W) of the spatial feature map into which 'x' will be scattered, where H is the height and W is the width.
- location (:obj:torch.Tensor): The tensor of locations of shape (B, M, 2). Each location should be (y, x).
Returns:
- output (:obj:torch.Tensor): The scattered feature map of shape (B, N, H, W).
Note:
When locations overlap, 'cover' mode loses information, because only one entity's features are kept at each shared location.
'add' mode is used as a temporary substitute.
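The following is a minimal sketch of what `forward` computes in 'add' mode. It assumes the locations are integer (y, x) pairs that lie inside the H x W map. Each location is flattened to the row-major index `y * W + x`, and the features are then accumulated with `Tensor.scatter_add_`. The function name `scatter_forward_add` is hypothetical, and this is not the DI-engine source.

```python
import torch


def scatter_forward_add(x: torch.Tensor, spatial_size, location: torch.Tensor) -> torch.Tensor:
    """Sketch of 'add' scattering: x (B, M, N), location (B, M, 2) as (y, x) -> (B, N, H, W).

    Hypothetical re-implementation for illustration only.
    """
    B, M, N = x.shape
    H, W = spatial_size
    # Flatten each (y, x) location into a row-major index over the H * W grid.
    index = location[..., 0].long() * W + location[..., 1].long()  # (B, M)
    # Repeat the index for each of the N feature channels: (B, N, M).
    index = index.unsqueeze(1).expand(B, N, M)
    src = x.permute(0, 2, 1)  # (B, N, M), channels first
    out = x.new_zeros((B, N, H * W))
    # Sum the features of entities that land on the same cell.
    out.scatter_add_(2, index, src)
    return out.view(B, N, H, W)


# B=1 batch, M=3 entities, N=2 features; entities 0 and 1 collide at (y=0, x=1).
x = torch.tensor([[[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]])
location = torch.tensor([[[0, 1], [0, 1], [2, 0]]])
out = scatter_forward_add(x, (3, 2), location)
print(out.shape)          # torch.Size([1, 2, 3, 2])
print(out[0, 0].tolist())  # [[0.0, 3.0], [0.0, 0.0], [3.0, 0.0]]
```

Putting the channel dimension before the flattened grid produces the (B, N, H, W) layout directly, so a single `view` reshapes the result into the feature map.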
xy_forward(x, spatial_size, coord_x, coord_y)
Overview
Scatter the input tensor 'x' into a spatial feature map, using separate x and y coordinates.
Arguments:
- x (:obj:torch.Tensor): The input tensor of shape (B, M, N), where B is the batch size, M is the number of entities, and N is the dimension of entity attributes.
- spatial_size (:obj:Tuple[int, int]): The size (H, W) of the spatial feature map into which 'x' will be scattered, where H is the height and W is the width.
- coord_x (:obj:torch.Tensor): The x-coordinates tensor of shape (B, M).
- coord_y (:obj:torch.Tensor): The y-coordinates tensor of shape (B, M).
Returns:
- output (:obj:torch.Tensor): The scattered feature map of shape (B, N, H, W).
Note:
When locations overlap, 'cover' mode loses information, because only one entity's features are kept at each shared location.
'add' mode is used as a temporary substitute.
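The sketch below assumes that `xy_forward` treats `coord_x` as the column and `coord_y` as the row, which is what the (y, x) convention of `forward` suggests. Under that assumption, the separate coordinate tensors correspond to a `location` tensor and a flat grid index as shown. Both helpers, `xy_to_location` and `flat_index`, are hypothetical.

```python
import torch


def xy_to_location(coord_x: torch.Tensor, coord_y: torch.Tensor) -> torch.Tensor:
    """Stack separate (B, M) coordinates into a (B, M, 2) tensor in (y, x) order.

    Hypothetical helper relating xy_forward's inputs to forward's `location`.
    """
    return torch.stack([coord_y, coord_x], dim=-1)


def flat_index(coord_x: torch.Tensor, coord_y: torch.Tensor, width: int) -> torch.Tensor:
    # Row-major index into the H * W grid: row = y, column = x (assumed convention).
    return coord_y.long() * width + coord_x.long()


coord_x = torch.tensor([[1, 0]])  # B=1, M=2
coord_y = torch.tensor([[0, 2]])
print(xy_to_location(coord_x, coord_y).tolist())     # [[[0, 1], [2, 0]]]
print(flat_index(coord_x, coord_y, width=2).tolist())  # [[1, 4]]
```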
shape_fn_scatter_connection(args, kwargs)
Overview
Return the shape of scatter_connection, as required by the HPC implementation.
Arguments:
- args (:obj:Tuple): The arguments passed to the scatter_connection function.
- kwargs (:obj:Dict): The keyword arguments passed to the scatter_connection function.
Returns:
- shape (:obj:List[int]): A list representing the shape of scatter_connection, in the form of [B, M, N, H, W, scatter_type].
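The returned list combines the input sizes, the map size, and the scatter type. The hypothetical helper below builds that list from explicit values. In DI-engine these values are instead read from the `args` and `kwargs` of the call, and this sketch does not reproduce how that parsing is done.

```python
from typing import List, Sequence, Union


def scatter_shape(x_shape: Sequence[int], spatial_size: Sequence[int],
                  scatter_type: str) -> List[Union[int, str]]:
    """Build [B, M, N, H, W, scatter_type] from the inputs of a scatter call.

    Hypothetical helper illustrating the layout of the returned list.
    """
    B, M, N = x_shape       # batch, entities, entity attribute dimension
    H, W = spatial_size     # height and width of the feature map
    return [B, M, N, H, W, scatter_type]


print(scatter_shape((4, 16, 32), (64, 64), 'add'))  # [4, 16, 32, 64, 64, 'add']
```

Although the documented return type is `List[int]`, the last element is the scatter type string.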
Full Source Code
../ding/torch_utils/network/scatter_connection.py