ding.framework.middleware.functional.trainer
trainer(cfg, policy, log_freq=100)
Overview
The middleware that executes a single training step on each call.
Arguments:
- cfg (EasyDict): Config.
- policy (Policy): The policy to be trained in step-by-step mode.
- log_freq (int): The frequency (in iterations) of printing the training log.
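The returned closure mutates the context in place: it returns early when no data is available, calls `policy.forward` once, and advances `train_iter`. A minimal sketch of that contract, using stand-in classes (`StubPolicy`, `Ctx`, and `make_train` are illustrative only, not DI-engine APIs):

```python
# Stand-ins that mimic the interface trainer() relies on.
# StubPolicy and Ctx are illustrative, not DI-engine classes.
class StubPolicy:
    def forward(self, data):
        # A real Policy returns a dict of training stats such as total_loss.
        return {'total_loss': 0.5}


class Ctx:
    def __init__(self):
        self.train_data = {'obs': [1, 2, 3]}  # non-empty, so training runs
        self.train_iter = 0
        self.train_output = None


def make_train(policy):
    # Same shape as trainer()'s inner closure: skip when there is no data,
    # forward once, then advance the iteration counter.
    def _train(ctx):
        if ctx.train_data is None:
            return
        ctx.train_output = policy.forward(ctx.train_data)
        ctx.train_iter += 1
    return _train


ctx = Ctx()
step = make_train(StubPolicy())
step(ctx)
print(ctx.train_iter, ctx.train_output['total_loss'])  # 1 0.5
ctx.train_data = None
step(ctx)  # no data: context is left untouched
```

The real middleware adds device handling and throttled logging on top of this core loop.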
multistep_trainer(policy, log_freq=100)
Overview
The middleware that executes training for a target number of steps.
Arguments:
- policy (Policy): The policy specialized for multi-step training.
- log_freq (int): The frequency (in iterations) of printing the training log.
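Unlike `trainer`, the multi-step variant expects `policy.forward` to return a list of per-step outputs, advances `train_iter` by the number of steps, and throttles logging so at least `log_freq` iterations pass between log lines. A sketch of that bookkeeping (`make_multistep_train` and the dict-based context are illustrative, not DI-engine APIs):

```python
# Illustrative sketch of multistep_trainer's bookkeeping: train_iter
# advances by the number of per-step outputs, and logging is throttled
# via a last_log_iter captured in the closure.
def make_multistep_train(policy, log_freq=100):
    last_log_iter = -1

    def _train(ctx):
        if ctx['train_data'] is None:
            return
        train_output = policy(ctx['train_data'])  # list of per-step dicts
        nonlocal last_log_iter
        if ctx['train_iter'] - last_log_iter >= log_freq:
            loss = sum(o['total_loss'] for o in train_output) / len(train_output)
            print('iter {} loss {:.3f}'.format(ctx['train_iter'], loss))
            last_log_iter = ctx['train_iter']
        ctx['train_iter'] += len(train_output)  # one tick per step
        ctx['train_output'] = train_output
    return _train


# Two optimization steps per call, so train_iter grows by 2 each time.
step = make_multistep_train(lambda data: [{'total_loss': 0.4}, {'total_loss': 0.6}], log_freq=2)
ctx = {'train_data': [1], 'train_iter': 0, 'train_output': None}
step(ctx)
step(ctx)
print(ctx['train_iter'])  # 4
```

Advancing by `len(train_output)` keeps `train_iter` comparable between the single-step and multi-step trainers, since both count individual optimization steps.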
Full Source Code
../ding/framework/middleware/functional/trainer.py
from typing import TYPE_CHECKING, Callable, Union
from easydict import EasyDict
import treetensor.torch as ttorch
from ditk import logging
import numpy as np
from ding.policy import Policy
from ding.framework import task, OfflineRLContext, OnlineRLContext


def trainer(cfg: EasyDict, policy: Policy, log_freq: int = 100) -> Callable:
    """
    Overview:
        The middleware that executes a single training step on each call.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - policy (:obj:`Policy`): The policy to be trained in step-by-step mode.
        - log_freq (:obj:`int`): The frequency (in iterations) of printing the training log.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
        """
        Input of ctx:
            - train_data (:obj:`Dict`): The data used to update the network. It will train only if \
                the data is not empty.
            - train_iter (:obj:`int`): The training iteration count. The log will be printed once \
                it reaches certain values.
        Output of ctx:
            - train_output (:obj:`Dict`): The training output in the Dict format, including loss info.
        """

        if ctx.train_data is None:
            return
        train_output = policy.forward(ctx.train_data)
        if ctx.train_iter % log_freq == 0:
            if isinstance(train_output, list):
                train_output_loss = np.mean([item['total_loss'] for item in train_output])
            else:
                train_output_loss = train_output['total_loss']
            if isinstance(ctx, OnlineRLContext):
                logging.info(
                    'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format(
                        ctx.train_iter, ctx.env_step, train_output_loss
                    )
                )
            elif isinstance(ctx, OfflineRLContext):
                logging.info('Training: Train Iter({})\tLoss({:.3f})'.format(ctx.train_iter, train_output_loss))
            else:
                raise TypeError("not supported ctx type: {}".format(type(ctx)))
        ctx.train_iter += 1
        ctx.train_output = train_output

    return _train


def multistep_trainer(policy: Policy, log_freq: int = 100) -> Callable:
    """
    Overview:
        The middleware that executes training for a target number of steps.
    Arguments:
        - policy (:obj:`Policy`): The policy specialized for multi-step training.
        - log_freq (:obj:`int`): The frequency (in iterations) of printing the training log.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    last_log_iter = -1

    def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
        """
        Input of ctx:
            - train_data: The data used to update the network.
                It will train only if the data is not empty.
            - train_iter (:obj:`int`): The training iteration count.
                The log will be printed if it reaches certain values.
        Output of ctx:
            - train_output (:obj:`List[Dict]`): The training output listed by steps.
        """

        if ctx.train_data is None:  # not enough data from the data fetcher
            return
        if hasattr(policy, "_device"):  # for ppof policy
            data = ctx.train_data.to(policy._device)
        elif hasattr(policy, "get_attribute"):  # for other policies
            data = ctx.train_data.to(policy.get_attribute("device"))
        else:
            raise AttributeError("Policy should have attribute '_device'.")
        train_output = policy.forward(data)
        nonlocal last_log_iter
        if ctx.train_iter - last_log_iter >= log_freq:
            loss = np.mean([o['total_loss'] for o in train_output])
            if isinstance(ctx, OfflineRLContext):
                logging.info('Training: Train Iter({})\tLoss({:.3f})'.format(ctx.train_iter, loss))
            else:
                logging.info(
                    'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format(ctx.train_iter, ctx.env_step, loss)
                )
            last_log_iter = ctx.train_iter
        ctx.train_iter += len(train_output)
        ctx.train_output = train_output

    return _train


# TODO: reward model