"""
RoPE (Rotary Position Embedding) 实现与可视化
对应教程第 3 章
"""

import torch
import torch.nn as nn
import math
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np


def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute the complex rotation factors used by RoPE.

    For every position ``t`` and every channel pair ``i``::

        freq_i          = 1 / theta^(2i / dim),   i = 0, 1, ..., dim/2 - 1
        freqs_cis[t, i] = exp(j * t * freq_i)
                        = cos(t * freq_i) + j * sin(t * freq_i)

    Args:
        dim: Number of channels to rotate. Must be a positive even integer
            because channels are rotated two at a time.
        seq_len: Number of positions (rows) to generate.
        theta: Base of the geometric frequency progression.

    Returns:
        A complex tensor of shape ``(seq_len, dim // 2)``; every entry has
        magnitude 1.

    Raises:
        ValueError: If ``dim`` is not a positive even integer.
    """
    # An odd dim would make arange(0, dim, 2) produce one frequency more than
    # there are channel pairs, so reject it up front instead of returning a
    # table that cannot be applied.
    if dim <= 0 or dim % 2 != 0:
        raise ValueError(f"dim must be a positive even integer, got {dim}")

    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(seq_len).float()
    # angles[t, i] = t * freq_i
    angles = torch.outer(positions, inv_freq)
    # Unit-magnitude complex numbers with the given phase.
    return torch.polar(torch.ones_like(angles), angles)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors with precomputed RoPE factors.

    Adjacent channels ``(x[2i], x[2i+1])`` are read as the complex number
    ``x[2i] + j*x[2i+1]`` and multiplied by ``cos(m*theta_i) + j*sin(m*theta_i)``.
    The product is written back as two real channels::

        real part: x[2i]*cos(m*theta_i) - x[2i+1]*sin(m*theta_i)
        imag part: x[2i]*sin(m*theta_i) + x[2i+1]*cos(m*theta_i)

    Args:
        xq: Query tensor; its last dimension must be even. Callers in this
            file pass ``(batch, seq, heads, dim)``.
        xk: Key tensor, handled exactly like ``xq``.
        freqs_cis: Complex rotation factors. Singleton axes are inserted at
            positions 0 and 2 before multiplying, so a ``(seq, dim // 2)``
            table broadcasts over batch and heads.

    Returns:
        ``(xq_rotated, xk_rotated)``, each cast back to the dtype of its input.
    """
    # (seq, pairs) -> (1, seq, 1, pairs) so the factors broadcast.
    rotation = freqs_cis.unsqueeze(0).unsqueeze(2)

    def _rotate(x):
        # Group the last axis into (pairs, 2) and view each pair as one complex value.
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(pairs) * rotation
        # Back to real pairs, then merge the (pairs, 2) axes into one.
        return torch.view_as_real(rotated).flatten(-2).type_as(x)

    return _rotate(xq), _rotate(xk)


def rope_manual(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply RoPE with explicit real-valued 2-D rotations (no complex numbers).

    Every adjacent channel pair ``(x[2i], x[2i+1])`` is rotated by its angle::

        [cos(m*theta_i)  -sin(m*theta_i)] [x[2i]  ]
        [sin(m*theta_i)   cos(m*theta_i)] [x[2i+1]]

    Args:
        x: Input tensor whose last dimension ``d`` is even.
        freqs: Rotation angles ``m * theta_i``; must broadcast against
            ``x[..., 0::2]``, i.e. its last dimension is ``d // 2``.

    Returns:
        Tensor with the rotated pairs written back in interleaved order
        (same channel layout as ``x``).
    """
    # Pair even channels with their odd neighbours. Splitting the vector into
    # front/back halves instead would rotate (x[i], x[i + d/2]) and disagree
    # with the reshape(..., -1, 2) pairing used by apply_rotary_emb.
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]

    cos_f = freqs.cos()
    sin_f = freqs.sin()

    out_even = x_even * cos_f - x_odd * sin_f
    out_odd = x_even * sin_f + x_odd * cos_f

    # (..., d/2, 2) -> (..., d): re-interleave so channel 2i/2i+1 hold the pair.
    return torch.stack((out_even, out_odd), dim=-1).flatten(-2)


def verify_rope_property():
    """Check numerically that the RoPE dot product depends only on relative position.

    One random query and one random key are rotated to several absolute
    positions ``(m, m + 5)``. The dot product is printed for each placement,
    followed by the largest deviation from the first placement and a
    pass/fail verdict against a 1e-5 threshold.
    """
    banner = "=" * 60
    print(banner)
    print("验证 RoPE 核心性质：q_m^T · k_n 只依赖 (m-n)")
    print(banner)

    head_dim, max_positions = 64, 128
    rotations = precompute_freqs_cis(head_dim, max_positions)

    query = torch.randn(1, 1, 1, head_dim)
    key = torch.randn(1, 1, 1, head_dim)

    # Same relative offset, different absolute positions.
    offset = 5
    scores = []
    for m in (0, 10, 20, 50, 100):
        n = m + offset
        if n >= max_positions:
            break
        # apply_rotary_emb returns a (q, k) pair; only one side is needed per call.
        rotated_q = apply_rotary_emb(query, query, rotations[m:m + 1])[0]
        rotated_k = apply_rotary_emb(key, key, rotations[n:n + 1])[1]
        score = (rotated_q * rotated_k).sum().item()
        scores.append(score)
        print(f"  位置 (m={m:3d}, n={n:3d}), 相对距离={offset}: 内积 = {score:.6f}")

    baseline = scores[0]
    max_diff = max(abs(score - baseline) for score in scores)
    verdict = '通过' if max_diff < 1e-5 else '失败'
    print(f"\n  最大偏差: {max_diff:.2e} (应接近 0)")
    print(f"  结论: {verdict} — 内积与绝对位置无关\n")


def _plot_cosine_by_dimension(ax, freqs_real, seq_len):
    """Panel 1: cos(m*theta_i) over position for a handful of channel pairs."""
    positions = np.arange(seq_len)
    for pair_idx in (0, 4, 8, 16, 31):
        # Pair i covers channels 2i and 2i+1, hence the label.
        ax.plot(positions, freqs_real[:, pair_idx], label=f'dim {pair_idx*2}', alpha=0.8)
    ax.set_xlabel('Position')
    ax.set_ylabel('cos(m·θ_i)')
    ax.set_title('Cosine Component by Dimension')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)


def _plot_cosine_heatmap(ax, freqs_real):
    """Panel 2: heatmap of cos(m*theta_i) for the first 64 positions."""
    # Transposed so positions run along x and pair indices along y.
    im = ax.imshow(freqs_real[:64, :].T, aspect='auto', cmap='RdBu_r',
                   interpolation='nearest', vmin=-1, vmax=1)
    ax.set_xlabel('Position')
    ax.set_ylabel('Dimension pair index')
    ax.set_title('RoPE cos(m·θ_i) Heatmap')
    plt.colorbar(im, ax=ax)


def _plot_rotation_vectors(ax, freqs_real, freqs_imag):
    """Panel 3: rotation factors of the first 16 pairs at several positions."""
    for pos in (0, 16, 32, 64, 128):
        ax.plot(freqs_real[pos, :16], freqs_imag[pos, :16], 'o-',
                label=f'pos={pos}', markersize=3, alpha=0.7)
    ax.set_xlabel('Real (cos)')
    ax.set_ylabel('Imaginary (sin)')
    ax.set_title('Rotation Vectors at Different Positions')
    ax.legend(fontsize=8)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)


def _plot_score_vs_distance(ax, freqs_cis, dim):
    """Panel 4: dot product of q at position 0 with k at every position."""
    q = torch.randn(1, 1, 1, dim)
    k = torch.randn(1, 1, 1, dim)

    # The query stays at position 0, so rotate it once.
    q_rot, _ = apply_rotary_emb(q, q, freqs_cis[0:1])
    # Passing the whole table broadcasts the single key over every position:
    # the result has shape (1, seq_len, 1, dim).
    _, k_rot = apply_rotary_emb(k, k, freqs_cis)
    dots = (q_rot * k_rot).sum(dim=-1).flatten().tolist()

    ax.plot(range(len(dots)), dots, linewidth=0.8)
    ax.set_xlabel('Relative Distance |m-n|')
    ax.set_ylabel('q_m^T · k_n')
    ax.set_title('Attention Score vs Relative Distance')
    ax.grid(True, alpha=0.3)


def visualize_rope(output_path=None):
    """Render four diagnostic RoPE plots and save them as a single PNG.

    Panels (dim=64, seq_len=256):
      1. cos(m*theta_i) over position for a few channel pairs.
      2. Heatmap of cos(m*theta_i) for the first 64 positions.
      3. Rotation factors of the first 16 pairs at several positions.
      4. Dot product of one random (q, k) pair versus relative distance.

    Args:
        output_path: Destination file (str or path-like). Defaults to
            ``rope_visualization.png`` in the directory containing this
            script, falling back to the current working directory when
            ``__file__`` is unavailable. Missing parent directories are
            created.
    """
    from pathlib import Path  # local import keeps this block self-contained

    if output_path is None:
        try:
            base_dir = Path(__file__).resolve().parent
        except NameError:  # e.g. code pasted into an interactive session
            base_dir = Path.cwd()
        output_path = base_dir / 'rope_visualization.png'
    else:
        output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    dim = 64
    seq_len = 256
    freqs_cis = precompute_freqs_cis(dim, seq_len)
    freqs_real = freqs_cis.real.numpy()
    freqs_imag = freqs_cis.imag.numpy()

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    try:
        fig.suptitle('RoPE (Rotary Position Embedding) 可视化', fontsize=14)

        _plot_cosine_by_dimension(axes[0, 0], freqs_real, seq_len)
        _plot_cosine_heatmap(axes[0, 1], freqs_real)
        _plot_rotation_vectors(axes[1, 0], freqs_real, freqs_imag)
        _plot_score_vs_distance(axes[1, 1], freqs_cis, dim)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

    print(f"可视化已保存到 {output_path}")


def compare_theta_values():
    """Print the RoPE wavelength range obtained with several theta bases.

    For dim=128 and each base, the per-pair frequencies
    ``1 / theta^(2i/dim)`` are computed and the shortest and longest
    wavelengths (``2*pi / freq``) are printed, followed by a short list of
    models and the base each one is listed with.
    """
    rule = "=" * 60
    print(rule)
    print("不同 θ 基值的频率分布")
    print(rule)

    dim = 128
    # The exponents 2i/dim do not depend on theta, so build them once.
    exponents = torch.arange(0, dim, 2).float() / dim
    for theta in (10_000, 500_000, 10_000_000):
        inv_freq = 1.0 / (theta ** exponents)
        # Highest frequency -> shortest wavelength, and vice versa.
        shortest = 2 * math.pi / inv_freq.max().item()
        longest = 2 * math.pi / inv_freq.min().item()
        print(f"  θ = {theta:>10,}: 波长范围 [{shortest:.1f}, {longest:.1f}]")

    print()
    for note in (
        "  LLaMA 1/2: θ=10,000",
        "  LLaMA 3:   θ=500,000",
        "  DeepSeek:  θ=10,000,000 (支持更长上下文)",
    ):
        print(note)


def main():
    """Run the whole demo: property check, theta comparison, cross-check, plots.

    After the two report functions, a random ``(4, 1024, 16, 128)`` tensor is
    rotated once with ``apply_rotary_emb`` and once with ``rope_manual``; the
    largest absolute difference is printed with a verdict against a 1e-5
    threshold. Finally the visualisation is generated.
    """
    print("RoPE (Rotary Position Embedding) 实现与分析\n")

    verify_rope_property()
    compare_theta_values()
    print()

    # Cross-check the complex-number implementation against the manual one.
    divider = "=" * 60
    print(divider)
    print("两种 RoPE 实现对比")
    print(divider)

    batch_size, seq_len, n_heads, dim = 4, 1024, 16, 128
    x = torch.randn(batch_size, seq_len, n_heads, dim)

    # Complex-number path.
    rotated_complex, _ = apply_rotary_emb(x, x, precompute_freqs_cis(dim, seq_len))

    # Manual path: raw angles m * theta_i, reshaped to (1, seq, 1, dim/2)
    # so they broadcast over batch and heads.
    positions = torch.arange(seq_len).float()
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(positions, inv_freq)[None, :, None, :]
    rotated_manual = rope_manual(x, angles)

    diff = (rotated_complex - rotated_manual).abs().max().item()
    verdict = '一致' if diff < 1e-5 else '不一致'
    print(f"  两种实现最大误差: {diff:.2e}")
    print(f"  结论: {verdict}")

    print("\n生成可视化图表...")
    visualize_rope()


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
