"""
Bradley-Terry 奖励模型训练
对应教程第 12、17 章

核心: P(y_w > y_l | x) = σ(r(x, y_w) - r(x, y_l))
损失: L = -E[log σ(r_w - r_l)]
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math


class RewardModel(nn.Module):
    """Scalar reward model: a small Transformer encoder plus a reward head.

    Pipeline: token embedding + learned position embedding ->
    ``nn.TransformerEncoder`` (pre-norm layers) -> hidden state of the last
    valid token of each sequence -> two-layer MLP -> one scalar per sequence.

    Only a key-padding mask is passed to the encoder (no causal mask), so
    every valid token attends to every other valid token.

    Args:
        d_model: Width of the embeddings and of the encoder layers.
        n_heads: Number of attention heads per encoder layer.
        n_layers: Number of stacked encoder layers.
        vocab_size: Number of rows in the token-embedding table.
        max_len: Number of rows in the position-embedding table; it bounds
            the sequence length ``forward`` can embed.
    """

    def __init__(self, d_model: int = 256, n_heads: int = 4, n_layers: int = 2,
                 vocab_size: int = 1000, max_len: int = 512):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            batch_first=True,
            norm_first=True,
        )
        # The nested-tensor path is not used with norm_first=True; opting out
        # explicitly avoids the construction-time warning about that.
        self.encoder = nn.TransformerEncoder(
            encoder_layer, n_layers, enable_nested_tensor=False
        )

        # Reward head: hidden state -> scalar.
        self.reward_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 1)
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        """Score each sequence in the batch with a single scalar.

        Args:
            input_ids: Token ids of shape ``[B, T]``.
            attention_mask: Optional ``[B, T]`` mask; non-zero marks a real
                token, zero marks padding. Padding may sit on either side.
                When omitted, every position is treated as valid.

        Returns:
            Tensor of shape ``[B]`` holding one reward per sequence.
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        positions = torch.arange(seq_len, device=device).unsqueeze(0)  # [1, T]
        x = self.embed(input_ids) + self.pos_embed(positions)

        if attention_mask is None:
            hidden = self.encoder(x)
            return self.reward_head(hidden[:, -1]).squeeze(-1)

        valid = attention_mask.bool()
        # src_key_padding_mask uses the opposite convention: True = ignore.
        hidden = self.encoder(x, src_key_padding_mask=~valid)

        # Index of the right-most valid token per row. Valid positions are
        # scored 1..T and padding 0, so the maximum is unique whenever a row
        # has at least one valid token. This works for left- and right-padded
        # masks alike; `mask.sum() - 1` is only correct for right padding and
        # turns into -1 for an all-padding row.
        last_idx = (valid.long() * (positions + 1)).argmax(dim=-1)  # [B]
        batch_idx = torch.arange(batch_size, device=device)
        last_hidden = hidden[batch_idx, last_idx]

        return self.reward_head(last_hidden).squeeze(-1)


class PreferenceDataset(Dataset):
    """Synthetic pairwise-preference dataset made of random token sequences.

    Every item is a dict with four entries, each of shape ``[seq_len]``:
    ``chosen_ids`` / ``rejected_ids`` (integer token ids) and
    ``chosen_mask`` / ``rejected_mask`` (1.0 for real tokens, 0.0 for
    right-side padding).

    Token ids for both sides are drawn independently and uniformly from
    ``[0, vocab_size)``, so the pairs carry no systematic preference signal;
    the data only exercises the training pipeline.

    Args:
        n_samples: Number of (chosen, rejected) pairs.
        vocab_size: Exclusive upper bound for the random token ids.
        seq_len: Padded length of every sequence; must be at least 1.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """

    def __init__(self, n_samples: int = 1000, vocab_size: int = 1000,
                 seq_len: int = 64):
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")

        self.n_samples = n_samples
        self.chosen = torch.randint(0, vocab_size, (n_samples, seq_len))
        self.rejected = torch.randint(0, vocab_size, (n_samples, seq_len))

        # Simulate variable-length sequences with right padding.
        self.chosen_mask = self._random_prefix_mask(n_samples, seq_len)
        self.rejected_mask = self._random_prefix_mask(n_samples, seq_len)

    @staticmethod
    def _random_prefix_mask(n_samples: int, seq_len: int) -> torch.Tensor:
        """Return an ``[n_samples, seq_len]`` mask of 1s followed by 0s.

        Valid lengths are uniform over ``[max(1, seq_len // 2), seq_len]``:
        the lower bound guarantees at least one real token per row and the
        inclusive upper bound allows sequences with no padding at all.
        """
        min_len = max(1, seq_len // 2)
        lengths = torch.randint(min_len, seq_len + 1, (n_samples, 1))
        positions = torch.arange(seq_len).unsqueeze(0)
        return (positions < lengths).to(torch.get_default_dtype())

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return {
            'chosen_ids': self.chosen[idx],
            'chosen_mask': self.chosen_mask[idx],
            'rejected_ids': self.rejected[idx],
            'rejected_mask': self.rejected_mask[idx],
        }


def bt_loss(reward_chosen: torch.Tensor, reward_rejected: torch.Tensor) -> torch.Tensor:
    """Bradley-Terry pairwise preference loss, averaged over all elements.

    Per pair:
        L = -log sigmoid(r_w - r_l)
          = log(1 + exp(r_l - r_w))

    Args:
        reward_chosen: Rewards r_w of the preferred responses.
        reward_rejected: Rewards r_l of the dispreferred responses.

    Returns:
        Scalar tensor with the mean negative log-likelihood.
    """
    reward_gap = reward_chosen - reward_rejected
    # logsigmoid evaluates log(sigmoid(x)) directly instead of composing the
    # two functions.
    mean_log_prob = F.logsigmoid(reward_gap).mean()
    return -mean_log_prob


def margin_loss(reward_chosen: torch.Tensor, reward_rejected: torch.Tensor,
                margin: float = 1.0) -> torch.Tensor:
    """Hinge-style preference loss with a margin, averaged over all elements.

    Per pair:
        L = max(0, margin - (r_w - r_l))

    A pair contributes zero loss once the chosen reward exceeds the rejected
    reward by at least ``margin``.

    Args:
        reward_chosen: Rewards r_w of the preferred responses.
        reward_rejected: Rewards r_l of the dispreferred responses.
        margin: Required gap between chosen and rejected rewards.

    Returns:
        Scalar tensor with the mean hinge loss.
    """
    reward_gap = reward_chosen - reward_rejected
    shortfall = margin - reward_gap
    per_pair = F.relu(shortfall)
    return per_pair.mean()


def train_reward_model(*, n_epochs: int = 3, batch_size: int = 32,
                       lr: float = 1e-4) -> None:
    """Train and evaluate a small reward model on synthetic preference pairs.

    Steps:
      1. Build a ``RewardModel`` and a ``PreferenceDataset`` of 2000 pairs.
      2. Optimise the Bradley-Terry loss with AdamW, cosine LR annealing over
         the whole run, and gradient-norm clipping at 1.0.
      3. Score a freshly generated set of 500 pairs and print the preference
         accuracy and reward statistics.
      4. Print a short note on outcome vs. process reward models.

    All results are printed; nothing is returned.

    Args:
        n_epochs: Number of passes over the training set. It also sets the
            length of the cosine schedule, so the two cannot drift apart.
        batch_size: Training mini-batch size.
        lr: Initial AdamW learning rate.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"设备: {device}\n")

    print("=" * 60)
    print("Bradley-Terry 奖励模型训练")
    print("=" * 60)

    # Model and data.
    model = RewardModel(d_model=128, n_heads=4, n_layers=2, vocab_size=500).to(device)
    dataset = PreferenceDataset(n_samples=2000, vocab_size=500, seq_len=48)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    # The scheduler is stepped once per batch, so T_max is the total number
    # of optimizer steps; max(1, ...) keeps it valid when n_epochs is 0.
    total_steps = max(1, len(dataloader) * n_epochs)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

    params = sum(p.numel() for p in model.parameters())
    print(f"\n模型参数: {params:,}")
    print(f"训练集: {len(dataset)} 对偏好数据\n")

    # Training loop.
    for epoch in range(n_epochs):
        model.train()
        loss_sum = 0.0
        n_correct = 0
        n_total = 0

        for batch in dataloader:
            chosen_ids = batch['chosen_ids'].to(device)
            chosen_mask = batch['chosen_mask'].to(device)
            rejected_ids = batch['rejected_ids'].to(device)
            rejected_mask = batch['rejected_mask'].to(device)

            r_chosen = model(chosen_ids, chosen_mask)
            r_rejected = model(rejected_ids, rejected_mask)

            loss = bt_loss(r_chosen, r_rejected)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            # Weight each batch mean by its size so the shorter final batch
            # does not skew the epoch average.
            n_pairs = len(r_chosen)
            loss_sum += loss.item() * n_pairs
            n_correct += (r_chosen > r_rejected).sum().item()
            n_total += n_pairs

        acc = n_correct / n_total
        avg_loss = loss_sum / n_total
        print(f"  Epoch {epoch+1}: loss={avg_loss:.4f}, accuracy={acc:.2%}")

    # Evaluation.
    print("\n" + "=" * 60)
    print("奖励模型评估")
    print("=" * 60)

    model.eval()
    test_data = PreferenceDataset(n_samples=500, vocab_size=500, seq_len=48)
    test_loader = DataLoader(test_data, batch_size=64)

    chosen_chunks = []
    rejected_chunks = []
    with torch.no_grad():
        for batch in test_loader:
            r_c = model(batch['chosen_ids'].to(device), batch['chosen_mask'].to(device))
            r_r = model(batch['rejected_ids'].to(device), batch['rejected_mask'].to(device))
            chosen_chunks.append(r_c.cpu())
            rejected_chunks.append(r_r.cpu())

    rewards_chosen = torch.cat(chosen_chunks)
    rewards_rejected = torch.cat(rejected_chunks)

    accuracy = (rewards_chosen > rewards_rejected).float().mean().item()
    avg_margin = (rewards_chosen - rewards_rejected).mean().item()

    print("\n测试集结果:")
    print(f"  偏好准确率: {accuracy:.2%}")
    print(f"  平均奖励差: {avg_margin:.4f}")
    print(f"  chosen 奖励:  {rewards_chosen.mean():.4f} ± {rewards_chosen.std():.4f}")
    print(f"  rejected 奖励: {rewards_rejected.mean():.4f} ± {rewards_rejected.std():.4f}")

    # Explanatory note: process RM vs. outcome RM.
    print("\n" + "=" * 60)
    print("Process RM vs Outcome RM")
    print("=" * 60)
    print("""
  Outcome RM (ORM):
    - 对整个回复给出单一奖励分数
    - 本示例实现的就是 ORM
    - 适用于一般对话场景

  Process RM (PRM):
    - 对推理过程中每一步给出奖励
    - 适用于数学推理等需要逐步验证的场景
    - Lightman et al. (2023): PRM 在数学任务上显著优于 ORM
    - 实现: 在每个推理步骤的 token 位置输出奖励

  Best-of-N 采样:
    - 生成 N 个回复，用 RM 选最好的
    - 简单但有效的对齐方法
    - 计算量 = N × 生成成本 + N × RM 推理成本
""")


# Run the full train-and-evaluate demo only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    train_reward_model()
