ding.rl_utils.sampler

ArgmaxSampler

Overview

Argmax sampler, return the index of the maximum value

__call__(logit)

Overview

Return the index of the maximum value

Arguments:
- logit (:obj:`torch.Tensor`): The input tensor; the maximum is taken over the last dimension

Returns:
- action (:obj:`torch.Tensor`): The index of the maximum value along the last dimension
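As a quick illustration of what this call computes, the sketch below applies the same `argmax(dim=-1)` operation to a small batch using plain PyTorch. The tensor values and the batch shape are made up for the example; it does not import `ding` itself.

```python
import torch

# Batch of 2 samples with 3 discrete actions each (illustrative values).
logit = torch.tensor([[0.1, 2.0, -1.0],
                      [1.5, 0.3, 0.2]])

# Same operation as ArgmaxSampler.__call__: argmax over the last dimension.
action = logit.argmax(dim=-1)
print(action)  # tensor([1, 0])
```

The result has one index per row, so the last dimension of the input is dropped.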

MultinomialSampler

Overview

Multinomial sampler, return the index of the sampled value

__call__(logit)

Overview

Return the index of the sampled value

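Arguments:
- logit (:obj:`torch.Tensor`): The input tensor, interpreted as unnormalized log-probabilities (logits) over the last dimension

Returns:
- action (:obj:`torch.Tensor`): The index sampled from the categorical distribution defined by the logits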
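The sampling step can be sketched with plain PyTorch as follows. The logit values are illustrative assumptions, chosen so that the second row puts effectively all probability mass on index 1; the seed is only there to make the sketch repeatable.

```python
import torch

torch.manual_seed(0)  # only to make the sketch repeatable

logit = torch.tensor([[0.0, 0.0, 0.0],       # uniform over 3 actions
                      [-1e9, 10.0, -1e9]])   # effectively all mass on action 1

# Same steps as MultinomialSampler.__call__: build a categorical
# distribution from the logits and draw one index per row.
dist = torch.distributions.Categorical(logits=logit)
action = dist.sample()

print(action.shape)  # torch.Size([2])
```

Unlike the argmax case, repeated calls on the first row can return different indices, which is what makes this sampler suitable for exploration.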

MuSampler

Overview

Mu sampler, return the mu of the input tensor

__call__(logit)

Overview

Return the mu of the input tensor

Arguments:
- logit (:obj:`ttorch.Tensor`): The input tree tensor, which must provide a `mu` field

Returns:
- action (:obj:`torch.Tensor`): The `mu` field of the input, returned unchanged

ReparameterizationSampler

Overview

Reparameterization sampler, return the reparameterized value of the input tensor

__call__(logit)

Overview

Return the reparameterized value of the input tensor

Arguments:
- logit (:obj:`ttorch.Tensor`): The input tree tensor, which must provide `mu` and `sigma` fields

Returns:
- action (:obj:`torch.Tensor`): A reparameterized sample drawn from the Gaussian defined by `mu` and `sigma`
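To show why the reparameterized sample matters, here is a minimal sketch using plain PyTorch. A `SimpleNamespace` is used as a hypothetical stand-in for the `ttorch.Tensor` input (it only mimics the `mu` and `sigma` attribute access), and the shapes are assumptions made for the example.

```python
import torch
from types import SimpleNamespace
from torch.distributions import Independent, Normal

# Hypothetical stand-in for a ttorch.Tensor with `mu` and `sigma` fields.
logit = SimpleNamespace(
    mu=torch.zeros(4, 2, requires_grad=True),  # batch of 4, action dim 2
    sigma=torch.ones(4, 2),
)

# Same steps as the sampler: a diagonal Gaussian over the last dimension.
dist = Independent(Normal(logit.mu, logit.sigma), 1)
action = dist.rsample()  # shape (4, 2), differentiable w.r.t. mu

# rsample() computes mu + sigma * eps, so gradients flow back to mu.
action.sum().backward()
print(logit.mu.grad)                # all ones, since d(mu + sigma * eps)/d(mu) = 1
print(dist.log_prob(action).shape)  # torch.Size([4])
```

Wrapping the `Normal` in `Independent(..., 1)` does not change the samples; it makes `log_prob` sum over the action dimension, giving one value per batch element.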

HybridStochasticSampler

Overview

Hybrid stochastic sampler, return the sampled action type and the reparameterized action args

__call__(logit)

Overview

Return the sampled action type and the reparameterized action args

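Arguments:
- logit (:obj:`ttorch.Tensor`): The input tree tensor, which must provide `action_type` logits and `action_args` with `mu` and `sigma` fields

Returns:
- action (:obj:`ttorch.Tensor`): A tree tensor with the sampled `action_type` and the reparameterized `action_args`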
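The nested input layout can be easier to see in code. The sketch below reproduces the two sampling steps with plain PyTorch; nested `SimpleNamespace` objects are a hypothetical stand-in for the `ttorch.Tensor` input, a plain dict stands in for the returned tree tensor, and all shapes are assumptions.

```python
import torch
from types import SimpleNamespace
from torch.distributions import Categorical, Independent, Normal

# Hypothetical stand-in for the nested ttorch.Tensor the sampler expects.
logit = SimpleNamespace(
    action_type=torch.randn(4, 3),   # 4 samples, 3 discrete action types
    action_args=SimpleNamespace(
        mu=torch.zeros(4, 2),        # 2 continuous arguments per sample
        sigma=torch.ones(4, 2),
    ),
)

# Discrete part: sample an action type from the categorical logits.
action_type = Categorical(logits=logit.action_type).sample()

# Continuous part: reparameterized sample from a diagonal Gaussian.
args_dist = Independent(Normal(logit.action_args.mu, logit.action_args.sigma), 1)
action_args = args_dist.rsample()

# A plain dict here; the real sampler wraps this in ttorch.as_tensor(...).
action = {'action_type': action_type, 'action_args': action_args}
print(action['action_type'].shape, action['action_args'].shape)
```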

HybridDeterminsticSampler

Overview

Hybrid deterministic sampler, return the argmax action type and the mu action args

__call__(logit)

Overview

Return the argmax action type and the mu action args

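Arguments:
- logit (:obj:`ttorch.Tensor`): The input tree tensor, which must provide `action_type` logits and `action_args` with a `mu` field

Returns:
- action (:obj:`ttorch.Tensor`): A tree tensor with the argmax `action_type` and the `mu` of `action_args`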
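For comparison with the stochastic variant, the deterministic path can be sketched in plain PyTorch as below. The nested `SimpleNamespace` is a hypothetical stand-in for the `ttorch.Tensor` input, the values are illustrative, and the presence of a `sigma` field is an assumption (this sampler does not read it).

```python
import torch
from types import SimpleNamespace

# Hypothetical stand-in for the nested ttorch.Tensor input.
logit = SimpleNamespace(
    action_type=torch.tensor([[0.2, 1.3, -0.5],
                              [2.0, 0.1, 0.4]]),
    action_args=SimpleNamespace(
        mu=torch.tensor([[0.5, -0.5],
                         [1.0, 0.0]]),
        sigma=torch.ones(2, 2),  # assumed present, but not used here
    ),
)

# Same two steps as the sampler: argmax for the type, mu for the args.
action = {
    'action_type': logit.action_type.argmax(dim=-1),
    'action_args': logit.action_args.mu,
}
print(action['action_type'])  # tensor([1, 0])
```

Because no random draw is involved, the same input always yields the same action, which is the usual choice at evaluation time.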

Full Source Code

../ding/rl_utils/sampler.py

import torch
import treetensor.torch as ttorch
from torch.distributions import Normal, Independent


class ArgmaxSampler:
    '''
    Overview:
        Argmax sampler, return the index of the maximum value
    '''

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Return the index of the maximum value
        Arguments:
            - logit (:obj:`torch.Tensor`): The input tensor
        Returns:
            - action (:obj:`torch.Tensor`): The index of the maximum value
        '''
        return logit.argmax(dim=-1)


class MultinomialSampler:
    '''
    Overview:
        Multinomial sampler, return the index of the sampled value
    '''

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Return the index of the sampled value
        Arguments:
            - logit (:obj:`torch.Tensor`): The input tensor
        Returns:
            - action (:obj:`torch.Tensor`): The index of the sampled value
        '''
        dist = torch.distributions.Categorical(logits=logit)
        return dist.sample()


class MuSampler:
    '''
    Overview:
        Mu sampler, return the mu of the input tensor
    '''

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Return the mu of the input tensor
        Arguments:
            - logit (:obj:`ttorch.Tensor`): The input tensor
        Returns:
            - action (:obj:`torch.Tensor`): The mu of the input tensor
        '''
        return logit.mu


class ReparameterizationSampler:
    '''
    Overview:
        Reparameterization sampler, return the reparameterized value of the input tensor
    '''

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Return the reparameterized value of the input tensor
        Arguments:
            - logit (:obj:`ttorch.Tensor`): The input tensor
        Returns:
            - action (:obj:`torch.Tensor`): The reparameterized value of the input tensor
        '''
        dist = Normal(logit.mu, logit.sigma)
        dist = Independent(dist, 1)
        return dist.rsample()


class HybridStochasticSampler:
    '''
    Overview:
        Hybrid stochastic sampler, return the sampled action type and the reparameterized action args
    '''

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        '''
        Overview:
            Return the sampled action type and the reparameterized action args
        Arguments:
            - logit (:obj:`ttorch.Tensor`): The input tensor
        Returns:
            - action (:obj:`ttorch.Tensor`): The sampled action type and the reparameterized action args
        '''
        dist = torch.distributions.Categorical(logits=logit.action_type)
        action_type = dist.sample()
        dist = Normal(logit.action_args.mu, logit.action_args.sigma)
        dist = Independent(dist, 1)
        action_args = dist.rsample()
        return ttorch.as_tensor({
            'action_type': action_type,
            'action_args': action_args,
        })


class HybridDeterminsticSampler:
    '''
    Overview:
        Hybrid deterministic sampler, return the argmax action type and the mu action args
    '''

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        '''
        Overview:
            Return the argmax action type and the mu action args
        Arguments:
            - logit (:obj:`ttorch.Tensor`): The input tensor
        Returns:
            - action (:obj:`ttorch.Tensor`): The argmax action type and the mu action args
        '''
        action_type = logit.action_type.argmax(dim=-1)
        action_args = logit.action_args.mu
        return ttorch.as_tensor({
            'action_type': action_type,
            'action_args': action_args,
        })