import os
import time
from typing import Any, Callable, Dict, Union
from queue import Full, Queue
from threading import Thread

from ding.utils import read_file, save_file, COMM_COLLECTOR_REGISTRY
from ding.utils.file_helper import save_to_di_store
from ding.interaction import Slave, TaskFail
from .base_comm_collector import BaseCommCollector


class CollectorSlave(Slave):
    """
    Overview:
        A slave, whose master is coordinator.
        Used to pass message between comm collector and coordinator.
    Interfaces:
        __init__, _process_task
    """

    # override
    def __init__(self, *args, callback_fn: Dict[str, Callable], **kwargs) -> None:
        """
        Overview:
            Init callback functions additionally. Callback functions are methods in comm collector.
        Arguments:
            - callback_fn (:obj:`Dict[str, Callable]`): Callbacks looked up by ``_process_task``. The keys used are \
                'deal_with_resource', 'deal_with_collector_start', 'deal_with_collector_data' and \
                'deal_with_collector_close'.
        """
        super().__init__(*args, **kwargs)
        self._callback_fn = callback_fn
        # Task info of the most recent 'collector_start_task'; its 'buffer_id' / 'task_id' are attached to
        # the results of the following data / close tasks.
        self._current_task_info = None

    def _process_task(self, task: dict) -> Union[dict, TaskFail]:
        """
        Overview:
            Process a task according to input task info dict, which is passed in by master coordinator.
            For each type of task, you can refer to corresponding callback function in comm collector for details.
        Arguments:
            - task (:obj:`dict`): Task dict. Must contain key "name"; a 'collector_start_task' must also \
                contain key "task_info".
        Returns:
            - result (:obj:`Union[dict, TaskFail]`): Task result dict, or task fail exception.
        Raises:
            - TaskFail: If the task name is not one of the supported names.
        """
        task_name = task['name']
        if task_name == 'resource':
            return self._callback_fn['deal_with_resource']()
        elif task_name == 'collector_start_task':
            self._current_task_info = task['task_info']
            self._callback_fn['deal_with_collector_start'](self._current_task_info)
            return {'message': 'collector task has started'}
        elif task_name == 'collector_data_task':
            data = self._callback_fn['deal_with_collector_data']()
            data['buffer_id'] = self._current_task_info['buffer_id']
            data['task_id'] = self._current_task_info['task_id']
            return data
        elif task_name == 'collector_close_task':
            data = self._callback_fn['deal_with_collector_close']()
            data['task_id'] = self._current_task_info['task_id']
            return data
        else:
            raise TaskFail(
                result={'message': 'task name error'}, message='illegal collector task <{}>'.format(task_name)
            )


@COMM_COLLECTOR_REGISTRY.register('flask_fs')
class FlaskFileSystemCollector(BaseCommCollector):
    """
    Overview:
        An implementation of CommCollector, using flask and the file system.
    Interfaces:
        __init__, deal_with_resource, deal_with_collector_start, deal_with_collector_data, deal_with_collector_close,\
        get_policy_update_info, send_stepdata, send_metadata, start, close
    """

    # override
    def __init__(self, cfg: dict) -> None:
        """
        Overview:
            Initialization method.
        Arguments:
            - cfg (:obj:`EasyDict`): Config dict. The attributes read here are ``host``, ``port``, \
                ``path_policy`` and ``path_data``.
        """
        BaseCommCollector.__init__(self, cfg)
        host, port = cfg.host, cfg.port
        self._callback_fn = {
            'deal_with_resource': self.deal_with_resource,
            'deal_with_collector_start': self.deal_with_collector_start,
            'deal_with_collector_data': self.deal_with_collector_data,
            'deal_with_collector_close': self.deal_with_collector_close,
        }
        self._slave = CollectorSlave(host, port, callback_fn=self._callback_fn)

        self._path_policy = cfg.path_policy
        self._path_data = cfg.path_data
        try:
            # ``exist_ok=True`` removes the check-then-create race when the directory appears concurrently.
            os.makedirs(self._path_data, exist_ok=True)
        except OSError:
            # Directory creation stays best-effort, as before: a failure here is not fatal at construction
            # time, and a later write into ``path_data`` will surface the real error.
            pass
        self._metadata_queue = Queue(8)
        self._collector_close_flag = False
        self._collector = None

    def deal_with_resource(self) -> dict:
        """
        Overview:
            Callback function in ``CollectorSlave``. Return how many resources are needed to start current collector.
        Returns:
            - resource (:obj:`dict`): Resource info dict, including ['gpu', 'cpu'].
        """
        return {'gpu': 1, 'cpu': 20}

    def deal_with_collector_start(self, task_info: dict) -> None:
        """
        Overview:
            Callback function in ``CollectorSlave``.
            Create a collector and start a collector thread of the created one.
        Arguments:
            - task_info (:obj:`dict`): Task info dict.
        Note:
            In ``_create_collector`` method in base class ``BaseCommCollector``, 4 methods
            'send_metadata', 'send_stepdata', 'get_policy_update_info', and policy are set.
            You can refer to it for details.
        """
        self._collector_close_flag = False
        self._collector = self._create_collector(task_info)
        self._collector_thread = Thread(target=self._collector.start, args=(), daemon=True, name='collector_start')
        self._collector_thread.start()

    def deal_with_collector_data(self) -> dict:
        """
        Overview:
            Callback function in ``CollectorSlave``. Get data sample dict from ``_metadata_queue``,
            which will be sent to coordinator afterwards. Blocks until an item is available.
        Returns:
            - data (:obj:`Any`): Data sample dict.
        """
        # A blocking ``get`` waits for the next item without the former 0.1s polling loop.
        return self._metadata_queue.get()

    def deal_with_collector_close(self) -> dict:
        """
        Overview:
            Callback function in ``CollectorSlave``. Mark the comm collector as closing, fetch the finish info
            from the current collector, close it, wait for the collector thread to end and drop both references.
        Returns:
            - finish_info (:obj:`dict`): The value returned by the collector's ``get_finish_info``.
        """
        # Set the flag first: ``send_metadata`` / ``send_stepdata`` / ``get_policy_update_info`` check it and
        # return early, which lets the collector thread finish so that ``join`` below can return.
        self._collector_close_flag = True
        finish_info = self._collector.get_finish_info()
        self._collector.close()
        self._collector_thread.join()
        del self._collector_thread
        self._collector = None
        return finish_info

    # override
    def get_policy_update_info(self, path: str) -> Union[dict, None]:
        """
        Overview:
            Get policy information in corresponding path.
        Arguments:
            - path (:obj:`str`): path to policy update information. If it does not contain ``path_policy``, \
                it is joined onto ``path_policy``.
        Returns:
            - info (:obj:`Union[dict, None]`): Content read from the file, or ``None`` if the collector is closing.
        """
        if self._collector_close_flag:
            return
        if self._path_policy not in path:
            path = os.path.join(self._path_policy, path)
        return read_file(path, use_lock=True)

    # override
    def send_stepdata(self, path: str, stepdata: list) -> Any:
        """
        Overview:
            Save collector's step data, either through ``save_to_di_store`` (when it is available) or as a
            file under ``path_data``.
        Arguments:
            - path (:obj:`str`): Path to save data, relative to ``path_data``. Not used in the di-store branch.
            - stepdata (:obj:`Any`): Data of one step.
        Returns:
            - object_ref (:obj:`Any`): In the di-store branch, the value returned by ``save_to_di_store``, or a \
                20-byte placeholder if the collector is closing. In the file branch, ``None``.
        """
        if save_to_di_store:
            if self._collector_close_flag:
                return b'0' * 20  # return an object reference that doesn't exist
            return save_to_di_store(stepdata)

        if self._collector_close_flag:
            return
        name = os.path.join(self._path_data, path)
        save_file(name, stepdata, use_lock=False)

    # override
    def send_metadata(self, metadata: dict) -> None:
        """
        Overview:
            Store meta data dict in queue, which will be retrieved by callback function
            "deal_with_collector_data" in collector slave, then will be sent to coordinator.
            Waits while the queue is full; the meta data is dropped if the collector starts closing meanwhile.
        Arguments:
            - metadata (:obj:`Any`): meta data. Must contain either the keys ['data_id', 'policy_iter'] or the \
                keys ['collector_done', 'cur_episode', 'cur_sample', 'cur_step'].
        """
        if self._collector_close_flag:
            return
        necessary_metadata_keys = {'data_id', 'policy_iter'}
        necessary_info_keys = {'collector_done', 'cur_episode', 'cur_sample', 'cur_step'}
        metadata_keys = set(metadata.keys())
        assert necessary_metadata_keys.issubset(metadata_keys) or necessary_info_keys.issubset(metadata_keys)
        # Re-check the close flag between attempts. ``deal_with_collector_close`` joins the collector thread,
        # and that thread may be the one waiting here on a full queue that nobody drains any more; without the
        # check the join would never return.
        while not self._collector_close_flag:
            try:
                self._metadata_queue.put(metadata, timeout=0.1)
            except Full:
                continue
            else:
                return

    def start(self) -> None:
        """
        Overview:
            Start comm collector itself and the collector slave.
        """
        BaseCommCollector.start(self)
        self._slave.start()

    def close(self) -> None:
        """
        Overview:
            Close comm collector itself and the collector slave. If a collector is still set, wait up to
            10 rounds of 1 second for it to be cleared before closing anyway.
        """
        if self._end_flag:
            return
        total_sleep_count = 0
        while total_sleep_count < 10:
            # Read the attribute once per round: ``deal_with_collector_close`` may reset it to None from
            # another thread between the check and the call.
            collector = self._collector
            if collector is None:
                break
            collector.info("please first close collector")
            time.sleep(1)
            total_sleep_count += 1
        self._slave.close()
        BaseCommCollector.close(self)

    def __del__(self) -> None:
        self.close()