
ding.torch_utils.network.scatter_connection

ScatterConnection

Bases: Module

Overview

Scatter features to their corresponding locations. In AlphaStar, each entity is embedded into a tensor, and these tensors are scattered onto a spatial feature map of a given size.

Interfaces: `__init__`, `forward`, `xy_forward`

__init__(scatter_type)

Overview

Initialize the ScatterConnection object.

Arguments:

- scatter_type (`str`): The scatter type, which decides the behavior when two entities share the same location. It must be either `'add'` or `'cover'`. With `'add'`, entities at the same location are summed; with `'cover'`, the later entity overwrites the earlier one.
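As a standalone illustration of the two modes, here is a pure-Python sketch that needs no torch. The `scatter_flat` helper and its `(index, value)` pairs are hypothetical, not part of the module; they only mimic what happens when two entities share a location:

```python
def scatter_flat(entities, length, scatter_type):
    """Scatter (index, value) pairs into a zero-initialized flat buffer."""
    out = [0.0] * length
    for idx, val in entities:
        if scatter_type == 'cover':
            out[idx] = val      # a later entity overwrites an earlier one
        elif scatter_type == 'add':
            out[idx] += val     # overlapping entities are summed
    return out

entities = [(2, 1.0), (2, 3.0)]            # two entities share location 2
print(scatter_flat(entities, 4, 'cover'))  # [0.0, 0.0, 3.0, 0.0]
print(scatter_flat(entities, 4, 'add'))    # [0.0, 0.0, 4.0, 0.0]
```

The real module performs the same choice per cell, but batched over tensors via `scatter_` / `scatter_add_`.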

forward(x, spatial_size, location)

Overview

Scatter input tensor 'x' into a spatial feature map.

Arguments:

- x (`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` is the number of entities, and `N` is the dimension of entity attributes.
- spatial_size (`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which `x` will be scattered, where `H` is the height and `W` is the width.
- location (`torch.Tensor`): The tensor of locations, of shape `(B, M, 2)`. Each location is a `(y, x)` pair.

Returns:

- output (`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`.

Note: When locations overlap, `'cover'` mode loses information, since later entities overwrite earlier ones; `'add'` mode can be used as a substitute.
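The index arithmetic behind `forward` can be checked in isolation: a `(y, x)` location maps to the row-major flat offset `y * W + x`, so scattering into an `H * W` buffer and reshaping to `(H, W)` lands each entity at its cell. A minimal pure-Python sketch (variable names are illustrative, not from the module):

```python
H, W = 3, 4
locations = [(0, 1), (2, 3)]   # (y, x) pairs, as `forward` expects
values = [5.0, 7.0]

flat = [0.0] * (H * W)
for (y, x), v in zip(locations, values):
    flat[y * W + x] = v        # same formula as location[..., 1] + location[..., 0] * W

# Reshape the flat buffer back into an H x W grid.
grid = [flat[r * W:(r + 1) * W] for r in range(H)]
print(grid[0][1], grid[2][3])  # 5.0 7.0
```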

xy_forward(x, spatial_size, coord_x, coord_y)

Overview

Scatter input tensor 'x' into a spatial feature map using separate x and y coordinates.

Arguments:

- x (`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` is the number of entities, and `N` is the dimension of entity attributes.
- spatial_size (`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which `x` will be scattered, where `H` is the height and `W` is the width.
- coord_x (`torch.Tensor`): The x-coordinates tensor of shape `(B, M)`.
- coord_y (`torch.Tensor`): The y-coordinates tensor of shape `(B, M)`.

Returns:

- output (`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`.

Note: When locations overlap, `'cover'` mode loses information, since later entities overwrite earlier ones; `'add'` mode can be used as a substitute.
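In the listed source, `xy_forward` flattens with `coord_x * W + coord_y`, so `coord_x` is the term multiplied by the width and therefore selects the row of the reshaped `(H, W)` map, while `coord_y` selects the column. A pure-Python sketch of that indexing (names illustrative, not from the module):

```python
H, W = 3, 4
coord_x = [0, 2]               # multiplied by W, i.e. row indices after reshape
coord_y = [1, 3]               # column offsets within a row
values = [5.0, 7.0]

flat = [0.0] * (H * W)
for cx, cy, v in zip(coord_x, coord_y, values):
    flat[cx * W + cy] = v      # same formula as (coord_x * W + coord_y).long()

grid = [flat[r * W:(r + 1) * W] for r in range(H)]
print(grid[0][1], grid[2][3])  # 5.0 7.0
```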

shape_fn_scatter_connection(args, kwargs)

Overview

Return the shape of scatter_connection for HPC.

Arguments:

- args (`Tuple`): The positional arguments passed to the scatter_connection function.
- kwargs (`Dict`): The keyword arguments passed to the scatter_connection function.

Returns:

- shape (`List[int]`): A list representing the shape of scatter_connection, in the form `[B, M, N, H, W, scatter_type]`.
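The shape function simply concatenates the `(B, M, N)` shape of `x`, the `(H, W)` spatial size, and the configured scatter type. A hypothetical pure-Python replica of that bookkeeping (not the HPC hook itself, which pulls these values from `args`/`kwargs`):

```python
def shape_fn_sketch(x_shape, spatial_size, scatter_type):
    """Build the [B, M, N, H, W, scatter_type] description used for HPC."""
    shape = list(x_shape)       # [B, M, N] from the input tensor
    shape.extend(spatial_size)  # ... plus [H, W]
    shape.append(scatter_type)  # ... plus the scatter mode tag
    return shape

print(shape_fn_sketch((4, 8, 16), (32, 32), 'add'))
# [4, 8, 16, 32, 32, 'add']
```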

Full Source Code

../ding/torch_utils/network/scatter_connection.py

```python
import torch
import torch.nn as nn
from typing import Tuple, List
from ding.hpc_rl import hpc_wrapper


def shape_fn_scatter_connection(args, kwargs) -> List[int]:
    """
    Overview:
        Return the shape of scatter_connection for HPC.
    Arguments:
        - args (:obj:`Tuple`): The arguments passed to the scatter_connection function.
        - kwargs (:obj:`Dict`): The keyword arguments passed to the scatter_connection function.
    Returns:
        - shape (:obj:`List[int]`): A list representing the shape of scatter_connection, \
            in the form of [B, M, N, H, W, scatter_type].
    """
    if len(args) <= 1:
        tmp = list(kwargs['x'].shape)
    else:
        tmp = list(args[1].shape)  # args[0] is __main__.ScatterConnection object
    if len(args) <= 2:
        tmp.extend(kwargs['spatial_size'])
    else:
        tmp.extend(args[2])
    tmp.append(args[0].scatter_type)
    return tmp


class ScatterConnection(nn.Module):
    """
    Overview:
        Scatter feature to its corresponding location. In AlphaStar, each entity is embedded into a tensor,
        and these tensors are scattered into a feature map with map size.
    Interfaces:
        ``__init__``, ``forward``, ``xy_forward``
    """

    def __init__(self, scatter_type: str) -> None:
        """
        Overview:
            Initialize the ScatterConnection object.
        Arguments:
            - scatter_type (:obj:`str`): The scatter type, which decides the behavior when two entities have the \
                same location. It can be either 'add' or 'cover'. If 'add', the first one will be added to the \
                second one. If 'cover', the first one will be covered by the second one.
        """
        super(ScatterConnection, self).__init__()
        self.scatter_type = scatter_type
        assert self.scatter_type in ['cover', 'add']

    @hpc_wrapper(
        shape_fn=shape_fn_scatter_connection,
        namedtuple_data=False,
        include_args=[0, 2],
        include_kwargs=['x', 'location'],
        is_cls_method=True
    )
    def forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], location: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Scatter input tensor 'x' into a spatial feature map.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` \
                is the number of entities, and `N` is the dimension of entity attributes.
            - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which 'x' \
                will be scattered, where `H` is the height and `W` is the width.
            - location (:obj:`torch.Tensor`): The tensor of locations of shape `(B, M, 2)`. \
                Each location should be (y, x).
        Returns:
            - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`.
        Note:
            When there are some overlapping in locations, 'cover' mode will result in the loss of information.
            'add' mode is used as a temporary substitute.
        """
        device = x.device
        B, M, N = x.shape
        x = x.permute(0, 2, 1)
        H, W = spatial_size
        index = location[:, :, 1] + location[:, :, 0] * W
        index = index.unsqueeze(dim=1).repeat(1, N, 1)
        output = torch.zeros(size=(B, N, H, W), device=device).view(B, N, H * W)
        if self.scatter_type == 'cover':
            output.scatter_(dim=2, index=index, src=x)
        elif self.scatter_type == 'add':
            output.scatter_add_(dim=2, index=index, src=x)
        output = output.view(B, N, H, W)
        return output

    def xy_forward(
            self, x: torch.Tensor, spatial_size: Tuple[int, int], coord_x: torch.Tensor, coord_y
    ) -> torch.Tensor:
        """
        Overview:
            Scatter input tensor 'x' into a spatial feature map using separate x and y coordinates.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` \
                is the number of entities, and `N` is the dimension of entity attributes.
            - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which 'x' \
                will be scattered, where `H` is the height and `W` is the width.
            - coord_x (:obj:`torch.Tensor`): The x-coordinates tensor of shape `(B, M)`.
            - coord_y (:obj:`torch.Tensor`): The y-coordinates tensor of shape `(B, M)`.
        Returns:
            - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`.
        Note:
            When there are some overlapping in locations, 'cover' mode will result in the loss of information.
            'add' mode is used as a temporary substitute.
        """
        device = x.device
        B, M, N = x.shape
        x = x.permute(0, 2, 1)
        H, W = spatial_size
        index = (coord_x * W + coord_y).long()
        index = index.unsqueeze(dim=1).repeat(1, N, 1)
        output = torch.zeros(size=(B, N, H, W), device=device).view(B, N, H * W)
        if self.scatter_type == 'cover':
            output.scatter_(dim=2, index=index, src=x)
        elif self.scatter_type == 'add':
            output.scatter_add_(dim=2, index=index, src=x)
        output = output.view(B, N, H, W)
        return output
```