Skip to content

ding.worker.learner.comm.base_comm_learner

ding.worker.learner.comm.base_comm_learner

BaseCommLearner

Bases: ABC

Overview

Abstract baseclass for CommLearner.

Interfaces: __init__, send_policy, get_data, send_learn_info, start, close. Property: hooks4call

__init__(cfg)

Overview

Initialization method.

Arguments: - cfg (:obj:EasyDict): Config dict

send_policy(state_dict) abstractmethod

Overview

Save learner's policy in corresponding path. Will be registered in base learner.

Arguments: - state_dict (:obj:dict): State dict of the runtime policy.

get_data(batch_size) abstractmethod

Overview

Get batched meta data from coordinator. Will be registered in base learner.

Arguments: - batch_size (:obj:int): Batch size. Returns: - stepdata (:obj:list): A list of training data, each element is one trajectory.

send_learn_info(learn_info) abstractmethod

Overview

Send learn info to coordinator. Will be registered in base learner.

Arguments: - learn_info (:obj:dict): Learn info in dict type.

start()

Overview

Start comm learner.

close()

Overview

Close comm learner.

hooks4call()

Returns:

Type Description
list
  • hooks (:obj:list): The hooks which comm learner has. Will be registered in learner as well.

create_comm_learner(cfg)

Overview

Given the key (comm_learner_name), create a new comm learner instance if it is among comm_map's values, or raise a KeyError otherwise. In other words, a derived comm learner must first be registered; only then can create_comm_learner be called to get an instance.

Arguments: - cfg (:obj:dict): Learner config. Keys read: import_names (optional, defaults to an empty list) and type (the registered comm learner type). Returns: - learner (:obj:BaseCommLearner): The created new comm learner, should be an instance of one of comm_map's values.

Full Source Code

../ding/worker/learner/comm/base_comm_learner.py

from abc import ABC, abstractmethod, abstractproperty  # noqa: F401 (abstractproperty kept for backward compatibility)
from easydict import EasyDict

from ding.utils import EasyTimer, import_module, get_task_uid, dist_init, dist_finalize, COMM_LEARNER_REGISTRY
from ding.policy import create_policy
from ding.worker.learner import create_learner


class BaseCommLearner(ABC):
    """
    Overview:
        Abstract baseclass for CommLearner.
    Interfaces:
        __init__, send_policy, get_data, send_learn_info, start, close
    Property:
        hooks4call
    """

    def __init__(self, cfg: 'EasyDict') -> None:  # noqa
        """
        Overview:
            Initialization method. Stores the config, obtains a task uid and a timer, and records
            the distributed rank / world size.
        Arguments:
            - cfg (:obj:`EasyDict`): Config dict. ``cfg.multi_gpu`` is read to decide whether \
                ``dist_init`` is called.
        """
        self._cfg = cfg
        self._learner_uid = get_task_uid()
        self._timer = EasyTimer()
        if cfg.multi_gpu:
            # Multi-GPU mode: rank and world size come from the distributed initialization.
            self._rank, self._world_size = dist_init()
        else:
            # Single-process mode: rank 0 in a world of size 1.
            self._rank, self._world_size = 0, 1
        self._multi_gpu = cfg.multi_gpu
        # The comm learner counts as "ended" until ``start`` is called.
        self._end_flag = True

    @abstractmethod
    def send_policy(self, state_dict: dict) -> None:
        """
        Overview:
            Save learner's policy in corresponding path.
            Will be registered in base learner.
        Arguments:
            - state_dict (:obj:`dict`): State dict of the runtime policy.
        """
        raise NotImplementedError

    @abstractmethod
    def get_data(self, batch_size: int) -> list:
        """
        Overview:
            Get batched meta data from coordinator.
            Will be registered in base learner.
        Arguments:
            - batch_size (:obj:`int`): Batch size.
        Returns:
            - stepdata (:obj:`list`): A list of training data, each element is one trajectory.
        """
        raise NotImplementedError

    @abstractmethod
    def send_learn_info(self, learn_info: dict) -> None:
        """
        Overview:
            Send learn info to coordinator.
            Will be registered in base learner.
        Arguments:
            - learn_info (:obj:`dict`): Learn info in dict type.
        """
        raise NotImplementedError

    def start(self) -> None:
        """
        Overview:
            Start comm learner by clearing the end flag.
        """
        self._end_flag = False

    def close(self) -> None:
        """
        Overview:
            Close comm learner: set the end flag and, in multi-GPU mode, call ``dist_finalize``.
        """
        self._end_flag = True
        if self._multi_gpu:
            dist_finalize()

    # ``abc.abstractproperty`` is deprecated since Python 3.3; stacking ``property`` on top of
    # ``abstractmethod`` is the documented replacement and is treated identically by ``ABC``.
    @property
    @abstractmethod
    def hooks4call(self) -> list:
        """
        Returns:
            - hooks (:obj:`list`): The hooks which comm learner has. Will be registered in learner as well.
        """
        raise NotImplementedError

    def _create_learner(self, task_info: dict) -> 'BaseLearner':  # noqa
        """
        Overview:
            Receive ``task_info`` passed from coordinator and create a learner.
        Arguments:
            - task_info (:obj:`dict`): Task info dict from coordinator. Should be like \
                {"learner_cfg": xxx, "policy": xxx}.
        Returns:
            - learner (:obj:`BaseLearner`): Created base learner.

        .. note::
            Three methods('get_data', 'send_policy', 'send_learn_info'), dataloader and policy are set.
            The reason why they are set here rather than base learner is that, they highly depend on the specific task.
            Only after task info is passed from coordinator to comm learner through learner slave, can they be
            clarified and initialized.
        """
        # Prepare learner config and instantiate a learner object.
        learner_cfg = EasyDict(task_info['learner_cfg'])
        learner = create_learner(learner_cfg, dist_info=[self._rank, self._world_size], exp_name=learner_cfg.exp_name)
        # Set 3 methods and dataloader in created learner that are necessary in parallel setting.
        for item in ['get_data', 'send_policy', 'send_learn_info']:
            setattr(learner, item, getattr(self, item))
        # Set policy in created learner.
        policy_cfg = task_info['policy']
        policy_cfg = EasyDict(policy_cfg)
        learner.policy = create_policy(policy_cfg, enable_field=['learn']).learn_mode
        # The dataloader is set up only after the data-fetching method and the policy are attached.
        learner.setup_dataloader()
        return learner


def create_comm_learner(cfg: EasyDict) -> BaseCommLearner:
    """
    Overview:
        Given the key(comm_learner_name), create a new comm learner instance if in comm_map's values,
        or raise a KeyError. In other words, a derived comm learner must first register,
        then can call ``create_comm_learner`` to get the instance.
    Arguments:
        - cfg (:obj:`dict`): Learner config. Keys read: ``import_names`` (optional, defaults to an empty \
            list; passed to ``import_module``) and ``type`` (the name passed to the registry).
    Returns:
        - learner (:obj:`BaseCommLearner`): The created new comm learner, should be an instance of one of \
            comm_map's values.

    .. note::
        The error raised for an unregistered ``type`` is determined by ``COMM_LEARNER_REGISTRY.build``,
        which is not defined in this file; the ``KeyError`` above is carried over from the original
        documentation and should be confirmed against the registry implementation.
    """
    # Import the listed modules before building; the original docs state that a derived comm learner
    # must be registered first (presumably these imports trigger that registration; confirm against
    # ``import_module`` and the registry).
    import_module(cfg.get('import_names', []))
    return COMM_LEARNER_REGISTRY.build(cfg.type, cfg=cfg)