"""
注意力机制变体实现与对比：MHA / GQA / MQA
对应教程第 3、4 章
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time


class MultiHeadAttention(nn.Module):
    """Standard multi-head attention (MHA).

    Each of the ``n_heads`` query heads has its own key/value head, so the
    K and V projections are both ``d_model -> d_model``.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model, self.n_heads = d_model, n_heads
        self.head_dim = d_model // n_heads

        # Projections are created in Q, K, V, O order so seeded init is reproducible.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``(B, T, n_heads * head_dim)`` into ``(B, n_heads, T, head_dim)``."""
        batch, seq = t.shape[0], t.shape[1]
        return t.view(batch, seq, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape ``(B, T, d_model)``.

        Args:
            x: input sequence batch, shape ``(B, T, d_model)``.
            mask: optional tensor broadcastable to ``(B, n_heads, T, T)``;
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        batch, seq, width = x.shape
        query, key, value = (
            self._split_heads(proj(x)) for proj in (self.W_q, self.W_k, self.W_v)
        )

        scores = (query @ key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = scores.softmax(dim=-1)
        # Merge heads back: (B, H, T, D) -> (B, T, H * D)
        context = (weights @ value).transpose(1, 2).reshape(batch, seq, width)
        return self.W_o(context)


class GroupedQueryAttention(nn.Module):
    """Grouped-query attention (GQA), as used by LLaMA 2/3.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads. Query head
    ``h`` reads KV head ``h // n_rep``, where ``n_rep = n_heads // n_kv_heads``.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of query heads.
        n_kv_heads: number of key/value heads; must divide ``n_heads``.
    """

    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        assert n_heads % n_kv_heads == 0
        self.d_model = d_model
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.head_dim = d_model // n_heads
        self.n_rep = n_heads // n_kv_heads

        q_width = n_heads * self.head_dim
        kv_width = n_kv_heads * self.head_dim
        # Keep Q, K, V, O creation order so seeded initialization is unchanged.
        self.W_q = nn.Linear(d_model, q_width, bias=False)
        self.W_k = nn.Linear(d_model, kv_width, bias=False)
        self.W_v = nn.Linear(d_model, kv_width, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def _repeat_kv(self, t: torch.Tensor) -> torch.Tensor:
        """Expand ``(B, n_kv_heads, T, D)`` to ``(B, n_heads, T, D)``, each KV head repeated ``n_rep`` times in place."""
        batch, kv_heads, seq, dim = t.shape
        return (
            t[:, :, None, :, :]
            .expand(batch, kv_heads, self.n_rep, seq, dim)
            .reshape(batch, kv_heads * self.n_rep, seq, dim)
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Apply grouped-query self-attention to ``x`` of shape ``(B, T, d_model)``.

        Args:
            x: input sequence batch, shape ``(B, T, d_model)``.
            mask: optional tensor broadcastable to ``(B, n_heads, T, T)``;
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        batch, seq, width = x.shape
        hd = self.head_dim

        query = self.W_q(x).view(batch, seq, self.n_heads, hd).transpose(1, 2)
        key = self.W_k(x).view(batch, seq, self.n_kv_heads, hd).transpose(1, 2)
        value = self.W_v(x).view(batch, seq, self.n_kv_heads, hd).transpose(1, 2)

        # Repeat KV heads so there is one per query head.
        key = self._repeat_kv(key)
        value = self._repeat_kv(value)

        scores = (query @ key.transpose(-2, -1)) / math.sqrt(hd)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        context = scores.softmax(dim=-1) @ value
        merged = context.transpose(1, 2).reshape(batch, seq, width)
        return self.W_o(merged)


class MultiQueryAttention(nn.Module):
    """Multi-query attention (MQA): all query heads share one key/value head.

    The K and V projections therefore produce only ``head_dim`` features.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of query heads.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = head_dim = d_model // n_heads

        # Keep Q, K, V, O creation order so seeded initialization is unchanged.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, head_dim, bias=False)
        self.W_v = nn.Linear(d_model, head_dim, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Apply multi-query self-attention to ``x`` of shape ``(B, T, d_model)``.

        Args:
            x: input sequence batch, shape ``(B, T, d_model)``.
            mask: optional tensor broadcastable to ``(B, n_heads, T, T)``;
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        batch, seq, width = x.shape
        heads, hd = self.n_heads, self.head_dim

        query = self.W_q(x).view(batch, seq, heads, hd).transpose(1, 2)
        # Single shared KV head: (B, T, hd) -> (B, 1, T, hd), then a stride-0 view over all heads.
        key = self.W_k(x).unsqueeze(1).expand(batch, heads, seq, hd)
        value = self.W_v(x).unsqueeze(1).expand(batch, heads, seq, hd)

        scores = (query @ key.transpose(-2, -1)) / math.sqrt(hd)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        context = scores.softmax(dim=-1) @ value
        return self.W_o(context.transpose(1, 2).reshape(batch, seq, width))


def count_params(module: nn.Module) -> int:
    """Return the total number of scalar elements in all parameters of ``module``."""
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total


def benchmark_attention(attn_module: nn.Module, x: torch.Tensor, name: str, n_iters: int = 100,
                        n_warmup: int = 10) -> float:
    """Benchmark the mean forward latency of an attention module.

    Runs ``n_warmup`` untimed forward passes, synchronizes the device, then
    times ``n_iters`` calls of ``attn_module(x)``. Prints the module's
    parameter count and mean latency. Gradient tracking follows the caller's
    context; wrap the call in ``torch.no_grad()`` to measure inference only.

    Args:
        attn_module: module to benchmark; called as ``attn_module(x)``.
        x: input tensor; its device decides whether CUDA sync is needed.
        name: label printed in the result line.
        n_iters: number of timed forward passes (must be > 0).
        n_warmup: number of untimed warm-up passes (must be >= 0).

    Returns:
        Mean latency per forward pass, in milliseconds.

    Raises:
        ValueError: if ``n_iters <= 0`` or ``n_warmup < 0``.
    """
    # Without this guard, n_iters == 0 divides by zero and a negative value
    # silently yields a negative latency.
    if n_iters <= 0:
        raise ValueError(f"n_iters must be positive, got {n_iters}")
    if n_warmup < 0:
        raise ValueError(f"n_warmup must be non-negative, got {n_warmup}")

    is_cuda = x.device.type == 'cuda'

    # Warm-up: absorb one-off costs (kernel selection, allocator growth).
    for _ in range(n_warmup):
        attn_module(x)
    if is_cuda:
        # CUDA launches are async; drain the queue before starting the clock.
        torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(n_iters):
        attn_module(x)
    if is_cuda:
        torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / n_iters * 1000

    params = count_params(attn_module)
    print(f"  {name:5s} | 参数量: {params:>10,} | 延迟: {elapsed:.3f} ms")
    return elapsed


def _sync(device: torch.device) -> None:
    """Wait for queued CUDA work so wall-clock timings cover it (no-op on other devices)."""
    if device.type == 'cuda':
        torch.cuda.synchronize()


def _time_ms(fn, device: torch.device, n_iters: int = 100, n_warmup: int = 10) -> float:
    """Return the mean wall-clock latency of ``fn()`` in milliseconds.

    ``n_warmup`` untimed calls run first, and the device is synchronized before
    the timer starts. One-off costs and still-queued async work are therefore
    not billed to the measured loop.
    """
    for _ in range(n_warmup):
        fn()
    _sync(device)
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    _sync(device)
    return (time.perf_counter() - start) / n_iters * 1000


def _print_kv_cache_analysis(d_model: int, n_heads: int, n_kv_heads: int, seq_len: int,
                             n_layers: int = 32) -> None:
    """Print the FP16 KV-cache footprint of MHA / GQA / MQA.

    Sizes are totals over ``n_layers`` layers for a single sequence of
    ``seq_len`` tokens (batch size 1).
    """
    head_dim = d_model // n_heads
    bytes_per_elem = 2  # FP16

    def kv_bytes(kv_heads: int) -> int:
        # Leading factor 2: one cache for K and one for V.
        return 2 * n_layers * kv_heads * head_dim * seq_len * bytes_per_elem

    kv_mha, kv_gqa, kv_mqa = kv_bytes(n_heads), kv_bytes(n_kv_heads), kv_bytes(1)

    # The figures are summed over all layers, so the header must not say "per layer".
    print(f"KV-Cache 显存分析 ({n_layers} 层合计, batch=1, FP16):")
    print(f"  MHA: {kv_mha / 1024**2:.1f} MB  (n_kv_heads={n_heads})")
    print(f"  GQA: {kv_gqa / 1024**2:.1f} MB  (n_kv_heads={n_kv_heads})")
    print(f"  MQA: {kv_mqa / 1024**2:.1f} MB  (n_kv_heads=1)")
    print(f"  GQA 节省: {(1 - kv_gqa/kv_mha)*100:.0f}%")
    print(f"  MQA 节省: {(1 - kv_mqa/kv_mha)*100:.0f}%")


def _compare_sdpa(batch_size: int, n_heads: int, seq_len: int, head_dim: int,
                  device: torch.device) -> None:
    """Benchmark naive softmax attention against ``F.scaled_dot_product_attention``.

    Both paths are warmed up and timed the same way, then their outputs are
    compared. If this PyTorch build lacks SDPA, or a RuntimeError is raised,
    prints an "unavailable" message instead of failing.
    """
    print("FlashAttention (PyTorch SDPA) 对比:")
    sdpa_fn = getattr(F, 'scaled_dot_product_attention', None)
    if sdpa_fn is None:
        print("  SDPA 不可用: torch.nn.functional.scaled_dot_product_attention not found")
        return

    q = torch.randn(batch_size, n_heads, seq_len, head_dim, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    scale = math.sqrt(head_dim)  # loop-invariant, computed once

    def naive() -> torch.Tensor:
        attn = torch.matmul(q, k.transpose(-2, -1)) / scale
        return torch.matmul(F.softmax(attn, dim=-1), v)

    def sdpa() -> torch.Tensor:
        return sdpa_fn(q, k, v)  # PyTorch picks the best available backend

    try:
        with torch.no_grad():
            naive_time = _time_ms(naive, device)
            sdpa_time = _time_ms(sdpa, device)
            diff = (naive() - sdpa()).abs().max().item()
    except RuntimeError as e:
        print(f"  SDPA 不可用: {e}")
        return

    print(f"  朴素实现: {naive_time:.3f} ms")
    print(f"  SDPA:     {sdpa_time:.3f} ms")
    print(f"  加速比:   {naive_time / sdpa_time:.2f}x")
    # Numerical agreement between the two implementations.
    print(f"  最大误差: {diff:.6e}")


def main():
    """Run the attention-variant demo.

    The demo checks output shapes, prints the KV-cache size table, benchmarks
    MHA / GQA / MQA latency, and compares naive attention with SDPA.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"设备: {device}\n")

    d_model = 1024
    n_heads = 16
    n_kv_heads = 4
    seq_len = 512
    batch_size = 4

    print("=" * 60)
    print("注意力机制变体对比")
    print("=" * 60)
    print(f"配置: d_model={d_model}, n_heads={n_heads}, n_kv_heads={n_kv_heads}")
    print(f"输入: batch={batch_size}, seq_len={seq_len}\n")

    mha = MultiHeadAttention(d_model, n_heads).to(device)
    gqa = GroupedQueryAttention(d_model, n_heads, n_kv_heads).to(device)
    mqa = MultiQueryAttention(d_model, n_heads).to(device)

    x = torch.randn(batch_size, seq_len, d_model, device=device)

    # Sanity check: every variant maps (B, T, d_model) -> (B, T, d_model).
    with torch.no_grad():
        out_mha = mha(x)
        out_gqa = gqa(x)
        out_mqa = mqa(x)

    print("输出形状验证:")
    print(f"  MHA: {out_mha.shape}")
    print(f"  GQA: {out_gqa.shape}")
    print(f"  MQA: {out_mqa.shape}")
    print()

    _print_kv_cache_analysis(d_model, n_heads, n_kv_heads, seq_len)
    print()

    print("性能基准:")
    with torch.no_grad():
        benchmark_attention(mha, x, "MHA")
        benchmark_attention(gqa, x, "GQA")
        benchmark_attention(mqa, x, "MQA")

    print()
    _compare_sdpa(batch_size, n_heads, seq_len, d_model // n_heads, device)


if __name__ == '__main__':
    main()
