"""
GRPO (Group Relative Policy Optimization) 训练实现
对应教程第 15 章

核心思想 (DeepSeek-Math, Shao et al. 2024):
  - 对每个 prompt 生成一组 G 个回复
  - 用奖励计算组内相对优势 (无需 critic/value model)
  - 用 PPO-style clipped objective 更新策略

优势估计:
  Â_i = (r_i - mean(r_1..G)) / std(r_1..G)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
import math


@dataclass
class GRPOConfig:
    """Hyper-parameters for the GRPO demo.

    Field order defines the positional constructor signature, so keep it stable.
    """

    group_size: int = 8         # G: number of responses sampled per prompt
    clip_eps: float = 0.2       # epsilon in clip(ratio, 1 - eps, 1 + eps)
    beta: float = 0.04          # weight of the KL-to-reference penalty
    n_epochs: int = 2           # optimizer passes over each freshly sampled batch
    batch_size: int = 4         # prompts per step (samples per step = batch_size * group_size)
    max_grad_norm: float = 1.0  # threshold for gradient-norm clipping
    temperature: float = 0.8    # sampling temperature passed to generation


class SimpleLM(nn.Module):
    """Minimal causal language model used by the GRPO demo.

    Token embedding -> stack of pre-norm ``nn.TransformerEncoderLayer`` blocks
    run under a causal mask -> linear projection to vocabulary logits.
    NOTE(review): there is no positional encoding; position information can
    only come from the causal mask.
    """

    def __init__(self, vocab_size: int = 200, d_model: int = 128, n_layers: int = 2):
        super().__init__()
        # Creation order (embed -> layers -> head) is kept so seeded
        # initialisation and state_dict keys stay the same.
        self.embed = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(
                d_model, 4, d_model * 4, batch_first=True, norm_first=True
            )
            for _ in range(n_layers)
        )
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map token ids of shape (batch, seq) to logits of shape (batch, seq, vocab)."""
        hidden = self.embed(input_ids)
        seq_len = hidden.size(1)
        # Upper-triangular -inf mask: position t may only attend to positions <= t.
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=hidden.device)
        for block in self.layers:
            hidden = block(hidden, src_mask=mask, is_causal=True)
        return self.head(hidden)

    def generate(self, prompt: torch.Tensor, max_new_tokens: int = 32,
                 temperature: float = 0.8) -> tuple[torch.Tensor, torch.Tensor]:
        """Sample ``max_new_tokens`` tokens after ``prompt``.

        Returns:
            (sequences, log_probs): ``sequences`` is the prompt followed by the
            sampled tokens; ``log_probs`` has shape (batch, max_new_tokens) and
            holds the log-probability of each sampled token under the
            temperature-scaled distribution it was drawn from.

        The full sequence is re-encoded at every step (no KV cache).
        """
        tokens = prompt.clone()
        step_logps = []

        for _ in range(max_new_tokens):
            last_logits = self.forward(tokens)[:, -1, :]
            dist_logp = F.log_softmax(last_logits / temperature, dim=-1)
            sampled = torch.multinomial(dist_logp.exp(), 1)
            step_logps.append(torch.gather(dist_logp, -1, sampled).squeeze(-1))
            tokens = torch.cat([tokens, sampled], dim=1)

        return tokens, torch.stack(step_logps, dim=1)


def simple_reward_fn(sequences: torch.Tensor, prompt_len: int,
                     noise_std: float = 0.1) -> torch.Tensor:
    """Toy per-sample reward for the demo.

    reward_i = (#distinct tokens in response_i) / response_len
               - 0.01 * response_len
               + N(0, noise_std^2)

    where the response is everything after the first ``prompt_len`` tokens of
    each row of ``sequences``.

    In real applications this would be, e.g.:
    - a reward-model score
    - code execution results (pass/fail)
    - math verification (answer correctness)

    Args:
        sequences: (batch, total_len) token ids, prompt followed by response.
        prompt_len: number of leading prompt tokens to skip.
        noise_std: std of the Gaussian noise added to every reward
            (default 0.1, the previously hard-coded value).

    Returns:
        Float tensor of shape (batch,) with one reward per row.
    """
    response = sequences[:, prompt_len:]
    batch, resp_len = response.shape

    if resp_len > 0:
        # Distinct tokens *per row*: after sorting a row, every position where
        # the value changes starts a new distinct token. (``unique(dim=-1)``
        # would instead deduplicate whole columns across the batch and give one
        # shared scalar, making rewards differ only by noise.)
        ordered, _ = response.sort(dim=-1)
        n_distinct = 1 + (ordered[:, 1:] != ordered[:, :-1]).sum(dim=-1)
        diversity = n_distinct.float() / resp_len
    else:
        # Empty response: nothing to be diverse about (avoids division by zero).
        diversity = torch.zeros(batch, device=sequences.device)

    length_penalty = -0.01 * resp_len
    noise = torch.randn(batch, device=sequences.device) * noise_std
    return diversity + length_penalty + noise


def compute_group_advantages(rewards: torch.Tensor, group_size: int) -> torch.Tensor:
    """Compute group-relative advantages.

    Rewards are laid out as consecutive groups of ``group_size`` responses to
    the same prompt. Within each group of G responses:

      A_i = (r_i - mean(r_1..G)) / std(r_1..G)

    This is GRPO's main simplification over PPO: no value model is trained;
    the group statistics serve directly as the baseline.

    Args:
        rewards: one reward per response; flattened, so its element count must
            be a multiple of ``group_size``.
        group_size: G, number of responses per prompt (>= 1).

    Returns:
        1-D tensor of advantages in the same order as ``rewards``.

    Raises:
        ValueError: if ``group_size`` < 1 or the rewards cannot be split into
            whole groups.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be >= 1, got {group_size}")
    n_total = rewards.numel()
    if n_total % group_size != 0:
        raise ValueError(
            f"{n_total} rewards cannot be split into groups of {group_size}"
        )

    if group_size == 1:
        # A single response has no peers to compare against; the unbiased std
        # of one sample is NaN, which would otherwise poison every advantage.
        return torch.zeros_like(rewards.reshape(-1))

    grouped = rewards.reshape(n_total // group_size, group_size)
    mean = grouped.mean(dim=1, keepdim=True)
    # Floor the std so a group with identical rewards gets zero advantage
    # instead of 0/0.
    std = grouped.std(dim=1, keepdim=True).clamp(min=1e-8)

    return ((grouped - mean) / std).reshape(-1)


def grpo_loss(
    policy_log_probs: torch.Tensor,
    old_log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    config: "GRPOConfig",
) -> tuple[torch.Tensor, dict]:
    """GRPO loss (DeepSeekMath formulation).

      L = -mean_{i,t}[ min(r_{i,t} * A_i, clip(r_{i,t}, 1-eps, 1+eps) * A_i) ]
          + beta * mean_{i,t}[ D_KL,t ]

    where
      r_{i,t} = pi_theta(o_{i,t} | q, o_{i,<t}) / pi_old(o_{i,t} | q, o_{i,<t})
                per-token importance ratio
      A_i     = group-relative advantage, shared by every token of response i
      D_KL,t  = pi_ref/pi_theta - log(pi_ref/pi_theta) - 1
                per-token "k3" KL estimator: always >= 0, and unlike the naive
                log(pi_theta/pi_ref) its gradient on sampled tokens actually
                pulls the policy toward the reference.

    The mean runs over all tokens; this equals averaging per-response token
    means only when every response has the same length (no padding mask is
    applied here).

    Args:
        policy_log_probs: (batch, response_len) per-token log-probs under the
            current policy (with grad), e.g. from ``get_per_token_logps``.
        old_log_probs: same shape, log-probs under the policy that sampled
            the responses (no grad).
        ref_log_probs: same shape, log-probs under the frozen reference model.
        advantages: (batch,) per-response advantages, or (batch, response_len)
            per-token advantages.
        config: object exposing ``clip_eps`` and ``beta``.

    Returns:
        (total_loss, stats) where stats holds scalar floats for logging.
    """
    # Broadcast per-response advantages over the token dimension.
    adv = advantages.unsqueeze(-1) if advantages.dim() == 1 else advantages

    # Per-token importance ratio. Summing log-probs over the response first
    # (a sequence-level ratio) multiplies T per-token ratios together, so the
    # clip would trigger on nearly every sample once the policy moves at all.
    ratio = (policy_log_probs - old_log_probs).exp()

    # PPO-Clip objective
    surr1 = ratio * adv
    surr2 = torch.clamp(ratio, 1 - config.clip_eps, 1 + config.clip_eps) * adv
    policy_loss = -torch.min(surr1, surr2).mean()

    # KL penalty (per-token k3 estimator, non-negative)
    ref_log_ratio = ref_log_probs - policy_log_probs
    kl = (ref_log_ratio.exp() - ref_log_ratio - 1).mean()
    kl_loss = config.beta * kl

    total_loss = policy_loss + kl_loss

    with torch.no_grad():
        # Fraction of tokens whose ratio fell outside the clip range.
        clip_frac = ((ratio - 1).abs() > config.clip_eps).float().mean()

    stats = {
        'loss': total_loss.item(),
        'policy_loss': policy_loss.item(),
        'kl': kl.item(),
        'clip_frac': clip_frac.item(),
        'mean_advantage': advantages.mean().item(),
        'mean_ratio': ratio.mean().item(),
    }

    return total_loss, stats


def get_per_token_logps(model: nn.Module, sequences: torch.Tensor,
                        prompt_len: int) -> torch.Tensor:
    """Return the log-probability of every response token under ``model``.

    The response is ``sequences[:, prompt_len:]``. The logits at position t
    predict token t+1, so the logits are taken one position earlier than the
    tokens they score.

    Returns:
        Tensor of shape (batch, total_len - prompt_len).
    """
    all_logits = model(sequences)
    targets = sequences[:, prompt_len:]
    shifted = all_logits[:, prompt_len - 1:-1, :]
    log_dist = F.log_softmax(shifted, dim=-1)
    picked = torch.gather(log_dist, -1, targets.unsqueeze(-1))
    return picked.squeeze(-1)


def main():
    """Run a small GRPO training demo on random prompts and print progress.

    Each step samples ``config.group_size`` responses for each of
    ``config.batch_size`` random prompts, scores them with ``simple_reward_fn``,
    converts rewards to group-relative advantages, and takes
    ``config.n_epochs`` optimizer steps on ``grpo_loss`` against a frozen copy
    of the initial policy. Finally prints a GRPO / PPO / DPO comparison table.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"设备: {device}\n")

    print("=" * 60)
    print("GRPO (Group Relative Policy Optimization) 训练演示")
    print("=" * 60)

    config = GRPOConfig(group_size=8, clip_eps=0.2, beta=0.04, batch_size=4)
    vocab_size = 200

    # GRPO needs only two models (vs four for PPO)
    policy = SimpleLM(vocab_size, 128, 2).to(device)
    ref_model = SimpleLM(vocab_size, 128, 2).to(device)
    ref_model.load_state_dict(policy.state_dict())
    for p in ref_model.parameters():
        p.requires_grad_(False)

    # SimpleLM's TransformerEncoderLayers use the default dropout (0.1) and
    # new modules start in training mode. With dropout on, sampling, the
    # "old" log-probs, the reference log-probs and the policy log-probs would
    # each see a different random mask: the importance ratio would not be 1 on
    # the first epoch (spurious clipping) and the KL term would be noise.
    # eval() only disables dropout; gradients still flow through the policy.
    policy.eval()
    ref_model.eval()

    optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-4)

    print(f"\nGRPO 配置:")
    print(f"  group_size (G) = {config.group_size}")
    print(f"  clip_eps = {config.clip_eps}")
    print(f"  β (KL 系数) = {config.beta}")
    print(f"\nGRPO vs PPO 关键区别:")
    print(f"  - 无需 Value Model (critic)")
    print(f"  - 无需 GAE — 用组内统计量代替")
    print(f"  - 模型数量: 2 (policy + ref) vs 4 (policy + ref + reward + value)")
    print(f"  - 显存节省: ~50%")

    # Training loop
    print("\n" + "-" * 60)
    print("开始 GRPO 训练")
    print("-" * 60)

    prompt_len = 8
    response_len = 24
    n_steps = 15

    for step in range(n_steps):
        # 1. Sample a group of responses for every prompt
        n_prompts = config.batch_size
        prompts = torch.randint(0, vocab_size, (n_prompts, prompt_len), device=device)

        # Repeat each prompt G times; consecutive rows form one group
        expanded_prompts = prompts.repeat_interleave(config.group_size, dim=0)

        with torch.no_grad():
            sequences, _ = policy.generate(
                expanded_prompts, max_new_tokens=response_len,
                temperature=config.temperature
            )

        # 2. Score the responses
        with torch.no_grad():
            rewards = simple_reward_fn(sequences, prompt_len)

        # 3. Group-relative advantages
        advantages = compute_group_advantages(rewards, config.group_size)

        # 4. GRPO update. old/ref log-probs are fixed for all epochs of this step.
        # NOTE(review): responses were sampled at config.temperature, but these
        # log-probs use untempered logits — confirm this is intended.
        with torch.no_grad():
            old_log_probs = get_per_token_logps(policy, sequences, prompt_len)
            ref_log_probs = get_per_token_logps(ref_model, sequences, prompt_len)

        for _ in range(config.n_epochs):
            policy_log_probs = get_per_token_logps(policy, sequences, prompt_len)

            loss, stats = grpo_loss(
                policy_log_probs, old_log_probs, ref_log_probs,
                advantages, config
            )

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), config.max_grad_norm)
            optimizer.step()

        if step % 3 == 0:
            # Show the per-group reward spread (stats are from the last epoch)
            group_rewards = rewards.view(n_prompts, config.group_size)
            best_in_group = group_rewards.max(dim=1).values.mean()
            worst_in_group = group_rewards.min(dim=1).values.mean()

            print(f"  Step {step:3d}: "
                  f"loss={stats['loss']:.4f}, "
                  f"kl={stats['kl']:.4f}, "
                  f"clip={stats['clip_frac']:.2%}, "
                  f"r_mean={rewards.mean():.3f}, "
                  f"r_best={best_in_group:.3f}, "
                  f"r_worst={worst_in_group:.3f}")

    print("\nGRPO 训练完成")

    # Compare GRPO with other methods
    print("\n" + "=" * 60)
    print("GRPO vs PPO vs DPO 对比")
    print("=" * 60)
    print("""
  ┌────────────────┬──────────┬──────────┬──────────┐
  │                │   PPO    │   DPO    │   GRPO   │
  ├────────────────┼──────────┼──────────┼──────────┤
  │ 需要 RM        │    是    │    否    │  可选(1) │
  │ 需要 Value     │    是    │    否    │    否    │
  │ 在线采样       │    是    │    否    │    是    │
  │ 模型数         │    4     │    2     │    2     │
  │ 优势估计       │   GAE    │   隐式   │ 组内统计 │
  │ 数据效率       │   中等   │    高    │   中等   │
  │ 探索能力       │    强    │    弱    │    强    │
  │ 实现复杂度     │    高    │    低    │    中    │
  │ DeepSeek 使用  │    -     │    -     │  R1/Math │
  └────────────────┴──────────┴──────────┴──────────┘

  (1) GRPO 可用 RM 打分，也可用规则奖励 (如代码执行、数学验证)

  GRPO 在 DeepSeek-R1 中的应用:
  - 使用 GRPO 替代 PPO 进行 RL 训练
  - 奖励来自规则验证 (数学正确性、代码执行)
  - 无需人类标注数据，完全自主学习推理能力
  - Group size G=64，每个 prompt 生成 64 个回复
""")


if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()
