"""
DPO (Direct Preference Optimization) 训练
对应教程第 14 章

核心推导:
  RLHF 目标: max E[r(x,y)] - β·KL(π_θ || π_ref)
  最优解: π*(y|x) = (1/Z(x))·π_ref(y|x)·exp(r(x,y)/β)
  隐式奖励: r(x,y) = β·log(π_θ(y|x)/π_ref(y|x)) + β·log Z(x)

  DPO 损失:
  L_DPO = -E[log σ(β·(log(π_θ(y_w|x)/π_ref(y_w|x)) - log(π_θ(y_l|x)/π_ref(y_l|x))))]
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


@dataclass
class DPOConfig:
    """Hyperparameters for the DPO loss and the demo training loop.

    ``beta``, ``label_smoothing`` and ``loss_type`` are read by ``dpo_loss``;
    ``lr``, ``n_epochs`` and ``batch_size`` are read by ``main``.

    Raises:
        ValueError: If ``loss_type`` is not a supported variant, ``beta`` is
            not strictly positive, or ``label_smoothing`` is outside [0, 1].
    """

    # Scale applied to the policy/reference log-ratio difference.
    beta: float = 0.1
    # Weight given to the flipped preference label (0 disables smoothing).
    label_smoothing: float = 0.0
    # Loss variant: 'sigmoid' | 'hinge' | 'ipo'.
    loss_type: str = 'sigmoid'
    # Optimizer learning rate.
    lr: float = 5e-5
    # Number of passes over the preference data.
    n_epochs: int = 3
    # Preference pairs per optimization step.
    batch_size: int = 4

    def __post_init__(self) -> None:
        # Fail fast here: an unsupported loss_type or a zero beta would
        # otherwise only surface as an obscure error inside the loss.
        supported = ('sigmoid', 'hinge', 'ipo')
        if self.loss_type not in supported:
            raise ValueError(
                f"loss_type must be one of {supported}, got {self.loss_type!r}"
            )
        # `not x > 0` also rejects NaN.
        if not self.beta > 0:
            raise ValueError(f"beta must be > 0, got {self.beta!r}")
        if not 0.0 <= self.label_smoothing <= 1.0:
            raise ValueError(
                f"label_smoothing must be in [0, 1], got {self.label_smoothing!r}"
            )


class SimpleLM(nn.Module):
    """Minimal causal language model.

    Token embedding -> stack of pre-norm Transformer encoder layers run
    under a causal mask -> linear projection to vocabulary logits. No
    positional encoding is added to the embeddings.
    """

    def __init__(self, vocab_size: int = 200, d_model: int = 128, n_layers: int = 2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)

        # Each layer uses 4 attention heads and a 4x-wide feed-forward block.
        blocks = []
        for _ in range(n_layers):
            blocks.append(
                nn.TransformerEncoderLayer(
                    d_model=d_model,
                    nhead=4,
                    dim_feedforward=d_model * 4,
                    batch_first=True,
                    norm_first=True,
                )
            )
        self.layers = nn.ModuleList(blocks)

        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map token ids to next-token logits.

        Assumes ``input_ids`` is an integer tensor of shape (batch, seq);
        the result then has shape (batch, seq, vocab_size).
        """
        hidden = self.embed(input_ids)
        seq_len = hidden.size(1)
        # Square mask that blocks attention to later positions.
        attn_mask = nn.Transformer.generate_square_subsequent_mask(
            seq_len, device=hidden.device
        )
        for block in self.layers:
            hidden = block(hidden, src_mask=attn_mask, is_causal=True)
        logits = self.head(hidden)
        return logits


def get_per_token_logps(model: nn.Module, input_ids: torch.Tensor,
                        prompt_len: int) -> torch.Tensor:
    """Return the log-probability of each response token under ``model``.

    Args:
        model: Callable mapping ``input_ids`` to logits with a trailing
            vocabulary dimension (assumed shape (batch, seq, vocab)).
        input_ids: Token ids, assumed shape (batch, seq), holding the prompt
            followed by the response.
        prompt_len: Number of leading prompt tokens; must be at least 1 so
            that the first response token has a preceding position to be
            predicted from.

    Returns:
        Tensor of shape (batch, seq - prompt_len) with one log-probability
        per response token.

    Raises:
        ValueError: If ``prompt_len`` is smaller than 1.
    """
    if prompt_len < 1:
        # With prompt_len == 0 the logits slice below would be empty while
        # the labels keep the full length, and gather would fail with an
        # opaque size-mismatch error; negative values would silently score
        # a suffix instead.
        raise ValueError(f"prompt_len must be >= 1, got {prompt_len}")

    logits = model(input_ids)
    # The logits at position t score the token at position t + 1, so the
    # response tokens [prompt_len:] are predicted by positions
    # [prompt_len - 1 : -1].
    response_logits = logits[:, prompt_len - 1:-1, :]
    response_labels = input_ids[:, prompt_len:]
    log_probs = F.log_softmax(response_logits, dim=-1)
    # Pick, at every position, the log-prob assigned to the actual token.
    return log_probs.gather(-1, response_labels.unsqueeze(-1)).squeeze(-1)


def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    config: "DPOConfig",
) -> tuple[torch.Tensor, dict]:
    """Compute the DPO loss (or a variant) for a batch of preference pairs.

    With the log-ratio Δ = log π_θ(y|x) - log π_ref(y|x) and the margin
    h = Δ_chosen - Δ_rejected:

        sigmoid: -log σ(β·h)   (label-smoothed when label_smoothing > 0)
        hinge:   max(0, 1 - β·h)
        ipo:     (h - 1/(2β))²

    Args:
        policy_chosen_logps: Per-token log-probs of the chosen responses
            under the policy; summed over the last dimension.
        policy_rejected_logps: Same for the rejected responses.
        ref_chosen_logps: Per-token log-probs of the chosen responses under
            the reference model.
        ref_rejected_logps: Same for the rejected responses.
        config: Object providing ``beta``, ``label_smoothing`` and
            ``loss_type``.

    Returns:
        ``(loss, stats)`` where ``loss`` is the batch-mean loss and ``stats``
        is a dict of Python floats with keys ``loss``, ``reward_margin``,
        ``accuracy``, ``chosen_reward`` and ``rejected_reward``.

    Raises:
        ValueError: If ``config.loss_type`` is not 'sigmoid', 'hinge' or 'ipo'.
    """
    # Sequence log-likelihood is the sum of the per-token log-probs.
    chosen_logratios = policy_chosen_logps.sum(dim=-1) - ref_chosen_logps.sum(dim=-1)
    rejected_logratios = policy_rejected_logps.sum(dim=-1) - ref_rejected_logps.sum(dim=-1)

    # h: how much more the policy (relative to the reference) prefers the
    # chosen response over the rejected one.
    logratio_margin = chosen_logratios - rejected_logratios
    # β·h: difference of the implicit rewards.
    logits = config.beta * logratio_margin

    if config.loss_type == 'sigmoid':
        if config.label_smoothing > 0:
            # Conservative DPO: assume the label is flipped with
            # probability label_smoothing.
            losses = (
                -F.logsigmoid(logits) * (1 - config.label_smoothing)
                - F.logsigmoid(-logits) * config.label_smoothing
            )
        else:
            losses = -F.logsigmoid(logits)
    elif config.loss_type == 'hinge':
        losses = F.relu(1 - logits)
    elif config.loss_type == 'ipo':
        # IPO regresses the *unscaled* margin h toward 1/(2β); applying the
        # target to β·h would move the optimum to h = 1/(2β²).
        losses = (logratio_margin - 1 / (2 * config.beta)) ** 2
    else:
        # Without this branch `losses` would be unbound further down.
        raise ValueError(
            f"Unknown loss_type {config.loss_type!r}; "
            "expected 'sigmoid', 'hinge' or 'ipo'"
        )

    loss = losses.mean()

    # Monitoring statistics only; kept out of the autograd graph.
    with torch.no_grad():
        chosen_rewards = config.beta * chosen_logratios
        rejected_rewards = config.beta * rejected_logratios
        reward_margin = (chosen_rewards - rejected_rewards).mean()
        # Fraction of pairs whose chosen response has the higher reward.
        accuracy = (chosen_rewards > rejected_rewards).float().mean()

    stats = {
        'loss': loss.item(),
        'reward_margin': reward_margin.item(),
        'accuracy': accuracy.item(),
        'chosen_reward': chosen_rewards.mean().item(),
        'rejected_reward': rejected_rewards.mean().item(),
    }

    return loss, stats


def main():
    """Run the DPO demo end to end.

    Builds a policy and a frozen reference copy, trains the policy with the
    DPO loss on randomly generated preference pairs, prints per-epoch
    statistics, then prints the loss of each DPO variant on a fixed batch
    and a DPO-vs-PPO comparison table.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"设备: {device}\n")

    print("=" * 60)
    print("DPO (Direct Preference Optimization) 训练演示")
    print("=" * 60)

    config = DPOConfig(beta=0.1, lr=5e-5, n_epochs=3, batch_size=8)
    vocab_size = 200

    # Policy and reference start from identical weights; the reference is
    # frozen and never updated.
    policy = SimpleLM(vocab_size, 128, 2).to(device)
    ref_model = SimpleLM(vocab_size, 128, 2).to(device)
    ref_model.load_state_dict(policy.state_dict())
    for p in ref_model.parameters():
        p.requires_grad_(False)
    # The reference must be deterministic: in train mode the dropout inside
    # the Transformer layers would add noise to every reference log-prob.
    ref_model.eval()

    optimizer = torch.optim.AdamW(policy.parameters(), lr=config.lr, weight_decay=0.01)

    print(f"\nDPO 配置:")
    print(f"  β = {config.beta} (温度参数，越大越保守)")
    print(f"  loss_type = {config.loss_type}")
    print(f"  model params = {sum(p.numel() for p in policy.parameters()):,}")
    print(f"\n比 PPO 的优势: 无需奖励模型、无需价值模型、无需在线采样")

    # Synthetic preference data: each pair shares a random prompt and has
    # two independent random responses.
    n_samples = 500
    prompt_len = 16
    response_len = 32

    prompts = torch.randint(0, vocab_size, (n_samples, prompt_len))
    chosen = torch.cat([prompts, torch.randint(0, vocab_size, (n_samples, response_len))], dim=1)
    rejected = torch.cat([prompts, torch.randint(0, vocab_size, (n_samples, response_len))], dim=1)

    print(f"\n训练数据: {n_samples} 对偏好样本")
    print(f"  prompt_len={prompt_len}, response_len={response_len}")

    # Training loop
    print("\n" + "-" * 60)
    print("开始 DPO 训练")
    print("-" * 60)

    policy.train()
    for epoch in range(config.n_epochs):
        total_stats = {k: 0 for k in ['loss', 'reward_margin', 'accuracy', 'chosen_reward', 'rejected_reward']}
        n_batches = 0

        # Reshuffle every epoch; a trailing partial batch is dropped.
        indices = torch.randperm(n_samples)
        for i in range(0, n_samples - config.batch_size + 1, config.batch_size):
            batch_idx = indices[i:i+config.batch_size]
            chosen_batch = chosen[batch_idx].to(device)
            rejected_batch = rejected[batch_idx].to(device)

            # Policy log-probs carry gradients; reference log-probs do not.
            policy_chosen_logps = get_per_token_logps(policy, chosen_batch, prompt_len)
            policy_rejected_logps = get_per_token_logps(policy, rejected_batch, prompt_len)

            with torch.no_grad():
                ref_chosen_logps = get_per_token_logps(ref_model, chosen_batch, prompt_len)
                ref_rejected_logps = get_per_token_logps(ref_model, rejected_batch, prompt_len)

            loss, stats = dpo_loss(
                policy_chosen_logps, policy_rejected_logps,
                ref_chosen_logps, ref_rejected_logps,
                config
            )

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
            optimizer.step()

            for k, v in stats.items():
                total_stats[k] += v
            n_batches += 1

        avg_stats = {k: v / n_batches for k, v in total_stats.items()}
        print(f"  Epoch {epoch+1}: "
              f"loss={avg_stats['loss']:.4f}, "
              f"acc={avg_stats['accuracy']:.2%}, "
              f"margin={avg_stats['reward_margin']:.4f}, "
              f"r_chosen={avg_stats['chosen_reward']:.3f}, "
              f"r_rejected={avg_stats['rejected_reward']:.3f}")

    # Comparison of DPO variants
    print("\n" + "=" * 60)
    print("DPO 变体对比")
    print("=" * 60)

    variant_configs = [
        ('DPO (标准)', DPOConfig(beta=0.1, loss_type='sigmoid')),
        ('cDPO (标签平滑)', DPOConfig(beta=0.1, loss_type='sigmoid', label_smoothing=0.1)),
        ('IPO', DPOConfig(beta=0.1, loss_type='ipo')),
        ('Hinge DPO', DPOConfig(beta=0.1, loss_type='hinge')),
    ]

    # Evaluate every variant on the same fixed batch.
    test_chosen = chosen[:16].to(device)
    test_rejected = rejected[:16].to(device)

    # Evaluation must not be perturbed by dropout either.
    policy.eval()
    with torch.no_grad():
        p_chosen = get_per_token_logps(policy, test_chosen, prompt_len)
        p_rejected = get_per_token_logps(policy, test_rejected, prompt_len)
        r_chosen = get_per_token_logps(ref_model, test_chosen, prompt_len)
        r_rejected = get_per_token_logps(ref_model, test_rejected, prompt_len)

    print(f"\n{'变体':20s} | {'损失':>8s} | {'准确率':>8s} | {'奖励差':>8s}")
    print("-" * 55)
    for name, cfg in variant_configs:
        _, stats = dpo_loss(p_chosen, p_rejected, r_chosen, r_rejected, cfg)
        print(f"{name:20s} | {stats['loss']:8.4f} | {stats['accuracy']:7.2%} | {stats['reward_margin']:8.4f}")

    print("""
  DPO vs PPO 选择指南:
  ┌──────────────┬──────────────┬──────────────┐
  │              │     DPO      │     PPO      │
  ├──────────────┼──────────────┼──────────────┤
  │ 需要奖励模型 │      否      │      是      │
  │ 需要在线采样 │      否      │      是      │
  │ 模型数量     │      2       │      4       │
  │ 显存需求     │      低      │      高      │
  │ 实现复杂度   │      低      │      高      │
  │ 分布偏移     │   可能严重   │   在线纠正   │
  │ 适用场景     │  离线偏好数据 │  需要探索     │
  └──────────────┴──────────────┴──────────────┘
""")


# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
