from typing import Dict, Any, Callable
from collections import namedtuple
from easydict import EasyDict
import gym
import gymnasium
import torch

from ding.torch_utils import to_device


class PolicyFactory:
    """
    Overview:
        Policy factory class, used to generate different policies for general purpose. Such as random action policy, \
        which is used for initial sample collecting for better exploration when ``random_collect_size`` > 0.
    Interfaces:
        ``get_random_policy``
    """

    @staticmethod
    def get_random_policy(
            policy: 'Policy.collect_mode',  # noqa
            action_space: 'gym.spaces.Space' = None,  # noqa
            forward_fn: Callable = None,
    ) -> 'Policy.collect_mode':  # noqa
        """
        Overview:
            According to the given action space, define the forward function of the random policy, then pack it with \
            other interfaces of the given policy, and return the final collect mode interfaces of policy.
        Arguments:
            - policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy.
            - action_space (:obj:`gym.spaces.Space`): The action space of the environment, gym-style.
            - forward_fn (:obj:`Callable`): If the action space is too complex, you can define your own forward \
                function and pass it to this method; note you should set ``action_space`` to ``None`` in this case.
        Returns:
            - random_policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the random policy.
        """
        assert not (action_space is None and forward_fn is None)
        random_collect_function = namedtuple(
            'random_collect_function', [
                'forward',
                'process_transition',
                'get_train_sample',
                'reset',
                'get_attribute',
            ]
        )

        def forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:
            # Produce one random action per ready env (keyed by env_id), plus a 'logit'
            # entry where downstream multi-agent policies expect it.
            actions = {}
            for env_id in data:
                if isinstance(action_space, list):
                    # Multi-agent case: one action space per agent.
                    if 'global_state' in data[env_id].keys():
                        # for smac: honour the per-action availability mask by sampling
                        # from a categorical whose unavailable actions get logit -1e8
                        # (i.e. effectively zero probability).
                        logit = torch.ones_like(data[env_id]['action_mask'])
                        logit[data[env_id]['action_mask'] == 0.0] = -1e8
                        # torch.Tensor(...) casts the logits to float for Categorical,
                        # in case the action mask tensor is not a float dtype.
                        dist = torch.distributions.categorical.Categorical(logits=torch.Tensor(logit))
                        actions[env_id] = {'action': dist.sample(), 'logit': logit}
                    else:
                        # for gfootball: sample each agent's own space independently.
                        # NOTE(review): assumes every per-agent space is Discrete with
                        # the same ``n`` (uniform dummy logits) — confirm against envs.
                        actions[env_id] = {
                            'action': torch.as_tensor(
                                [action_space_agent.sample() for action_space_agent in action_space]
                            ),
                            'logit': torch.ones([len(action_space), action_space[0].n])
                        }
                elif isinstance(action_space, gymnasium.spaces.Dict):  # pettingzoo
                    actions[env_id] = {
                        'action': torch.as_tensor(
                            [action_space_agent.sample() for action_space_agent in action_space.values()]
                        )
                    }
                else:
                    # Single-agent gym spaces: normalize the sampled action to tensor form.
                    if isinstance(action_space, gym.spaces.Discrete):
                        action = torch.LongTensor([action_space.sample()])
                    elif isinstance(action_space, gym.spaces.MultiDiscrete):
                        action = [torch.LongTensor([v]) for v in action_space.sample()]
                    else:
                        action = torch.as_tensor(action_space.sample())
                    actions[env_id] = {'action': action}
            return actions

        def reset(*args, **kwargs) -> None:
            # The random policy is stateless, so reset is a deliberate no-op.
            pass

        # Prefer a user-supplied forward function; otherwise use the generic one built
        # above. (The original if/elif fell through and silently returned ``None`` when
        # both ``action_space`` and ``forward_fn`` were provided — fixed here.)
        if forward_fn is None:
            forward_fn = forward
        return random_collect_function(
            forward_fn, policy.process_transition, policy.get_train_sample, reset, policy.get_attribute
        )


def get_random_policy(
        cfg: EasyDict,
        policy: 'Policy.collect_mode',  # noqa
        env: 'BaseEnvManager'  # noqa
) -> 'Policy.collect_mode':  # noqa
    """
    Overview:
        The entry function to get the corresponding random policy. If a policy needs special data items in a \
        transition, then return itself, otherwise, we will use ``PolicyFactory`` to return a general random policy.
    Arguments:
        - cfg (:obj:`EasyDict`): The EasyDict-type dict configuration.
        - policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy.
        - env (:obj:`BaseEnvManager`): The env manager instance, which is used to get the action space for random \
            action generation.
    Returns:
        - random_policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the random policy.
    """
    if cfg.policy.get('transition_with_policy_data', False):
        # The policy records its own outputs (e.g. logits) in transitions, so a generic
        # random forward would produce incompatible data — use the policy itself.
        return policy
    else:
        action_space = env.action_space
        return PolicyFactory.get_random_policy(policy, action_space=action_space)