from typing import Union, Optional, List, Any, Tuple
import torch
import os
from functools import partial

from tensorboardX import SummaryWriter
from copy import deepcopy

from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
    get_buffer_cls, create_serial_collector
from ding.world_model import WorldModel
from ding.worker import IBuffer
from ding.envs import get_vec_env_setting, create_env_manager
from ding.config import read_config, compile_config
from ding.utils import set_pkg_seed, deep_merge_dicts
from ding.policy import create_policy
from ding.world_model import create_world_model
from ding.entry.utils import random_collect


def mbrl_entry_setup(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
) -> Tuple:
    """
    Overview:
        Prepare the common components (config, policy, world model, env buffer, workers and logger) \
        shared by the model-based RL pipelines below.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(
        cfg,
        seed=seed,
        env=env_fn,
        auto=True,
        create_cfg=create_cfg,
        save_cfg=True,
        renew_dir=not cfg.policy.learn.get('resume_training', False)
    )

    if env_setting is None:
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    else:
        env_fn, collector_env_cfg, evaluator_env_cfg = env_setting

    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

    collector_env.seed(cfg.seed)
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

    # create logger
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))

    # create world model
    world_model = create_world_model(cfg.world_model, env_fn(cfg.env), tb_logger)

    # create policy
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])

    # create workers
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = create_serial_collector(
        cfg.policy.collect.collector,
        env=collector_env,
        policy=policy.collect_mode,
        tb_logger=tb_logger,
        exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    env_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
    commander = BaseSerialCommander(
        cfg.policy.other.commander, learner, collector, evaluator, env_buffer, policy.command_mode
    )

    return (cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger)


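# Usage sketch for ``mbrl_entry_setup`` (kept as a comment so this module stays
# side-effect free on import; ``main_config`` and ``create_config`` are
# placeholders for a user-supplied config pair, e.g. one exported by a dizoo
# config module for an MBRL algorithm):
#
#     cfg, policy, world_model, env_buffer, learner, collector, collector_env, \
#         evaluator, commander, tb_logger = mbrl_entry_setup(
#             (main_config, create_config), seed=0
#         )

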
def create_img_buffer(
        cfg: dict, input_cfg: Union[str, Tuple[dict, dict]], world_model: WorldModel, tb_logger: 'SummaryWriter'
) -> IBuffer:  # noqa
    if isinstance(input_cfg, str):
        _, create_cfg = read_config(input_cfg)
    else:
        _, create_cfg = input_cfg
    img_buffer_cfg = cfg.world_model.other.imagination_buffer
    img_buffer_cfg.update(create_cfg.imagination_buffer)
    buffer_cls = get_buffer_cls(img_buffer_cfg)
    cfg.world_model.other.imagination_buffer.update(deep_merge_dicts(buffer_cls.default_config(), img_buffer_cfg))
    if img_buffer_cfg.type == 'elastic':
        img_buffer_cfg.set_buffer_size = world_model.buffer_size_scheduler
    img_buffer = create_buffer(cfg.world_model.other.imagination_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
    return img_buffer


def serial_pipeline_dyna(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        max_train_iter: Optional[int] = int(1e10),
        max_env_step: Optional[int] = int(1e10),
) -> 'Policy':  # noqa
    """
    Overview:
        Serial pipeline entry for dyna-style model-based RL.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
    """
    cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
        mbrl_entry_setup(input_cfg, seed, env_setting, model)

    img_buffer = create_img_buffer(cfg, input_cfg, world_model, tb_logger)

    learner.call_hook('before_run')
    if cfg.policy.learn.get('resume_training', False):
        collector.envstep = learner.collector_envstep

    if cfg.policy.get('random_collect_size', 0) > 0:
        random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)

    while True:
        collect_kwargs = commander.step()
        # evaluate the policy
        if evaluator.should_eval(collector.envstep):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break

        # fill the environment buffer with real transitions
        data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
        env_buffer.push(data, cur_collector_envstep=collector.envstep)

        # evaluate & train the world model and fill the imagination buffer
        if world_model.should_eval(collector.envstep):
            world_model.eval(env_buffer, collector.envstep, learner.train_iter)
        if world_model.should_train(collector.envstep):
            world_model.train(env_buffer, collector.envstep, learner.train_iter)
            world_model.fill_img_buffer(
                policy.collect_mode, env_buffer, img_buffer, collector.envstep, learner.train_iter
            )

        for _ in range(cfg.policy.learn.update_per_collect):
            batch_size = learner.policy.get_attribute('batch_size')
            train_data = world_model.sample(env_buffer, img_buffer, batch_size, learner.train_iter)
            learner.train(train_data, collector.envstep)

        if cfg.policy.on_policy:
            # on-policy algorithms must clear the replay buffers after each update
            env_buffer.clear()
            img_buffer.clear()

        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break

    learner.call_hook('after_run')

    return policy


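# The dyna-style update above mixes real and imagined data via ``world_model.sample``.
# A conceptual sketch of such mixed sampling (illustrative only, not the actual
# implementation; ``real_ratio`` is a hypothetical knob):
#
#     real_n = int(batch_size * real_ratio)
#     imag_n = batch_size - real_n
#     train_data = env_buffer.sample(real_n, train_iter) + img_buffer.sample(imag_n, train_iter)

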
def serial_pipeline_dream(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        max_train_iter: Optional[int] = int(1e10),
        max_env_step: Optional[int] = int(1e10),
) -> 'Policy':  # noqa
    """
    Overview:
        Serial pipeline entry for dreamer-style model-based RL.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
    """
    cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
        mbrl_entry_setup(input_cfg, seed, env_setting, model)

    learner.call_hook('before_run')
    if cfg.policy.learn.get('resume_training', False):
        collector.envstep = learner.collector_envstep

    if cfg.policy.get('random_collect_size', 0) > 0:
        random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)

    while True:
        collect_kwargs = commander.step()
        # evaluate the policy
        if evaluator.should_eval(collector.envstep):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break

        # fill the environment buffer with real transitions
        data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
        env_buffer.push(data, cur_collector_envstep=collector.envstep)

        # evaluate & train the world model
        if world_model.should_eval(collector.envstep):
            world_model.eval(env_buffer, collector.envstep, learner.train_iter)
        if world_model.should_train(collector.envstep):
            world_model.train(env_buffer, collector.envstep, learner.train_iter)

        # longer imagined rollouts reduce the number of policy updates per collection
        update_per_collect = cfg.policy.learn.update_per_collect // world_model.rollout_length_scheduler(
            collector.envstep
        )
        update_per_collect = max(1, update_per_collect)
        for _ in range(update_per_collect):
            batch_size = learner.policy.get_attribute('batch_size')
            train_data = env_buffer.sample(batch_size, learner.train_iter)
            # dreamer-style: use pure on-policy imagined rollouts to train the policy,
            # where the current envstep decides the rollout length
            learner.train(
                train_data, collector.envstep, policy_kwargs=dict(world_model=world_model, envstep=collector.envstep)
            )

        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break

    learner.call_hook('after_run')

    return policy


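# For example, with ``update_per_collect=100`` and a scheduled rollout length of 5
# at the current envstep, the loop above performs max(1, 100 // 5) = 20 policy
# updates per collection: longer imagined rollouts trade update frequency for
# rollout depth at a roughly constant imagination budget.

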
def serial_pipeline_dreamer(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        max_train_iter: Optional[int] = int(1e10),
        max_env_step: Optional[int] = int(1e10),
) -> 'Policy':  # noqa
    """
    Overview:
        Serial pipeline entry for DreamerV3.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
    """
    cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
        mbrl_entry_setup(input_cfg, seed, env_setting, model)

    learner.call_hook('before_run')

    # prefill the environment buffer
    if cfg.policy.get('random_collect_size', 0) > 0:
        # the collect size is counted in sequence samples of length ``unroll_len``, hence the division
        cfg.policy.random_collect_size = cfg.policy.random_collect_size // cfg.policy.collect.unroll_len
        random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)

    while True:
        collect_kwargs = commander.step()
        # evaluate the policy
        if evaluator.should_eval(collector.envstep):
            stop, reward = evaluator.eval(
                learner.save_checkpoint,
                learner.train_iter,
                collector.envstep,
                policy_kwargs=dict(world_model=world_model)
            )
            if stop:
                break

        # train the world model and the policy on imagined rollouts;
        # run the pretrain budget once, otherwise the number of steps decided by ``should_train``
        steps = (
            cfg.world_model.pretrain
            if world_model.should_pretrain() else int(world_model.should_train(collector.envstep))
        )
        for _ in range(steps):
            batch_size = learner.policy.get_attribute('batch_size')
            batch_length = cfg.policy.learn.batch_length
            post, context = world_model.train(
                env_buffer, collector.envstep, learner.train_iter, batch_size, batch_length
            )

            # the posterior states from world-model training seed the imagined rollouts
            start = post

            learner.train(
                start, collector.envstep, policy_kwargs=dict(world_model=world_model, envstep=collector.envstep)
            )

        # fill the environment buffer with real transitions
        data = collector.collect(
            train_iter=learner.train_iter,
            policy_kwargs=dict(world_model=world_model, envstep=collector.envstep, **collect_kwargs)
        )
        env_buffer.push(data, cur_collector_envstep=collector.envstep)

        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break

    learner.call_hook('after_run')

    return policy


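# Minimal usage sketch for these entries, kept as a comment so that importing
# this module stays side-effect free. ``main_config`` and ``create_config`` are
# placeholders for a user-supplied config pair (e.g. a DreamerV3 config from dizoo):
#
#     if __name__ == "__main__":
#         policy = serial_pipeline_dreamer(
#             (main_config, create_config), seed=0, max_env_step=int(1e6)
#         )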