ding.worker.collector.comm.utils
Full Source Code
../ding/worker/collector/comm/utils.py
import torch
from ding.interaction.slave import Slave, TaskFail


class NaiveCollector(Slave):
    """
    Overview:
        A slave, whose master is coordinator.
        Used to pass message between comm collector and coordinator.
        It fakes a collector: data tasks save randomly generated timesteps to local files and report
        their ids, and after ``DATA_TASK_NUM`` data tasks the collector reports itself as done.
    Interfaces:
        _process_task, _get_timestep
    """

    # Number of ``collector_data_task`` calls (per started task) after which the collector
    # returns a "collector_done" result instead of a data result. Override in subclasses/tests if needed.
    DATA_TASK_NUM = 20

    def __init__(self, *args, prefix='', **kwargs):
        """
        Overview:
            Initialize the slave and the per-task bookkeeping.
        Arguments:
            - args: Positional arguments forwarded to ``Slave.__init__``.
            - prefix (:obj:`str`): Prefix of the local file names the fake data is saved to.
            - kwargs: Keyword arguments forwarded to ``Slave.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._prefix = prefix
        # Per-task state, (re)set by ``collector_start_task``. Initialised here so that a data task
        # arriving before any start task is reported as ``TaskFail`` instead of crashing with AttributeError.
        self.count = 0
        self.task_info = None

    def _process_task(self, task):
        """
        Overview:
            Process a task according to input task info dict, which is passed in by master coordinator.
            For each type of task, you can refer to corresponding callback function in comm collector for details.
        Arguments:
            - task (:obj:`dict`): Task dict. Must contain key "name". A ``collector_start_task`` must also \
                contain "task_info", which is expected to hold "task_id" and "buffer_id".
        Returns:
            - result (:obj:`dict`): Task result dict.
        Raises:
            - TaskFail: If the task name is unknown, or a ``collector_data_task`` arrives before any \
                ``collector_start_task``.
        """
        task_name = task['name']
        if task_name == 'resource':
            return {'cpu': '20', 'gpu': '1'}
        elif task_name == 'collector_start_task':
            self.count = 0
            self.task_info = task['task_info']
            return {'message': 'collector task has started'}
        elif task_name == 'collector_data_task':
            if self.task_info is None:
                # Report through the same failure channel used for illegal task names below.
                raise TaskFail(
                    result={'message': 'collector task not started'},
                    message='collector data task received before collector_start_task'
                )
            self.count += 1
            task_id = self.task_info['task_id']
            data_id = './{}_{}_{}'.format(self._prefix, task_id, self.count)
            # The fake timestep is written to a local file; only its path is sent to the coordinator.
            torch.save(self._get_timestep(), data_id)
            data = {
                'data_id': data_id,
                'buffer_id': self.task_info['buffer_id'],
                'unroll_split_begin': 0,
                'task_id': task_id,
            }
            if self.count == self.DATA_TASK_NUM:
                # Final data task: report completion with fixed fake statistics instead of the data result.
                return {
                    'task_id': task_id,
                    'collector_done': True,
                    'cur_episode': 1,
                    'cur_step': 314,
                    'cur_sample': 314,
                }
            return data
        else:
            raise TaskFail(
                result={'message': 'task name error'}, message='illegal collector task <{}>'.format(task_name)
            )

    def _get_timestep(self):
        """
        Overview:
            Build a fake trajectory made of one randomly generated timestep.
        Returns:
            - timesteps (:obj:`list`): A single-element list holding a dict with keys "obs" (uniform random, \
                shape (4, )), "next_obs" (normal random, shape (4, )), "reward" (float 0/1 values, shape (3, )), \
                "action" (int 0/1, shape (1, )) and "done" (always False).
        """
        return [
            {
                'obs': torch.rand(4),
                'next_obs': torch.randn(4),
                'reward': torch.randint(0, 2, size=(3, )).float(),
                'action': torch.randint(0, 2, size=(1, )),
                'done': False,
            }
        ]