ding.world_model.base_world_model

WorldModel

Bases: ABC

Overview

Abstract base class for world models.

Interfaces

should_train, should_eval, train, eval, step

should_train(envstep)

Overview

Check whether the world model needs to be trained.

should_eval(envstep)

Overview

Check whether the world model needs to be evaluated.
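The cadence of these two checks follows the default config (train_freq=250 and eval_freq=250 environment steps, with evaluation gated on at least one completed training round). A minimal self-contained sketch; `FreqGate` is a hypothetical stand-in for the counters that `WorldModel` keeps internally:

```python
class FreqGate:
    """Hypothetical stand-in mirroring WorldModel's frequency checks."""

    def __init__(self, train_freq: int = 250, eval_freq: int = 250):
        self.train_freq = train_freq
        self.eval_freq = eval_freq
        self.last_train_step = 0
        self.last_eval_step = 0

    def should_train(self, envstep: int) -> bool:
        # Train once enough new environment steps have accumulated.
        return (envstep - self.last_train_step) >= self.train_freq

    def should_eval(self, envstep: int) -> bool:
        # Never evaluate before the first training round has happened.
        return (envstep - self.last_eval_step) >= self.eval_freq \
            and self.last_train_step != 0


gate = FreqGate()
assert gate.should_train(250)
assert not gate.should_eval(250)  # no training has happened yet
gate.last_train_step = 250        # pretend one training round completed
assert gate.should_eval(500)
```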

train(env_buffer, envstep, train_iter) abstractmethod

Overview

Train the world model using data from env_buffer.

Parameters:

- env_buffer (IBuffer): the buffer which collects real environment steps. Required.
- envstep (int): the current number of environment steps in the real environment. Required.
- train_iter (int): the current number of policy training iterations. Required.

eval(env_buffer, envstep, train_iter) abstractmethod

Overview

Evaluate the world model using data from env_buffer.

Parameters:

- env_buffer (IBuffer): the buffer that collects real environment steps. Required.
- envstep (int): the current number of environment steps in the real environment. Required.
- train_iter (int): the current number of policy training iterations. Required.

step(obs, action) abstractmethod

Overview

Take one step in the world model.

Parameters:

- obs (torch.Tensor): current observations S_t. Required.
- action (torch.Tensor): current actions A_t. Required.

Returns:

- reward (torch.Tensor): rewards R_t
- next_obs (torch.Tensor): next observations S_{t+1}
- done (torch.Tensor): whether the episode ends

Shapes:

B: batch size, O: observation dimension, A: action dimension

- obs: [B, O]
- action: [B, A]
- reward: [B, ]
- next_obs: [B, O]
- done: [B, ]
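The step() contract above can be illustrated with a toy sketch. `ToyWorldModel` is a hypothetical stand-in kept self-contained for brevity; a real implementation would subclass `ding.world_model.base_world_model.WorldModel` and implement train() and eval() as well:

```python
from typing import Tuple

import torch
from torch import Tensor


class ToyWorldModel:
    """Hypothetical model illustrating step()'s input/output shapes."""

    def step(self, obs: Tensor, action: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        # reward: [B, ] -- dummy reward computed from the state
        reward = (obs * obs).sum(dim=1)
        # next_obs: [B, O] -- dummy dynamics: shift state by summed action
        next_obs = obs + action.sum(dim=1, keepdim=True)
        # done: [B, ] -- toy model never terminates
        done = torch.zeros(obs.shape[0], dtype=torch.bool)
        return reward, next_obs, done


B, O, A = 4, 3, 2
model = ToyWorldModel()
reward, next_obs, done = model.step(torch.zeros(B, O), torch.zeros(B, A))
assert reward.shape == (B,)
assert next_obs.shape == (B, O)
assert done.shape == (B,)
```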

DynaWorldModel

Bases: WorldModel, ABC

Overview

Dyna-style world model (summarized in arXiv:1907.02057) which stores and reuses imagination rollouts in the imagination buffer.

Interfaces

sample, fill_img_buffer, should_train, should_eval, train, eval, step

sample(env_buffer, img_buffer, batch_size, train_iter)

Overview

Sample from the combination of the environment buffer and the imagination buffer with a certain ratio to generate batched data for policy training.

Parameters:

- env_buffer (IBuffer): the buffer that collects real environment steps. Required.
- img_buffer (IBuffer): the buffer that collects imagination steps. Required.
- batch_size (int): the batch size for policy training. Required.
- train_iter (int): the current number of policy training iterations. Required.

Returns:

- data (dict): the training data for policy training
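The real/imagination split follows from real_ratio (default 0.05 in DynaWorldModel's config): a fixed fraction of each policy-training batch is drawn from the environment buffer and the remainder from the imagination buffer. A sketch of the batch-size arithmetic:

```python
def split_batch(batch_size: int, real_ratio: float = 0.05):
    """Split a policy-training batch between real and imagined data,
    mirroring the arithmetic in DynaWorldModel.sample."""
    # Portion of the batch drawn from the real environment buffer.
    env_batch_size = int(batch_size * real_ratio)
    # The rest comes from the imagination buffer.
    img_batch_size = batch_size - env_batch_size
    return env_batch_size, img_batch_size


assert split_batch(256) == (12, 244)  # int(256 * 0.05) == 12
assert split_batch(100) == (5, 95)
```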

fill_img_buffer(policy, env_buffer, img_buffer, envstep, train_iter)

Overview

Sample from the env_buffer, roll out to generate new data, and push the new data into the img_buffer.

Parameters:

- policy (namedtuple): policy in collect mode. Required.
- env_buffer (IBuffer): the buffer that collects real environment steps. Required.
- img_buffer (IBuffer): the buffer that collects imagination steps. Required.
- envstep (int): the current number of environment steps in the real environment. Required.
- train_iter (int): the current number of policy training iterations. Required.

DreamWorldModel

Bases: WorldModel, ABC

Overview

Dreamer-style world model which uses each imagination rollout only once and backpropagates through time (the rollout) to optimize the policy.

Interfaces

rollout, should_train, should_eval, train, eval, step

rollout(obs, actor_fn, envstep, **kwargs)

Overview

Generate batched imagination rollouts starting from the current observations. This function is useful for value gradients, where the policy is optimized by BPTT.

Parameters:

- obs (Tensor): the current observations S_t. Required.
- actor_fn (Callable): the unified API (A_t, H_t) = π(S_t). Required.
- envstep (int): the current number of environment steps in the real environment. Required.

Returns:

- obss (Tensor): S_t, ..., S_{t+n}
- actions (Tensor): A_t, ..., A_{t+n}
- rewards (Tensor): R_t, ..., R_{t+n-1}
- aug_rewards (Tensor): H_t, ..., H_{t+n}; this can be an entropy bonus as in SAC, otherwise it should be a zero tensor
- dones (Tensor): done_t, ..., done_{t+n}
Shapes:

N: time step, B: batch size, O: observation dimension, A: action dimension

- obss: [N+1, B, O], where obss[0] are the real observations
- actions: [N+1, B, A]
- rewards: [N, B]
- aug_rewards: [N+1, B]
- dones: [N, B]

.. note::
    - The rollout length is determined by the rollout length scheduler.
    - actor_fn's input and output shapes are similar to those of WorldModel.step().
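The bookkeeping behind these shapes (N model steps, plus one trailing actor_fn call at the final state, which is why actions and aug_rewards have N+1 entries while rewards and dones have N) can be sketched with a self-contained toy loop. `toy_rollout`, `actor_fn`, and `step_fn` here are hypothetical stand-ins, not the DI-engine API:

```python
import torch


def toy_rollout(obs, actor_fn, step_fn, horizon):
    """Sketch of the rollout bookkeeping used by DreamWorldModel.rollout."""
    obss, actions, rewards, aug_rewards, dones = [obs], [], [], [], []
    for _ in range(horizon):
        action, aug_reward = actor_fn(obs)
        reward, obs, done = step_fn(obs, action)
        obss.append(obs)
        actions.append(action)
        rewards.append(reward + aug_reward)  # fold the bonus into the reward
        aug_rewards.append(aug_reward)
        dones.append(done)
    # One extra actor call for the final imagined state: hence N+1 actions.
    action, aug_reward = actor_fn(obs)
    actions.append(action)
    aug_rewards.append(aug_reward)
    return (torch.stack(obss), torch.stack(actions), torch.stack(rewards),
            torch.stack(aug_rewards), torch.stack(dones))


N, B, O, A = 3, 2, 4, 1
actor_fn = lambda s: (torch.zeros(B, A), torch.zeros(B))   # dummy policy
step_fn = lambda s, a: (torch.zeros(B), s, torch.zeros(B))  # dummy dynamics
obss, actions, rewards, aug_rewards, dones = toy_rollout(
    torch.zeros(B, O), actor_fn, step_fn, N)
assert obss.shape == (N + 1, B, O)
assert actions.shape == (N + 1, B, A)
assert rewards.shape == (N, B) and dones.shape == (N, B)
assert aug_rewards.shape == (N + 1, B)
```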

HybridWorldModel

Bases: DynaWorldModel, DreamWorldModel, ABC

Overview

The hybrid model that combines reused and on-the-fly rollouts.

Interfaces

rollout, sample, fill_img_buffer, should_train, should_eval, train, eval, step

Full Source Code

../ding/world_model/base_world_model.py

```python
from typing import Tuple, Callable, Optional
from collections import namedtuple
from abc import ABC, abstractmethod

import torch
from torch import Tensor, nn
from easydict import EasyDict

from ding.worker import IBuffer
from ding.envs import BaseEnv
from ding.utils import deep_merge_dicts
from ding.world_model.utils import get_rollout_length_scheduler

from ding.utils import import_module, WORLD_MODEL_REGISTRY


def get_world_model_cls(cfg):
    import_module(cfg.get('import_names', []))
    return WORLD_MODEL_REGISTRY.get(cfg.type)


def create_world_model(cfg, *args, **kwargs):
    import_module(cfg.get('import_names', []))
    return WORLD_MODEL_REGISTRY.build(cfg.type, cfg, *args, **kwargs)


class WorldModel(ABC):
    r"""
    Overview:
        Abstract baseclass for world model.

    Interfaces:
        should_train, should_eval, train, eval, step
    """

    config = dict(
        train_freq=250,  # w.r.t environment step
        eval_freq=250,  # w.r.t environment step
        cuda=True,
        rollout_length_scheduler=dict(
            type='linear',
            rollout_start_step=20000,
            rollout_end_step=150000,
            rollout_length_min=1,
            rollout_length_max=25,
        )
    )

    def __init__(self, cfg: dict, env: BaseEnv, tb_logger: 'SummaryWriter'):  # noqa
        self.cfg = cfg
        self.env = env
        self.tb_logger = tb_logger

        self._cuda = cfg.cuda
        self.train_freq = cfg.train_freq
        self.eval_freq = cfg.eval_freq
        self.rollout_length_scheduler = get_rollout_length_scheduler(cfg.rollout_length_scheduler)

        self.last_train_step = 0
        self.last_eval_step = 0

    @classmethod
    def default_config(cls: type) -> EasyDict:
        # can not call default_config() recursively
        # because config will be overwritten by subclasses
        merge_cfg = EasyDict(cfg_type=cls.__name__ + 'Dict')
        while cls != ABC:
            merge_cfg = deep_merge_dicts(merge_cfg, cls.config)
            cls = cls.__base__
        return merge_cfg

    def should_train(self, envstep: int):
        r"""
        Overview:
            Check whether need to train world model.
        """
        return (envstep - self.last_train_step) >= self.train_freq

    def should_eval(self, envstep: int):
        r"""
        Overview:
            Check whether need to evaluate world model.
        """
        return (envstep - self.last_eval_step) >= self.eval_freq and self.last_train_step != 0

    @abstractmethod
    def train(self, env_buffer: IBuffer, envstep: int, train_iter: int):
        r"""
        Overview:
            Train world model using data from env_buffer.

        Arguments:
            - env_buffer (:obj:`IBuffer`): the buffer which collects real environment steps
            - envstep (:obj:`int`): the current number of environment steps in real environment
            - train_iter (:obj:`int`): the current number of policy training iterations
        """
        raise NotImplementedError

    @abstractmethod
    def eval(self, env_buffer: IBuffer, envstep: int, train_iter: int):
        r"""
        Overview:
            Evaluate world model using data from env_buffer.

        Arguments:
            - env_buffer (:obj:`IBuffer`): the buffer that collects real environment steps
            - envstep (:obj:`int`): the current number of environment steps in real environment
            - train_iter (:obj:`int`): the current number of policy training iterations
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, obs: Tensor, action: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        r"""
        Overview:
            Take one step in world model.

        Arguments:
            - obs (:obj:`torch.Tensor`): current observations :math:`S_t`
            - action (:obj:`torch.Tensor`): current actions :math:`A_t`

        Returns:
            - reward (:obj:`torch.Tensor`): rewards :math:`R_t`
            - next_obs (:obj:`torch.Tensor`): next observations :math:`S_t+1`
            - done (:obj:`torch.Tensor`): whether the episode ends

        Shapes:
            :math:`B`: batch size
            :math:`O`: observation dimension
            :math:`A`: action dimension

            - obs: [B, O]
            - action: [B, A]
            - reward: [B, ]
            - next_obs: [B, O]
            - done: [B, ]
        """
        raise NotImplementedError


class DynaWorldModel(WorldModel, ABC):
    r"""
    Overview:
        Dyna-style world model (summarized in arXiv: 1907.02057) which stores and\
        reuses imagination rollout in the imagination buffer.

    Interfaces:
        sample, fill_img_buffer, should_train, should_eval, train, eval, step
    """

    config = dict(
        other=dict(
            real_ratio=0.05,
            rollout_retain=4,
            rollout_batch_size=100000,
            imagination_buffer=dict(
                type='elastic',
                replay_buffer_size=6000000,
                deepcopy=False,
                enable_track_used_data=False,
                # set_buffer_size=set_buffer_size,
                periodic_thruput_seconds=60,
            ),
        )
    )

    def __init__(self, cfg: dict, env: BaseEnv, tb_logger: 'SummaryWriter'):  # noqa
        super().__init__(cfg, env, tb_logger)
        self.real_ratio = cfg.other.real_ratio
        self.rollout_batch_size = cfg.other.rollout_batch_size
        self.rollout_retain = cfg.other.rollout_retain
        self.buffer_size_scheduler = \
            lambda x: self.rollout_length_scheduler(x) * self.rollout_batch_size * self.rollout_retain

    def sample(self, env_buffer: IBuffer, img_buffer: IBuffer, batch_size: int, train_iter: int) -> dict:
        r"""
        Overview:
            Sample from the combination of environment buffer and imagination buffer with\
            certain ratio to generate batched data for policy training.

        Arguments:
            - env_buffer (:obj:`IBuffer`): the buffer that collects real environment steps
            - img_buffer (:obj:`IBuffer`): the buffer that collects imagination steps
            - batch_size (:obj:`int`): the batch size for policy training
            - train_iter (:obj:`int`): the current number of policy training iterations

        Returns:
            - data (:obj:`dict`): the training data for policy training
        """
        env_batch_size = int(batch_size * self.real_ratio)
        img_batch_size = batch_size - env_batch_size
        env_data = env_buffer.sample(env_batch_size, train_iter)
        img_data = img_buffer.sample(img_batch_size, train_iter)
        train_data = env_data + img_data
        return train_data

    def fill_img_buffer(
            self, policy: namedtuple, env_buffer: IBuffer, img_buffer: IBuffer, envstep: int, train_iter: int
    ):
        r"""
        Overview:
            Sample from the env_buffer, rollouts to generate new data, and push them into the img_buffer.

        Arguments:
            - policy (:obj:`namedtuple`): policy in collect mode
            - env_buffer (:obj:`IBuffer`): the buffer that collects real environment steps
            - img_buffer (:obj:`IBuffer`): the buffer that collects imagination steps
            - envstep (:obj:`int`): the current number of environment steps in real environment
            - train_iter (:obj:`int`): the current number of policy training iterations
        """
        from ding.torch_utils import to_tensor
        from ding.envs import BaseEnvTimestep
        from ding.worker.collector.base_serial_collector import to_tensor_transitions

        def step(obs, act):
            # This function has the same input and output format as env manager's step
            data_id = list(obs.keys())
            obs = torch.stack([obs[id] for id in data_id], dim=0)
            act = torch.stack([act[id] for id in data_id], dim=0)
            with torch.no_grad():
                rewards, next_obs, terminals = self.step(obs, act)
            # terminals = self.termination_fn(next_obs)
            timesteps = {
                id: BaseEnvTimestep(n, r, d, {})
                for id, n, r, d in zip(
                    data_id,
                    next_obs.cpu().numpy(),
                    rewards.unsqueeze(-1).cpu().numpy(),  # ding api
                    terminals.cpu().numpy()
                )
            }
            return timesteps

        # set rollout length
        rollout_length = self.rollout_length_scheduler(envstep)
        # load data
        data = env_buffer.sample(self.rollout_batch_size, train_iter, replace=True)
        obs = {id: data[id]['obs'] for id in range(len(data))}
        # rollout
        buffer = [[] for id in range(len(obs))]
        new_data = []
        for i in range(rollout_length):
            # get action
            obs = to_tensor(obs, dtype=torch.float32)
            policy_output = policy.forward(obs)
            actions = {id: output['action'] for id, output in policy_output.items()}
            # predict next obs and reward
            # timesteps = self.step(obs, actions, env_model)
            timesteps = step(obs, actions)
            obs_new = {}
            for id, timestep in timesteps.items():
                transition = policy.process_transition(obs[id], policy_output[id], timestep)
                transition['collect_iter'] = train_iter
                buffer[id].append(transition)
                if not timestep.done:
                    obs_new[id] = timestep.obs
                if timestep.done or i + 1 == rollout_length:
                    transitions = to_tensor_transitions(buffer[id])
                    train_sample = policy.get_train_sample(transitions)
                    new_data.extend(train_sample)
            if len(obs_new) == 0:
                break
            obs = obs_new

        img_buffer.push(new_data, cur_collector_envstep=envstep)


class DreamWorldModel(WorldModel, ABC):
    r"""
    Overview:
        Dreamer-style world model which uses each imagination rollout only once\
        and backpropagate through time(rollout) to optimize policy.

    Interfaces:
        rollout, should_train, should_eval, train, eval, step
    """

    def rollout(self, obs: Tensor, actor_fn: Callable[[Tensor], Tuple[Tensor, Tensor]], envstep: int,
                **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Optional[bool]]:
        r"""
        Overview:
            Generate batched imagination rollouts starting from the current observations.\
            This function is useful for value gradients where the policy is optimized by BPTT.

        Arguments:
            - obs (:obj:`Tensor`): the current observations :math:`S_t`
            - actor_fn (:obj:`Callable`): the unified API :math:`(A_t, H_t) = pi(S_t)`
            - envstep (:obj:`int`): the current number of environment steps in real environment

        Returns:
            - obss (:obj:`Tensor`): :math:`S_t, ..., S_t+n`
            - actions (:obj:`Tensor`): :math:`A_t, ..., A_t+n`
            - rewards (:obj:`Tensor`): :math:`R_t, ..., R_t+n-1`
            - aug_rewards (:obj:`Tensor`): :math:`H_t, ..., H_t+n`, this can be entropy bonus as in SAC,
              otherwise it should be a zero tensor
            - dones (:obj:`Tensor`): :math:`\text{done}_t, ..., \text{done}_t+n`

        Shapes:
            :math:`N`: time step
            :math:`B`: batch size
            :math:`O`: observation dimension
            :math:`A`: action dimension

            - obss: :math:`[N+1, B, O]`, where obss[0] are the real observations
            - actions: :math:`[N+1, B, A]`
            - rewards: :math:`[N, B]`
            - aug_rewards: :math:`[N+1, B]`
            - dones: :math:`[N, B]`

        .. note::
            - The rollout length is determined by rollout length scheduler.

            - actor_fn's inputs and outputs shape are similar to WorldModel.step()
        """
        horizon = self.rollout_length_scheduler(envstep)
        if isinstance(self, nn.Module):
            # Rollouts should propagate gradients only to policy,
            # so make sure that the world model is not updated by rollout.
            self.requires_grad_(False)
        obss = [obs]
        actions = []
        rewards = []
        aug_rewards = []  # -temperature*logprob
        dones = []
        for _ in range(horizon):
            action, aug_reward = actor_fn(obs)
            # done: probability of termination
            reward, obs, done = self.step(obs, action, **kwargs)
            reward = reward + aug_reward
            obss.append(obs)
            actions.append(action)
            rewards.append(reward)
            aug_rewards.append(aug_reward)
            dones.append(done)
        action, aug_reward = actor_fn(obs)
        actions.append(action)
        aug_rewards.append(aug_reward)
        if isinstance(self, nn.Module):
            self.requires_grad_(True)
        return (
            torch.stack(obss),
            torch.stack(actions),
            # rewards is an empty list when horizon=0
            torch.stack(rewards) if rewards else torch.tensor(rewards, device=obs.device),
            torch.stack(aug_rewards),
            torch.stack(dones) if dones else torch.tensor(dones, device=obs.device)
        )


class HybridWorldModel(DynaWorldModel, DreamWorldModel, ABC):
    r"""
    Overview:
        The hybrid model that combines reused and on-the-fly rollouts.

    Interfaces:
        rollout, sample, fill_img_buffer, should_train, should_eval, train, eval, step
    """

    def __init__(self, cfg: dict, env: BaseEnv, tb_logger: 'SummaryWriter'):  # noqa
        DynaWorldModel.__init__(self, cfg, env, tb_logger)
        DreamWorldModel.__init__(self, cfg, env, tb_logger)
```