import copy
import itertools
import os
import os.path as osp
import uuid
from abc import abstractmethod
from typing import Dict, Optional, Union

from easydict import EasyDict
from tabulate import tabulate

from ding.league.player import ActivePlayer, HistoricalPlayer, create_player
from ding.league.shared_payoff import create_payoff
from ding.utils import import_module, read_file, save_file, LockContext, LockContextType, LEAGUE_REGISTRY, \
    deep_merge_dicts
from .metric import LeagueMetricEnv


class BaseLeague:
    """
    Overview:
        League, proposed by Google Deepmind AlphaStar. Can manage multiple players in one league.
    Interface:
        get_job_info, judge_snapshot, update_active_player, finish_job, save_checkpoint

    .. note::
        The ``__init__`` method also initializes the players (through ``_init_players``).

    .. note::
        This class does not derive from ``abc.ABC``, so the ``@abstractmethod`` markers below are not enforced
        at instantiation time; an un-overridden hook raises ``NotImplementedError`` when it is called.
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Return a deep copy of the class-level ``config`` as an ``EasyDict``, tagged with ``cfg_type``.
        Returns:
            - cfg (:obj:`EasyDict`): Default config; ``cfg.cfg_type`` is the class name followed by ``'Dict'``.
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    config = dict(
        league_type='base',
        import_names=["ding.league.base_league"],
        # ---player----
        # "player_category" is just a name. Depends on the env.
        # For example, in StarCraft, this can be ['zerg', 'terran', 'protoss'].
        player_category=['default'],
        # Support different types of active players for solo and battle league.
        # For solo league, supports ['solo_active_player'].
        # For battle league, supports ['battle_active_player', 'main_player', 'main_exploiter', 'league_exploiter'].
        # active_players=dict(),
        # "use_pretrain" means whether to use pretrain model to initialize active player.
        use_pretrain=False,
        # "use_pretrain_init_historical" means whether to use pretrain model to initialize historical player.
        # "pretrain_checkpoint_path" is the pretrain checkpoint path used in "use_pretrain" and
        # "use_pretrain_init_historical". If both are False, "pretrain_checkpoint_path" can be omitted as well.
        # Otherwise, "pretrain_checkpoint_path" should list paths of all player categories.
        use_pretrain_init_historical=False,
        pretrain_checkpoint_path=dict(default='default_cate_pretrain.pth', ),
        # ---payoff---
        payoff=dict(
            # Supports ['battle']
            type='battle',
            decay=0.99,
            min_win_rate_games=8,
        ),
        metric=dict(
            mu=0,
            sigma=25 / 3,
            beta=25 / 3 / 2,
            tau=0.0,
            draw_probability=0.02,
        ),
    )

    def __init__(self, cfg: EasyDict) -> None:
        """
        Overview:
            Initialization method. Merges ``cfg`` over the default config, prepares the policy directory,
            creates the shared payoff and the metric env, then initializes all players.
        Arguments:
            - cfg (:obj:`EasyDict`): League config. Must provide ``path_policy`` (read from ``cfg`` directly).
        """
        self.cfg = deep_merge_dicts(self.default_config(), cfg)
        self.path_policy = cfg.path_policy
        if not osp.exists(self.path_policy):
            # ``makedirs`` also creates missing parents; ``exist_ok`` tolerates another worker creating the
            # directory between the ``exists`` check and this call.
            os.makedirs(self.path_policy, exist_ok=True)

        self.league_uid = str(uuid.uuid1())
        # TODO dict players
        self.active_players = []
        self.historical_players = []
        self.player_path = "./league"
        self.payoff = create_payoff(self.cfg.payoff)
        metric_cfg = self.cfg.metric
        # NOTE(review): ``metric_cfg.beta`` is configured but never passed, and the values below are positional.
        # If ``LeagueMetricEnv`` forwards positional arguments to ``trueskill.TrueSkill`` (whose order is
        # mu, sigma, beta, tau, draw_probability), ``tau`` would be consumed as ``beta`` -- confirm against
        # the ``.metric`` module before changing this call.
        self.metric_env = LeagueMetricEnv(metric_cfg.mu, metric_cfg.sigma, metric_cfg.tau, metric_cfg.draw_probability)
        self._active_players_lock = LockContext(lock_type=LockContextType.THREAD_LOCK)
        self._init_players()

    def _init_players(self) -> None:
        """
        Overview:
            Initialize players (active & historical) in the league, and register each of them in the payoff.
        """
        # Add different types of active players for each player category, according to ``cfg.active_players``.
        for cate in self.cfg.player_category:  # Player's category (Depends on the env)
            for k, n in self.cfg.active_players.items():  # Active player's type and its count
                for i in range(n):
                    name = '{}_{}_{}'.format(k, cate, i)
                    ckpt_path = osp.join(self.path_policy, '{}_ckpt.pth'.format(name))
                    player = create_player(
                        self.cfg, k, self.cfg[k], cate, self.payoff, ckpt_path, name, 0, self.metric_env.create_rating()
                    )
                    if self.cfg.use_pretrain:
                        # Seed the active player's checkpoint with this category's pretrain checkpoint.
                        self.save_checkpoint(self.cfg.pretrain_checkpoint_path[cate], ckpt_path)
                    self.active_players.append(player)
                    self.payoff.add_player(player)

        # Add pretrain player as the initial HistoricalPlayer for each player category.
        if self.cfg.use_pretrain_init_historical:
            # The config key containing 'main_player' does not depend on the category, so resolve it once.
            main_player_keys = [k for k in self.cfg.keys() if 'main_player' in k]
            assert len(main_player_keys) == 1, main_player_keys
            main_player_name = main_player_keys[0]
            for cate in self.cfg.player_category:
                name = '{}_{}_0_pretrain_historical'.format(main_player_name, cate)
                parent_name = '{}_{}_0'.format(main_player_name, cate)
                hp = HistoricalPlayer(
                    self.cfg.get(main_player_name),
                    cate,
                    self.payoff,
                    self.cfg.pretrain_checkpoint_path[cate],
                    name,
                    0,
                    self.metric_env.create_rating(),
                    parent_id=parent_name
                )
                self.historical_players.append(hp)
                self.payoff.add_player(hp)

        # Save active players' ``player_id`` & ``player_ckpt``.
        self.active_players_ids = [p.player_id for p in self.active_players]
        self.active_players_ckpts = [p.checkpoint_path for p in self.active_players]
        # Validate active players are unique by ``player_id``.
        assert len(self.active_players_ids) == len(set(self.active_players_ids))

    def get_job_info(self, player_id: Optional[str] = None, eval_flag: bool = False) -> dict:
        """
        Overview:
            Get info dict of the job which is to be launched to an active player.
        Arguments:
            - player_id (:obj:`Optional[str]`): The active player's id. Defaults to the first active player.
            - eval_flag (:obj:`bool`): Whether this is an evaluation job.
        Returns:
            - job_info (:obj:`dict`): Job info.
        ReturnsKeys:
            - necessary: ``launch_player`` (the active player)
        """
        if player_id is None:
            player_id = self.active_players_ids[0]
        with self._active_players_lock:
            idx = self.active_players_ids.index(player_id)
            player = self.active_players[idx]
            job_info = self._get_job_info(player, eval_flag)
            assert 'launch_player' in job_info.keys() and job_info['launch_player'] == player.player_id
        return job_info

    @abstractmethod
    def _get_job_info(self, player: ActivePlayer, eval_flag: bool = False) -> dict:
        """
        Overview:
            Real ``get_job_info`` implementation. Called by ``self.get_job_info``.
        Arguments:
            - player (:obj:`ActivePlayer`): The active player to be launched a job.
            - eval_flag (:obj:`bool`): Whether this is an evaluation job.
        Returns:
            - job_info (:obj:`dict`): Job info. Should include keys ['launch_player'].
        """
        raise NotImplementedError

    def judge_snapshot(self, player_id: str, force: bool = False) -> bool:
        """
        Overview:
            Judge whether a player is trained enough for snapshot. If yes, call player's ``snapshot``, create a
            historical player (prepare the checkpoint and add it to the shared payoff), then mutate it, and
            return True. Otherwise, return False.
        Arguments:
            - player_id (:obj:`str`): The active player's id.
            - force (:obj:`bool`): If True, snapshot without checking ``is_trained_enough``.
        Returns:
            - snapshot_or_not (:obj:`bool`): Whether the active player is snapshotted.
        """
        with self._active_players_lock:
            idx = self.active_players_ids.index(player_id)
            player = self.active_players[idx]
            if force or player.is_trained_enough():
                # Snapshot
                hp = player.snapshot(self.metric_env)
                self.save_checkpoint(player.checkpoint_path, hp.checkpoint_path)
                self.historical_players.append(hp)
                self.payoff.add_player(hp)
                # Mutate
                self._mutate_player(player)
                return True
            else:
                return False

    @abstractmethod
    def _mutate_player(self, player: ActivePlayer) -> None:
        """
        Overview:
            Players have the probability to mutate, e.g. Reset network parameters.
            Called by ``self.judge_snapshot``.
        Arguments:
            - player (:obj:`ActivePlayer`): The active player that may mutate.
        """
        raise NotImplementedError

    def update_active_player(self, player_info: dict) -> None:
        """
        Overview:
            Update an active player's info. A ``ValueError`` raised while looking up or updating the player
            (e.g. an unknown ``player_id``) is printed rather than propagated.
        Arguments:
            - player_info (:obj:`dict`): Info dict of the player which is to be updated.
        ArgumentsKeys:
            - necessary: `player_id`, `train_iteration`
        """
        try:
            idx = self.active_players_ids.index(player_info['player_id'])
            player = self.active_players[idx]
            return self._update_player(player, player_info)
        except ValueError as e:
            print(e)

    @abstractmethod
    def _update_player(self, player: ActivePlayer, player_info: dict) -> None:
        """
        Overview:
            Update an active player. Called by ``self.update_active_player``.
        Arguments:
            - player (:obj:`ActivePlayer`): The active player that will be updated.
            - player_info (:obj:`dict`): Info dict of the active player which is to be updated.
        """
        raise NotImplementedError

    def finish_job(self, job_info: dict) -> None:
        """
        Overview:
            Finish current job. Update shared payoff to record the game results. For evaluation jobs
            (truthy ``job_info['eval_flag']``), also update both players' ratings through the metric env.
        Arguments:
            - job_info (:obj:`dict`): A dict containing job result information. Evaluation jobs must provide \
                ``player_id`` (a home/away id pair) and ``result``.
        """
        # TODO(nyz) more fine-grained job info
        self.payoff.update(job_info)
        if job_info.get('eval_flag'):
            home_id, away_id = job_info['player_id']
            home_player, away_player = self.get_player_by_id(home_id), self.get_player_by_id(away_id)
            job_info_result = job_info['result']
            if isinstance(job_info_result[0], list):
                # Flatten a list of per-episode result lists into one flat list (linear time).
                job_info_result = list(itertools.chain.from_iterable(job_info_result))
            home_player.rating, away_player.rating = self.metric_env.rate_1vs1(
                home_player.rating, away_player.rating, result=job_info_result
            )

    def get_player_by_id(self, player_id: str) -> 'Player':  # noqa
        """
        Overview:
            Look up a player by id. Ids containing ``'historical'`` are searched among the historical players,
            all other ids among the active players.
        Arguments:
            - player_id (:obj:`str`): The id of the player to find.
        Returns:
            - player (:obj:`Player`): The first player whose ``player_id`` equals ``player_id``.
        Raises:
            - IndexError: If no such player exists in the searched pool.
        """
        pool = self.historical_players if 'historical' in player_id else self.active_players
        for player in pool:
            if player.player_id == player_id:
                return player
        # Keep ``IndexError`` (what the former ``[...][0]`` lookup raised), but say which id was missing.
        raise IndexError("player '{}' not found in league".format(player_id))

    @staticmethod
    def save_checkpoint(src_checkpoint, dst_checkpoint) -> None:
        """
        Overview:
            Copy a checkpoint from path ``src_checkpoint`` to path ``dst_checkpoint``.
        Arguments:
            - src_checkpoint (:obj:`str`): Source checkpoint's path, e.g. s3://alphastar_fake_data/ckpt.pth
            - dst_checkpoint (:obj:`str`): Destination checkpoint's path, e.g. s3://alphastar_fake_data/ckpt.pth
        """
        checkpoint = read_file(src_checkpoint)
        save_file(dst_checkpoint, checkpoint)

    def player_rank(self, string: bool = False) -> Union[str, Dict[str, float]]:
        """
        Overview:
            Collect ``rating.exposure`` of every active and historical player.
        Arguments:
            - string (:obj:`bool`): If True, return a pipe-format table string instead of a dict.
        Returns:
            - rank (:obj:`Union[str, Dict[str, float]]`): Mapping from ``player_id`` to exposure, or the \
                rendered table (prefixed with a newline) when ``string`` is True.
        """
        rank = {p.player_id: p.rating.exposure for p in self.active_players + self.historical_players}
        if not string:
            return rank
        headers = ["Player ID", "Rank (TrueSkill)"]
        data = [[k, "{:.2f}".format(v)] for k, v in rank.items()]
        return "\n" + tabulate(data, headers=headers, tablefmt='pipe')


def create_league(cfg: EasyDict, *args) -> BaseLeague:
    """
    Overview:
        Import the modules listed in ``cfg.import_names`` and build the league registered in
        ``LEAGUE_REGISTRY`` under ``cfg.league_type``. In other words, a derived league must first
        register itself, then ``create_league`` can be called to get the instance object.
    Arguments:
        - cfg (:obj:`EasyDict`): League config. Necessary key: ``league_type``; optional key: ``import_names``.
        - args: Extra positional arguments forwarded to ``LEAGUE_REGISTRY.build``.
    Returns:
        - league (:obj:`BaseLeague`): The object returned by ``LEAGUE_REGISTRY.build``.
    """
    import_module(cfg.get('import_names', []))
    # Positional star-args go before the keyword argument; the resulting call is unchanged.
    return LEAGUE_REGISTRY.build(cfg.league_type, *args, cfg=cfg)