Skip to content

ding.model.template.madqn

ding.model.template.madqn

Full Source Code

../ding/model/template/madqn.py

import torch.nn as nn
from ding.utils import MODEL_REGISTRY
from .qmix import QMix


@MODEL_REGISTRY.register('madqn')
class MADQN(nn.Module):
    """Pair of QMix networks registered under the name ``'madqn'``.

    Two sub-networks are built with the same hyper-parameters:

    * ``self.current`` always receives ``obs_shape`` as its observation shape.
    * ``self.cooperation`` receives ``global_obs_shape`` as its observation
      shape when ``global_cooperation`` is true, otherwise ``obs_shape``.

    ``forward`` dispatches to one of the two depending on its ``cooperation``
    flag.
    """

    def __init__(
            self,
            agent_num: int,
            obs_shape: int,
            action_shape: int,
            hidden_size_list: list,
            global_obs_shape: int = None,
            mixer: bool = False,
            global_cooperation: bool = True,
            lstm_type: str = 'gru',
            dueling: bool = False
    ) -> None:
        """Build the ``current`` and ``cooperation`` QMix networks.

        Args:
            agent_num: Forwarded unchanged to both QMix constructors.
            obs_shape: Observation shape of ``current``; also used for
                ``cooperation`` when ``global_cooperation`` is false.
            action_shape: Forwarded unchanged to both QMix constructors.
            hidden_size_list: Forwarded unchanged to both QMix constructors.
            global_obs_shape: Forwarded to both QMix constructors as
                ``global_obs_shape``; additionally used as the observation
                shape of ``cooperation`` when ``global_cooperation`` is true.
            mixer: Forwarded unchanged to both QMix constructors.
            global_cooperation: Whether ``cooperation`` is built for (and fed
                with) the global state instead of the per-agent state.
            lstm_type: Forwarded unchanged to both QMix constructors.
            dueling: Forwarded unchanged to both QMix constructors.
        """
        super().__init__()
        self.global_cooperation = global_cooperation

        # Hyper-parameters shared by both sub-networks; only obs_shape differs.
        shared_kwargs = dict(
            agent_num=agent_num,
            action_shape=action_shape,
            hidden_size_list=hidden_size_list,
            global_obs_shape=global_obs_shape,
            mixer=mixer,
            lstm_type=lstm_type,
            dueling=dueling,
        )
        self.current = QMix(obs_shape=obs_shape, **shared_kwargs)

        # NOTE(review): with the defaults (global_cooperation=True,
        # global_obs_shape=None) the cooperation network is constructed with
        # obs_shape=None -- confirm QMix accepts that or require callers to
        # pass global_obs_shape explicitly.
        cooperation_obs_shape = global_obs_shape if self.global_cooperation else obs_shape
        self.cooperation = QMix(obs_shape=cooperation_obs_shape, **shared_kwargs)

    def forward(self, data: dict, cooperation: bool = False, single_step: bool = True) -> dict:
        """Run either the ``current`` or the ``cooperation`` network.

        Args:
            data: Input passed to the selected QMix network. When both
                ``cooperation`` and ``self.global_cooperation`` are true it
                must provide ``data['obs']['global_state']``.
            cooperation: If true, use ``self.cooperation``; otherwise use
                ``self.current``.
            single_step: Forwarded unchanged to the selected network.

        Returns:
            Whatever the selected QMix network returns.
        """
        if not cooperation:
            return self.current(data, single_step=single_step)

        if self.global_cooperation:
            # Feed the global state where the network reads 'agent_state'.
            # The substitution is done on shallow copies so the caller's
            # ``data`` is left untouched and can be reused afterwards with
            # ``cooperation=False``; the underlying values are shared, not
            # cloned.
            obs = dict(data['obs'])
            obs['agent_state'] = obs['global_state']
            data = dict(data)
            data['obs'] = obs
        return self.cooperation(data, single_step=single_step)