class Scheduler(object):
    """
    Overview:
        Update a learning parameter when a monitored metric (e.g. trueskill) has stopped improving.
        For example, models often benefit from reducing the entropy weight once the learning process
        stagnates. This scheduler reads a metric value at every step and, if no improvement is seen
        for more than ``patience`` consecutive steps, the parameter is increased or decreased
        according to ``schedule_mode``.

        Notes on behaviour:
        - "Improvement" is measured against the metric of the *previous* step (not the best value
          seen so far), shifted by ``threshold``.
        - The cooldown counter is initialised to ``cooldown``, so the scheduler starts in cooldown
          and ignores bad steps for the first ``cooldown`` calls to ``step``.
    Arguments:
        - schedule_flag (:obj:`bool`): Indicates whether to use the scheduler in the training pipeline.
            Default: False
        - schedule_mode (:obj:`str`): One of 'reduce', 'add', 'multi', 'div'. The schedule_mode
            decides the way of updating the parameter. Default: 'reduce'.
        - factor (:obj:`float`): Amount (greater than 0) by which the parameter will be
            increased/decreased. Default: 0.05
        - change_range (:obj:`list`): Indicates the minimum and maximum value
            the parameter can reach respectively. Default: [-1, 1]
        - threshold (:obj:`float`): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        - optimize_mode (:obj:`str`): One of 'min', 'max', which indicates the sign of the
            optimization objective. Dynamic_threshold = last_metrics + threshold in `max`
            mode or last_metrics - threshold in `min` mode. Default: 'min'
        - patience (:obj:`int`): Number of epochs with no improvement after which
            the parameter will be updated. For example, if `patience = 2`, then we
            will ignore the first 2 epochs with no improvement, and will only update
            the parameter after the 3rd epoch if the metric still hasn't improved then.
            Default: 10.
        - cooldown (:obj:`int`): Number of epochs to wait before resuming
            normal operation after the parameter has been updated. Default: 0.
    Interfaces:
        __init__, update_param, step, is_better
    Property:
        in_cooldown
    """

    config = dict(
        schedule_flag=False,
        schedule_mode='reduce',
        factor=0.05,
        change_range=[-1, 1],
        threshold=1e-4,
        optimize_mode='min',
        patience=10,
        cooldown=0,
    )

    # Supported update rules, in the order reported by error messages.
    _SCHEDULE_MODES = ('reduce', 'add', 'multi', 'div')

    @staticmethod
    def _require(condition: bool, message: str) -> None:
        """
        Overview:
            Raise ``AssertionError(message)`` when ``condition`` is false. Unlike a bare ``assert``
            statement this check is not removed when Python runs with ``-O``, while callers that
            catch ``AssertionError`` keep working.
        """
        if not condition:
            raise AssertionError(message)

    def __init__(self, merged_scheduler_config: "EasyDict") -> None:
        """
        Overview:
            Initialize the scheduler.
        Arguments:
            - merged_scheduler_config (:obj:`EasyDict`): the scheduler config, which merges the user
                config and the default config. It is read by attribute access and must provide
                ``schedule_mode``, ``factor``, ``change_range``, ``threshold``, ``optimize_mode``,
                ``patience`` and ``cooldown``.
        Raises:
            - AssertionError: if any of the config values has an invalid type or value.
        """
        require = self._require

        schedule_mode = merged_scheduler_config.schedule_mode
        factor = merged_scheduler_config.factor
        change_range = merged_scheduler_config.change_range
        threshold = merged_scheduler_config.threshold
        optimize_mode = merged_scheduler_config.optimize_mode
        patience = merged_scheduler_config.patience
        cooldown = merged_scheduler_config.cooldown

        require(
            schedule_mode in self._SCHEDULE_MODES,
            "The schedule mode should be one of ['reduce', 'add', 'multi','div']"
        )
        self.schedule_mode = schedule_mode

        require(isinstance(factor, (float, int)), 'The factor should be a float/int number ')
        require(factor > 0, 'The factor should be greater than 0')
        self.factor = float(factor)

        require(
            isinstance(change_range, list) and len(change_range) == 2,
            'The change_range should be a list with 2 float numbers'
        )
        require(
            isinstance(change_range[0], (float, int)) and isinstance(change_range[1], (float, int)),
            'The change_range should be a list with 2 float/int numbers'
        )
        require(change_range[0] < change_range[1], 'The first num should be smaller than the second num')
        self.change_range = change_range

        require(isinstance(threshold, (float, int)), 'The threshold should be a float/int number')
        self.threshold = threshold

        require(optimize_mode in ['min', 'max'], "The optimize_mode should be one of ['min', 'max']")
        self.optimize_mode = optimize_mode

        require(isinstance(patience, int), 'The patience should be a integer greater than or equal to 0')
        require(patience >= 0, 'The patience should be a integer greater than or equal to 0')
        self.patience = patience

        require(isinstance(cooldown, int), 'The cooldown_counter should be a integer greater than or equal to 0')
        require(cooldown >= 0, 'The cooldown_counter should be a integer greater than or equal to 0')
        self.cooldown = cooldown
        # The counter starts full, i.e. the scheduler begins in its cooldown period.
        self.cooldown_counter = cooldown

        self.last_metrics = None
        self.bad_epochs_num = 0

    def step(self, metrics: float, param: float) -> float:
        """
        Overview:
            Record the current metric and decide whether to update the scheduled parameter.
        Arguments:
            - metrics (:obj:`float`): current input metrics
            - param (:obj:`float`): parameter that may need to be updated
        Returns:
            - step_param (:obj:`float`): parameter after one step (unchanged if no update was due)
        Raises:
            - AssertionError: if ``metrics`` is not a float.
        """
        self._require(isinstance(metrics, float), 'The metrics should be converted to a float number')
        cur_metrics = metrics

        if self.is_better(cur_metrics):
            self.bad_epochs_num = 0
        else:
            self.bad_epochs_num += 1
        # The next comparison is made against this step's metric, not the best seen so far.
        self.last_metrics = cur_metrics

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.bad_epochs_num = 0  # ignore any bad epochs in cooldown

        if self.bad_epochs_num > self.patience:
            param = self.update_param(param)
            self.cooldown_counter = self.cooldown
            self.bad_epochs_num = 0
        return param

    def update_param(self, param: float) -> float:
        """
        Overview:
            Apply one update of the scheduling rule to the parameter and keep the result from
            moving past the bound of ``change_range`` that lies in the direction of the update.
        Arguments:
            - param (:obj:`float`): parameter that needs to be updated
        Returns:
            - updated param (:obj:`float`): parameter after updating
        Raises:
            - KeyError: if ``schedule_mode`` is not one of the supported modes.
        """
        mode = self.schedule_mode
        if mode not in self._SCHEDULE_MODES:
            raise KeyError("invalid schedule_mode({}) in {}".format(mode, list(self._SCHEDULE_MODES)))

        lower, upper = self.change_range[0], self.change_range[1]
        factor = self.factor

        if mode == 'reduce':
            return max(param - factor, lower)
        if mode == 'add':
            return min(param + factor, upper)

        if mode == 'multi':
            candidate = param * factor
            magnitude_shrinks = factor < 1
        else:  # 'div'
            candidate = param / factor
            magnitude_shrinks = factor >= 1

        # Scaling changes the magnitude; whether the value itself goes up or down depends on the
        # sign of the parameter. A negative parameter whose magnitude grows moves *down*, so it
        # has to be limited by the lower bound (and vice versa), otherwise it leaves change_range.
        moves_up = magnitude_shrinks if param < 0 else not magnitude_shrinks
        if moves_up:
            return min(candidate, upper)
        return max(candidate, lower)

    @property
    def in_cooldown(self) -> bool:
        """
        Overview:
            Checks whether the scheduler is in its cooldown period. If in cooldown, the scheduler
            will ignore any bad epochs.
        """
        return self.cooldown_counter > 0

    def is_better(self, cur: float) -> bool:
        """
        Overview:
            Checks whether the current metric is better than the last metric with respect to
            the threshold. The very first metric is always considered better.
        Arguments:
            - cur (:obj:`float`): current metrics
        Returns:
            - better (:obj:`bool`): True if ``cur`` improves on the last metric by more than
                ``threshold`` in the direction given by ``optimize_mode``.
        """
        if self.last_metrics is None:
            return True
        if self.optimize_mode == 'min':
            return cur < self.last_metrics - self.threshold
        if self.optimize_mode == 'max':
            return cur > self.last_metrics + self.threshold
        # Unknown mode (only possible if the attribute was altered after construction).
        return False