Logging
As reinforcement learning algorithms are historically challenging to debug, it’s important to pay careful attention to logging.
By default, the TRL PPOTrainer saves a lot of relevant information to wandb or tensorboard.
Upon initialization, pass one of these two options to the PPOConfig:
```python
from trl import PPOConfig

config = PPOConfig(
    model_name=args.model_name,  # args: your own parsed command-line arguments
    log_with="wandb",  # or "tensorboard"
)
```
If you want to log with tensorboard, add the keyword argument project_kwargs={"logging_dir": PATH_TO_LOGS} to the PPOConfig.
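The two backend choices can be wrapped in a small helper. This is a sketch. ppo_logging_kwargs is a hypothetical function, and only log_with and project_kwargs come from the text above. It refuses tensorboard without a logging directory, which follows the note that tensorboard needs project_kwargs.

```python
def ppo_logging_kwargs(backend, logging_dir=None):
    """Return logging-related keyword arguments for PPOConfig (hypothetical helper)."""
    if backend not in ("wandb", "tensorboard"):
        raise ValueError(f"unsupported logging backend: {backend!r}")
    kwargs = {"log_with": backend}
    if backend == "tensorboard":
        if logging_dir is None:
            raise ValueError("tensorboard logging needs a logging_dir")
        # tensorboard writes its event files into this directory
        kwargs["project_kwargs"] = {"logging_dir": logging_dir}
    return kwargs


# Usage (requires trl; args.model_name comes from your own argument parser):
# from trl import PPOConfig
# config = PPOConfig(model_name=args.model_name, **ppo_logging_kwargs("tensorboard", "./logs"))
```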
PPO Logging
Here’s a brief explanation of the logged metrics:
Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy:
env/reward_mean: The average reward obtained from the environment. Alias ppo/mean_scores, which is used to specifically monitor the reward model.
env/reward_std: The standard deviation of the reward obtained from the environment. Alias ppo/std_scores, which is used to specifically monitor the reward model.
env/reward_dist: The histogram distribution of the reward obtained from the environment.
objective/kl: The mean Kullback-Leibler (KL) divergence between the current policy and the reference (initial) policy. It measures how much the policy being trained deviates from the reference policy. The KL divergence is used to compute the KL penalty in the objective function.
objective/kl_dist: The histogram distribution of objective/kl.
objective/kl_coef: The coefficient for the Kullback-Leibler (KL) divergence in the objective function.
ppo/mean_non_score_reward: The negated KL penalty, computed as -objective/kl * objective/kl_coef. It is added to the score to form the total reward used for optimization, which keeps the new policy from deviating too far from the reference policy.
objective/entropy: The entropy of the model’s policy, calculated by -logprobs.sum(-1).mean(). High entropy means the model’s actions are more random, which can be beneficial for exploration.
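The pure-Python sketch below computes the key metrics for one batch, which makes their relationships concrete. It rests on several assumptions:

- scores holds one reward-model score per response.
- logprobs and ref_logprobs hold the per-token log-probabilities of the sampled response tokens under the policy and the reference model.
- key_metrics is a hypothetical helper.
- The KL is estimated per sequence as the sum of log-probability differences.
- env/reward_std uses the sample standard deviation, which may not match how the library computes it.

```python
from statistics import mean, stdev


def key_metrics(scores, logprobs, ref_logprobs, kl_coef):
    """Sketch of the key metrics for one batch (hypothetical helper).

    scores:       one reward-model score per response
    logprobs:     per-token log-probs of each response under the policy
    ref_logprobs: per-token log-probs of the same tokens under the reference model
    kl_coef:      current KL coefficient (objective/kl_coef)
    """
    # Per-sequence KL estimate: sum over tokens of log p_policy - log p_ref
    kl_per_seq = [
        sum(lp - ref for lp, ref in zip(seq, ref_seq))
        for seq, ref_seq in zip(logprobs, ref_logprobs)
    ]
    # Per-sequence entropy estimate from sampled tokens: -(sum of log-probs)
    entropy_per_seq = [-sum(seq) for seq in logprobs]
    # Negated KL penalty that is added to the score before optimization
    non_score_per_seq = [-kl_coef * kl for kl in kl_per_seq]
    return {
        "env/reward_mean": mean(scores),
        "env/reward_std": stdev(scores),  # sample standard deviation (assumption)
        "objective/kl": mean(kl_per_seq),
        "objective/kl_coef": kl_coef,
        "objective/entropy": mean(entropy_per_seq),
        "ppo/mean_non_score_reward": mean(non_score_per_seq),
    }
```

In this sketch the penalty is linear in the KL, so ppo/mean_non_score_reward equals -objective/kl_coef * objective/kl and the two curves mirror each other.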
Training stats:
ppo/learning_rate: The learning rate for the PPO algorithm.
ppo/policy/entropy: The entropy of the model’s policy, calculated by pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1). It measures the randomness of the policy.
ppo/policy/clipfrac: The fraction of probability ratios (new policy / old policy) that fell outside the clipping range in the PPO objective. This can be used to monitor the optimization process.
ppo/policy/approxkl: The approximate KL divergence between the old and new policies, measured by 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask), corresponding to the k2 estimator described at http://joschu.net/blog/kl-approx.html
ppo/policy/policykl: Similar to ppo/policy/approxkl, but measured by masked_mean(old_logprobs - logprobs, mask), corresponding to the k1 estimator described at http://joschu.net/blog/kl-approx.html
ppo/policy/ratio: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective.
ppo/policy/advantages_mean: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state.
ppo/policy/advantages: The histogram distribution of the advantage estimates, whose mean is ppo/policy/advantages_mean.
ppo/returns/mean: The mean of the TD(λ) returns, calculated by returns = advantage + values, another indicator of model performance. See https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for more details.
ppo/returns/var: The variance of the TD(λ) returns, calculated by returns = advantage + values, another indicator of model performance.
ppo/val/mean: The mean of the values, used to monitor the value function’s performance.
ppo/val/var: The variance of the values, used to monitor the value function’s performance.
ppo/val/var_explained: The explained variance for the value function, used to monitor the value function’s performance.
ppo/val/clipfrac: The fraction of the value function’s predicted values that are clipped.
ppo/val/vpred: The predicted values from the value function.
ppo/val/error: The mean squared error between ppo/val/vpred and the returns, used to monitor the value function’s performance.
ppo/loss/policy: The policy loss for the Proximal Policy Optimization (PPO) algorithm.
ppo/loss/value: The loss for the value function in the PPO algorithm. This value quantifies how well the function estimates the expected future rewards.
ppo/loss/total: The total loss for the PPO algorithm: the policy loss plus the value function loss, the latter weighted by the vf_coef setting of PPOConfig.
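The pure-Python sketch below checks the formulas behind several training stats on small inputs. A few points about it:

- The helper names (entropy_from_logits, policy_stats, value_stats) are hypothetical.
- Inputs are assumed to be flat per-token lists with padding already removed, which is what masked_mean handles in the formulas above.
- clipfrac follows the description above and counts ratios outside [1 - cliprange, 1 + cliprange]. An implementation may instead count only the tokens where clipping actually changed the loss.
- var_explained uses the standard definition 1 - Var(returns - vpred) / Var(returns), which may differ in detail from the library.

```python
import math
from statistics import pvariance


def _mean(xs):
    xs = list(xs)
    return sum(xs) / len(xs)


def entropy_from_logits(logits):
    """Entropy of softmax(logits) via logsumexp(logits) - sum(softmax(logits) * logits)."""
    m = max(logits)  # subtract the max for a numerically stable logsumexp
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    probs = [math.exp(x - lse) for x in logits]
    return lse - sum(p * x for p, x in zip(probs, logits))


def policy_stats(logprobs, old_logprobs, cliprange=0.2):
    """Ratio-based stats over flat lists of per-token log-probs (padding already removed)."""
    diffs = [lp - old for lp, old in zip(logprobs, old_logprobs)]
    ratios = [math.exp(d) for d in diffs]  # new policy / old policy
    return {
        "ppo/policy/ratio": ratios,  # raw values behind the histogram
        # share of ratios outside [1 - cliprange, 1 + cliprange]
        "ppo/policy/clipfrac": _mean(abs(r - 1.0) > cliprange for r in ratios),
        # k2 estimator: 0.5 * mean((logprobs - old_logprobs) ** 2)
        "ppo/policy/approxkl": 0.5 * _mean(d * d for d in diffs),
        # k1 estimator: mean(old_logprobs - logprobs)
        "ppo/policy/policykl": _mean(-d for d in diffs),
    }


def value_stats(values, advantages, vpreds):
    """Returns, value error and explained variance from flat per-token lists."""
    returns = [a + v for a, v in zip(advantages, values)]  # returns = advantage + values
    var_returns = pvariance(returns)
    residual_var = pvariance([r - p for r, p in zip(returns, vpreds)])
    return {
        "ppo/returns/mean": _mean(returns),
        "ppo/returns/var": var_returns,
        "ppo/val/mean": _mean(values),
        "ppo/val/var": pvariance(values),
        "ppo/val/error": _mean((p - r) ** 2 for p, r in zip(vpreds, returns)),
        # standard explained variance; undefined when the returns are constant
        "ppo/val/var_explained": (1.0 - residual_var / var_returns) if var_returns > 0 else float("nan"),
    }
```

On a uniform distribution over two tokens, entropy_from_logits([0.0, 0.0]) gives log 2, the maximum for two outcomes.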
Stats on queries, responses, and logprobs:
tokens/queries_len_mean: The average length of the query tokens.
tokens/queries_len_std: The standard deviation of the length of the query tokens.
tokens/queries_dist: The histogram distribution of the length of the query tokens.
tokens/responses_len_mean: The average length of the response tokens.
tokens/responses_len_std: The standard deviation of the length of the response tokens.
tokens/responses_dist: The histogram distribution of the length of the response tokens. (Costa: inconsistent naming, should be tokens/responses_len_dist)
objective/logprobs: The histogram distribution of the log probabilities of the actions taken by the model.
objective/ref_logprobs: The histogram distribution of the log probabilities of the actions taken by the reference model.
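The token statistics are plain summaries of sequence lengths. This is a minimal sketch that assumes queries and responses are already lists of token ids with padding removed. token_length_stats is hypothetical, and the population standard deviation is an assumption about the estimator used.

```python
from statistics import mean, pstdev


def token_length_stats(query_tokens, response_tokens):
    """Length stats for lists of token-id sequences (hypothetical helper)."""
    q_lens = [len(q) for q in query_tokens]
    r_lens = [len(r) for r in response_tokens]
    return {
        "tokens/queries_len_mean": mean(q_lens),
        "tokens/queries_len_std": pstdev(q_lens),  # population std (assumption)
        "tokens/responses_len_mean": mean(r_lens),
        "tokens/responses_len_std": pstdev(r_lens),
        # the *_dist metrics are histograms of these raw lengths
        "tokens/queries_dist": q_lens,
        "tokens/responses_dist": r_lens,
    }
```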
Crucial values
During training, many values are logged; here are the most important ones:
env/reward_mean, env/reward_std, env/reward_dist: the properties of the reward distribution from the “environment” / reward model
ppo/mean_non_score_reward: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
Here are some metrics that are useful to monitor for stability (when these diverge or collapse to 0, try tuning the training hyperparameters):
ppo/loss/value: it will spike / NaN when training is not going well.
ppo/policy/ratio: a ratio of 1 is the baseline, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is very high, for example 200, the token is 200 times more likely to be sampled under the new policy than under the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on.
ppo/policy/clipfrac and ppo/policy/approxkl: if the ratio is too high, it is going to get clipped, resulting in a high clipfrac and a high approxkl as well.
objective/kl: it should stay positive and ideally not too large, so that the policy is not too far away from the reference policy.
objective/kl_coef: The KL coefficient, adapted by AdaptiveKLController to keep the KL near its target. It often increases before numerical instabilities.
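The last two items can be made concrete with two hypothetical helpers. The class below is written in the style of the adaptive KL scheme of Ziegler et al. (2019), the kind of rule an adaptive controller such as AdaptiveKLController applies. It nudges the coefficient up whenever the measured KL exceeds its target, so a rising objective/kl_coef signals a policy that keeps drifting. stability_warnings checks the failure signs listed above. Its thresholds are illustrative assumptions, not recommended values, and it assumes ppo/policy/ratio is available as the list of raw ratios behind the histogram.

```python
import math


class AdaptiveKLSketch:
    """Hypothetical adaptive KL coefficient (Ziegler et al., 2019 style)."""

    def __init__(self, init_kl_coef, target, horizon):
        self.value = init_kl_coef  # current KL coefficient (objective/kl_coef)
        self.target = target       # desired KL
        self.horizon = horizon     # larger horizon -> slower adaptation

    def update(self, current_kl, n_steps):
        # Proportional error, clipped to [-0.2, 0.2]
        error = min(max(current_kl / self.target - 1.0, -0.2), 0.2)
        # KL above target -> coefficient grows; below target -> it shrinks
        self.value *= 1.0 + error * n_steps / self.horizon
        return self.value


def stability_warnings(stats, max_ratio=10.0, max_clipfrac=0.5):
    """Flag the failure signs described above (thresholds are hypothetical)."""
    warnings = []
    if not math.isfinite(stats["ppo/loss/value"]):
        warnings.append("ppo/loss/value is NaN or infinite")
    if max(stats["ppo/policy/ratio"]) > max_ratio:
        warnings.append("ppo/policy/ratio far above 1: policy is moving too fast")
    if stats["ppo/policy/clipfrac"] > max_clipfrac:
        warnings.append("ppo/policy/clipfrac is high")
    if stats["objective/kl"] < 0:
        warnings.append("objective/kl is negative")
    return warnings
```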