This pseudocode outlines the core logic for computing the EMPG advantage, showing how per-step credit assignment is modulated by step-level entropy.
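For reference, the quantities the code computes can be summarized compactly (reconstructed here from the code itself; the equation numbers follow the comments in the listing):

\[ \bar{H}_t = \frac{H_t - \min_j H_j}{\max_j H_j - \min_j H_j + \epsilon} \;\text{(Eq. 12)}, \qquad g(H_t) = \frac{e^{-k\,\bar{H}_t}}{\frac{1}{T}\sum_j e^{-k\,\bar{H}_j}} \;\text{(Eq. 10)}, \qquad f(H_t) = e^{-k_f\,\bar{H}_t} \;\text{(Eq. 11)} \]

Each step's advantage is then modulated as \(A_t \leftarrow g(H_t)\,A_t + \zeta\, f(H_{t+1})\) (Eq. 8), and finally all advantages are centered by subtracting the mean step advantage (Eq. 7).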
import numpy as np
import torch


def compute_empg_advantage(tokenizer, batch, k=1.0, k_f=1.0, zeta=0.1):
    """Modulate per-step advantages by step-level entropy (EMPG).

    Args:
        tokenizer: The tokenizer used to identify response segments.
        batch: A data batch with 'responses', 'old_entropy', and 'advantages'.
        k (float): Hyperparameter for self-calibrating gradient scaling.
        k_f (float): Hyperparameter for the future clarity bonus.
        zeta (float): Weight of the future clarity bonus.
    """
    # --- 1. First Pass: Collect Step-Level Entropies ---
    all_step_entropies = []
    # segments_to_modify stores {'sample_idx', 'start', 'end'} for each step
    segments_to_modify = []
    for i in range(batch.batch.batch_size[0]):
        # Find "assistant" segments, which correspond to agent steps.
        token_segments = process_token_sequences(
            batch.batch['responses'][i],
            tokenizer.encode("<|im_start|>assistant\n"),
            tokenizer.encode('<|im_end|>'),
        )
        for start, end in token_segments:
            if start >= end:
                continue
            # Calculate the average token-level entropy for the step
            step_entropy = batch.batch['old_entropy'][i][start:end].mean().item()
            all_step_entropies.append(step_entropy)
            segments_to_modify.append({'sample_idx': i, 'start': start, 'end': end})

    if not all_step_entropies:
        return  # No steps found; leave advantages unchanged.
    # --- 2. Calculate Modulated Advantage Components ---
    H = np.array(all_step_entropies)
    # Batch-level entropy normalization (Eq. 12) with epsilon = 1e-8
    min_H, max_H = np.min(H), np.max(H)
    H_norm = (H - min_H) / (max_H - min_H + 1e-8)
    # Self-calibrating gradient scaling g(H) (Eq. 10)
    g_H_unnormalized = np.exp(-k * H_norm)
    mean_g_H = np.mean(g_H_unnormalized)
    g_H = g_H_unnormalized / (mean_g_H + 1e-8)
    # Future clarity bonus f(H) (Eq. 11)
    f_H = np.exp(-k_f * H_norm)
    # Convert to tensors for PyTorch operations
    device = batch.batch['advantages'].device
    g_H = torch.tensor(g_H, device=device, dtype=torch.float32)
    f_H = torch.tensor(f_H, device=device, dtype=torch.float32)
    # --- 3. Second Pass: Apply Advantage Modulation (Eq. 8) ---
    step_advantages = []
    for i, segment in enumerate(segments_to_modify):
        idx, start, end = segment['sample_idx'], segment['start'], segment['end']
        # Apply self-calibrating gradient scaling
        batch.batch['advantages'][idx][start:end] *= g_H[i]
        # Add the future clarity bonus if the next segment belongs to the
        # same trajectory, i.e., this sample has a next step
        next_seg = segments_to_modify[i + 1] if i + 1 < len(segments_to_modify) else None
        if next_seg and next_seg['sample_idx'] == idx:
            batch.batch['advantages'][idx][start:end] += zeta * f_H[i + 1]
        # Record one representative advantage per step for normalization
        step_advantages.append(batch.batch['advantages'][idx][start])

    # --- 4. Final Advantage Normalization (Eq. 7) ---
    # Center all advantages on the mean step advantage; positions outside
    # the response segments are assumed to be zeroed later by the loss mask.
    if step_advantages:
        final_adv_mean = torch.mean(torch.stack(step_advantages))
        batch.batch['advantages'] -= final_adv_mean
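The listing relies on a helper, process_token_sequences, whose implementation is not shown. A minimal sketch of what it plausibly does, assuming it returns the (start, end) token-index spans of the content between each start-marker occurrence and the next end-marker (this reconstruction is an assumption, not the authors' code):

def process_token_sequences(tokens, start_pattern, end_pattern):
    """Hypothetical reconstruction: return (start, end) index pairs for the
    content between each start_pattern occurrence and the next end_pattern."""
    tokens = tokens.tolist() if hasattr(tokens, 'tolist') else list(tokens)

    def find_all(pattern):
        # All indices where the token sub-sequence `pattern` begins
        n = len(pattern)
        return [j for j in range(len(tokens) - n + 1) if tokens[j:j + n] == pattern]

    end_positions = find_all(end_pattern)
    segments = []
    for s in find_all(start_pattern):
        content_start = s + len(start_pattern)
        # Pair with the first end marker at or after the content start
        ends = [e for e in end_positions if e >= content_start]
        if ends:
            segments.append((content_start, ends[0]))
    return segments

With such a helper in place, compute_empg_advantage is called once per training batch, after token-level entropies and baseline advantages have been filled in; it modifies batch.batch['advantages'] in place and returns nothing.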