Skip to content

ding.framework.middleware.functional.logger

ding.framework.middleware.functional.logger

online_logger(record_train_iter=False, train_show_freq=100)

Overview

Create an online RL tensorboard logger for recording training and evaluation metrics.

Arguments: - record_train_iter (:obj:bool): Whether to record training iteration. Default is False. - train_show_freq (:obj:int): Frequency of showing training logs. Default is 100. Returns: - _logger (:obj:Callable): A logger function that takes an OnlineRLContext object as input. Raises: - RuntimeError: If writer is None. - NotImplementedError: If the key of train_output is not supported, such as "scalars".

Examples:

>>> task.use(online_logger(record_train_iter=False, train_show_freq=1000))

offline_logger(train_show_freq=100)

Overview

Create an offline RL tensorboard logger for recording training and evaluation metrics.

Arguments: - train_show_freq (:obj:int): Frequency of showing training logs. Defaults to 100. Returns: - _logger (:obj:Callable): A logger function that takes an OfflineRLContext object as input. Raises: - RuntimeError: If writer is None. - NotImplementedError: If the key of train_output is not supported, such as "scalars".

Examples:

>>> task.use(offline_logger(train_show_freq=1000))

wandb_online_logger(record_path=None, cfg=None, exp_config=None, metric_list=None, env=None, model=None, anonymous=False, project_name='default-project', run_name=None, wandb_sweep=False)

Overview

Wandb visualizer to track the experiment.

Arguments: - record_path (:obj:str): The path to save the replay of simulation. - cfg (:obj:Union[dict, EasyDict]): Config, a dict of the following settings: - gradient_logger: boolean. Whether to track the gradient. - plot_logger: boolean. Whether to track the metrics like reward and loss. - video_logger: boolean. Whether to upload the rendering video replay. - action_logger: boolean. q_value or action probability. - return_logger: boolean. Whether to track the return value. - metric_list (:obj:Optional[List[str]]): Logged metric list, specialized by different policies. - env (:obj:BaseEnvManagerV2): Evaluator environment. - model (:obj:nn.Module): Policy neural network model. - anonymous (:obj:bool): Open the anonymous mode of wandb or not. The anonymous mode allows visualization of data without wandb count. - project_name (:obj:str): The name of wandb project. - run_name (:obj:str): The name of wandb run. - wandb_sweep (:obj:bool): Whether to use wandb sweep. Returns: - _plot (:obj:Callable): A logger function that takes an OnlineRLContext object as input.

wandb_offline_logger(record_path=None, cfg=None, exp_config=None, metric_list=None, env=None, model=None, anonymous=False, project_name='default-project', run_name=None, wandb_sweep=False)

Overview

Wandb visualizer to track the experiment.

Arguments: - record_path (:obj:str): The path to save the replay of simulation. - cfg (:obj:Union[dict, EasyDict]): Config, a dict of the following settings: - gradient_logger: boolean. Whether to track the gradient. - plot_logger: boolean. Whether to track the metrics like reward and loss. - video_logger: boolean. Whether to upload the rendering video replay. - action_logger: boolean. q_value or action probability. - return_logger: boolean. Whether to track the return value. - vis_dataset: boolean. Whether to visualize the dataset. - metric_list (:obj:Optional[List[str]]): Logged metric list, specialized by different policies. - env (:obj:BaseEnvManagerV2): Evaluator environment. - model (:obj:nn.Module): Policy neural network model. - anonymous (:obj:bool): Open the anonymous mode of wandb or not. The anonymous mode allows visualization of data without wandb count. - project_name (:obj:str): The name of wandb project. - run_name (:obj:str): The name of wandb run. - wandb_sweep (:obj:bool): Whether to use wandb sweep. Returns: - _plot (:obj:Callable): A logger function that takes an OfflineRLContext object as input.

Full Source Code

../ding/framework/middleware/functional/logger.py

from typing import TYPE_CHECKING, Optional, Callable, Dict, List, Union
from ditk import logging
from easydict import EasyDict
from matplotlib import pyplot as plt
from matplotlib import animation
import os
import numpy as np
import torch
import wandb
import pickle
import treetensor.numpy as tnp
from ding.framework import task
from ding.envs import BaseEnvManagerV2
from ding.utils import DistributedWriter
from ding.torch_utils import to_ndarray
from ding.utils.default_helper import one_time_warning

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext, OfflineRLContext


def online_logger(record_train_iter: bool = False, train_show_freq: int = 100) -> Callable:
    """
    Overview:
        Create an online RL tensorboard logger for recording training and evaluation metrics.
    Arguments:
        - record_train_iter (:obj:`bool`): Whether to record training iteration. Default is False.
        - train_show_freq (:obj:`int`): Frequency of showing training logs. Default is 100.
    Returns:
        - _logger (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input.
    Raises:
        - RuntimeError: If writer is None.
        - NotImplementedError: If the key of train_output is not supported, such as "scalars".

    Examples:
        >>> task.use(online_logger(record_train_iter=False, train_show_freq=1000))
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    writer = DistributedWriter.get_instance()
    if writer is None:
        raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.")
    last_train_show_iter = -1

    def _logger(ctx: "OnlineRLContext"):
        if task.finish:
            writer.close()
        nonlocal last_train_show_iter

        if not np.isinf(ctx.eval_value):
            if record_train_iter:
                writer.add_scalar('basic/eval_episode_return_mean-env_step', ctx.eval_value, ctx.env_step)
                writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter)
            else:
                writer.add_scalar('basic/eval_episode_return_mean', ctx.eval_value, ctx.env_step)
        if ctx.train_output is not None and ctx.train_iter - last_train_show_iter >= train_show_freq:
            last_train_show_iter = ctx.train_iter
            if isinstance(ctx.train_output, List):
                output = ctx.train_output.pop()  # only use latest output for some algorithms, like PPO
            else:
                output = ctx.train_output
            for k, v in output.items():
                if k in ['priority', 'td_error_priority']:
                    continue
                if "[scalars]" in k:
                    # logging a dict of scalars is not supported yet
                    raise NotImplementedError
                elif "[histogram]" in k:
                    new_k = k.split(']')[-1]
                    writer.add_histogram(new_k, v, ctx.env_step)
                    if record_train_iter:
                        writer.add_histogram(new_k, v, ctx.train_iter)
                else:
                    if record_train_iter:
                        writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter)
                        writer.add_scalar('basic/train_{}-env_step'.format(k), v, ctx.env_step)
                    else:
                        writer.add_scalar('basic/train_{}'.format(k), v, ctx.env_step)

    return _logger


def offline_logger(train_show_freq: int = 100) -> Callable:
    """
    Overview:
        Create an offline RL tensorboard logger for recording training and evaluation metrics.
    Arguments:
        - train_show_freq (:obj:`int`): Frequency of showing training logs. Defaults to 100.
    Returns:
        - _logger (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input.
    Raises:
        - RuntimeError: If writer is None.
        - NotImplementedError: If the key of train_output is not supported, such as "scalars".

    Examples:
        >>> task.use(offline_logger(train_show_freq=1000))
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    writer = DistributedWriter.get_instance()
    if writer is None:
        raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.")
    last_train_show_iter = -1

    def _logger(ctx: "OfflineRLContext"):
        nonlocal last_train_show_iter
        if task.finish:
            writer.close()
        if not np.isinf(ctx.eval_value):
            writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter)
        if ctx.train_output is not None and ctx.train_iter - last_train_show_iter >= train_show_freq:
            last_train_show_iter = ctx.train_iter
            output = ctx.train_output
            for k, v in output.items():
                if k in ['priority']:
                    continue
                if "[scalars]" in k:
                    # logging a dict of scalars is not supported yet
                    raise NotImplementedError
                elif "[histogram]" in k:
                    new_k = k.split(']')[-1]
                    writer.add_histogram(new_k, v, ctx.train_iter)
                else:
                    writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter)

    return _logger


# four utility functions for wandb logger
def softmax(logit: np.ndarray) -> np.ndarray:
    """Plain softmax over the last axis of ``logit``."""
    v = np.exp(logit)
    return v / v.sum(axis=-1, keepdims=True)


def action_prob(num, action_prob, ln):
    """Frame-update callback for ``FuncAnimation``: set the bar heights to the action distribution of frame ``num``."""
    ax = plt.gca()
    ax.set_ylim([0, 1])
    for rect, x in zip(ln, action_prob[num]):
        rect.set_height(x)
    return ln


def return_prob(num, return_prob, ln):
    """Frame-update callback for ``FuncAnimation``: the return histogram is static, so just return the artists."""
    return ln


def return_distribution(episode_return):
    """Histogram episode returns into 5 bins over [min - 50, max + 50]; return (normalized hist, bin labels)."""
    num = len(episode_return)
    max_return = max(episode_return)
    min_return = min(episode_return)
    hist, bins = np.histogram(episode_return, bins=np.linspace(min_return - 50, max_return + 50, 6))
    gap = (max_return - min_return + 100) / 5
    x_dim = ['{:.1f}'.format(min_return - 50 + gap * x) for x in range(5)]
    return hist / num, x_dim


def _init_wandb_run(
        project_name: str,
        exp_config: Union[dict, EasyDict],
        run_name: Optional[str],
        anonymous: bool,
        wandb_sweep: bool,
) -> None:
    """
    Overview:
        Call ``wandb.init`` with the subset of keyword arguments implied by the flags.
        Equivalent to (and replaces) the previous 16-branch if/else enumeration that was
        duplicated in ``wandb_online_logger`` and ``wandb_offline_logger``.
    """
    kwargs = {'project': project_name}
    if exp_config:
        kwargs['config'] = exp_config
    if not wandb_sweep:
        # reinit is incompatible with wandb sweep, so it is only set for plain runs
        kwargs['reinit'] = True
    if run_name is not None:
        kwargs['name'] = run_name
    if anonymous:
        kwargs['anonymous'] = "must"
    wandb.init(**kwargs)


def wandb_online_logger(
        record_path: str = None,
        cfg: Union[dict, EasyDict] = None,
        exp_config: Union[dict, EasyDict] = None,
        metric_list: Optional[List[str]] = None,
        env: Optional[BaseEnvManagerV2] = None,
        model: Optional[torch.nn.Module] = None,
        anonymous: bool = False,
        project_name: str = 'default-project',
        run_name: str = None,
        wandb_sweep: bool = False,
) -> Callable:
    """
    Overview:
        Wandb visualizer to track the experiment.
    Arguments:
        - record_path (:obj:`str`): The path to save the replay of simulation.
        - cfg (:obj:`Union[dict, EasyDict]`): Config, a dict of following settings:
            - gradient_logger: boolean. Whether to track the gradient.
            - plot_logger: boolean. Whether to track the metrics like reward and loss.
            - video_logger: boolean. Whether to upload the rendering video replay.
            - action_logger: boolean. `q_value` or `action probability`.
            - return_logger: boolean. Whether to track the return value.
        - metric_list (:obj:`Optional[List[str]]`): Logged metric list, specialized by different policies.
        - env (:obj:`BaseEnvManagerV2`): Evaluator environment.
        - model (:obj:`nn.Module`): Policy neural network model.
        - anonymous (:obj:`bool`): Open the anonymous mode of wandb or not. The anonymous mode allows visualization \
            of data without wandb count.
        - project_name (:obj:`str`): The name of wandb project.
        - run_name (:obj:`str`): The name of wandb run.
        - wandb_sweep (:obj:`bool`): Whether to use wandb sweep.
    Returns:
        - _plot (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
    if metric_list is None:
        metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]
    # Initialize wandb with default settings
    # Settings can be covered by calling wandb.init() at the top of the script
    _init_wandb_run(project_name, exp_config, run_name, anonymous, wandb_sweep)
    plt.switch_backend('agg')
    if cfg is None:
        cfg = EasyDict(
            dict(
                gradient_logger=False,
                plot_logger=True,
                video_logger=False,
                action_logger=False,
                return_logger=False,
            )
        )
    else:
        if not isinstance(cfg, EasyDict):
            cfg = EasyDict(cfg)
        for key in ["gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger", "vis_dataset"]:
            if key not in cfg.keys():
                cfg[key] = False

    # The visualizer is called to save the replay of the simulation
    # which will be uploaded to wandb later
    if env is not None and cfg.video_logger is True and record_path is not None:
        env.enable_save_replay(replay_path=record_path)
    if cfg.gradient_logger:
        wandb.watch(model, log="all", log_freq=100, log_graph=True)
    else:
        one_time_warning(
            "If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config."
        )

    first_plot = True

    def _plot(ctx: "OnlineRLContext"):
        nonlocal first_plot
        if first_plot:
            first_plot = False
            ctx.wandb_url = wandb.run.get_project_url()

        info_for_logging = {}

        if cfg.plot_logger:
            for metric in metric_list:
                if isinstance(ctx.train_output, Dict) and metric in ctx.train_output:
                    if isinstance(ctx.train_output[metric], torch.Tensor):
                        info_for_logging.update({metric: ctx.train_output[metric].cpu().detach().numpy()})
                    else:
                        info_for_logging.update({metric: ctx.train_output[metric]})
                elif isinstance(ctx.train_output, List) and len(ctx.train_output) > 0 and metric in ctx.train_output[0]:
                    metric_value_list = []
                    for item in ctx.train_output:
                        if isinstance(item[metric], torch.Tensor):
                            metric_value_list.append(item[metric].cpu().detach().numpy())
                        else:
                            metric_value_list.append(item[metric])
                    metric_value = np.mean(metric_value_list)
                    info_for_logging.update({metric: metric_value})
        else:
            one_time_warning(
                "If you want to use wandb to visualize the result, please set plot_logger = True in the config."
            )

        # Everything below depends on a finished evaluation (eval_output), so it is
        # gated on eval_value having been set this iteration.
        if ctx.eval_value != -np.inf:
            if hasattr(ctx, "eval_value_min"):
                info_for_logging.update({
                    "episode return min": ctx.eval_value_min,
                })
            if hasattr(ctx, "eval_value_max"):
                info_for_logging.update({
                    "episode return max": ctx.eval_value_max,
                })
            if hasattr(ctx, "eval_value_std"):
                info_for_logging.update({
                    "episode return std": ctx.eval_value_std,
                })
            if hasattr(ctx, "eval_value"):
                info_for_logging.update({
                    "episode return mean": ctx.eval_value,
                })
            if hasattr(ctx, "train_iter"):
                info_for_logging.update({
                    "train iter": ctx.train_iter,
                })
            if hasattr(ctx, "env_step"):
                info_for_logging.update({
                    "env step": ctx.env_step,
                })

            eval_output = ctx.eval_output['output']
            episode_return = ctx.eval_output['episode_return']
            episode_return = np.array(episode_return)
            if len(episode_return.shape) == 2:
                episode_return = episode_return.squeeze(1)

            if cfg.video_logger:
                if 'replay_video' in ctx.eval_output:
                    # save numpy array "images" of shape (N,1212,3,224,320) to N video files in mp4 format
                    # The numpy tensor must be either 4 dimensional or 5 dimensional.
                    # Channels should be (time, channel, height, width) or (batch, time, channel, height width)
                    video_images = ctx.eval_output['replay_video']
                    video_images = video_images.astype(np.uint8)
                    info_for_logging.update({"replay_video": wandb.Video(video_images, fps=60)})
                elif record_path is not None:
                    file_list = []
                    for p in os.listdir(record_path):
                        if os.path.splitext(p)[-1] == ".mp4":
                            file_list.append(p)
                    file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn)))
                    # NOTE(review): [-2] skips the most recent file — presumably because it may
                    # still be open for writing; confirm against the replay writer.
                    video_path = os.path.join(record_path, file_list[-2])
                    info_for_logging.update({"video": wandb.Video(video_path, format="mp4")})

            if cfg.action_logger:
                action_path = os.path.join(record_path, (str(ctx.env_step) + "_action.gif"))
                if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"):
                    # NOTE: the local is named action_probs (not action_prob) so the
                    # module-level action_prob callback is not shadowed before being
                    # handed to FuncAnimation below.
                    if isinstance(eval_output, tnp.ndarray):
                        action_probs = softmax(eval_output.logit)
                    else:
                        action_probs = [softmax(to_ndarray(v['logit'])) for v in eval_output]
                    fig, ax = plt.subplots()
                    plt.ylim([-1, 1])
                    action_dim = len(action_probs[1])
                    x_range = [str(x + 1) for x in range(action_dim)]
                    ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
                    ani = animation.FuncAnimation(
                        fig, action_prob, fargs=(action_probs, ln), blit=True, save_count=len(action_probs)
                    )
                    ani.save(action_path, writer='pillow')
                    info_for_logging.update({"action": wandb.Video(action_path, format="gif")})

                elif all(['action' in v for v in eval_output[0]]):
                    for i, action_trajectory in enumerate(eval_output):
                        fig, ax = plt.subplots()
                        fig_data = np.array([[i + 1, *v['action']] for i, v in enumerate(action_trajectory)])
                        steps = fig_data[:, 0]
                        actions = fig_data[:, 1:]
                        plt.ylim([-1, 1])
                        for j in range(actions.shape[1]):
                            ax.scatter(steps, actions[:, j])
                        info_for_logging.update({"actions_of_trajectory_{}".format(i): fig})

            if cfg.return_logger:
                return_path = os.path.join(record_path, (str(ctx.env_step) + "_return.gif"))
                fig, ax = plt.subplots()
                ax = plt.gca()
                ax.set_ylim([0, 1])
                hist, x_dim = return_distribution(episode_return)
                assert len(hist) == len(x_dim)
                ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7)
                ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
                ani.save(return_path, writer='pillow')
                info_for_logging.update({"return distribution": wandb.Video(return_path, format="gif")})

        if bool(info_for_logging):
            wandb.log(data=info_for_logging, step=ctx.env_step)
        plt.clf()

    return _plot


def wandb_offline_logger(
        record_path: str = None,
        cfg: Union[dict, EasyDict] = None,
        exp_config: Union[dict, EasyDict] = None,
        metric_list: Optional[List[str]] = None,
        env: Optional[BaseEnvManagerV2] = None,
        model: Optional[torch.nn.Module] = None,
        anonymous: bool = False,
        project_name: str = 'default-project',
        run_name: str = None,
        wandb_sweep: bool = False,
) -> Callable:
    """
    Overview:
        Wandb visualizer to track the experiment.
    Arguments:
        - record_path (:obj:`str`): The path to save the replay of simulation.
        - cfg (:obj:`Union[dict, EasyDict]`): Config, a dict of following settings:
            - gradient_logger: boolean. Whether to track the gradient.
            - plot_logger: boolean. Whether to track the metrics like reward and loss.
            - video_logger: boolean. Whether to upload the rendering video replay.
            - action_logger: boolean. `q_value` or `action probability`.
            - return_logger: boolean. Whether to track the return value.
            - vis_dataset: boolean. Whether to visualize the dataset.
        - metric_list (:obj:`Optional[List[str]]`): Logged metric list, specialized by different policies.
        - env (:obj:`BaseEnvManagerV2`): Evaluator environment.
        - model (:obj:`nn.Module`): Policy neural network model.
        - anonymous (:obj:`bool`): Open the anonymous mode of wandb or not. The anonymous mode allows visualization \
            of data without wandb count.
        - project_name (:obj:`str`): The name of wandb project.
        - run_name (:obj:`str`): The name of wandb run.
        - wandb_sweep (:obj:`bool`): Whether to use wandb sweep.
    Returns:
        - _plot (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
    if metric_list is None:
        metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]
    # Initialize wandb with default settings
    # Settings can be covered by calling wandb.init() at the top of the script
    _init_wandb_run(project_name, exp_config, run_name, anonymous, wandb_sweep)
    plt.switch_backend('agg')
    if cfg is None:
        cfg = EasyDict(
            dict(
                gradient_logger=False,
                plot_logger=True,
                video_logger=False,
                action_logger=False,
                return_logger=False,
                vis_dataset=True,
            )
        )
    else:
        if not isinstance(cfg, EasyDict):
            cfg = EasyDict(cfg)
        for key in ["gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger", "vis_dataset"]:
            if key not in cfg.keys():
                cfg[key] = False

    # The visualizer is called to save the replay of the simulation
    # which will be uploaded to wandb later
    if env is not None and cfg.video_logger is True and record_path is not None:
        env.enable_save_replay(replay_path=record_path)
    if cfg.gradient_logger:
        wandb.watch(model, log="all", log_freq=100, log_graph=True)
    else:
        one_time_warning(
            "If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config."
        )

    first_plot = True

    def _vis_dataset(datasetpath: str):
        """Embed the offline dataset with t-SNE and upload a 3-panel scatter plot to wandb."""
        try:
            from sklearn.manifold import TSNE
        except ImportError:
            import sys
            logging.warning("Please install sklearn first, such as `pip3 install scikit-learn`.")
            sys.exit(1)
        try:
            import h5py
        except ImportError:
            import sys
            logging.warning("Please install h5py first, such as `pip3 install h5py`.")
            sys.exit(1)
        assert os.path.splitext(datasetpath)[-1] in ['.pkl', '.h5', '.hdf5']
        if os.path.splitext(datasetpath)[-1] == '.pkl':
            with open(datasetpath, 'rb') as f:
                data = pickle.load(f)
                obs = []
                action = []
                reward = []
                for i in range(len(data)):
                    obs.extend(data[i]['observations'])
                    action.extend(data[i]['actions'])
                    reward.extend(data[i]['rewards'])
        elif os.path.splitext(datasetpath)[-1] in ['.h5', '.hdf5']:
            with h5py.File(datasetpath, 'r') as f:
                obs = f['obs'][()]
                action = f['action'][()]
                reward = f['reward'][()]

        cmap = plt.cm.hsv
        obs = np.array(obs)
        reward = np.array(reward)
        obs_action = np.hstack((obs, np.array(action)))
        reward = reward / (max(reward) - min(reward))

        embedded_obs = TSNE(n_components=2).fit_transform(obs)
        embedded_obs_action = TSNE(n_components=2).fit_transform(obs_action)
        x_min, x_max = np.min(embedded_obs, 0), np.max(embedded_obs, 0)
        embedded_obs = embedded_obs / (x_max - x_min)

        x_min, x_max = np.min(embedded_obs_action, 0), np.max(embedded_obs_action, 0)
        embedded_obs_action = embedded_obs_action / (x_max - x_min)

        fig = plt.figure()
        f, axes = plt.subplots(nrows=1, ncols=3)

        axes[0].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(reward))
        axes[1].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(action))
        axes[2].scatter(embedded_obs_action[:, 0], embedded_obs_action[:, 1], c=cmap(reward))
        axes[0].set_title('state-reward')
        axes[1].set_title('state-action')
        axes[2].set_title('stateAction-reward')
        plt.savefig('dataset.png')

        wandb.log({"dataset": wandb.Image("dataset.png")})

    if cfg.vis_dataset is True:
        # NOTE(review): this requires exp_config to be provided and to carry
        # dataset_path; with the default cfg (vis_dataset=True) and no exp_config
        # this raises AttributeError — confirm intended usage with callers.
        _vis_dataset(exp_config.dataset_path)

    def _plot(ctx: "OfflineRLContext"):
        nonlocal first_plot
        if first_plot:
            first_plot = False
            ctx.wandb_url = wandb.run.get_project_url()

        info_for_logging = {}

        if cfg.plot_logger:
            for metric in metric_list:
                if isinstance(ctx.train_output, Dict) and metric in ctx.train_output:
                    if isinstance(ctx.train_output[metric], torch.Tensor):
                        info_for_logging.update({metric: ctx.train_output[metric].cpu().detach().numpy()})
                    else:
                        info_for_logging.update({metric: ctx.train_output[metric]})
                elif isinstance(ctx.train_output, List) and len(ctx.train_output) > 0 and metric in ctx.train_output[0]:
                    metric_value_list = []
                    for item in ctx.train_output:
                        if isinstance(item[metric], torch.Tensor):
                            metric_value_list.append(item[metric].cpu().detach().numpy())
                        else:
                            metric_value_list.append(item[metric])
                    metric_value = np.mean(metric_value_list)
                    info_for_logging.update({metric: metric_value})
        else:
            one_time_warning(
                "If you want to use wandb to visualize the result, please set plot_logger = True in the config."
            )

        # Everything below depends on a finished evaluation (eval_output), so it is
        # gated on eval_value having been set this iteration.
        if ctx.eval_value != -np.inf:
            if hasattr(ctx, "info_for_logging"):
                """
                .. note::
                    The info_for_logging is a dict that contains the information to be logged.
                    Users can add their own information to the dict.
                    All the information in the dict will be logged to wandb.
                """
                info_for_logging.update(ctx.info_for_logging)

            if hasattr(ctx, "eval_value_min"):
                info_for_logging.update({
                    "episode return min": ctx.eval_value_min,
                })
            if hasattr(ctx, "eval_value_max"):
                info_for_logging.update({
                    "episode return max": ctx.eval_value_max,
                })
            if hasattr(ctx, "eval_value_std"):
                info_for_logging.update({
                    "episode return std": ctx.eval_value_std,
                })
            if hasattr(ctx, "eval_value"):
                info_for_logging.update({
                    "episode return mean": ctx.eval_value,
                })
            if hasattr(ctx, "train_iter"):
                info_for_logging.update({
                    "train iter": ctx.train_iter,
                })
            if hasattr(ctx, "train_epoch"):
                info_for_logging.update({
                    "train_epoch": ctx.train_epoch,
                })

            eval_output = ctx.eval_output['output']
            episode_return = ctx.eval_output['episode_return']
            episode_return = np.array(episode_return)
            if len(episode_return.shape) == 2:
                episode_return = episode_return.squeeze(1)

            if cfg.video_logger:
                if 'replay_video' in ctx.eval_output:
                    # save numpy array "images" of shape (N,1212,3,224,320) to N video files in mp4 format
                    # The numpy tensor must be either 4 dimensional or 5 dimensional.
                    # Channels should be (time, channel, height, width) or (batch, time, channel, height width)
                    video_images = ctx.eval_output['replay_video']
                    video_images = video_images.astype(np.uint8)
                    info_for_logging.update({"replay_video": wandb.Video(video_images, fps=60)})
                elif record_path is not None:
                    file_list = []
                    for p in os.listdir(record_path):
                        if os.path.splitext(p)[-1] == ".mp4":
                            file_list.append(p)
                    file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn)))
                    # NOTE(review): [-2] skips the most recent file — presumably because it may
                    # still be open for writing; confirm against the replay writer.
                    video_path = os.path.join(record_path, file_list[-2])
                    info_for_logging.update({"video": wandb.Video(video_path, format="mp4")})

            if cfg.action_logger:
                action_path = os.path.join(record_path, (str(ctx.trained_env_step) + "_action.gif"))
                if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"):
                    # NOTE: the local is named action_probs (not action_prob) so the
                    # module-level action_prob callback is not shadowed before being
                    # handed to FuncAnimation below.
                    if isinstance(eval_output, tnp.ndarray):
                        action_probs = softmax(eval_output.logit)
                    else:
                        action_probs = [softmax(to_ndarray(v['logit'])) for v in eval_output]
                    fig, ax = plt.subplots()
                    plt.ylim([-1, 1])
                    action_dim = len(action_probs[1])
                    x_range = [str(x + 1) for x in range(action_dim)]
                    ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
                    ani = animation.FuncAnimation(
                        fig, action_prob, fargs=(action_probs, ln), blit=True, save_count=len(action_probs)
                    )
                    ani.save(action_path, writer='pillow')
                    info_for_logging.update({"action": wandb.Video(action_path, format="gif")})

                elif all(['action' in v for v in eval_output[0]]):
                    for i, action_trajectory in enumerate(eval_output):
                        fig, ax = plt.subplots()
                        fig_data = np.array([[i + 1, *v['action']] for i, v in enumerate(action_trajectory)])
                        steps = fig_data[:, 0]
                        actions = fig_data[:, 1:]
                        plt.ylim([-1, 1])
                        for j in range(actions.shape[1]):
                            ax.scatter(steps, actions[:, j])
                        info_for_logging.update({"actions_of_trajectory_{}".format(i): fig})

            if cfg.return_logger:
                return_path = os.path.join(record_path, (str(ctx.trained_env_step) + "_return.gif"))
                fig, ax = plt.subplots()
                ax = plt.gca()
                ax.set_ylim([0, 1])
                hist, x_dim = return_distribution(episode_return)
                assert len(hist) == len(x_dim)
                ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7)
                ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
                ani.save(return_path, writer='pillow')
                info_for_logging.update({"return distribution": wandb.Video(return_path, format="gif")})

        if bool(info_for_logging):
            wandb.log(data=info_for_logging, step=ctx.trained_env_step)
        plt.clf()

    return _plot