"""
PPO (Proximal Policy Optimization) 训练 LLM — 简化版
对应教程第 13 章

LLM-PPO 四模型架构:
  1. Policy Model (πθ) — 待优化的语言模型
  2. Reference Model (πref) — 冻结的 SFT 基线
  3. Reward Model (rφ) — 评估回复质量
  4. Value Model (Vψ) — 估计状态价值 (用于 GAE)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
import math


@dataclass
class PPOConfig:
    """Hyper-parameters for the toy PPO training loop.

    Field order is part of the public interface (the dataclass accepts
    positional arguments), so new fields must only ever be appended.
    """

    # --- policy / loss weighting ---
    clip_eps: float = 0.2       # epsilon of the clipped ratio: clip(r, 1 - eps, 1 + eps)
    vf_coef: float = 0.5        # weight of the value-function MSE term in the loss
    kl_coef: float = 0.1        # weight of the policy/reference log-ratio (KL) penalty
    # --- advantage estimation (GAE) ---
    gamma: float = 1.0          # discount factor
    lam: float = 0.95           # GAE lambda
    # --- optimization schedule ---
    n_epochs: int = 4           # PPO update passes over each rollout batch
    batch_size: int = 4         # number of prompts sampled per step
    max_grad_norm: float = 1.0  # threshold for global gradient-norm clipping


class SimpleLM(nn.Module):
    """Minimal causal language model used for the PPO demo.

    Tokens are embedded, passed through ``n_layers`` pre-norm Transformer
    encoder layers under a causal attention mask, and projected back to
    vocabulary logits. There is no positional embedding, so the causal mask is
    the only source of token-order information.

    Args:
        vocab_size: number of token ids.
        d_model: embedding / hidden width; must be divisible by the 4
            attention heads used in every layer.
        n_layers: number of Transformer layers.
        dropout: dropout probability inside each Transformer layer. The
            default (0.1) equals ``nn.TransformerEncoderLayer``'s own default,
            i.e. the previous hard-wired behavior. For PPO, pass 0.0 (or call
            ``eval()``) so that scoring the same tokens twice gives identical
            log-probs.
    """

    def __init__(self, vocab_size: int = 100, d_model: int = 128, n_layers: int = 2,
                 dropout: float = 0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, 4, d_model * 4, dropout=dropout,
                                       batch_first=True, norm_first=True)
            for _ in range(n_layers)
        ])
        self.head = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return next-token logits for every position.

        ``input_ids`` is (batch, seq); the result is (batch, seq, vocab_size).
        Position i only attends to positions <= i (causal mask).
        """
        x = self.embed(input_ids)
        causal_mask = nn.Transformer.generate_square_subsequent_mask(x.size(1), device=x.device)
        for layer in self.layers:
            x = layer(x, src_mask=causal_mask, is_causal=True)
        return self.head(x)

    @torch.no_grad()
    def generate(self, prompt: torch.Tensor, max_new_tokens: int = 32,
                 temperature: float = 1.0) -> torch.Tensor:
        """Autoregressively append ``max_new_tokens`` tokens to ``prompt``.

        Tokens are sampled from softmax(logits / temperature). A
        ``temperature <= 0`` selects greedy (argmax) decoding instead of
        dividing by zero. Runs without autograd because sampled token ids are
        not differentiable. Each step re-runs the whole sequence (no KV cache).

        Returns:
            (batch, prompt_len + max_new_tokens) tensor: the prompt followed by
            the generated tokens.
        """
        generated = prompt.clone()
        for _ in range(max_new_tokens):
            logits = self.forward(generated)[:, -1, :]
            if temperature <= 0:
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, 1)
            generated = torch.cat([generated, next_token], dim=1)
        return generated


class ValueHead(nn.Module):
    """Scalar value head placed on top of per-position features.

    A two-layer MLP (Linear -> ReLU -> Linear) whose last layer has a single
    output unit. The trailing size-1 axis is dropped, so input of shape
    (..., d_model) yields output of shape (...).
    """

    def __init__(self, d_model: int):
        super().__init__()
        hidden_width = d_model
        layers = [
            nn.Linear(d_model, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 1),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        scalar_per_position = self.head(hidden_states)  # (..., 1)
        return scalar_per_position[..., 0]


class SimpleRewardModel(nn.Module):
    """Toy reward model: one scalar score per sequence.

    Token embeddings are mean-pooled over dim 1 (the sequence axis for
    (batch, seq) input), then scored by a small MLP. Because of the mean
    pooling the score does not depend on token order.
    """

    def __init__(self, vocab_size: int = 100, d_model: int = 64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.net = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, 1))

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        token_vectors = self.embed(input_ids)
        pooled = token_vectors.mean(dim=1)
        score = self.net(pooled)  # (batch, 1)
        return score[..., 0]


def compute_log_probs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return log p(label) for every position.

    ``logits`` has the vocabulary on its last axis; ``labels`` holds token ids
    with the same leading shape. The result has the shape of ``labels``.
    """
    normalized = logits.log_softmax(dim=-1)
    index = labels.unsqueeze(-1)
    picked = torch.gather(normalized, -1, index)
    return picked.squeeze(-1)


def compute_gae(rewards: torch.Tensor, values: torch.Tensor,
                gamma: float, lam: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Generalized Advantage Estimation (GAE) along the last (time) axis.

        δ_t = r_t + γ·V(s_{t+1}) - V(s_t),   with V(s_T) = 0 (no bootstrap after the last step)
        Â_t = Σ_{l=0}^{T-1-t} (γλ)^l · δ_{t+l}

    Args:
        rewards: per-step rewards, shape (..., T).
        values: value estimates V(s_t), same shape as ``rewards``.
        gamma: discount factor γ.
        lam: GAE smoothing factor λ.

    Returns:
        (advantages, returns), both shaped like ``rewards``, where
        returns = advantages + values (the value-function targets).

    Leading dimensions are independent sequences, so a (B, T) batch gives the
    same result as calling this once per row. An empty time axis returns
    empty tensors.
    """
    num_steps = rewards.size(-1)
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0

    # Backward recursion: Â_t = δ_t + γλ·Â_{t+1}.
    for t in reversed(range(num_steps)):
        next_value = values[..., t + 1] if t + 1 < num_steps else 0.0
        delta = rewards[..., t] + gamma * next_value - values[..., t]
        last_gae = delta + gamma * lam * last_gae
        advantages[..., t] = last_gae

    returns = advantages + values
    return advantages, returns


def ppo_step(
    policy: SimpleLM,
    ref_policy: SimpleLM,
    value_model: ValueHead,
    sequences: torch.Tensor,
    prompt_len: int,
    old_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    returns: torch.Tensor,
    config: PPOConfig,
    ent_coef: float = 0.01,
) -> dict:
    """Compute the PPO loss and diagnostics for one update on a rollout batch.

    The returned ``'loss'`` tensor is still attached to the autograd graph;
    the caller runs ``backward()`` and the optimizer step.

    Loss (minimized), with r_t = π_θ(a_t|s_t) / π_old(a_t|s_t):
        L_CLIP = mean(min(r_t · Â_t, clip(r_t, 1-ε, 1+ε) · Â_t))
        L_VF   = mean((V(s_t) - R_t)²)
        L      = -L_CLIP + vf_coef·L_VF - ent_coef·H(π_θ) + kl_coef·KL(π_θ ‖ π_ref)

    Args:
        policy: model being optimized; its forward returns (B, S, vocab) logits.
        ref_policy: frozen reference model, evaluated without gradients.
        value_model: value head applied to the policy's token embeddings.
        sequences: (B, S) token ids: prompt followed by response.
        prompt_len: number of prompt tokens at the start of every sequence.
        old_log_probs: log-probs of the response tokens under the rollout
            policy, shape (B, S - prompt_len).
        advantages: GAE advantages for the response tokens, same shape.
        returns: value targets for the response tokens, same shape.
        config: reads ``clip_eps``, ``vf_coef`` and ``kl_coef``.
        ent_coef: weight of the entropy bonus (default 0.01).

    Returns:
        dict with the differentiable ``'loss'`` tensor plus float diagnostics:
        ``policy_loss``, ``value_loss``, ``entropy``, ``kl``, ``clip_frac``,
        ``approx_kl``.
    """
    response = sequences[:, prompt_len:]

    # Current-policy log-probs of the response tokens. Logits at position i
    # predict token i+1, hence the one-step shift of the slice.
    logits = policy(sequences)[:, prompt_len-1:-1, :]
    log_probs = compute_log_probs(logits, response)

    # Same tokens under the frozen reference policy (no gradient).
    with torch.no_grad():
        ref_logits = ref_policy(sequences)[:, prompt_len-1:-1, :]
        ref_log_probs = compute_log_probs(ref_logits, response)

    # Mean per-token log-ratio policy/reference, used as the KL term.
    # NOTE(review): the per-token rewards built in main() already contain a
    # -kl_coef * log-ratio term, so the KL is penalized both through the
    # reward and here in the loss; confirm this double penalty is intended.
    kl = (log_probs - ref_log_probs).mean()

    # Importance-sampling ratio π_θ / π_old per response token.
    ratio = (log_probs - old_log_probs).exp()

    # Clipped surrogate objective (negated because the loss is minimized).
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - config.clip_eps, 1 + config.clip_eps) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()

    # Value loss. The value head reads the policy's token embeddings, computed
    # without gradient so this loss does not update the policy's embedding.
    # NOTE(review): these features are non-contextual and are taken at the
    # response-token positions (prompt_len:) rather than at the positions whose
    # logits predict them (prompt_len-1:-1). main() uses the same slice to
    # build the GAE values, so change both together if this is revisited.
    with torch.no_grad():
        hidden = policy.embed(sequences)
    values = value_model(hidden[:, prompt_len:, :])
    value_loss = F.mse_loss(values, returns)

    # Mean per-position entropy over the full vocabulary (exploration bonus).
    entropy = -(F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1)).sum(-1).mean()

    total_loss = (policy_loss
                  + config.vf_coef * value_loss
                  - ent_coef * entropy
                  + config.kl_coef * kl)

    # Diagnostics only (no gradient).
    with torch.no_grad():
        # Fraction of tokens whose ratio lies outside the clip range.
        clip_frac = ((ratio - 1).abs() > config.clip_eps).float().mean()
        # Estimate of KL(π_old ‖ π_θ) via E_old[(r - 1) - log r].
        approx_kl = ((ratio - 1) - ratio.log()).mean()

    return {
        'loss': total_loss,
        'policy_loss': policy_loss.item(),
        'value_loss': value_loss.item(),
        'entropy': entropy.item(),
        'kl': kl.item(),
        'clip_frac': clip_frac.item(),
        'approx_kl': approx_kl.item(),
    }


def main():
    """Run the toy PPO-for-LLM training demo and print progress.

    Each of ``n_steps`` iterations samples random prompts and rolls out
    responses with the policy. It scores the responses with the reward model
    and builds per-token rewards: a KL penalty against the reference model on
    every token, plus the sequence reward on the last token. It then computes
    GAE per sequence and runs ``config.n_epochs`` PPO updates on that batch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"设备: {device}\n")

    print("=" * 60)
    print("PPO 训练 LLM 演示")
    print("=" * 60)

    config = PPOConfig()
    vocab_size = 100
    d_model = 128

    # The four models: policy (trained), reference (frozen copy of the initial
    # policy), reward model (frozen), value head (trained).
    policy = SimpleLM(vocab_size, d_model).to(device)
    ref_policy = SimpleLM(vocab_size, d_model).to(device)
    ref_policy.load_state_dict(policy.state_dict())
    reward_model = SimpleRewardModel(vocab_size).to(device)
    value_model = ValueHead(d_model).to(device)

    # Reference and reward models are never optimized; freeze both so the
    # "(frozen)" labels printed below are actually true.
    ref_policy.requires_grad_(False)
    reward_model.requires_grad_(False)

    # Eval mode disables dropout (nn.TransformerEncoderLayer enables it by
    # default). PPO scores the same tokens several times: old log-probs at
    # rollout, new log-probs in every update epoch, and reference log-probs.
    # The importance ratio and the KL penalty assume those passes agree.
    # Eval mode does not block gradients, so the policy and value head still train.
    for model in (policy, ref_policy, reward_model, value_model):
        model.eval()

    trainable_params = [*policy.parameters(), *value_model.parameters()]
    optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)

    print("\n四模型架构:")
    print(f"  Policy:    {sum(p.numel() for p in policy.parameters()):>8,} params")
    print(f"  Reference: {sum(p.numel() for p in ref_policy.parameters()):>8,} params (frozen)")
    print(f"  Reward:    {sum(p.numel() for p in reward_model.parameters()):>8,} params (frozen)")
    print(f"  Value:     {sum(p.numel() for p in value_model.parameters()):>8,} params")

    print("\nPPO 配置:")
    print(f"  clip_eps={config.clip_eps}, kl_coef={config.kl_coef}")
    print(f"  gamma={config.gamma}, lambda={config.lam}")
    print(f"  n_epochs={config.n_epochs}")

    # PPO training loop
    print("\n" + "-" * 60)
    print("开始 PPO 训练")
    print("-" * 60)

    prompt_len = 8
    response_len = 24
    n_steps = 20

    for step in range(n_steps):
        with torch.no_grad():
            # 1. Roll out responses from the current policy.
            prompts = torch.randint(0, vocab_size, (config.batch_size, prompt_len), device=device)
            sequences = policy.generate(prompts, max_new_tokens=response_len, temperature=0.8)
            response_tokens = sequences[:, prompt_len:]

            # 2. Sequence-level reward and per-token log-ratio policy/reference.
            #    Logits at position i predict token i+1, hence the shifted slice.
            rewards_raw = reward_model(sequences)
            policy_logits = policy(sequences)[:, prompt_len-1:-1, :]
            ref_logits = ref_policy(sequences)[:, prompt_len-1:-1, :]
            policy_lp = compute_log_probs(policy_logits, response_tokens)
            ref_lp = compute_log_probs(ref_logits, response_tokens)
            kl_per_token = policy_lp - ref_lp

            # 3. Value estimates: the value head over the policy's token
            #    embeddings at the response positions (same slice as ppo_step).
            values = value_model(policy.embed(sequences)[:, prompt_len:, :])

        # Per-token rewards: -kl_coef * log-ratio on every response token, with
        # the sequence reward added on the final token. Then GAE per sequence.
        per_token_rewards = -config.kl_coef * kl_per_token
        per_token_rewards[:, -1] += rewards_raw
        gae_per_seq = [
            compute_gae(seq_rewards, seq_values, config.gamma, config.lam)
            for seq_rewards, seq_values in zip(per_token_rewards, values)
        ]
        advantages = torch.stack([adv for adv, _ in gae_per_seq])
        returns = torch.stack([ret for _, ret in gae_per_seq])

        # Normalize advantages over the whole batch.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        old_log_probs = policy_lp.detach()

        # 4. Several PPO update epochs on the same rollout batch.
        for _ in range(config.n_epochs):
            stats = ppo_step(
                policy, ref_policy, value_model,
                sequences, prompt_len,
                old_log_probs, advantages, returns, config,
            )

            optimizer.zero_grad()
            stats['loss'].backward()
            torch.nn.utils.clip_grad_norm_(trainable_params, config.max_grad_norm)
            optimizer.step()

        if step % 5 == 0:
            print(f"  Step {step:3d}: "
                  f"reward={rewards_raw.mean().item():.3f}, "
                  f"policy_loss={stats['policy_loss']:.4f}, "
                  f"value_loss={stats['value_loss']:.4f}, "
                  f"kl={stats['kl']:.4f}, "
                  f"clip_frac={stats['clip_frac']:.2%}")

    print("\nPPO 训练完成")

    # Summary of key PPO techniques
    print("\n" + "=" * 60)
    print("PPO 训练 LLM 关键技巧")
    print("=" * 60)
    print("""
  1. KL 惩罚/约束: 防止策略偏离参考模型太远 (reward hacking)
  2. 优势归一化: 稳定训练，减少方差
  3. 梯度裁剪: 防止梯度爆炸
  4. 多 epoch 更新: 提高样本效率
  5. 价值函数裁剪: 稳定价值估计
  6. 奖励归一化/裁剪: 防止奖励尺度问题
  7. 大 batch size: 减少策略梯度方差

  显存优化:
  - Reference model 可以使用 LoRA 共享权重
  - Value head 可以挂在 policy model 上 (共享 backbone)
  - 使用梯度检查点减少激活显存
""")


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
