ding.framework.context¶
ding.framework.context
¶
Context
dataclass
¶
Overview
Context is an object that passes contextual data between middlewares, and its life cycle is only one training iteration. It is a dict that reflects itself, so you can set any properties you wish. Note that the initial value of each property must be equal to False.
Full Source Code
../ding/framework/context.py
import dataclasses
from typing import Union, Dict, List

import numpy as np
import treetensor.torch as ttorch


@dataclasses.dataclass
class Context:
    """
    Overview:
        Context is an object that passes contextual data between middlewares, whose life cycle
        is only one training iteration. It is a dict that reflects itself, so you can set
        any properties as you wish.
        Note that the initial value of the property must be equal to False.
    """
    # Names of the attributes that ``renew`` copies onto the next context.
    _kept_keys: set = dataclasses.field(default_factory=set)
    total_step: int = 0

    def renew(self) -> 'Context':  # noqa
        """
        Overview:
            Build the context for the next iteration: create a fresh instance of the same class,
            copy over every kept attribute that exists on ``self``, and set ``total_step`` to the
            current value plus one.
        Returns:
            - ctx (:obj:`Context`): The newly created context.
        """
        total_step = self.total_step
        # A fresh instance resets every field to its default (and re-runs ``__post_init__``).
        ctx = type(self)()
        for key in self._kept_keys:
            if self.has_attr(key):
                setattr(ctx, key, getattr(self, key))
        # Assigned after the copy loop so the increment wins even if 'total_step' was kept.
        ctx.total_step = total_step + 1
        return ctx

    def keep(self, *keys: str) -> None:
        """
        Overview:
            Keep this key/keys until next iteration.
        Arguments:
            - keys (:obj:`str`): Attribute names to carry over on the next ``renew`` call.
        """
        self._kept_keys.update(keys)

    def has_attr(self, key: str) -> bool:
        """
        Overview:
            Check whether this context has an attribute named ``key``.
        Arguments:
            - key (:obj:`str`): Attribute name to look up.
        Returns:
            - exists (:obj:`bool`): Result of ``hasattr(self, key)``.
        """
        return hasattr(self, key)


# TODO: Restrict data to specific types
@dataclasses.dataclass
class OnlineRLContext(Context):
    """
    Overview:
        Context variant whose fields are grouped into common, collect, eval and wandb sections.
        The keys registered in ``__post_init__`` are carried over by ``renew``; every other
        field falls back to its default on the renewed instance.
    """

    # common
    total_step: int = 0
    env_step: int = 0
    env_episode: int = 0
    train_iter: int = 0
    train_data: Union[Dict, List] = None
    train_output: Union[Dict, List[Dict]] = None
    # collect
    collect_kwargs: Dict = dataclasses.field(default_factory=dict)
    obs: ttorch.Tensor = None
    # ``action`` and ``inference_output`` used to be declared twice in this class body; the
    # later ``Dict`` declarations were the ones in effect, so only those are kept, at the
    # position of the first declaration to preserve the generated ``__init__`` argument order.
    action: Dict = None
    inference_output: Dict = None
    trajectories: List = None
    episodes: List = None
    trajectory_end_idx: List = dataclasses.field(default_factory=list)
    # eval
    eval_value: float = -np.inf
    last_eval_iter: int = -1
    # ``-np.inf`` is a float, so the annotation is ``float`` rather than ``int``.
    last_eval_value: float = -np.inf
    # The default is an empty dict, so the annotation is ``Dict`` rather than ``List``.
    eval_output: Dict = dataclasses.field(default_factory=dict)
    # wandb
    info_for_logging: Dict = dataclasses.field(default_factory=dict)
    wandb_url: str = ""

    def __post_init__(self):
        # This method is called just after __init__ method. Here, concretely speaking,
        # this method is called just after the object initialize its fields.
        # We use this method here to keep the fields needed for each iteration.
        self.keep('env_step', 'env_episode', 'train_iter', 'last_eval_iter', 'last_eval_value', 'wandb_url')


@dataclasses.dataclass
class OfflineRLContext(Context):
    """
    Overview:
        Context variant whose fields are grouped into common, eval and wandb sections.
        The keys registered in ``__post_init__`` are carried over by ``renew``; every other
        field falls back to its default on the renewed instance.
    """

    # common
    total_step: int = 0
    trained_env_step: int = 0
    train_epoch: int = 0
    train_iter: int = 0
    train_data: Union[Dict, List] = None
    train_output: Union[Dict, List[Dict]] = None
    # eval
    eval_value: float = -np.inf
    last_eval_iter: int = -1
    # ``-np.inf`` is a float, so the annotation is ``float`` rather than ``int``.
    last_eval_value: float = -np.inf
    # The default is an empty dict, so the annotation is ``Dict`` rather than ``List``.
    eval_output: Dict = dataclasses.field(default_factory=dict)
    # wandb
    info_for_logging: Dict = dataclasses.field(default_factory=dict)
    wandb_url: str = ""

    def __post_init__(self):
        # This method is called just after __init__ method. Here, concretely speaking,
        # this method is called just after the object initialize its fields.
        # We use this method here to keep the fields needed for each iteration.
        self.keep('trained_env_step', 'train_iter', 'last_eval_iter', 'last_eval_value', 'wandb_url')