"""
Mixture of Experts (MoE) 层实现 — 含 Top-K 路由与负载均衡损失
对应教程第 5 章
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    """One feed-forward expert using a gated (SwiGLU-style) FFN.

    Computes ``w2(silu(w1(x)) * w3(x))`` with three bias-free linear maps:
    ``w1`` and ``w3`` project ``d_model -> d_ff`` and ``w2`` projects back
    ``d_ff -> d_model``.
    """

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        # Layers are created in w1, w2, w3 order so seeded initialization
        # stays reproducible.
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # output projection
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # value projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated FFN to ``x`` (last dimension must be ``d_model``)."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        hidden = gate * value
        return self.w2(hidden)


class TopKRouter(nn.Module):
    """Linear gate that picks the ``top_k`` highest-scoring experts per token.

    The router does not compute a loss itself; it returns the raw logits so
    the caller can derive a load-balancing auxiliary loss from them.
    """

    def __init__(self, d_model: int, n_experts: int, top_k: int):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k
        self.gate = nn.Linear(d_model, n_experts, bias=False)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score every expert and select the best ``top_k`` for each token.

        Args:
            x: Token features whose last dimension is ``d_model``
                (``MoELayer`` passes a flattened ``[B*T, d_model]`` tensor).

        Returns:
            A tuple ``(router_logits, topk_indices, topk_weights)``:
              * ``router_logits`` ``[..., n_experts]`` - raw gate scores.
              * ``topk_indices``  ``[..., top_k]`` - selected expert ids.
              * ``topk_weights``  ``[..., top_k]`` - softmax taken over the
                selected logits only, so the weights sum to 1 per token.
        """
        router_logits = self.gate(x)
        selected = router_logits.topk(self.top_k, dim=-1)
        # Normalize among the chosen experts, not across all experts.
        routing_weights = selected.values.softmax(dim=-1)
        return router_logits, selected.indices, routing_weights


class MoELayer(nn.Module):
    """Mixture of Experts 层

    支持两种路由策略:
    - Top-K (GShard / Mixtral 风格)
    - Switch Transformer (Top-1)

    包含负载均衡损失，防止路由坍缩到少数专家
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        n_experts: int = 8,
        top_k: int = 2,
        n_shared_experts: int = 0,
        aux_loss_weight: float = 0.01,
    ):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k
        self.n_shared_experts = n_shared_experts
        self.aux_loss_weight = aux_loss_weight

        self.router = TopKRouter(d_model, n_experts, top_k)
        self.experts = nn.ModuleList([Expert(d_model, d_ff) for _ in range(n_experts)])

        if n_shared_experts > 0:
            self.shared_experts = nn.ModuleList(
                [Expert(d_model, d_ff) for _ in range(n_shared_experts)]
            )
        else:
            self.shared_experts = None

    def load_balancing_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
        """负载均衡辅助损失 (Switch Transformer 公式)

        L_aux = N · Σ_i (f_i · P_i)

        其中:
          f_i = 分配到专家 i 的 token 比例
          P_i = 路由器分配给专家 i 的平均概率
          N = 专家数量

        最优值: 每个专家均匀分配时 L_aux = 1
        """
        probs = F.softmax(router_logits, dim=-1)  # [B*T, n_experts]

        # f_i: 每个专家被选中的 token 比例
        topk_indices = torch.topk(router_logits, self.top_k, dim=-1).indices
        expert_mask = F.one_hot(topk_indices, self.n_experts).sum(dim=1)  # [B*T, n_experts]
        f = expert_mask.float().mean(dim=0)  # [n_experts]

        # P_i: 平均路由概率
        P = probs.mean(dim=0)  # [n_experts]

        return self.n_experts * (f * P).sum()

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        返回: (output, aux_loss)
        """
        B, T, C = x.shape
        x_flat = x.view(-1, C)  # [B*T, C]

        router_logits, topk_indices, topk_weights = self.router(x_flat)

        # 计算负载均衡损失
        aux_loss = self.aux_loss_weight * self.load_balancing_loss(router_logits)

        # 逐 token 分发到选中的专家
        output = torch.zeros_like(x_flat)
        for i in range(self.n_experts):
            # 找到路由到专家 i 的 token
            expert_mask = (topk_indices == i).any(dim=-1)  # [B*T]
            if not expert_mask.any():
                continue

            expert_input = x_flat[expert_mask]
            expert_output = self.experts[i](expert_input)

            # 按路由权重加权
            weight_mask = (topk_indices[expert_mask] == i).float()
            weights = (topk_weights[expert_mask] * weight_mask).sum(dim=-1, keepdim=True)
            output[expert_mask] += weights * expert_output

        # 加入共享专家输出 (DeepSeek-MoE 风格)
        if self.shared_experts is not None:
            shared_out = sum(expert(x_flat) for expert in self.shared_experts)
            output = output + shared_out / self.n_shared_experts

        return output.view(B, T, C), aux_loss


class MoETransformerBlock(nn.Module):
    """Pre-norm Transformer block whose feed-forward part is a MoE layer.

    Structure::

        h   = x + SelfAttention(RMSNorm(x))
        out = h + MoE(RMSNorm(h))

    The MoE layer is built with its default ``n_shared_experts`` and
    ``aux_loss_weight``.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int,
                 n_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.norm1 = nn.RMSNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.RMSNorm(d_model)
        self.moe = MoELayer(d_model, d_ff, n_experts, top_k)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply self-attention and the MoE FFN with residual connections.

        Args:
            x: ``[B, T, d_model]`` input (attention is ``batch_first``).

        Returns:
            ``(output, aux_loss)``: the block output with the same shape as
            ``x`` and the auxiliary loss returned by the MoE layer.
        """
        # Normalize once and reuse the result for query, key and value
        # instead of evaluating norm1 three times.
        normed = self.norm1(x)
        # The attention weights are discarded, so skip computing them.
        attn_out = self.attn(normed, normed, normed, need_weights=False)[0]
        h = x + attn_out

        moe_out, aux_loss = self.moe(self.norm2(h))
        return h + moe_out, aux_loss


def analyze_routing(moe: "MoELayer", x: torch.Tensor):
    """Print how the router of ``moe`` distributes the tokens in ``x``.

    Reports, on stdout:
      * per-expert selection counts and percentages with a text bar chart,
      * the ideal per-expert count and a max-min imbalance percentage,
      * the entropy of the mean routing distribution against its maximum
        ``log(n_experts)``.

    Args:
        moe: Object exposing ``router`` (callable returning
            ``(logits, indices, weights)``) and ``n_experts``.
        x: Input whose last dimension is the model width; all leading
            dimensions are flattened into tokens. Need not be contiguous.

    Returns:
        None. The function only prints.
    """
    # reshape (not view) so non-contiguous inputs are accepted.
    x_flat = x.reshape(-1, x.shape[-1])

    with torch.no_grad():
        logits, indices, _ = moe.router(x_flat)
        mean_probs = F.softmax(logits, dim=-1).mean(dim=0)

    # One bincount instead of a Python loop with a device sync per expert.
    expert_counts = (
        torch.bincount(indices.reshape(-1), minlength=moe.n_experts).float().cpu()
    )

    total = indices.numel()
    print("\n路由分布分析:")
    print("-" * 40)
    for i in range(moe.n_experts):
        pct = expert_counts[i] / total * 100
        bar = "█" * int(pct * 2)  # two bar cells per percentage point
        print(f"  Expert {i}: {expert_counts[i]:5.0f} ({pct:5.1f}%) {bar}")

    # Load-balance metric: spread between busiest and idlest expert,
    # relative to the perfectly even share.
    ideal = total / moe.n_experts
    imbalance = (expert_counts.max() - expert_counts.min()) / ideal * 100
    print(f"\n  理想分配: {ideal:.0f} tokens/expert")
    print(f"  不均衡度: {imbalance:.1f}%")

    # Routing entropy. entr(p) = -p * log(p) with entr(0) = 0, so an expert
    # whose mean probability is exactly zero no longer turns the sum into NaN.
    entropy = torch.special.entr(mean_probs).sum().item()
    max_entropy = math.log(moe.n_experts)
    print(f"  路由熵: {entropy:.3f} / {max_entropy:.3f} (越高越均匀)")


import math

def main():
    """Run the MoE demo: forward pass, parameter accounting, routing stats,
    a shared-expert variant, and a short training loop on a toy regression.

    Everything is printed to stdout; nothing is returned.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"设备: {device}\n")

    d_model = 512
    d_ff = 1024
    n_experts = 8
    top_k = 2
    batch_size = 2
    seq_len = 256

    print("=" * 60)
    print("MoE (Mixture of Experts) 层演示")
    print("=" * 60)

    # Standard MoE layer.
    moe = MoELayer(d_model, d_ff, n_experts, top_k).to(device)
    x = torch.randn(batch_size, seq_len, d_model, device=device)

    output, aux_loss = moe(x)
    print(f"\n配置: {n_experts} experts, top-{top_k}")
    print(f"输入:  {x.shape}")
    print(f"输出:  {output.shape}")
    print(f"辅助损失: {aux_loss.item():.4f}")

    # Parameter comparison. A dense FFN equals one expert, i.e. the three
    # SwiGLU matrices (3 * d_model * d_ff weights). Counting it from a real
    # expert keeps the figure correct if the expert architecture changes.
    dense_params = sum(p.numel() for p in moe.experts[0].parameters())
    moe_params = sum(p.numel() for p in moe.parameters())
    router_params = sum(p.numel() for p in moe.router.parameters())
    # Per token only top_k experts plus the router are active.
    active_params = dense_params * top_k + router_params
    print(f"\n参数量对比:")
    print(f"  Dense FFN:        {dense_params:>10,}")
    print(f"  MoE 总参数:       {moe_params:>10,}")
    print(f"  MoE 激活参数:     {active_params:>10,}")
    print(f"  总参数膨胀:       {moe_params / dense_params:.1f}x")
    print(f"  每 token 计算量:  {active_params / dense_params:.1f}x (vs Dense)")

    # Routing statistics before any training.
    analyze_routing(moe, x)

    # DeepSeek-MoE style layer with shared experts.
    print("\n" + "=" * 60)
    print("DeepSeek-MoE 风格 (含共享专家)")
    print("=" * 60)

    moe_ds = MoELayer(d_model, d_ff, n_experts, top_k,
                       n_shared_experts=2, aux_loss_weight=0.01).to(device)
    output_ds, aux_loss_ds = moe_ds(x)
    print(f"\n配置: {n_experts} routed + 2 shared experts, top-{top_k}")
    print(f"输出:  {output_ds.shape}")
    print(f"辅助损失: {aux_loss_ds.item():.4f}")
    print(f"总参数: {sum(p.numel() for p in moe_ds.parameters()):,}")

    # Training demo: fit a random target with one MoE transformer block.
    print("\n" + "=" * 60)
    print("MoE 训练演示 (简单回归任务)")
    print("=" * 60)

    block = MoETransformerBlock(d_model, 8, d_ff, n_experts, top_k).to(device)
    optimizer = torch.optim.AdamW(block.parameters(), lr=1e-3)

    target = torch.randn(batch_size, seq_len, d_model, device=device)

    for step in range(50):
        output, aux_loss = block(x)
        task_loss = F.mse_loss(output, target)
        # The auxiliary loss is already scaled by the layer's aux_loss_weight.
        total_loss = task_loss + aux_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if step % 10 == 0:
            print(f"  Step {step:3d}: task_loss={task_loss.item():.4f}, "
                  f"aux_loss={aux_loss.item():.4f}, total={total_loss.item():.4f}")

    print("\n训练后路由分布:")
    analyze_routing(block.moe, x)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
