"""
LoRA (Low-Rank Adaptation) 从零实现
对应教程第 10 章

核心公式: h = (W₀ + ΔW)x = W₀x + BAx
其中 B ∈ R^{d×r}, A ∈ R^{r×k}, r << min(d,k)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass


@dataclass
class LoRAConfig:
    """Hyper-parameters describing a LoRA adapter.

    Raises:
        ValueError: If ``r`` is not strictly positive. A zero rank would
            otherwise only surface later as a ``ZeroDivisionError`` in
            ``scaling``.
    """

    # Rank of the low-rank update.
    r: int = 8
    # Numerator of the ``alpha / r`` factor that scales the update.
    alpha: float = 16.0
    # Dropout probability for the adapter branch.
    dropout: float = 0.0
    # Names (or name fragments) of the modules to adapt; ``None`` leaves the
    # choice to the code that consumes this config.
    target_modules: "list[str] | None" = None

    def __post_init__(self) -> None:
        # Fail at construction time instead of deep inside the model code.
        if self.r <= 0:
            raise ValueError(
                f"LoRA rank r must be a positive integer, got {self.r!r}"
            )

    @property
    def scaling(self) -> float:
        """Return the multiplier ``alpha / r`` applied to the low-rank update."""
        return self.alpha / self.r


class LoRALinear(nn.Module):
    """Linear layer with a trainable low-rank (LoRA) update on frozen weights.

    Effective weight: W_new = W_frozen + (alpha / r) * B @ A

    Initialisation:
      A: Kaiming-uniform with ``a = sqrt(5)``
      B: zeros, so the update is exactly zero before the first training step.
    """

    def __init__(self, original: nn.Linear, config: "LoRAConfig"):
        """Wrap ``original`` and attach LoRA matrices sized from ``config.r``.

        Args:
            original: Layer to adapt. Its weight and bias are frozen in place
                (``requires_grad`` is set to ``False`` on the passed object).
            config: Provides ``r``, ``alpha``, ``dropout`` and ``scaling``.
        """
        super().__init__()
        self.original = original
        self.config = config

        in_features = original.in_features
        out_features = original.out_features

        # Freeze the wrapped layer; only the adapter matrices are trained.
        self.original.weight.requires_grad_(False)
        if self.original.bias is not None:
            self.original.bias.requires_grad_(False)

        # Create the adapter on the same device and dtype as the wrapped
        # weight; otherwise wrapping a layer that was already moved (e.g. to
        # CUDA or to another float dtype) fails in forward() with a mismatch.
        factory = {
            "device": original.weight.device,
            "dtype": original.weight.dtype,
        }
        self.lora_A = nn.Parameter(torch.empty(config.r, in_features, **factory))
        self.lora_B = nn.Parameter(torch.zeros(out_features, config.r, **factory))

        # Kaiming-uniform init for A; B stays zero so the initial delta is 0.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

        self.scaling = config.scaling
        self.dropout = nn.Dropout(config.dropout) if config.dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``original(x) + scaling * B(A(dropout(x)))``."""
        base_out = self.original(x)
        # Dropout is applied to the adapter branch only, never to the base path.
        lora_out = F.linear(F.linear(self.dropout(x), self.lora_A), self.lora_B)
        return base_out + self.scaling * lora_out

    def merge_weights(self) -> nn.Linear:
        """Return a new ``nn.Linear`` with the LoRA update folded into its weight.

        The returned layer lives on the same device and dtype as the wrapped
        layer and owns its own storage (it does not alias the original bias).
        Dropout is not part of the merged weight, so the result reproduces
        ``forward`` only when dropout is inactive (eval mode or ``Identity``).
        """
        base_weight = self.original.weight
        base_bias = self.original.bias
        merged = nn.Linear(
            self.original.in_features,
            self.original.out_features,
            bias=base_bias is not None,
            device=base_weight.device,
            dtype=base_weight.dtype,
        )
        # Merging is a pure weight computation; keep it out of the autograd graph.
        with torch.no_grad():
            delta = self.scaling * (self.lora_B @ self.lora_A)
            merged.weight.copy_(base_weight + delta)
            if base_bias is not None:
                merged.bias.copy_(base_bias)
        return merged

    def extra_repr(self) -> str:
        """Summarise layer sizes and LoRA hyper-parameters for ``print(module)``."""
        return (f"in={self.original.in_features}, out={self.original.out_features}, "
                f"r={self.config.r}, alpha={self.config.alpha}, "
                f"scaling={self.scaling:.2f}")


def apply_lora(model: nn.Module, config: LoRAConfig) -> nn.Module:
    """Wrap the targeted ``nn.Linear`` submodules of ``model`` with ``LoRALinear``.

    A submodule is targeted when it is an ``nn.Linear`` and any entry of
    ``config.target_modules`` occurs as a substring of its dotted name. When
    ``config.target_modules`` is empty or ``None`` the default
    ``['q_proj', 'v_proj']`` is used.

    Args:
        model: Model to modify in place.
        config: LoRA hyper-parameters passed to every created ``LoRALinear``.

    Returns:
        The same ``model`` object, for call chaining.
    """
    target_modules = config.target_modules or ['q_proj', 'v_proj']

    # Snapshot the matches before touching the tree: swapping children while
    # the named_modules() generator is still running is fragile.
    matches = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and any(target in name for target in target_modules)
    ]

    for name, module in matches:
        parent_name, _, child_name = name.rpartition('.')
        parent = model.get_submodule(parent_name) if parent_name else model
        if isinstance(parent, LoRALinear):
            # This is the frozen layer inside an existing adapter (its name
            # still contains the target substring). Wrapping it again would
            # nest adapters when apply_lora is called more than once.
            continue
        setattr(parent, child_name, LoRALinear(module, config))

    return model


def count_parameters(model: nn.Module) -> tuple[int, int]:
    """Count the parameters of ``model``.

    Returns:
        A ``(trainable, total)`` pair, where ``trainable`` counts only the
        elements of parameters whose ``requires_grad`` flag is set.
    """
    n_trainable = 0
    n_total = 0
    for param in model.parameters():
        size = param.numel()
        n_total += size
        if param.requires_grad:
            n_trainable += size
    return n_trainable, n_total


class SimpleTransformer(nn.Module):
    """Small decoder-style Transformer used to demonstrate LoRA.

    Each layer normalises its input, runs causal multi-head self-attention
    with a residual connection, then normalises again and applies a
    SiLU feed-forward block (4x expansion) with a second residual connection.
    No positional information is added to the token embeddings.
    """

    def __init__(self, d_model: int = 256, n_heads: int = 4, n_layers: int = 4, vocab_size: int = 1000):
        """Build the embedding, ``n_layers`` blocks and the output head.

        Raises:
            ValueError: If ``d_model`` is not divisible by ``n_heads``; the
                per-head reshape in ``forward`` would otherwise fail later.
        """
        super().__init__()
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.embed = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            # A ModuleDict keeps the projection names (q_proj, v_proj, ...)
            # visible in named_modules(), which name-based targeting relies on.
            layer = nn.ModuleDict({
                'norm1': nn.LayerNorm(d_model),
                'q_proj': nn.Linear(d_model, d_model, bias=False),
                'k_proj': nn.Linear(d_model, d_model, bias=False),
                'v_proj': nn.Linear(d_model, d_model, bias=False),
                'o_proj': nn.Linear(d_model, d_model, bias=False),
                'norm2': nn.LayerNorm(d_model),
                'up_proj': nn.Linear(d_model, d_model * 4, bias=False),
                'down_proj': nn.Linear(d_model * 4, d_model, bias=False),
            })
            self.layers.append(layer)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map token ids of shape (batch, seq) to logits of shape (batch, seq, vocab)."""
        x = self.embed(input_ids)
        B, T, C = x.shape

        for layer in self.layers:
            # Attention sub-block: split channels into heads -> (B, heads, T, head_dim).
            h = layer['norm1'](x)
            q = layer['q_proj'](h).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
            k = layer['k_proj'](h).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
            v = layer['v_proj'](h).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
            attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
            # Merge the heads back into a single channel dimension.
            attn = attn.transpose(1, 2).contiguous().view(B, T, C)
            x = x + layer['o_proj'](attn)

            # Feed-forward sub-block with residual connection.
            h = layer['norm2'](x)
            x = x + layer['down_proj'](F.silu(layer['up_proj'](h)))

        return self.head(x)


def main():
    """Run the LoRA demo: wrap a toy model, fine-tune it, merge, compare ranks.

    Everything is printed to stdout; nothing is returned.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"设备: {device}\n")

    print("=" * 60)
    print("LoRA (Low-Rank Adaptation) 从零实现演示")
    print("=" * 60)

    # Build the base model.
    model = SimpleTransformer(d_model=256, n_heads=4, n_layers=4, vocab_size=1000).to(device)
    trainable_before, total_before = count_parameters(model)
    print(f"\n基础模型参数: {total_before:,} (全部可训练)")

    # Apply LoRA. The adapters are created after the model was moved, so move
    # again to guarantee the new parameters sit on the same device.
    config = LoRAConfig(r=8, alpha=16, dropout=0.05, target_modules=['q_proj', 'v_proj'])
    model = apply_lora(model, config).to(device)
    trainable_after, total_after = count_parameters(model)

    print(f"\nLoRA 配置: r={config.r}, alpha={config.alpha}, scaling={config.scaling:.1f}")
    print(f"目标模块: {config.target_modules}")
    print("\n应用 LoRA 后:")
    print(f"  总参数:     {total_after:,}")
    print(f"  可训练参数: {trainable_after:,}")
    print(f"  比例:       {trainable_after/total_after*100:.2f}%")
    print(f"  减少:       {(1-trainable_after/trainable_before)*100:.1f}% 参数需要训练")

    # List the wrapped modules.
    print("\nLoRA 模块:")
    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            print(f"  {name}: {module}")

    # Short fine-tuning demo.
    print("\n" + "=" * 60)
    print("LoRA 微调演示")
    print("=" * 60)

    # Only the parameters that still require gradients are optimised.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=1e-3
    )

    # Random token batches stand in for instruction-tuning data.
    batch_size = 4
    seq_len = 64

    for step in range(30):
        input_ids = torch.randint(0, 1000, (batch_size, seq_len), device=device)
        target = torch.randint(0, 1000, (batch_size, seq_len), device=device)

        logits = model(input_ids)
        loss = F.cross_entropy(logits.view(-1, 1000), target.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 10 == 0:
            print(f"  Step {step:3d}: loss = {loss.item():.4f}")

    # Weight-merging demo.
    print("\n" + "=" * 60)
    print("LoRA 权重合并 (用于推理)")
    print("=" * 60)

    test_input = torch.randint(0, 1000, (1, 32), device=device)

    # Switch to eval mode first: with dropout active the adapter output is
    # stochastic and cannot be compared with the deterministic merged weights.
    model.eval()

    with torch.no_grad():
        out_lora = model(test_input)

    # Merge the adapters. Snapshot the module list because the loop replaces
    # children of the tree it is iterating over.
    for name, module in list(model.named_modules()):
        if isinstance(module, LoRALinear):
            parent_name, _, child_name = name.rpartition('.')
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, module.merge_weights().to(device))

    with torch.no_grad():
        out_merged = model(test_input)

    diff = (out_lora - out_merged).abs().max().item()
    print(f"  合并前后输出最大误差: {diff:.2e}")
    print(f"  验证: {'通过' if diff < 1e-5 else '失败'} — 合并不影响输出")

    _, total_merged = count_parameters(model)
    print(f"\n  合并后参数: {total_merged:,} (无 LoRA 额外参数)")
    print("  推理时无额外计算开销")

    # Parameter counts for different ranks.
    print("\n" + "=" * 60)
    print("不同 rank 的参数量对比")
    print("=" * 60)
    print(f"{'rank':>6s} | {'可训练参数':>12s} | {'比例':>8s} | {'ΔW 理论容量':>14s}")
    print("-" * 50)
    for r in [1, 2, 4, 8, 16, 32, 64]:
        cfg = LoRAConfig(r=r, target_modules=['q_proj', 'v_proj'])
        tmp_model = SimpleTransformer(d_model=256, n_heads=4, n_layers=4).to(device)
        tmp_model = apply_lora(tmp_model, cfg)
        t, total = count_parameters(tmp_model)
        capacity = r * 256 * 2 * 4  # r × d × 2 modules × 4 layers
        print(f"{r:6d} | {t:>12,} | {t/total*100:>6.2f}% | {capacity:>14,}")

if __name__ == '__main__':
    main()
