ding.reward_model.gail_irl_model
GailRewardModel
Bases: BaseRewardModel
Overview
The Gail reward model class (https://arxiv.org/abs/1606.03476)
Interface:
estimate, train, load_expert_data, collect_data, clear_data, __init__, state_dict, load_state_dict, learn
Config:
== ====================== ===== =============== ==================================== ========================
ID Symbol                 Type  Default Value   Description                          Other (Shape)
== ====================== ===== =============== ==================================== ========================
1  type                   str   gail            | Reward model register name,        | This arg is optional,
                                                | refer to registry                  | a placeholder
                                                | POLICY_REGISTRY                    |
2  expert_data_path       str   expert_data.pkl | Path to the expert dataset         | Should be a '.pkl'
                                                |                                    | file
3  learning_rate          float 0.001           | Step size of gradient descent      |
4  update_per_collect     int   100             | Number of updates per collect      |
5  batch_size             int   64              | Training batch size                |
6  input_size             int                   | Size of the input:                 |
                                                | obs_dim + act_dim                  |
7  target_new_data_count  int   64              | Collect steps per iteration        |
8  hidden_size            int   128             | Linear model hidden size           |
9  collect_count          int   100000          | Expert dataset size                | One entry is an (s, a)
                                                |                                    | tuple
10 clear_buffer_per_iters int   1               | Clear the buffer every fixed       | Ensure the replay
                                                | number of iterations               | buffer's data count
                                                |                                    | does not become too
                                                |                                    | small (handled in the
                                                |                                    | entry code)
== ====================== ===== =============== ==================================== ========================
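For illustration, the fields above can be gathered into a plain Python dict. The field names and defaults follow the table; the actual default config object in DI-engine may be structured differently (e.g. as an EasyDict):

```python
# Plain-dict sketch of the config fields in the table above; the real
# default config in DI-engine may be structured differently.
gail_default_config = dict(
    type='gail',                         # reward model register name
    expert_data_path='expert_data.pkl',  # must point to a '.pkl' file
    learning_rate=0.001,                 # gradient descent step size
    update_per_collect=100,              # updates per collect phase
    batch_size=64,                       # training batch size
    input_size=None,                     # obs_dim + act_dim, set per environment
    target_new_data_count=64,            # collect steps per iteration
    hidden_size=128,                     # linear model hidden size
    collect_count=100000,                # expert dataset size in (s, a) tuples
    clear_buffer_per_iters=1,            # clear the buffer every N iterations
)
```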
__init__(config, device, tb_logger)
Overview
Initialize self. See help(type(self)) for accurate signature.
Arguments:
- config (:obj:EasyDict): Training config
- device (:obj:str): Device to use, e.g. "cpu" or "cuda"
- tb_logger (:obj:SummaryWriter): Logger, by default a 'SummaryWriter' used for model summaries
load_expert_data()
Overview
Load the expert data from the path given by the data_path attribute of the config stored in self.
Effects:
This is a side effect function which updates the expert data attribute (i.e. self.expert_data) using fn:concat_state_action_pairs
learn(train_data, expert_data)
Overview
Helper function for train which calculates the loss over the train data and the expert data.
Arguments:
- train_data (:obj:torch.Tensor): Data used for training
- expert_data (:obj:torch.Tensor): Expert data
Returns:
- The combined reward-model loss computed from train_data and expert_data.
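The combined loss follows the standard GAIL discriminator objective: a binary classifier is trained to tell expert state-action pairs from policy-generated ones. A minimal sketch of that objective (the discriminator argument and the exact loss form here are illustrative assumptions, not DI-engine's exact code):

```python
import torch
import torch.nn as nn

# Sketch of a GAIL-style discriminator loss: expert (s, a) pairs are
# labeled 1 and policy-generated pairs 0, and a binary cross-entropy
# loss is summed over both. The discriminator passed in is a stand-in
# for the reward model's internal network.
def gail_discriminator_loss(discriminator: nn.Module,
                            train_data: torch.Tensor,
                            expert_data: torch.Tensor) -> torch.Tensor:
    bce = nn.BCEWithLogitsLoss()
    policy_logits = discriminator(train_data)
    expert_logits = discriminator(expert_data)
    loss_policy = bce(policy_logits, torch.zeros_like(policy_logits))
    loss_expert = bce(expert_logits, torch.ones_like(expert_logits))
    return loss_policy + loss_expert
```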
train()
Overview
Train the Gail reward model. Training data and expert data are randomly sampled in batches whose size is given by the batch_size attribute of self.cfg, drawn from the train_data and expert_data attributes initialized in self.
Effects:
- This is a side effect function which updates the reward model and increments the train iteration count.
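Sketched as plain Python, the sampling loop described above might look like the following; the learn callable and the function name are illustrative stand-ins for the model's own methods:

```python
import random

# Illustrative sampling loop for one train() call: sample matching-size
# batches from the stored policy data and expert data, and run a learn
# step on each pair. `learn` is a stand-in for the model's own method.
def train_step(train_data, expert_data, learn, batch_size, update_per_collect):
    losses = []
    for _ in range(update_per_collect):
        train_batch = random.sample(train_data, batch_size)
        expert_batch = random.sample(expert_data, batch_size)
        losses.append(learn(train_batch, expert_batch))
    return losses
```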
estimate(data)
Overview
Estimate the reward by rewriting the reward key in each row of the data.
Arguments:
- data (:obj:list): The list of data rows used for estimation, each with at least obs and action keys.
Effects:
- This is a side effect function which updates the reward values in place.
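The in-place rewriting can be pictured as below. The reward transform -log(1 - D(s, a)) is one common GAIL choice and is used here purely for illustration; the discriminator_prob callable and the library's actual transform are assumptions:

```python
import math

# Sketch of in-place reward rewriting: each row's 'reward' key is replaced
# by a discriminator-based reward. `discriminator_prob` is a hypothetical
# callable returning the discriminator's probability for an (obs, action)
# pair; the clamp avoids log(0).
def estimate_rewards(data, discriminator_prob):
    for row in data:
        d = discriminator_prob(row['obs'], row['action'])  # probability in (0, 1)
        row['reward'] = -math.log(max(1.0 - d, 1e-8))
```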
collect_data(data)
Overview
Collect training data, formatted by fn:concat_state_action_pairs.
Arguments:
- data (:obj:Any): Raw training data (e.g. states, actions, observations)
Effects:
- This is a side effect function which updates the data attribute in self
clear_data()
Overview
Clear the training data. This is a side effect function which clears the data attribute in self.
concat_state_action_pairs(iterator)
Overview
Concatenate state and action pairs from the input.
Arguments:
- iterator (:obj:Iterable): Iterable of items with at least obs and action tensor keys.
Returns:
- res (:obj:torch.Tensor): Concatenated state-action pairs.
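A minimal sketch of this concatenation, assuming each item carries flat or flattenable obs and action tensors (key names mirror the documented interface; this is not the library's exact code):

```python
import torch

# Flatten each item's obs and action tensors, join them, and stack the
# results into one 2-D tensor of shape (num_items, obs_dim + act_dim).
def concat_state_action_pairs(iterator):
    return torch.stack([
        torch.cat([item['obs'].flatten(), item['action'].flatten()])
        for item in iterator
    ])
```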
concat_state_action_pairs_one_hot(iterator, action_size)
Overview
Concatenate state and action pairs from the input, with action values one-hot encoded.
Arguments:
- iterator (:obj:Iterable): Iterable of items with at least obs and action tensor keys.
- action_size (:obj:int): Size of the discrete action space, used for one-hot encoding.
Returns:
- res (:obj:torch.Tensor): Concatenated state-action pairs.
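The one-hot variant can be sketched the same way, assuming discrete integer actions (illustrative only, not the library's exact code):

```python
import torch
import torch.nn.functional as F

# Expand each discrete action index to a one-hot vector of length
# action_size, then concatenate it with the flattened observation and
# stack the results into one 2-D tensor.
def concat_state_action_pairs_one_hot(iterator, action_size):
    return torch.stack([
        torch.cat([
            item['obs'].flatten(),
            F.one_hot(item['action'].long().flatten(),
                      num_classes=action_size).float().flatten(),
        ])
        for item in iterator
    ])
```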
Full Source Code
../ding/reward_model/gail_irl_model.py