from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Optional
from easydict import EasyDict
import copy

from ding.utils import import_module, COMMANDER_REGISTRY, LimitedSpaceContainer


class BaseCommander(ABC):
    r"""
    Overview:
        Base parallel commander abstract class.
    Interface:
        default_config, get_collector_task, judge_collector_finish, judge_learner_finish
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        r"""
        Overview:
            Build the default config of this commander class from its ``config`` class attribute.
            Subclasses are expected to define ``config``; it is deep-copied so the returned
            object can be modified without touching the class-level default.
        Returns:
            - cfg (:obj:`EasyDict`): Copy of ``cls.config`` with an extra ``cfg_type`` field set to \
                ``cls.__name__ + 'Dict'``.
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    @abstractmethod
    def get_collector_task(self) -> dict:
        r"""
        Overview:
            Return a new collector task. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def judge_collector_finish(self, task_id: str, info: dict) -> bool:
        r"""
        Overview:
            Judge whether a collector task is finished according to the ``collector_done`` field of ``info``.
        Arguments:
            - task_id (:obj:`str`): The collector task id (unused here).
            - info (:obj:`dict`): Info dict of the task; a missing ``collector_done`` key counts as not done.
        Returns:
            - finished (:obj:`bool`): True if ``info['collector_done']`` is truthy, otherwise False.
        """
        return bool(info.get('collector_done', False))

    def judge_learner_finish(self, task_id: str, info: dict) -> bool:
        r"""
        Overview:
            Judge whether a learner task is finished according to the ``learner_done`` field of ``info``.
        Arguments:
            - task_id (:obj:`str`): The learner task id (unused here).
            - info (:obj:`dict`): Info dict of the task; a missing ``learner_done`` key counts as not done.
        Returns:
            - finished (:obj:`bool`): True if ``info['learner_done']`` is truthy, otherwise False.
        """
        return bool(info.get('learner_done', False))


@COMMANDER_REGISTRY.register('naive')
class NaiveCommander(BaseCommander):
    r"""
    Overview:
        A naive implementation of parallel commander.
    Interface:
        __init__, get_collector_task, get_learner_task, finish_collector_task, finish_learner_task,
        notify_fail_collector_task, notify_fail_learner_task, update_learner_info,
        increase_collector_task_space, decrease_collector_task_space
    """
    # NOTE(review): ``eval_interval`` is not read anywhere in this class.
    config = dict(
        collector_task_space=1,
        learner_task_space=1,
        eval_interval=60,
    )

    def __init__(self, cfg: dict) -> None:
        r"""
        Overview:
            Init the naive commander according to config
        Arguments:
            - cfg (:obj:`dict`): The whole experiment config. It must provide ``exp_name``, \
                ``policy.other.commander`` (with "collector_task_space" and "learner_task_space") and \
                ``env`` (with "collector_episode_num", "evaluator_episode_num" and "manager").
        """
        self._cfg = cfg
        self._exp_name = cfg.exp_name
        commander_cfg = self._cfg.policy.other.commander
        # presumably (min_val, max_val) of the number of concurrently running tasks -- TODO confirm
        self._collector_task_space = LimitedSpaceContainer(0, commander_cfg.collector_task_space)
        self._learner_task_space = LimitedSpaceContainer(0, commander_cfg.learner_task_space)

        # Env config template handed to collectors: the two episode-number keys are removed and
        # the collector episode number is stored under ``manager.episode_num`` instead.
        self._collector_env_cfg = copy.deepcopy(self._cfg.env)
        self._collector_env_cfg.pop('collector_episode_num')
        self._collector_env_cfg.pop('evaluator_episode_num')
        self._collector_env_cfg.manager.episode_num = self._cfg.env.collector_episode_num

        self._collector_task_count = 0
        self._learner_task_count = 0
        self._learner_info = defaultdict(list)
        self._learner_task_finish_count = 0
        self._collector_task_finish_count = 0

    def get_collector_task(self) -> Optional[dict]:
        r"""
        Overview:
            Get a new collector task if a slot can be acquired from the collector task space.
        Return:
            - task (:obj:`Optional[dict]`): New collector task with keys ``task_id``, ``buffer_id`` and \
                ``collector_cfg``, or None when no slot is available.
        """
        if not self._collector_task_space.acquire_space():
            return None
        self._collector_task_count += 1
        collector_cfg = copy.deepcopy(self._cfg.policy.collect.collector)
        # Fixed placeholder settings of this naive implementation.
        collector_cfg.collect_setting = {'eps': 0.9}
        collector_cfg.eval_flag = False
        collector_cfg.policy = copy.deepcopy(self._cfg.policy)
        collector_cfg.policy_update_path = 'test.pth'
        # Hand out a private copy so that a consumer mutating its task config cannot
        # affect the template reused by later tasks.
        collector_cfg.env = copy.deepcopy(self._collector_env_cfg)
        collector_cfg.exp_name = self._exp_name
        return {
            'task_id': 'collector_task_id{}'.format(self._collector_task_count),
            'buffer_id': 'test',
            'collector_cfg': collector_cfg,
        }

    def get_learner_task(self) -> Optional[dict]:
        r"""
        Overview:
            Get a new learner task if a slot can be acquired from the learner task space.
        Return:
            - task (:obj:`Optional[dict]`): New learner task with keys ``task_id``, ``policy_id``, \
                ``buffer_id``, ``learner_cfg``, ``replay_buffer_cfg`` and ``policy``, or None when no \
                slot is available.
        """
        if not self._learner_task_space.acquire_space():
            return None
        self._learner_task_count += 1
        learner_cfg = copy.deepcopy(self._cfg.policy.learn.learner)
        learner_cfg.exp_name = self._exp_name
        return {
            'task_id': 'learner_task_id{}'.format(self._learner_task_count),
            'policy_id': 'test.pth',
            'buffer_id': 'test',
            'learner_cfg': learner_cfg,
            'replay_buffer_cfg': copy.deepcopy(self._cfg.policy.other.replay_buffer),
            'policy': copy.deepcopy(self._cfg.policy),
        }

    def finish_collector_task(self, task_id: str, finished_task: dict) -> None:
        r"""
        Overview:
            Release one collector task slot and increase ``collector_task_finish_count``.
        Arguments:
            - task_id (:obj:`str`): The finished collector task id (unused here).
            - finished_task (:obj:`dict`): The finished task info (unused here).
        """
        self._collector_task_space.release_space()
        self._collector_task_finish_count += 1

    def finish_learner_task(self, task_id: str, finished_task: dict) -> str:
        r"""
        Overview:
            Increase ``learner_task_finish_count``, release one learner task slot and return the
            buffer_id of the finished task so that the caller can close the buffer.
        Arguments:
            - task_id (:obj:`str`): The finished learner task id (unused here).
            - finished_task (:obj:`dict`): The finished task info, must contain ``buffer_id``.
        Return:
            - buffer_id (:obj:`str`): ``finished_task['buffer_id']``.
        """
        self._learner_task_finish_count += 1
        self._learner_task_space.release_space()
        return finished_task['buffer_id']

    def notify_fail_collector_task(self, task: dict) -> None:
        r"""
        Overview:
            Release the collector task slot held by a failed task. Nothing else is done with ``task``.
        Arguments:
            - task (:obj:`dict`): The failed collector task (unused here).
        """
        self._collector_task_space.release_space()

    def notify_fail_learner_task(self, task: dict) -> None:
        r"""
        Overview:
            Release the learner task slot held by a failed task. Nothing else is done with ``task``.
        Arguments:
            - task (:obj:`dict`): The failed learner task (unused here).
        """
        self._learner_task_space.release_space()

    def update_learner_info(self, task_id: str, info: dict) -> None:
        r"""
        Overview:
            Append ``info`` to the list of infos recorded for the given learner task.
        Arguments:
            - task_id (:obj:`str`): the learner task_id
            - info (:obj:`dict`): the info to append to learner
        """
        self._learner_info[task_id].append(info)

    def increase_collector_task_space(self) -> None:
        r"""
        Overview:
            Increase task space when a new collector has been added dynamically.
        """
        self._collector_task_space.increase_space()

    def decrease_collector_task_space(self) -> None:
        r"""
        Overview:
            Decrease task space when a collector has been removed dynamically.
        """
        self._collector_task_space.decrease_space()


def create_parallel_commander(cfg: EasyDict) -> BaseCommander:
    r"""
    Overview:
        Create the commander according to cfg.
    Arguments:
        - cfg (:obj:`dict`): The whole config; ``cfg.policy.other.commander`` should include \
            ``import_names`` (modules to import before building) and ``type`` (registered commander name).
    Returns:
        - commander (:obj:`BaseCommander`): The commander built by ``COMMANDER_REGISTRY`` with ``cfg=cfg``.
    """
    cfg = EasyDict(cfg)
    commander_cfg = cfg.policy.other.commander
    import_module(commander_cfg.import_names)
    return COMMANDER_REGISTRY.build(commander_cfg.type, cfg=cfg)


def get_parallel_commander_cls(cfg: EasyDict) -> type:
    r"""
    Overview:
        Look up the commander class registered under ``cfg.type``.
    Arguments:
        - cfg (:obj:`dict`): The commander config; should include ``type`` and may include \
            ``import_names`` (defaults to an empty list).
    Returns:
        - cls (:obj:`type`): The result of ``COMMANDER_REGISTRY.get(cfg.type)``.
    """
    cfg = EasyDict(cfg)
    import_module(cfg.get('import_names', []))
    return COMMANDER_REGISTRY.get(cfg.type)