import copy
from collections import defaultdict
from typing import Tuple, Optional
from easydict import EasyDict
from tabulate import tabulate
import numpy as np

from ding.utils import LockContext, LockContextType
from .player import Player


class BattleRecordDict(dict):
    """
    Overview:
        A dict which is used to record battle game result.
        Initialized with four fixed keys: `wins`, `draws`, `losses`, `games`; each with value 0.
    Interfaces:
        __mul__
    """
    data_keys = ['wins', 'draws', 'losses', 'games']

    def __init__(self) -> None:
        """
        Overview:
            Initialize four fixed keys ['wins', 'draws', 'losses', 'games'] and set each value to 0.
        """
        super(BattleRecordDict, self).__init__()
        for k in self.data_keys:
            self[k] = 0

    def __mul__(self, decay: float) -> dict:
        """
        Overview:
            Multiply each key's value with the input multiplier ``decay``.
            The receiver is left untouched; a scaled deep copy is returned.
        Arguments:
            - decay (:obj:`float`): The multiplier.
        Returns:
            - obj (:obj:`dict`): A deepcopied record after multiplication by ``decay``.
        """
        obj = copy.deepcopy(self)
        for k in obj:
            obj[k] *= decay
        return obj


class BattleSharedPayoff:
    """
    Overview:
        Payoff data structure to record historical match results; this payoff is shared among all the players.
        A LockContext guards the accessors and mutators, since players from several threads may use it.
    Interface:
        __getitem__, add_player, update, get_key
    Property:
        players
    """

    # TODO(nyz) whether ensures the thread-safe

    def __init__(self, cfg: EasyDict):
        """
        Overview:
            Initialize battle payoff.
        Arguments:
            - cfg (:obj:`dict`): config (contains ``decay`` and, optionally, ``min_win_rate_games``).
        """
        # ``_players`` is a list holding references (no copies) to all players,
        # while ``_players_ids`` is the list of their id strings, in the same order.
        self._players = []
        self._players_ids = []
        # ``_data`` is a defaultdict: querying a missing key creates and returns a fresh BattleRecordDict.
        # Key is the '[player_id]-[player_id]' string built by ``get_key``; value is the record of that pair.
        self._data = defaultdict(BattleRecordDict)
        # ``_decay`` controls how past game info (win, draw, loss, games) decays on every new result.
        self._decay = cfg.decay
        # ``_min_win_rate_games`` is the minimum (decayed) game count required by ``_win_rate``
        # before a real win rate is reported; it defaults to 8.
        self._min_win_rate_games = cfg.get('min_win_rate_games', 8)
        # Thread lock.
        self._lock = LockContext(lock_type=LockContextType.THREAD_LOCK)

    def __repr__(self) -> str:
        """
        Overview:
            Render all recorded pairs as a pipe-format table, one row per key of ``_data``, sorted by home player.
            Naive win rate is (wins + draws / 2) / (wins + losses + draws + 1e-8), from the home player's view.
        Returns:
            - s (:obj:`str`): The table produced by ``tabulate``.
        """
        headers = ["Home Player", "Away Player", "Wins", "Draws", "Losses", "Naive Win Rate"]
        data = []
        for k, v in self._data.items():
            # k is the format of '{}-{}'.format(name1, name2), and each HistoricalPlayer has `historical` suffix.
            # NOTE(review): splitting on '-' assumes player ids never contain '-' themselves; confirm.
            k1 = k.split('-')
            total = v['wins'] + v['losses'] + v['draws'] + 1e-8
            if 'historical' in k1[0]:
                # Reverse representation: show the non-historical player as home, so wins/losses swap.
                naive_win_rate = (v['losses'] + v['draws'] / 2) / total
                data.append([k1[1], k1[0], v['losses'], v['draws'], v['wins'], naive_win_rate])
            else:
                naive_win_rate = (v['wins'] + v['draws'] / 2) / total
                data.append([k1[0], k1[1], v['wins'], v['draws'], v['losses'], naive_win_rate])
        data = sorted(data, key=lambda x: x[0])
        s = tabulate(data, headers=headers, tablefmt='pipe')
        return s

    def __getitem__(self, players: tuple) -> np.ndarray:
        """
        Overview:
            Get win rates between home players and away players one by one.
        Arguments:
            - players (:obj:`tuple`): A tuple of (home, away), each one is a player or a player list.
        Returns:
            - win_rates (:obj:`np.ndarray`): Win rate (squeezed, see Shape for more details) \
                between each player from home and each player from away.
        Shape:
            - win_rates: Assume there are m home players and n away players.(m,n > 0)

                - m != 1 and n != 1: shape is (m, n)
                - m == 1: shape is (n)
                - n == 1: shape is (m)
        """
        with self._lock:
            home, away = players
            assert isinstance(home, (list, Player))
            assert isinstance(away, (list, Player))
            # Normalize single players into one-element lists so both sides can be iterated.
            if isinstance(home, Player):
                home = [home]
            if isinstance(away, Player):
                away = [away]
            win_rates = np.array([[self._win_rate(h.player_id, a.player_id) for a in away] for h in home])
            # Flatten when either side has exactly one player.
            if len(home) == 1 or len(away) == 1:
                win_rates = win_rates.reshape(-1)
            return win_rates

    def _win_rate(self, home: str, away: str) -> float:
        """
        Overview:
            Calculate win rate of one `home player` vs one `away player`.
        Arguments:
            - home (:obj:`str`): home player id to access win rate
            - away (:obj:`str`): away player id to access win rate
        Returns:
            - win rate (:obj:`float`): float win rate value. \
                Only when total games is no less than ``self._min_win_rate_games``, \
                can the win rate be calculated by (wins + draws/2) / games, or return 0.5 by default.

        .. note::
            Looking up a pair that has no record yet creates an all-zero record in ``_data`` (defaultdict).
        """
        key, reverse = self.get_key(home, away)
        handle = self._data[key]
        # Not enough game records.
        if handle['games'] < self._min_win_rate_games:
            return 0.5
        # The record is stored from the view of the first id in the key;
        # when ``home`` is the second id, its wins are the stored losses.
        wins = handle['losses'] if reverse else handle['wins']
        return (wins + 0.5 * handle['draws']) / (handle['games'])

    @property
    def players(self):
        """
        Overview:
            Get all the players.
        Returns:
            - players (:obj:`list`): players list (the internal list itself, not a copy)
        """
        with self._lock:
            return self._players

    def add_player(self, player: Player) -> None:
        """
        Overview:
            Add a player to the shared payoff.
        Arguments:
            - player (:obj:`Player`): The player to be added. Usually is a new one to the league as well.
        """
        with self._lock:
            self._players.append(player)
            self._players_ids.append(player.player_id)

    def update(self, job_info: dict) -> bool:
        """
        Overview:
            Update payoff with job_info when a job is to be finished.
            If update succeeds, return True; if job_info is invalid, print the reason and return False
            without modifying any record.
        Arguments:
            - job_info (:obj:`dict`): A dict containing job result information.
        Returns:
            - result (:obj:`bool`): Whether update is successful.

        .. note::
            job_info has at least 5 keys ['launch_player', 'player_id', 'env_num', 'episode_num', 'result'].
            Key ``player_id`` 's value is a tuple of (home_id, away_id).
            Key ``result`` 's value is a two-layer list with the length of (episode_num, env_num);
            a one-layer list is also accepted and treated as a single episode.
            Each result is one of 'wins', 'draws', 'losses', from the home player's point of view.
        """

        def _win_loss_reverse(result_: str, reverse_: bool) -> str:
            # Swap 'wins' and 'losses' when the key order is reversed; 'draws' is symmetric.
            if result_ == 'draws' or not reverse_:
                return result_
            reverse_dict = {'wins': 'losses', 'losses': 'wins'}
            return reverse_dict[result_]

        with self._lock:
            home_id, away_id = job_info['player_id']
            job_info_result = job_info['result']
            # Validate with explicit checks instead of ``assert``: asserts are stripped under ``python -O``,
            # which would let an invalid result corrupt a record halfway through the update loop.
            try:
                if home_id not in self._players_ids:
                    raise ValueError("home_id error")
                if away_id not in self._players_ids:
                    raise ValueError("away_id error")
                # An empty result cannot be indexed below; report it as invalid instead of crashing.
                if len(job_info_result) == 0:
                    raise ValueError("results error")
                # for compatibility of one-layer list
                if not isinstance(job_info_result[0], list):
                    job_info_result = [job_info_result]
                # All results must be in ['wins', 'draws', 'losses'].
                valid_results = BattleRecordDict.data_keys[:3]
                if not all(i in valid_results for j in job_info_result for i in j):
                    raise ValueError("results error")
            except Exception as e:
                print("[ERROR] invalid job_info: {}\n\tError reason is: {}".format(job_info, e))
                return False
            key, reverse = self.get_key(home_id, away_id)
            if home_id == away_id:  # self-play
                # Self-play defaults to draws: one draw and one game per job, without decay.
                self._data[key]['draws'] += 1
                self._data[key]['games'] += 1
            else:
                # Update with decay.
                # job_info_result is a two-layer list, including total NxM episodes of M envs,
                # the first(outer) layer is episode dimension and the second(inner) layer is env dimension.
                for one_episode_result in job_info_result:
                    for one_episode_result_per_env in one_episode_result:
                        # All categories should decay (``*=`` rebinds the entry to a scaled copy).
                        self._data[key] *= self._decay
                        self._data[key]['games'] += 1
                        result = _win_loss_reverse(one_episode_result_per_env, reverse)
                        self._data[key][result] += 1
            return True

    def get_key(self, home: str, away: str) -> Tuple[str, bool]:
        """
        Overview:
            Join home player id and away player id in alphabetical order.
        Arguments:
            - home (:obj:`str`): Home player id
            - away (:obj:`str`): Away player id
        Returns:
            - key (:obj:`str`): Two ids sorted in alphabetical order, and joined by '-'.
            - reverse (:obj:`bool`): Whether the two player ids are reordered.
        """
        assert isinstance(home, str)
        assert isinstance(away, str)
        if home <= away:
            return '-'.join([home, away]), False
        return '-'.join([away, home]), True


def create_payoff(cfg: EasyDict) -> Optional[BattleSharedPayoff]:
    """
    Overview:
        Given the key (payoff type), create a new payoff instance if the type is in payoff_mapping,
        or raise a KeyError. Currently the only supported type is 'battle'.
    Arguments:
        - cfg (:obj:`EasyDict`): payoff config containing at least one key 'type'
    Returns:
        - payoff (:obj:`BattleSharedPayoff`): the created new payoff, \
            an instance of one of payoff_mapping's values
    Raises:
        - KeyError: If ``cfg.type`` is not a key of payoff_mapping.
    """
    payoff_mapping = {'battle': BattleSharedPayoff}
    payoff_type = cfg.type
    if payoff_type not in payoff_mapping:
        raise KeyError("not support payoff type: {}".format(payoff_type))
    return payoff_mapping[payoff_type](cfg)