Skip to content

ding.worker.learner.comm.utils

ding.worker.learner.comm.utils

Full Source Code

../ding/worker/learner/comm/utils.py

import time
import os
from ding.interaction import Slave, TaskFail
from ding.utils import lists_to_dicts


class NaiveLearner(Slave):
    """Stub learner that answers each learner task with a canned response.

    Every reply is either hard-coded or echoed from the ``task_info`` stored
    by the most recent ``learner_start_task``; no actual training is done.

    Args:
        *args: Forwarded unchanged to ``Slave.__init__``.
        prefix: String prepended to ``_final_model.pth`` to form the path of
            the marker file touched once the learner reports it is done.
        **kwargs: Forwarded unchanged to ``Slave.__init__``.
    """

    def __init__(self, *args, prefix='', **kwargs):
        super().__init__(*args, **kwargs)
        self._prefix = prefix

    def _touch_final_model(self):
        """Create ``<prefix>_final_model.pth`` or refresh its timestamps.

        Done with file APIs rather than a shell ``touch`` so that the prefix
        is never interpreted by a shell, no pipe is left open, and the file
        exists by the time this method returns.

        Raises:
            OSError: If the file cannot be opened for appending.
        """
        final_model_path = '{}_final_model.pth'.format(self._prefix)
        with open(final_model_path, 'a'):
            # Appending creates the file when missing; utime mirrors the
            # timestamp refresh ``touch`` performs on an existing file.
            os.utime(final_model_path, None)

    def _process_task(self, task):
        """Return the canned response for ``task['name']``.

        Handled names: ``resource``, ``learner_start_task``,
        ``learner_get_data_task``, ``learner_learn_task`` and
        ``learner_close_task``. All but ``resource`` and
        ``learner_start_task`` read ``self.task_info``, which only exists
        after a ``learner_start_task`` has been processed.

        Args:
            task: Mapping with a ``'name'`` key. ``learner_start_task`` also
                needs ``'task_info'`` (with ``'task_id'`` and ``'buffer_id'``
                for the later tasks); ``learner_learn_task`` needs ``'data'``.

        Returns:
            A dict whose keys depend on the task name.

        Raises:
            TaskFail: If the task name is unknown, or if a
                ``learner_learn_task`` carries ``data`` equal to ``None``.
        """
        task_name = task['name']
        if task_name == 'resource':
            return {'cpu': 'xxx', 'gpu': 'xxx'}
        elif task_name == 'learner_start_task':
            time.sleep(1)
            self.task_info = task['task_info']
            # Number of learn tasks handled since this start task.
            self.count = 0
            return {'message': 'learner task has started'}
        elif task_name == 'learner_get_data_task':
            time.sleep(0.01)
            return {
                'task_id': self.task_info['task_id'],
                'buffer_id': self.task_info['buffer_id'],
                'batch_size': 2,
                'cur_learner_iter': 1
            }
        elif task_name == 'learner_learn_task':
            data = task['data']
            if data is None:
                raise TaskFail(result={'message': 'no data'})
            time.sleep(0.1)
            # presumably turns a list of per-sample dicts into one dict keyed
            # by field name -- verify against ding.utils.lists_to_dicts
            data = lists_to_dicts(data)
            assert 'data_id' in data
            priority_keys = ['replay_unique_id', 'replay_buffer_idx', 'priority']
            self.count += 1
            ret = {
                'info': {
                    'learner_step': self.count
                },
                'task_id': self.task_info['task_id'],
                'buffer_id': self.task_info['buffer_id']
            }
            ret['info']['priority_info'] = {k: data[k] for k in priority_keys}
            # From the sixth learn task onwards, report completion and leave
            # the marker file on disk.
            if self.count > 5:
                ret['info']['learner_done'] = True
                self._touch_final_model()
            return ret
        elif task_name == 'learner_close_task':
            return {'task_id': self.task_info['task_id'], 'buffer_id': self.task_info['buffer_id']}
        else:
            # NOTE(review): the message says "collector" although this class
            # handles learner tasks; text kept as-is in case callers match it.
            raise TaskFail(
                result={'message': 'task name error'}, message='illegal collector task <{}>'.format(task_name)
            )