Skip to content

ding.policy.command_mode_policy_instance

ding.policy.command_mode_policy_instance

EpsCommandModePolicy

DDPGCommandModePolicy

BCCommandModePolicy

Bases: BehaviourCloningPolicy, DummyCommandModePolicy

Full Source Code

../ding/policy/command_mode_policy_instance.py

"""
Command-mode variants of the policies defined in this package.

Every ``*_command`` entry registered below is a class that combines one
algorithm policy with a command mixin through multiple inheritance:

- ``EpsCommandModePolicy``: builds a schedule from ``cfg.other.eps`` and
  returns ``{'eps': ...}`` as the collect-mode setting.
- ``DummyCommandModePolicy``: returns empty settings for every mode.

``DDPGCommandModePolicy`` and ``BCCommandModePolicy`` implement their own
command hooks because the setting they return depends on the config.
"""
from ding.utils import POLICY_REGISTRY
from ding.rl_utils import get_epsilon_greedy_fn
from .base_policy import CommandModePolicy

from .dqn import DQNPolicy, DQNSTDIMPolicy
from .mdqn import MDQNPolicy
from .c51 import C51Policy
from .qrdqn import QRDQNPolicy
from .iqn import IQNPolicy
from .fqf import FQFPolicy
from .rainbow import RainbowDQNPolicy
from .r2d2 import R2D2Policy
from .r2d2_gtrxl import R2D2GTrXLPolicy
from .r2d2_collect_traj import R2D2CollectTrajPolicy
from .sqn import SQNPolicy
from .ppo import PPOPolicy, PPOOffPolicy, PPOPGPolicy, PPOSTDIMPolicy
from .vmpo import VMPOPolicy
from .offppo_collect_traj import OffPPOCollectTrajPolicy
from .ppg import PPGPolicy, PPGOffPolicy
from .pg import PGPolicy
from .a2c import A2CPolicy
from .impala import IMPALAPolicy
from .ngu import NGUPolicy
from .ddpg import DDPGPolicy
from .td3 import TD3Policy
from .td3_vae import TD3VAEPolicy
from .td3_bc import TD3BCPolicy
from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy
from .mbpolicy.mbsac import MBSACPolicy, STEVESACPolicy
from .mbpolicy.dreamer import DREAMERPolicy
from .qmix import QMIXPolicy
from .wqmix import WQMIXPolicy
from .collaq import CollaQPolicy
from .coma import COMAPolicy
from .atoc import ATOCPolicy
from .acer import ACERPolicy
from .qtran import QTRANPolicy
from .sql import SQLPolicy
from .bc import BehaviourCloningPolicy
from .ibc import IBCPolicy

from .dqfd import DQFDPolicy
from .r2d3 import R2D3Policy

from .d4pg import D4PGPolicy
from .cql import CQLPolicy, DiscreteCQLPolicy
from .iql import IQLPolicy
from .dt import DTPolicy
from .pdqn import PDQNPolicy
from .madqn import MADQNPolicy
from .bdq import BDQPolicy
from .bcq import BCQPolicy
from .edac import EDACPolicy
from .prompt_pg import PromptPGPolicy
from .plan_diffuser import PDPolicy
from .happo import HAPPOPolicy
from .prompt_awr import PromptAWRPolicy


class EpsCommandModePolicy(CommandModePolicy):
    """
    Overview:
        Command mixin that supplies an ``eps`` value to collect mode. The value comes from a
        schedule function built out of ``cfg.other.eps``.
    """

    def _init_command(self) -> None:
        """
        Overview:
            Command mode init method. Called by ``self.__init__``.
            Build the eps schedule function from ``self._cfg.other.eps`` (fields ``start``, ``end``,
            ``decay`` and ``type``) and store it as ``self.epsilon_greedy``.
        """
        eps_cfg = self._cfg.other.eps
        self.epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)

    def _get_setting_collect(self, command_info: dict) -> dict:
        """
        Overview:
            Collect mode setting information including eps.
        Arguments:
            - command_info (:obj:`dict`): Dict type, must contain the key ``'envstep'``.
        Returns:
            - collect_setting (:obj:`dict`): ``{'eps': value}``, where ``value`` is the schedule \
                evaluated at ``command_info['envstep']``.
        """
        # The schedule is driven by `envstep`, not by a learner-side counter.
        step = command_info['envstep']
        return {'eps': self.epsilon_greedy(step)}

    def _get_setting_learn(self, command_info: dict) -> dict:
        """
        Overview:
            Learn mode needs no setting from the command; always returns an empty dict.
        """
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        """
        Overview:
            Eval mode needs no setting from the command; always returns an empty dict.
        """
        return {}


class DummyCommandModePolicy(CommandModePolicy):
    """
    Overview:
        Command mixin with no behaviour: init does nothing and every mode gets an empty setting dict.
    """

    def _init_command(self) -> None:
        """
        Overview:
            Command mode init method. Nothing to set up.
        """
        pass

    def _get_setting_collect(self, command_info: dict) -> dict:
        """
        Overview:
            Always returns an empty dict; ``command_info`` is not read.
        """
        return {}

    def _get_setting_learn(self, command_info: dict) -> dict:
        """
        Overview:
            Always returns an empty dict; ``command_info`` is not read.
        """
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        """
        Overview:
            Always returns an empty dict; ``command_info`` is not read.
        """
        return {}


# ---------------------------------------------------------------------------
# Registered command-mode policies.
# Each class below only combines an algorithm policy (listed first in the
# bases) with a command mixin and registers the result under "<name>_command".
# ---------------------------------------------------------------------------


@POLICY_REGISTRY.register('bdq_command')
class BDQCommandModePolicy(BDQPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('mdqn_command')
class MDQNCommandModePolicy(MDQNPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('dqn_command')
class DQNCommandModePolicy(DQNPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('dqn_stdim_command')
class DQNSTDIMCommandModePolicy(DQNSTDIMPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('dqfd_command')
class DQFDCommandModePolicy(DQFDPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('c51_command')
class C51CommandModePolicy(C51Policy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('qrdqn_command')
class QRDQNCommandModePolicy(QRDQNPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('iqn_command')
class IQNCommandModePolicy(IQNPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('fqf_command')
class FQFCommandModePolicy(FQFPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('rainbow_command')
class RainbowDQNCommandModePolicy(RainbowDQNPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('r2d2_command')
class R2D2CommandModePolicy(R2D2Policy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('r2d2_gtrxl_command')
class R2D2GTrXLCommandModePolicy(R2D2GTrXLPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('r2d2_collect_traj_command')
class R2D2CollectTrajCommandModePolicy(R2D2CollectTrajPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('r2d3_command')
class R2D3CommandModePolicy(R2D3Policy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('sqn_command')
class SQNCommandModePolicy(SQNPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('sql_command')
class SQLCommandModePolicy(SQLPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('ppo_command')
class PPOCommandModePolicy(PPOPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('vmpo_command')
class VMPOCommandModePolicy(VMPOPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('happo_command')
class HAPPOCommandModePolicy(HAPPOPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('ppo_stdim_command')
class PPOSTDIMCommandModePolicy(PPOSTDIMPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('ppo_pg_command')
class PPOPGCommandModePolicy(PPOPGPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('ppo_offpolicy_command')
class PPOOffCommandModePolicy(PPOOffPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('offppo_collect_traj_command')
class PPOOffCollectTrajCommandModePolicy(OffPPOCollectTrajPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('pg_command')
class PGCommandModePolicy(PGPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('a2c_command')
class A2CCommandModePolicy(A2CPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('impala_command')
class IMPALACommandModePolicy(IMPALAPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('ppg_offpolicy_command')
class PPGOffCommandModePolicy(PPGOffPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('ppg_command')
class PPGCommandModePolicy(PPGPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('madqn_command')
class MADQNCommandModePolicy(MADQNPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('ddpg_command')
class DDPGCommandModePolicy(DDPGPolicy, CommandModePolicy):
    """
    Overview:
        Command-mode DDPG policy. An eps schedule is only created and used when
        ``cfg.action_space == 'hybrid'``; for any other action space every setting is empty.
    """

    def _init_command(self) -> None:
        """
        Overview:
            Command mode init method. Called by ``self.__init__``.
            If hybrid action space, build the eps schedule function from ``self._cfg.other.eps``
            and store it as ``self.epsilon_greedy``; otherwise do nothing.
        """
        if self._cfg.action_space == 'hybrid':
            eps_cfg = self._cfg.other.eps
            self.epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)

    def _get_setting_collect(self, command_info: dict) -> dict:
        """
        Overview:
            Collect mode setting information including eps when hybrid action space.
        Arguments:
            - command_info (:obj:`dict`): Dict type; the key ``'envstep'`` is read only in the \
                hybrid case.
        Returns:
            - collect_setting (:obj:`dict`): ``{'eps': value}`` for hybrid action space, \
                otherwise an empty dict.
        """
        # `self.epsilon_greedy` only exists in the hybrid case, so bail out first.
        if self._cfg.action_space != 'hybrid':
            return {}
        # The schedule is driven by `envstep`, not by `learner_step`.
        step = command_info['envstep']
        return {'eps': self.epsilon_greedy(step)}

    def _get_setting_learn(self, command_info: dict) -> dict:
        """
        Overview:
            Learn mode needs no setting from the command; always returns an empty dict.
        """
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        """
        Overview:
            Eval mode needs no setting from the command; always returns an empty dict.
        """
        return {}


@POLICY_REGISTRY.register('td3_command')
class TD3CommandModePolicy(TD3Policy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('td3_vae_command')
class TD3VAECommandModePolicy(TD3VAEPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('td3_bc_command')
class TD3BCCommandModePolicy(TD3BCPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('sac_command')
class SACCommandModePolicy(SACPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('mbsac_command')
class MBSACCommandModePolicy(MBSACPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('stevesac_command')
class STEVESACCommandModePolicy(STEVESACPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('dreamer_command')
class DREAMERCommandModePolicy(DREAMERPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('cql_command')
class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('iql_command')
class IQLCommandModePolicy(IQLPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('discrete_cql_command')
class DiscreteCQLCommandModePolicy(DiscreteCQLPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('dt_command')
class DTCommandModePolicy(DTPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('qmix_command')
class QMIXCommandModePolicy(QMIXPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('wqmix_command')
class WQMIXCommandModePolicy(WQMIXPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('collaq_command')
class CollaQCommandModePolicy(CollaQPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('coma_command')
class COMACommandModePolicy(COMAPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('atoc_command')
class ATOCCommandModePolicy(ATOCPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('acer_command')
class ACERCommandModePolisy(ACERPolicy, DummyCommandModePolicy):
    pass


# Correctly spelled alias. The misspelled class name above is kept so that
# existing imports keep working.
ACERCommandModePolicy = ACERCommandModePolisy


@POLICY_REGISTRY.register('qtran_command')
class QTRANCommandModePolicy(QTRANPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('ngu_command')
class NGUCommandModePolicy(NGUPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('d4pg_command')
class D4PGCommandModePolicy(D4PGPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('pdqn_command')
class PDQNCommandModePolicy(PDQNPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('discrete_sac_command')
class DiscreteSACCommandModePolicy(DiscreteSACPolicy, EpsCommandModePolicy):
    pass


@POLICY_REGISTRY.register('sqil_sac_command')
class SQILSACCommandModePolicy(SQILSACPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('ibc_command')
class IBCCommandModePolicy(IBCPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('bcq_command')
class BCQCommandModelPolicy(BCQPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('edac_command')
class EDACCommandModelPolicy(EDACPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('pd_command')
class PDCommandModelPolicy(PDPolicy, DummyCommandModePolicy):
    pass


# Aliases that follow the `*CommandModePolicy` naming used by every other class
# in this module. The original `*CommandModelPolicy` names are kept for
# backward compatibility.
BCQCommandModePolicy = BCQCommandModelPolicy
EDACCommandModePolicy = EDACCommandModelPolicy
PDCommandModePolicy = PDCommandModelPolicy


@POLICY_REGISTRY.register('bc_command')
class BCCommandModePolicy(BehaviourCloningPolicy, DummyCommandModePolicy):
    """
    Overview:
        Command-mode behaviour cloning policy. Depending on ``cfg.continuous`` the collect setting is
        either a ``sigma`` value (schedule built from ``cfg.collect.noise_sigma``) or an ``eps`` value
        (schedule built from ``cfg.other.eps``).
    """

    def _init_command(self) -> None:
        """
        Overview:
            Command mode init method. Called by ``self.__init__``.
            Build the schedule function stored as ``self.epsilon_greedy``: from
            ``self._cfg.collect.noise_sigma`` when ``self._cfg.continuous`` is truthy, otherwise from
            ``self._cfg.other.eps``. Both configs provide ``start``, ``end``, ``decay`` and ``type``.
        """
        # The same attribute holds either schedule; only the config source differs.
        schedule_cfg = self._cfg.collect.noise_sigma if self._cfg.continuous else self._cfg.other.eps
        self.epsilon_greedy = get_epsilon_greedy_fn(
            schedule_cfg.start, schedule_cfg.end, schedule_cfg.decay, schedule_cfg.type
        )

    def _get_setting_collect(self, command_info: dict) -> dict:
        """
        Overview:
            Collect mode setting information including sigma (continuous) or eps (otherwise).
        Arguments:
            - command_info (:obj:`dict`): Dict type, must contain ``'learner_step'`` in the \
                continuous case and ``'envstep'`` otherwise.
        Returns:
            - collect_setting (:obj:`dict`): ``{'sigma': value}`` in the continuous case, \
                otherwise ``{'eps': value}``.
        """
        if self._cfg.continuous:
            # The sigma schedule is driven by `learner_step`.
            step = command_info['learner_step']
            return {'sigma': self.epsilon_greedy(step)}
        # The eps schedule is driven by `envstep`.
        step = command_info['envstep']
        return {'eps': self.epsilon_greedy(step)}

    def _get_setting_learn(self, command_info: dict) -> dict:
        """
        Overview:
            Learn mode needs no setting from the command; always returns an empty dict.
        """
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        """
        Overview:
            Eval mode needs no setting from the command; always returns an empty dict.
        """
        return {}


@POLICY_REGISTRY.register('prompt_pg_command')
class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('prompt_awr_command')
class PromptAWRCommandModePolicy(PromptAWRPolicy, DummyCommandModePolicy):
    pass