Skip to content

ding.model.template.sqn

ding.model.template.sqn

Full Source Code

../ding/model/template/sqn.py

from typing import Dict

import torch
import torch.nn as nn

from ding.utils import MODEL_REGISTRY
from .q_learning import DQN


@MODEL_REGISTRY.register('sqn')
class SQN(nn.Module):
    """Wrapper holding two ``DQN`` sub-networks that are evaluated on the same input.

    The class is registered in ``MODEL_REGISTRY`` under the name ``'sqn'``.
    NOTE(review): "SQN" presumably stands for Soft Q Network -- confirm against
    the policy that consumes this model.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Build both sub-networks.

        Every positional and keyword argument is forwarded unchanged to the
        ``DQN`` constructor, once for each sub-network.
        """
        super().__init__()
        # Two separately constructed DQN instances, created from identical arguments.
        self.q0 = DQN(*args, **kwargs)
        self.q1 = DQN(*args, **kwargs)

    def forward(self, data: torch.Tensor) -> Dict:
        """Run ``data`` through both sub-networks and collect their ``'logit'`` outputs.

        Arguments:
            - data: Input handed as-is to ``q0`` and then to ``q1``.

        Returns:
            - A dict with two entries:
                ``'q_value'``: two-element list ``[q0 logit, q1 logit]``.
                ``'logit'``: the ``q0`` logit (the same object as ``'q_value'[0]``).
        """
        first_result = self.q0(data)
        second_result = self.q1(data)
        # The first network's logit is exposed twice: inside the pair and on its own.
        primary_logit = first_result['logit']
        return {
            'q_value': [primary_logit, second_result['logit']],
            'logit': primary_logit,
        }