ding.example.ppo_with_complex_obs
Full Source Code
../ding/example/ppo_with_complex_obs.py
from typing import Dict
import os
import torch
import torch.nn as nn
import numpy as np
import gym
from gym import spaces
from ditk import logging
from ding.envs import DingEnvWrapper, EvalEpisodeReturnWrapper, \
    BaseEnvManagerV2
from ding.config import compile_config
from ding.policy import PPOPolicy
from ding.utils import set_pkg_seed
from ding.model import VAC
from ding.framework import task, ding_init
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
    gae_estimator, online_logger
from easydict import EasyDict

# Main experiment configuration: environment counts, PPO hyper-parameters,
# and the (dict-structured) observation shape matching MyEnv below.
my_env_ppo_config = dict(
    exp_name='my_env_ppo_seed0',
    env=dict(
        collector_env_num=4,
        evaluator_env_num=4,
        n_evaluator_episode=4,
        stop_value=195,
    ),
    policy=dict(
        cuda=True,
        action_space='discrete',
        model=dict(
            # Composite observation: a nested dict, a sequence, an image, and a flat vector.
            obs_shape=dict(
                key_0=dict(k1=(), k2=()),
                key_1=(5, 10),
                key_2=(10, 10, 3),
                key_3=(2, ),
            ),
            action_shape=2,
            action_space='discrete',
            # 138 = 32 (dict branch) + 10 (transformer class token) + 64 (conv branch) + 32 (fc branch).
            critic_head_hidden_size=138,
            actor_head_hidden_size=138,
        ),
        learn=dict(
            epoch_per_collect=2,
            batch_size=64,
            learning_rate=0.001,
            value_weight=0.5,
            entropy_weight=0.01,
            clip_ratio=0.2,
            learner=dict(hook=dict(save_ckpt_after_iter=100)),
        ),
        collect=dict(
            n_sample=256, unroll_len=1, discount_factor=0.9, gae_lambda=0.95, collector=dict(transform_obs=True, )
        ),
        eval=dict(evaluator=dict(eval_freq=100, ), ),
    ),
)
my_env_ppo_config = EasyDict(my_env_ppo_config)
main_config = my_env_ppo_config
my_env_ppo_create_config = dict(
    env_manager=dict(type='base'),
    policy=dict(type='ppo'),
)
my_env_ppo_create_config = EasyDict(my_env_ppo_create_config)
create_config = my_env_ppo_create_config


class MyEnv(gym.Env):
    """Toy gym environment whose observation is a nested ``spaces.Dict``.

    It returns random samples from the observation space; rewards and episode
    termination are random as well — it exists only to exercise the complex
    observation pipeline.
    """

    def __init__(self, seq_len=5, feature_dim=10, image_size=(10, 10, 3)):
        super().__init__()

        # Two discrete actions.
        self.action_space = spaces.Discrete(2)

        # Composite observation space mirroring ``obs_shape`` in the config.
        self.observation_space = spaces.Dict(
            (
                {
                    'key_0': spaces.Dict(
                        {
                            'k1': spaces.Box(low=0, high=np.inf, shape=(1, ), dtype=np.float32),
                            'k2': spaces.Box(low=-1, high=1, shape=(1, ), dtype=np.float32),
                        }
                    ),
                    'key_1': spaces.Box(low=-np.inf, high=np.inf, shape=(seq_len, feature_dim), dtype=np.float32),
                    'key_2': spaces.Box(low=0, high=255, shape=image_size, dtype=np.uint8),
                    'key_3': spaces.Box(low=0, high=np.array([np.inf, 3]), shape=(2, ), dtype=np.float32)
                }
            )
        )

    def reset(self):
        """Return a randomly sampled initial observation."""
        return self.observation_space.sample()

    def step(self, action):
        """Advance one step: random reward, ~30% chance of termination."""
        reward = np.random.uniform(low=0.0, high=1.0)

        done = False
        if np.random.uniform(low=0.0, high=1.0) > 0.7:
            done = True

        info = {}

        return self.observation_space.sample(), reward, done, info


def ding_env_maker():
    """Wrap ``MyEnv`` with DI-engine's env wrapper plus episode-return tracking."""
    return DingEnvWrapper(
        MyEnv(), cfg={'env_wrapper': [
            lambda env: EvalEpisodeReturnWrapper(env),
        ]}
    )


class Encoder(nn.Module):
    """Fuse the four observation branches into a single flat feature vector.

    Each key of the observation dict is processed by a dedicated sub-network
    and the results are concatenated along the feature dimension
    (32 + feature_dim + 64 + 32 features for the default config).
    """

    def __init__(self, feature_dim: int):
        super(Encoder, self).__init__()

        # key_0: two scalar sub-keys, each lifted to 8 features, then fused to 32.
        self.fc_net_1_k1 = nn.Sequential(nn.Linear(1, 8), nn.ReLU())
        self.fc_net_1_k2 = nn.Sequential(nn.Linear(1, 8), nn.ReLU())
        self.fc_net_1 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        """
        Implementation of transformer_encoder refers to Vision Transformer (ViT) code:
        https://arxiv.org/abs/2010.11929
        https://pytorch.org/vision/main/_modules/torchvision/models/vision_transformer.html
        """
        # key_1: transformer over the sequence; a learned class token summarizes it.
        self.class_token = nn.Parameter(torch.zeros(1, 1, feature_dim))
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=feature_dim, nhead=2, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)

        # key_2: small CNN over the HWC image, flattened to 64 features.
        self.conv_net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.conv_fc_net = nn.Sequential(nn.Flatten(), nn.Linear(3200, 64), nn.ReLU())

        # key_3: plain MLP over the 2-dim vector.
        self.fc_net_2 = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 32), nn.ReLU(), nn.Flatten())

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Split the observation dict into its four branches.
        dict_input = inputs['key_0']  # dict{key:(B)}
        transformer_input = inputs['key_1']  # (B, seq_len, feature_dim)
        conv_input = inputs['key_2']  # (B, H, W, 3)
        fc_input = inputs['key_3']  # (B, X)

        B = fc_input.shape[0]

        # key_0: encode both scalars separately, then fuse.
        dict_output = self.fc_net_1(
            torch.cat(
                [self.fc_net_1_k1(dict_input['k1'].unsqueeze(-1)),
                 self.fc_net_1_k2(dict_input['k2'].unsqueeze(-1))],
                dim=1
            )
        )

        # key_1: prepend the class token and keep only its output as the summary.
        batch_class_token = self.class_token.expand(B, -1, -1)
        transformer_output = self.transformer_encoder(torch.cat([batch_class_token, transformer_input], dim=1))
        transformer_output = transformer_output[:, 0]

        # key_2: CNN expects CHW, so permute from HWC first.
        conv_output = self.conv_fc_net(self.conv_net(conv_input.permute(0, 3, 1, 2)))
        # key_3: MLP branch.
        fc_output = self.fc_net_2(fc_input)

        # Concatenate all branch outputs along the feature dimension.
        encoded_output = torch.cat([dict_output, transformer_output, conv_output, fc_output], dim=1)

        return encoded_output


def main():
    """Assemble and run the PPO training pipeline on ``MyEnv``."""
    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    ding_init(cfg)
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        collector_env = BaseEnvManagerV2(
            env_fn=[ding_env_maker for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
        )
        evaluator_env = BaseEnvManagerV2(
            env_fn=[ding_env_maker for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
        )

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        # The custom encoder replaces VAC's default observation encoder.
        encoder = Encoder(feature_dim=10)
        model = VAC(encoder=encoder, **cfg.policy.model)
        policy = PPOPolicy(cfg.policy, model=model)

        # Middleware pipeline: evaluate -> collect -> GAE -> train -> checkpoint -> log.
        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(StepCollector(cfg, policy.collect_mode, collector_env))
        task.use(gae_estimator(cfg, policy.collect_mode))
        task.use(multistep_trainer(policy.learn_mode, log_freq=50))
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
        task.use(online_logger(train_show_freq=3))
        task.run()


if __name__ == "__main__":
    main()