
ding.policy.policy_factory

PolicyFactory

Overview

Policy factory class, used to generate different policies for general purposes, such as a random action policy, which is used for initial sample collecting to improve exploration when `random_collect_size` > 0.

Interfaces: get_random_policy

get_random_policy(policy, action_space=None, forward_fn=None) staticmethod

Overview

According to the given action space, define the forward function of the random policy, then pack it with the other interfaces of the given policy, and return the final collect mode interfaces of the policy.

Arguments:

- policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy.
- action_space (:obj:`gym.spaces.Space`): The action space of the environment, gym-style.
- forward_fn (:obj:`Callable`): If the action space is too complex, you can define your own forward function and pass it to this function; note that you should set `action_space` to `None` in this case.

Returns:

- random_policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the random policy.
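The packing pattern can be sketched without the DI-engine dependencies. In the sketch below, `FakeDiscreteSpace` and the stub collect-mode interfaces are hypothetical stand-ins for a gym `Discrete` space and a real `Policy.collect_mode`; the namedtuple mirrors the one `PolicyFactory` builds:

```python
import random
from collections import namedtuple

# Hypothetical stand-in for gym.spaces.Discrete(n): sample() returns a random int in [0, n).
class FakeDiscreteSpace:

    def __init__(self, n):
        self.n = n

    def sample(self):
        return random.randrange(self.n)

# The interfaces a collect-mode policy must expose, mirroring PolicyFactory's namedtuple.
RandomCollect = namedtuple(
    'RandomCollect', ['forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute']
)

def make_random_policy(policy, action_space):
    # Random forward: ignore the observations and sample one action per env_id.
    def forward(data, *args, **kwargs):
        return {env_id: {'action': action_space.sample()} for env_id in data}

    # Random policies are stateless, so reset is a no-op.
    def reset(*args, **kwargs):
        pass

    # Pack the random forward/reset with the real policy's remaining interfaces.
    return RandomCollect(
        forward, policy.process_transition, policy.get_train_sample, reset, policy.get_attribute
    )
```

The key point is that only `forward` and `reset` are replaced; transition processing and sample packing still come from the real policy, so collected data stays in the format the training loop expects.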

get_random_policy(cfg, policy, env)

Overview

The entry function to get the corresponding random policy. If a policy needs special data items in a transition, then return itself, otherwise, we will use PolicyFactory to return a general random policy.

Arguments:

- cfg (:obj:`EasyDict`): The EasyDict-type dict configuration.
- policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy.
- env (:obj:`BaseEnvManager`): The env manager instance, which is used to get the action space for random action generation.

Returns:

- random_policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the random policy.
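The dispatch rule is simple: if the config marks that transitions carry policy-specific data (e.g. logits or hidden states), the policy must collect its own warm-up data; otherwise a generic random policy is substituted. A minimal sketch of that branch, with a plain dict standing in for the EasyDict config and a hypothetical `build_random` callable standing in for `PolicyFactory.get_random_policy`:

```python
from types import SimpleNamespace

def pick_collect_policy(cfg, policy, env, build_random):
    # If transitions need fields only the real policy can produce,
    # return the policy itself even for the random-collect phase.
    if cfg['policy'].get('transition_with_policy_data', False):
        return policy
    # Otherwise delegate to a factory that samples from the env's action space.
    return build_random(policy, action_space=env.action_space)
```

Called with a config where `transition_with_policy_data` is true, this returns the original policy unchanged; with the flag absent or false, it returns whatever the factory builds.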

Full Source Code

../ding/policy/policy_factory.py

```python
from typing import Dict, Any, Callable
from collections import namedtuple
from easydict import EasyDict
import gym
import gymnasium
import torch

from ding.torch_utils import to_device


class PolicyFactory:
    """
    Overview:
        Policy factory class, used to generate different policies for general purpose. Such as random action policy, \
        which is used for initial sample collecting for better exploration when ``random_collect_size`` > 0.
    Interfaces:
        ``get_random_policy``
    """

    @staticmethod
    def get_random_policy(
            policy: 'Policy.collect_mode',  # noqa
            action_space: 'gym.spaces.Space' = None,  # noqa
            forward_fn: Callable = None,
    ) -> 'Policy.collect_mode':  # noqa
        """
        Overview:
            According to the given action space, define the forward function of the random policy, then pack it with \
            other interfaces of the given policy, and return the final collect mode interfaces of policy.
        Arguments:
            - policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy.
            - action_space (:obj:`gym.spaces.Space`): The action space of the environment, gym-style.
            - forward_fn (:obj:`Callable`): If the action space is too complex, you can define your own forward \
                function and pass it to this function, note you should set ``action_space`` to ``None`` in this case.
        Returns:
            - random_policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the random policy.
        """
        assert not (action_space is None and forward_fn is None)
        random_collect_function = namedtuple(
            'random_collect_function', [
                'forward',
                'process_transition',
                'get_train_sample',
                'reset',
                'get_attribute',
            ]
        )

        def forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:

            actions = {}
            for env_id in data:
                if isinstance(action_space, list):
                    if 'global_state' in data[env_id].keys():
                        # for smac
                        logit = torch.ones_like(data[env_id]['action_mask'])
                        logit[data[env_id]['action_mask'] == 0.0] = -1e8
                        dist = torch.distributions.categorical.Categorical(logits=torch.Tensor(logit))
                        actions[env_id] = {'action': dist.sample(), 'logit': torch.as_tensor(logit)}
                    else:
                        # for gfootball
                        actions[env_id] = {
                            'action': torch.as_tensor(
                                [action_space_agent.sample() for action_space_agent in action_space]
                            ),
                            'logit': torch.ones([len(action_space), action_space[0].n])
                        }
                elif isinstance(action_space, gymnasium.spaces.Dict):  # pettingzoo
                    actions[env_id] = {
                        'action': torch.as_tensor(
                            [action_space_agent.sample() for action_space_agent in action_space.values()]
                        )
                    }
                else:
                    if isinstance(action_space, gym.spaces.Discrete):
                        action = torch.LongTensor([action_space.sample()])
                    elif isinstance(action_space, gym.spaces.MultiDiscrete):
                        action = [torch.LongTensor([v]) for v in action_space.sample()]
                    else:
                        action = torch.as_tensor(action_space.sample())
                    actions[env_id] = {'action': action}
            return actions

        def reset(*args, **kwargs) -> None:
            pass

        if action_space is None:
            return random_collect_function(
                forward_fn, policy.process_transition, policy.get_train_sample, reset, policy.get_attribute
            )
        elif forward_fn is None:
            return random_collect_function(
                forward, policy.process_transition, policy.get_train_sample, reset, policy.get_attribute
            )


def get_random_policy(
        cfg: EasyDict,
        policy: 'Policy.collect_mode',  # noqa
        env: 'BaseEnvManager'  # noqa
) -> 'Policy.collect_mode':  # noqa
    """
    Overview:
        The entry function to get the corresponding random policy. If a policy needs special data items in a \
        transition, then return itself, otherwise, we will use ``PolicyFactory`` to return a general random policy.
    Arguments:
        - cfg (:obj:`EasyDict`): The EasyDict-type dict configuration.
        - policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy.
        - env (:obj:`BaseEnvManager`): The env manager instance, which is used to get the action space for random \
            action generation.
    Returns:
        - random_policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the random policy.
    """
    if cfg.policy.get('transition_with_policy_data', False):
        return policy
    else:
        action_space = env.action_space
        return PolicyFactory.get_random_policy(policy, action_space=action_space)
```