from typing import Callable, Optional, List
from collections import namedtuple
import numpy as np
from easydict import EasyDict

from ding.utils import import_module, PLAYER_REGISTRY
from .algorithm import pfsp


class Player:
    """
    Overview:
        Base player class, player is the basic member of a league
    Interfaces:
        __init__
    Property:
        category, payoff, checkpoint_path, player_id, total_agent_step, rating
    """
    _name = "BasePlayer"  # override this variable for sub-class player

    def __init__(
            self,
            cfg: EasyDict,
            category: str,
            init_payoff: 'BattleSharedPayoff',  # noqa
            checkpoint_path: str,
            player_id: str,
            total_agent_step: int,
            rating: 'PlayerRating',  # noqa
    ) -> None:
        """
        Overview:
            Initialize base player metadata
        Arguments:
            - cfg (:obj:`EasyDict`): Player config dict.
            - category (:obj:`str`): Player category, depending on the game, \
                e.g. StarCraft has 3 races ['terran', 'protoss', 'zerg'].
            - init_payoff (:obj:`Union[BattleSharedPayoff, SoloSharedPayoff]`): Payoff shared by all players.
            - checkpoint_path (:obj:`str`): The path to load player checkpoint.
            - player_id (:obj:`str`): Player id in string format.
            - total_agent_step (:obj:`int`): For active player, it should be 0; \
                For historical player, it should be parent player's ``_total_agent_step`` when ``snapshot``.
            - rating (:obj:`PlayerRating`): player rating information in total league
        """
        self._cfg = cfg
        self._category = category
        self._payoff = init_payoff
        self._checkpoint_path = checkpoint_path
        assert isinstance(player_id, str)
        self._player_id = player_id
        assert isinstance(total_agent_step, int), (total_agent_step, type(total_agent_step))
        self._total_agent_step = total_agent_step
        self._rating = rating

    @property
    def category(self) -> str:
        return self._category

    @property
    def payoff(self) -> 'BattleSharedPayoff':  # noqa
        return self._payoff

    @property
    def checkpoint_path(self) -> str:
        return self._checkpoint_path

    @property
    def player_id(self) -> str:
        return self._player_id

    @property
    def total_agent_step(self) -> int:
        return self._total_agent_step

    @total_agent_step.setter
    def total_agent_step(self, step: int) -> None:
        self._total_agent_step = step

    @property
    def rating(self) -> 'PlayerRating':  # noqa
        return self._rating

    @rating.setter
    def rating(self, _rating: 'PlayerRating') -> None:  # noqa
        self._rating = _rating


@PLAYER_REGISTRY.register('historical_player')
class HistoricalPlayer(Player):
    """
    Overview:
        Historical player which is snapshotted from an active player, and is fixed with the checkpoint.
        Have a unique attribute ``parent_id``.
    Property:
        category, payoff, checkpoint_path, player_id, total_agent_step, rating, parent_id
    """
    _name = "HistoricalPlayer"

    def __init__(self, *args, parent_id: str) -> None:
        """
        Overview:
            Initialize ``_parent_id`` additionally
        Arguments:
            - args: Positional arguments forwarded unchanged to ``Player.__init__``.
            - parent_id (:obj:`str`): id of historical player's parent, should be an active player
        """
        super().__init__(*args)
        self._parent_id = parent_id

    @property
    def parent_id(self) -> str:
        return self._parent_id


class ActivePlayer(Player):
    """
    Overview:
        Active player can be updated, or snapshotted to a historical player in the league training.
    Interface:
        __init__, is_trained_enough, snapshot, mutate, get_job
    Property:
        category, payoff, checkpoint_path, player_id, total_agent_step, rating
    """
    _name = "ActivePlayer"
    BRANCH = namedtuple("BRANCH", ['name', 'prob'])

    def __init__(self, *args, **kwargs) -> None:
        """
        Overview:
            Initialize player metadata, depending on the game
        Arguments:
            - args: Positional arguments forwarded unchanged to ``Player.__init__``.
            - kwargs: Accepted for interface compatibility but currently ignored (not forwarded).
        Note:
            - one_phase_step (:obj:`int`): An active player will be considered trained enough for snapshot \
                after two phase steps.
            - last_enough_step (:obj:`int`): Player's last step number that satisfies ``_is_trained_enough``.
            - strong_win_rate (:obj:`float`): If win rates between this player and all the opponents are greater than
                this value, this player can be regarded as strong enough to these opponents. \
                If also already trained for one phase step, this player can be regarded as trained enough for snapshot.
            - branch_probs (:obj:`list`): A list of ``BRANCH(name, prob)`` namedtuples, built from the config dict \
                ``cfg.branch_probs``, giving the probability of selecting each opponent branch.
        """
        super().__init__(*args)
        self._one_phase_step = int(float(self._cfg.one_phase_step))  # ``one_phase_step`` is like 1e9
        self._last_enough_step = 0
        self._strong_win_rate = self._cfg.strong_win_rate
        assert isinstance(self._cfg.branch_probs, dict)
        self._branch_probs = [self.BRANCH(k, v) for k, v in self._cfg.branch_probs.items()]
        # self._eval_opponent_difficulty = ["WEAK", "MEDIUM", "STRONG"]
        self._eval_opponent_difficulty = ["RULE_BASED"]
        self._eval_opponent_index = 0

    def is_trained_enough(self, select_fn: Optional[Callable] = None) -> bool:
        """
        Overview:
            Judge whether this player is trained enough for further operations(e.g. snapshot, mutate...)
            according to past step count and overall win rates against opponents.
            If yes, set ``self._last_enough_step`` to ``self._total_agent_step`` and return True;
            otherwise return False.
        Arguments:
            - select_fn (:obj:`function`): The function to select opponent players. Defaults to selecting \
                every ``HistoricalPlayer`` in the shared payoff.
        Returns:
            - flag (:obj:`bool`): Whether this player is trained enough
        """
        if select_fn is None:
            select_fn = lambda x: isinstance(x, HistoricalPlayer)  # noqa
        step_passed = self._total_agent_step - self._last_enough_step
        if step_passed < self._one_phase_step:
            return False
        elif step_passed >= 2 * self._one_phase_step:
            # ``step_passed`` is 2 times of ``self._one_phase_step``, regarded as trained enough
            self._last_enough_step = self._total_agent_step
            return True
        else:
            # Get payoff against specific opponents (Different players have different type of opponent players)
            # If min win rate is larger than ``self._strong_win_rate``, then is judged trained enough
            selected_players = self._get_players(select_fn)
            if len(selected_players) == 0:  # No such player, therefore no past game
                return False
            win_rates = self._payoff[self, selected_players]
            if win_rates.min() > self._strong_win_rate:
                self._last_enough_step = self._total_agent_step
                return True
            else:
                return False

    def snapshot(self, metric_env: 'LeagueMetricEnv') -> HistoricalPlayer:  # noqa
        """
        Overview:
            Generate a snapshot historical player from the current player, called in league's ``_snapshot``.
        Argument:
            - metric_env (:obj:`LeagueMetricEnv`): player rating environment, one league one env
        Returns:
            - snapshot_player (:obj:`HistoricalPlayer`): new instantiated historical player

        .. note::
            This method only generates a historical player object, but without saving the checkpoint, which should be
            done by league.
        """
        # Checkpoint path: text before the first '.pth', suffixed with the current step, e.g. 'a.pth' -> 'a_100.pth'.
        path = self.checkpoint_path.split('.pth')[0] + '_{}'.format(self._total_agent_step) + '.pth'
        return HistoricalPlayer(
            self._cfg,
            self.category,
            self.payoff,
            path,
            self.player_id + '_{}_historical'.format(int(self._total_agent_step)),
            self._total_agent_step,
            metric_env.create_rating(mu=self.rating.mu),
            parent_id=self.player_id
        )

    def mutate(self, info: dict) -> Optional[str]:
        """
        Overview:
            Mutate the current player, called in league's ``_mutate_player``.
            The base implementation does nothing and returns None.
        Arguments:
            - info (:obj:`dict`): related information for the mutation
        Returns:
            - mutation_result (:obj:`str`): if the player does the mutation operation then returns the
                corresponding model path, otherwise returns None
        """
        pass

    def get_job(self, eval_flag: bool = False) -> dict:
        """
        Overview:
            Get a dict containing some info about the job to be launched, e.g. the selected opponent.
        Arguments:
            - eval_flag (:obj:`bool`): Whether to select an opponent for evaluator task.
        Returns:
            - ret (:obj:`dict`): The returned dict. Should contain key ['opponent'].
        """
        if eval_flag:
            # eval opponent is a str.
            opponent = self._eval_opponent_difficulty[self._eval_opponent_index]
        else:
            # collect opponent is a Player.
            opponent = self._get_collect_opponent()
        return {
            'opponent': opponent,
        }

    def _get_collect_opponent(self) -> Player:
        """
        Overview:
            Select an opponent according to the player's ``branch_probs``: one branch is sampled with its \
            configured probability, then the method ``_<name>_branch`` of that branch is called.
        Returns:
            - opponent (:obj:`Player`): Selected opponent.
        Raises:
            - ValueError: If the branch probabilities do not (approximately) sum to 1, so the sampled value \
                falls outside every branch interval.
        """
        p = np.random.uniform()
        # Walk the cumulative distribution: branch ``i`` owns the half-open interval [cum_i, cum_{i+1}).
        cum_prob = 0.
        idx = None
        for i, branch in enumerate(self._branch_probs):
            cum_prob += branch.prob
            if p < cum_prob:
                idx = i
                break
        if idx is None:
            # Floating-point accumulation can leave a valid total slightly below 1
            # (e.g. 0.7 + 0.2 + 0.1 == 0.9999999999999999); a sample in that tiny gap belongs to the
            # last branch with non-zero probability. A materially wrong total is a config error.
            positive = [i for i, branch in enumerate(self._branch_probs) if branch.prob > 0]
            if not positive or not np.isclose(cum_prob, 1.0):
                raise ValueError(
                    "branch_probs must sum to 1, got {} from {}".format(cum_prob, self._branch_probs)
                )
            idx = positive[-1]
        branch_name = '_{}_branch'.format(self._branch_probs[idx].name)
        opponent = getattr(self, branch_name)()
        return opponent

    def _get_players(self, select_fn: Callable) -> List[Player]:
        """
        Overview:
            Get a list of players in the league (shared_payoff), selected by ``select_fn`` .
        Arguments:
            - select_fn (:obj:`function`): players in the returned list must satisfy this function
        Returns:
            - players (:obj:`list`): a list of players that satisfies ``select_fn``
        """
        return [player for player in self._payoff.players if select_fn(player)]

    def _get_opponent(self, players: list, p: Optional[np.ndarray] = None) -> Player:
        """
        Overview:
            Get one opponent player from list ``players`` according to probability ``p``.
        Arguments:
            - players (:obj:`list`): a list of players that can select opponent from
            - p (:obj:`np.ndarray`): the selection probability of each player, should have the same size as \
                ``players``. If you don't need it and set None, it would select uniformly by default.
        Returns:
            - opponent_player (:obj:`Player`): a random chosen opponent player according to probability
        """
        idx = np.random.choice(len(players), p=p)
        return players[idx]

    def increment_eval_difficulty(self) -> bool:
        """
        Overview:
            When evaluating, active player will choose a specific builtin opponent difficulty.
            This method is used to increment the difficulty.
            It is usually called after the easier builtin bot has already been beaten by this player.
        Returns:
            - increment_or_not (:obj:`bool`): True means difficulty is incremented; \
                False means difficulty is already the hardest.
        """
        if self._eval_opponent_index < len(self._eval_opponent_difficulty) - 1:
            self._eval_opponent_index += 1
            return True
        else:
            return False

    @property
    def checkpoint_path(self) -> str:
        return self._checkpoint_path

    @checkpoint_path.setter
    def checkpoint_path(self, path: str) -> None:
        self._checkpoint_path = path


@PLAYER_REGISTRY.register('naive_sp_player')
class NaiveSpPlayer(ActivePlayer):

    def _pfsp_branch(self) -> HistoricalPlayer:
        """
        Overview:
            Select prioritized fictitious self-play opponent, should be a historical player.
            Falls back to normal self-play (returns ``self``) when there is no historical player yet.
        Returns:
            - player (:obj:`HistoricalPlayer`): The selected historical player.
        """
        historical = self._get_players(lambda p: isinstance(p, HistoricalPlayer))
        win_rates = self._payoff[self, historical]
        # Normal self-play if no historical players
        if win_rates.shape == (0, ):
            return self
        p = pfsp(win_rates, weighting='squared')
        return self._get_opponent(historical, p)

    def _sp_branch(self) -> ActivePlayer:
        """
        Overview:
            Select normal self-play opponent
        """
        return self


def create_player(cfg: EasyDict, player_type: str, *args, **kwargs) -> Player:
    """
    Overview:
        Given the key (player_type), create a new player instance if in player_mapping's values,
        or raise a KeyError. In other words, a derived player must first register then call ``create_player``
        to get the instance object.
    Arguments:
        - cfg (:obj:`EasyDict`): player config, necessary keys: [import_names]
        - player_type (:obj:`str`): the type of player to be created
    Returns:
        - player (:obj:`Player`): the created new player, should be an instance of one of \
            player_mapping's values
    """
    # ``cfg`` is only used to import modules that register extra player types; it is not passed to the constructor.
    import_module(cfg.get('import_names', []))
    return PLAYER_REGISTRY.build(player_type, *args, **kwargs)