Skip to content

ding.entry.utils

ding.entry.utils

maybe_init_wandb(cfg)

Overview

Optionally initialize a Weights & Biases (wandb) run for serial pipelines when `wandb_logger.enabled` is set in the config.

Full Source Code

../ding/entry/utils.py

import os
import re
from datetime import datetime
from typing import Optional, Callable, List, Any

from ditk import logging
from ding.policy import PolicyFactory
from ding.worker import IMetric, MetricSerialEvaluator


class AccMetric(IMetric):
    """
    Overview:
        Accuracy-style metric implementing the ``IMetric`` interface (``eval`` / ``reduce_mean`` / ``gt``).
    """

    def eval(self, inputs: Any, label: Any) -> dict:
        """
        Overview:
            Compute the fraction of samples for which ``inputs["logit"].sum(dim=1)`` equals ``label``.
        Arguments:
            - inputs (:obj:`Any`): Mapping holding a tensor under ``"logit"`` \
                (presumably shaped ``(batch, ...)`` - TODO confirm).
            - label (:obj:`Any`): Target tensor; ``label.shape[0]`` is used as the batch size.
        Returns:
            - metric (:obj:`dict`): ``{"Acc": accuracy}``.
        """
        # NOTE(review): logits are reduced with ``sum(dim=1)`` rather than ``argmax(dim=1)``; kept unchanged
        # to preserve existing behavior - confirm this is intended for the targeted task.
        correct = (inputs["logit"].sum(dim=1) == label).sum().item()
        return {"Acc": correct / label.shape[0]}

    def reduce_mean(self, inputs: List[Any]) -> Any:
        """
        Overview:
            Average the ``"Acc"`` entries of several ``eval`` results.
        Arguments:
            - inputs (:obj:`List[Any]`): Non-empty list of dicts, each holding an ``"Acc"`` value.
        Returns:
            - metric (:obj:`dict`): ``{"Acc": mean_accuracy}``.
        Raises:
            - ZeroDivisionError: If ``inputs`` is empty.
        """
        total = sum(item["Acc"] for item in inputs)
        return {"Acc": total / len(inputs)}

    def gt(self, metric1: Any, metric2: Any) -> bool:
        """
        Overview:
            Return whether ``metric1`` has a strictly higher ``"Acc"`` than ``metric2``.
        Arguments:
            - metric1 (:obj:`Any`): Dict with an ``"Acc"`` entry.
            - metric2 (:obj:`Any`): ``None`` (always considered worse), a dict with ``"Acc"``, or a raw number.
        Returns:
            - better (:obj:`bool`): ``True`` if ``metric1`` is better.
        """
        if metric2 is None:
            return True
        m2 = metric2["Acc"] if isinstance(metric2, dict) else metric2
        return metric1["Acc"] > m2


def mark_not_expert(ori_data: List[dict]) -> List[dict]:
    """
    Overview:
        Mark every transition as agent (non-expert) data, in place.
    Arguments:
        - ori_data (:obj:`List[dict]`): Transitions to mark; each dict is mutated.
    Returns:
        - ori_data (:obj:`List[dict]`): The same list, returned for chaining.
    """
    for transition in ori_data:
        # Set is_expert flag (expert 1, agent 0)
        transition["is_expert"] = 0
    return ori_data


def mark_warm_up(ori_data: List[dict]) -> List[dict]:
    """
    Overview:
        Mark every transition as warm-up data (used for td3_vae), in place.
    Arguments:
        - ori_data (:obj:`List[dict]`): Transitions to mark; each dict is mutated.
    Returns:
        - ori_data (:obj:`List[dict]`): The same list, returned for chaining.
    """
    for transition in ori_data:
        transition["warm_up"] = True
    return ori_data


def _sanitize_name_piece(value: Any, default: str) -> str:
    """
    Overview:
        Turn an arbitrary value into a compact name fragment containing only ``[0-9A-Za-z_.-]``.
        Path separators become ``-``, whitespace is removed, other characters become ``-``, repeated
        ``-`` are collapsed, and leading/trailing ``-``/``_`` are stripped.
    Arguments:
        - value (:obj:`Any`): Value to sanitize; ``None`` maps to ``default``.
        - default (:obj:`str`): Fallback used when the value is ``None`` or sanitizes to an empty string.
    Returns:
        - piece (:obj:`str`): The sanitized fragment, never empty as long as ``default`` is not empty.
    """
    text = str(value).strip() if value is not None else default
    if not text:
        text = default
    text = text.replace("/", "-").replace("\\", "-")
    text = re.sub(r"\s+", "", text)
    text = re.sub(r"[^0-9A-Za-z_.-]+", "-", text)
    text = re.sub(r"-{2,}", "-", text).strip("-_")
    return text or default


def _default_wandb_run_name(cfg: "EasyDict") -> str:  # noqa
    """
    Overview:
        Build a default run name ``"{algo}_{env}_{YYYYmmdd-HHMMSS}"`` from the experiment config.
        ``algo`` is taken from the 2nd and 3rd ``_``-separated tokens of ``cfg.exp_name``; ``env`` from
        ``cfg.env.env_id`` (falling back to ``cfg.env.type``).
    Arguments:
        - cfg (:obj:`EasyDict`): Experiment config; missing/``None`` ``exp_name`` or ``env`` fall back to defaults.
    Returns:
        - run_name (:obj:`str`): The generated run name.
    """
    # Tolerate a missing/None ``exp_name`` instead of crashing on ``None.split``.
    exp_name = str(cfg.get("exp_name", None) or "")
    algo = _sanitize_name_piece('-'.join(exp_name.split("_")[1:3]), default="algo").lower()
    # ``env`` may be absent or explicitly ``None``; treat both as an empty mapping.
    env_cfg = cfg.get("env", None) or {}
    env_name = env_cfg.get("env_id", None)
    if env_name is None:
        env_name = env_cfg.get("type", None)
    env = _sanitize_name_piece(env_name, default="env")
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    return f"{algo}_{env}_{timestamp}"


def _to_wandb_config(value: Any) -> Any:
    """
    Overview:
        Recursively convert ``value`` into plain Python types suitable for a wandb run config.
    Arguments:
        - value (:obj:`Any`): Arbitrary config value (dict-like configs, sequences, scalars, array-likes, ...).
    Returns:
        - converted (:obj:`Any`): ``None``/``str``/``int``/``float``/``bool`` unchanged; dicts with ``str`` keys;
            lists for list/tuple/set/frozenset; UTF-8-decoded text for ``bytes``; the converted result of
            ``item()`` or ``tolist()`` for objects exposing them (e.g. NumPy/PyTorch values); ``str(value)``
            otherwise.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, dict):
        return {str(k): _to_wandb_config(v) for k, v in value.items()}
    if isinstance(value, (list, tuple, set, frozenset)):
        return [_to_wandb_config(v) for v in value]
    if isinstance(value, bytes):
        try:
            return value.decode("utf-8")
        except UnicodeDecodeError:
            return str(value)
    # Best-effort conversion of array-like values: ``item()`` handles single-element scalars, ``tolist()``
    # handles multi-element arrays (where ``item()`` raises). Arbitrary objects may raise anything here,
    # so any failure deliberately falls through to the ``str`` fallback.
    for method_name in ("item", "tolist"):
        method = getattr(value, method_name, None)
        if callable(method):
            try:
                return _to_wandb_config(method())
            except Exception:
                pass
    return str(value)


def maybe_init_wandb(cfg: "EasyDict") -> Optional[Any]:  # noqa
    """
    Overview:
        Optionally initialize a wandb run for serial pipelines when ``cfg.wandb_logger.enabled`` is set.
        This is best-effort: any failure is logged and training continues without wandb.
    Arguments:
        - cfg (:obj:`EasyDict`): Experiment config. Reads ``cfg.wandb_logger`` keys ``enabled``, ``project_name``,
            ``run_name``, ``entity``, ``dir``, ``sync_tensorboard``, ``log_config``, ``group``, ``job_type``,
            ``mode``, ``notes``, ``anonymous`` and ``tags``; ``cfg.exp_name`` / ``cfg.env`` are used for defaults.
    Returns:
        - run (:obj:`Optional[Any]`): The object returned by ``wandb.init``, or ``None`` if wandb logging is
            disabled, ``wandb`` is not installed, or initialization fails.
    """
    wandb_cfg = cfg.get("wandb_logger", None)
    if wandb_cfg is None or not wandb_cfg.get("enabled", False):
        return None

    try:
        import wandb
    except ImportError:
        logging.warning("wandb is not installed, skip wandb logging.")
        return None

    project_name = wandb_cfg.get("project_name", os.getenv("WANDB_PROJECT", "DI-engine"))
    run_name = wandb_cfg.get("run_name", None) or _default_wandb_run_name(cfg)
    entity = wandb_cfg.get("entity", os.getenv("WANDB_ENTITY"))
    # Default to the experiment directory; an explicit ``dir: None`` also means "use the default", and a
    # missing ``exp_name`` falls back to the current directory instead of raising.
    wandb_dir = wandb_cfg.get("dir", None)
    if wandb_dir is None:
        wandb_dir = cfg.get("exp_name", None) or "."
    wandb_dir = os.path.abspath(wandb_dir)

    init_kwargs = dict(
        project=project_name,
        name=run_name,
        sync_tensorboard=wandb_cfg.get("sync_tensorboard", True),
        reinit=True,
        dir=wandb_dir,
    )
    if wandb_cfg.get("log_config", True):
        init_kwargs["config"] = _to_wandb_config(cfg)
    if entity:
        init_kwargs["entity"] = entity
    # Pass through optional ``wandb.init`` arguments only when explicitly configured.
    for key in ("group", "job_type", "mode", "notes", "anonymous"):
        option = wandb_cfg.get(key, None)
        if option is not None:
            init_kwargs[key] = option
    tags = wandb_cfg.get("tags", None)
    if tags is not None:
        # wandb expects a list of strings; a single scalar tag is wrapped.
        if isinstance(tags, (list, tuple, set)):
            init_kwargs["tags"] = [str(tag) for tag in tags]
        else:
            init_kwargs["tags"] = [str(tags)]

    try:
        # Creating the directory is part of the best-effort setup: a filesystem error must not abort training.
        os.makedirs(wandb_dir, exist_ok=True)
        run = wandb.init(**init_kwargs)
    except Exception as e:
        logging.warning("wandb init failed, continue without wandb logging: %s", e)
        return None
    logging.info("wandb logging enabled: project=%s, run=%s", project_name, run_name)
    return run


def maybe_finish_wandb(wandb_run: Optional[Any]) -> None:
    """
    Overview:
        Finish a run returned by ``maybe_init_wandb``; a ``None`` run is a no-op and failures are only logged.
    Arguments:
        - wandb_run (:obj:`Optional[Any]`): The run object, or ``None``.
    """
    if wandb_run is None:
        return
    try:
        wandb_run.finish()
    except Exception as e:
        logging.warning("wandb finish failed: %s", e)


def random_collect(
        policy_cfg: "EasyDict",  # noqa
        policy: "Policy",  # noqa
        collector: "ISerialCollector",  # noqa
        collector_env: "BaseEnvManager",  # noqa
        commander: "BaseSerialCommander",  # noqa
        replay_buffer: "IBuffer",  # noqa
        postprocess_data_fn: Optional[Callable] = None,
) -> None:  # noqa
    """
    Overview:
        Collect ``policy_cfg.random_collect_size`` samples (or episodes, for an ``"episode"`` collector) with a
        random policy - or with ``policy.collect_mode`` when ``transition_with_policy_data`` is set - push them
        into ``replay_buffer``, and finally reset the collector back to ``policy.collect_mode``.
    Arguments:
        - policy_cfg (:obj:`EasyDict`): Policy config; reads ``random_collect_size``, \
            ``transition_with_policy_data`` and ``collect.collector.type``.
        - policy (:obj:`Policy`): Policy whose ``collect_mode`` is used to build/restore the collecting policy.
        - collector (:obj:`ISerialCollector`): Collector whose policy is temporarily replaced.
        - collector_env (:obj:`BaseEnvManager`): Env manager providing ``action_space`` for the random policy.
        - commander (:obj:`BaseSerialCommander`): Provides ``policy_kwargs`` via ``step()``.
        - replay_buffer (:obj:`IBuffer`): Buffer receiving the collected data at ``cur_collector_envstep=0``.
        - postprocess_data_fn (:obj:`Optional[Callable]`): Optional transform applied to the data before pushing.
    """
    # NOTE: ``assert`` is stripped under ``python -O``; kept to preserve the existing exception type.
    assert policy_cfg.random_collect_size > 0
    if policy_cfg.get("transition_with_policy_data", False):
        collector.reset_policy(policy.collect_mode)
    else:
        action_space = collector_env.action_space
        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
        collector.reset_policy(random_policy)
    collect_kwargs = commander.step()
    if policy_cfg.collect.collector.type == "episode":
        new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs)
    else:
        new_data = collector.collect(
            n_sample=policy_cfg.random_collect_size,
            random_collect=True,
            record_random_collect=False,
            policy_kwargs=collect_kwargs,
        )  # 'record_random_collect=False' means random collect without output log
    if postprocess_data_fn is not None:
        new_data = postprocess_data_fn(new_data)
    replay_buffer.push(new_data, cur_collector_envstep=0)
    collector.reset_policy(policy.collect_mode)