开源代码:https://ai.gitcode.com/hf_mirrors/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py

本代码由 AI 快速生成,仅供展示:编写时未参考上下文代码块,且缺乏全局信息,不适用于实际应用,仅在编程方式上提供参考和学习。

# ============================================================
# DeepSeek 矩阵正则化重构版 v6 (最终修正版)
# 修正:RoPE旋转计算、padding mask形状匹配、因果掩码向量化、缓存类型兼容
# ============================================================

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Tuple, Union


# ============================================================
# 第一层:配置池 (ConfigPool)
# ============================================================
class DeepseekConfig:
    """Hyper-parameters for the DeepSeek MoE model.

    Every constructor argument is stored as an attribute of the same name. Two
    fields are derived: ``head_dim`` (``hidden_size // num_attention_heads``) and
    ``num_key_value_groups`` (query heads per KV head). Extra keyword arguments
    are kept as attributes too, so configs that carry additional fields still load.

    ``num_key_value_heads=None`` means one KV head per query head (plain
    multi-head attention).

    Raises:
        ValueError: if the head layout is inconsistent: non-positive sizes,
            non-divisible head counts, an odd ``head_dim`` (RoPE rotates channels
            in pairs), or a non-positive ``moe_layer_freq`` while MoE is enabled.
    """

    def __init__(
        self,
        vocab_size: int = 102400,
        hidden_size: int = 2048,
        num_attention_heads: int = 16,
        num_key_value_heads: Optional[int] = 4,
        num_hidden_layers: int = 27,
        intermediate_size: int = 5632,
        moe_intermediate_size: int = 1536,
        n_routed_experts: int = 64,
        n_shared_experts: int = 2,
        num_experts_per_tok: int = 8,
        norm_topk_prob: bool = True,
        scoring_func: str = 'softmax',
        rms_norm_eps: float = 1e-6,
        rope_theta: float = 10000.0,
        max_position_embeddings: int = 163840,
        attention_dropout: float = 0.0,
        initializer_range: float = 0.02,
        aux_loss_alpha: float = 0.001,
        seq_aux: bool = True,
        first_k_dense_replace: int = 3,
        moe_layer_freq: int = 1,
        use_cache: bool = True,
        **kwargs,
    ):
        # Keep unrecognised fields instead of dropping them silently. They are
        # set first so the derived fields computed below always win.
        for key, value in kwargs.items():
            setattr(self, key, value)

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        self.first_k_dense_replace = first_k_dense_replace
        self.moe_layer_freq = moe_layer_freq
        self.use_cache = use_cache

        # Validate before dividing, so a zero head count raises a ValueError
        # rather than a ZeroDivisionError.
        for field_name, value in (
            ("hidden_size", hidden_size),
            ("num_attention_heads", num_attention_heads),
            ("num_key_value_heads", num_key_value_heads),
        ):
            if value <= 0:
                raise ValueError(f"{field_name} ({value}) 必须为正整数")
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) 必须被 num_attention_heads ({num_attention_heads}) 整除"
            )
        if num_attention_heads % num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({num_attention_heads}) 必须被 num_key_value_heads ({num_key_value_heads}) 整除"
            )
        head_dim = hidden_size // num_attention_heads
        # RoPE rotates (even, odd) channel pairs, so an odd head_dim cannot be rotated.
        if head_dim % 2 != 0:
            raise ValueError(f"head_dim ({head_dim}) 必须为偶数")
        # The decoder uses `layer_idx % moe_layer_freq` when routed experts exist.
        if n_routed_experts and moe_layer_freq <= 0:
            raise ValueError(f"moe_layer_freq ({moe_layer_freq}) 必须为正整数")

        self.head_dim = head_dim
        self.num_key_value_groups = num_attention_heads // num_key_value_heads


# ============================================================
# 第二层:基础机床库 (Pure Machines)
# ============================================================

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    The statistics are computed in float32. The result is cast back to the
    input dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        h = x.float()
        # 1 / sqrt(mean(h^2) + eps), kept with a trailing singleton dim for broadcasting.
        inv_rms = torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + self.eps)
        normed = h * inv_rms * self.weight
        return normed.to(input_dtype)


def precompute_rope_freqs(dim: int, max_pos: int, theta: float = 10000.0):
    """Precompute RoPE cosine/sine tables for pairwise (even, odd) channel rotation.

    Args:
        dim: per-head channel count. It must be a positive even number because
            channels are rotated in pairs.
        max_pos: number of positions to tabulate (rows ``0 .. max_pos - 1``).
        theta: RoPE base; pair ``j`` rotates with frequency ``theta ** (-2j / dim)``.

    Returns:
        ``(cos, sin)``, each a float32 tensor of shape ``[max_pos, dim // 2]``.

    Raises:
        ValueError: if ``dim`` is not a positive even integer.
    """
    # With an odd dim the table would have ceil(dim/2) columns while the odd
    # channels number floor(dim/2); fail here instead of deep inside attention.
    if dim <= 0 or dim % 2 != 0:
        raise ValueError(f"RoPE dim ({dim}) 必须为正偶数")
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_pos, dtype=inv_freq.dtype)
    angles = torch.outer(positions, inv_freq)
    return angles.cos(), angles.sin()


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent (even, odd) channel pairs of ``x`` by position-dependent angles.

    Args:
        x: ``[B, H, L, D]`` query or key tensor.
        cos, sin: ``[max_pos, D/2]`` tables, one angle per channel pair.
        position_ids: ``[B, L]`` absolute positions used to index the tables.

    Returns:
        A tensor with the same shape and dtype as ``x``, where each pair
        ``(a, b)`` becomes ``(a*cos - b*sin, a*sin + b*cos)``.

    NOTE(review): this is the interleaved-pair convention. Some implementations
    use the half-split ("rotate_half") layout instead; confirm it matches the
    convention of any pretrained weights loaded into the model.
    """
    # Gather per-position angles and add a head axis: [B, 1, L, D/2].
    cos_pos = cos[position_ids].unsqueeze(1).to(x.dtype)
    sin_pos = sin[position_ids].unsqueeze(1).to(x.dtype)

    even, odd = x[..., ::2], x[..., 1::2]
    rotated_even = even * cos_pos - odd * sin_pos
    rotated_odd = even * sin_pos + odd * cos_pos

    # Re-interleave so channel 2j / 2j+1 land back in their original slots.
    return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)


def qkv_proj(x: torch.Tensor, q_w: torch.Tensor, k_w: torch.Tensor, v_w: torch.Tensor):
    """Apply three bias-free projections to ``x`` and return ``(q, k, v)``."""
    q, k, v = (F.linear(x, weight) for weight in (q_w, k_w, v_w))
    return q, k, v


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head ``n_rep`` times along dim 1: ``[B, Hk, L, D] -> [B, Hk*n_rep, L, D]``.

    Copies of one head are adjacent (head order h0, h0, ..., h1, h1, ...).
    When ``n_rep == 1`` the input object is returned unchanged.
    """
    if n_rep == 1:
        return hidden_states
    b, kv_heads, length, dim = hidden_states.shape
    widened = hidden_states.unsqueeze(2).expand(b, kv_heads, n_rep, length, dim)
    return widened.reshape(b, kv_heads * n_rep, length, dim)


def _reshape_for_attention(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
    """Split ``[B, L, num_heads*head_dim]`` into heads and move them forward: ``[B, num_heads, L, head_dim]``."""
    bsz, length, _ = x.shape
    per_head = x.view(bsz, length, num_heads, head_dim)
    return per_head.transpose(1, 2)


def attention_core(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                   attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Scaled dot-product attention with grouped-query (GQA) support and an additive mask.

    Args:
        q: queries ``[B, Hq, Lq, D]``.
        k: keys ``[B, Hk, Lk, D]``; ``Hq`` must be a multiple of ``Hk``.
        v: values ``[B, Hk, Lk, D]``.
        attn_mask: optional additive mask (0 = keep, ``-inf`` or a large negative
            value = drop). Accepted forms are 2D ``[Lq, Lk]``, 3D ``[B, Lq, Lk]``
            and 4D ``[B, 1 or Hq, Lq, Lk]``.

    Returns:
        ``[B, Hq, Lq, D]`` in ``q.dtype``. Query rows whose keys are all masked
        with ``-inf`` produce zeros instead of NaN.

    Raises:
        ValueError: on inconsistent head_dim / batch / head counts / K-V lengths,
            or an unsupported mask rank.
    """
    if q.size(-1) != k.size(-1) or k.size(-1) != v.size(-1):
        raise ValueError(f"Q/K/V head_dim 不一致: {q.size(-1)}, {k.size(-1)}, {v.size(-1)}")
    if q.size(0) != k.size(0) or k.size(0) != v.size(0):
        raise ValueError(f"batch_size 不一致: {q.size(0)}, {k.size(0)}, {v.size(0)}")
    if q.size(1) % k.size(1) != 0:
        raise ValueError(f"Q头数({q.size(1)})必须被KV头数({k.size(1)})整除")
    if k.size(-2) != v.size(-2):
        raise ValueError(f"K/V 序列长度不一致: {k.size(-2)}, {v.size(-2)}")

    num_kv_groups = q.size(1) // k.size(1)
    scale = 1.0 / math.sqrt(q.size(-1))

    if num_kv_groups > 1:
        k = repeat_kv(k, num_kv_groups)
        v = repeat_kv(v, num_kv_groups)

    # Scores, mask addition and softmax all run in fp32 for numerical stability.
    scores = torch.matmul(q, k.transpose(-2, -1)).float() * scale

    # Normalise 2D/3D/4D masks to something that broadcasts against [B, H, Lq, Lk].
    if attn_mask is not None:
        if attn_mask.dim() == 2:          # [Lq, Lk] broadcasts as-is
            pass
        elif attn_mask.dim() == 3:        # [B, Lq, Lk] -> [B, 1, Lq, Lk]
            attn_mask = attn_mask.unsqueeze(1)
        elif attn_mask.dim() == 4:        # [B, H, Lq, Lk] or [B, 1, Lq, Lk]
            if attn_mask.size(1) not in (1, q.size(1)):
                raise ValueError(f"mask head ({attn_mask.size(1)}) 与 Q head ({q.size(1)}) 不兼容")
        else:
            raise ValueError(f"mask维度 ({attn_mask.dim()}) 不支持,仅支持2D/3D/4D")
        scores = scores + attn_mask

    probs = F.softmax(scores, dim=-1)
    if attn_mask is not None:
        # softmax over a row that is entirely -inf is NaN. That NaN would leak
        # into every later layer (0 * NaN = NaN in the value matmul), so such
        # rows attend to nothing instead.
        fully_masked = torch.isneginf(scores).all(dim=-1, keepdim=True)
        probs = probs.masked_fill(fully_masked, 0.0)
    return torch.matmul(probs.to(q.dtype), v)


def output_proj(attn_out: torch.Tensor, o_w: torch.Tensor) -> torch.Tensor:
    """Apply the bias-free attention output projection (``attn_out @ o_w.T``)."""
    projected = F.linear(attn_out, o_w)
    return projected


class SwiGLU(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))`` with bias-free projections."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Creation order and attribute names are significant: they fix the RNG
        # consumption order at init and the state_dict keys.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated_gate = F.silu(self.gate_proj(x))
        mixed = activated_gate * self.up_proj(x)
        return self.down_proj(mixed)


def moe_gate(x: torch.Tensor, gate_weight: torch.Tensor, top_k: int,
             scoring_func: str = 'softmax', norm_topk_prob: bool = True):
    """Select the ``top_k`` routed experts for every token.

    Args:
        x: token features ``[..., hidden]`` (the decoder passes ``[T, hidden]``).
        gate_weight: router weight ``[num_experts, hidden]``.
        top_k: experts per token, with ``1 <= top_k <= num_experts``.
        scoring_func: ``'softmax'`` (scores normalised over experts) or
            ``'sigmoid'`` (independent per-expert scores).
        norm_topk_prob: when ``top_k > 1``, rescale the selected weights so they
            sum to 1 per token.

    Returns:
        ``(topk_idx, topk_weight)``, both ``[..., top_k]``. The order inside the
        last dim is unspecified (``sorted=False``). ``topk_weight`` is float32.

    Raises:
        ValueError: for an out-of-range ``top_k`` or an unknown ``scoring_func``.
    """
    num_experts = gate_weight.size(0)
    if top_k < 1 or top_k > num_experts:
        raise ValueError(f"top_k ({top_k}) 必须在 [1, {num_experts}] 范围内")
    if scoring_func not in ('softmax', 'sigmoid'):
        raise ValueError(f"scoring_func 仅支持 'softmax' 或 'sigmoid',收到 '{scoring_func}'")

    # Route in fp32. In bf16/fp16, softmax/top-k over many experts is coarse
    # enough to create near-ties that flip expert choice. Also, the 1e-20 clamp
    # below underflows to 0 in fp16.
    logits = F.linear(x.float(), gate_weight.float())
    scores = logits.softmax(dim=-1) if scoring_func == 'softmax' else logits.sigmoid()
    topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)
    if norm_topk_prob and top_k > 1:
        topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True).clamp(min=1e-20)
    return topk_idx, topk_weight


def moe_experts_flat(x: torch.Tensor, topk_idx: torch.Tensor, topk_weight: torch.Tensor,
                     experts: nn.ModuleList, shared_expert: Optional[nn.Module] = None) -> torch.Tensor:
    """Mix routed expert outputs for a flat batch of tokens.

    Args:
        x: ``[T, hidden]`` token features.
        topk_idx: ``[T, k]`` expert indices selected per token.
        topk_weight: ``[T, k]`` mixing weights aligned with ``topk_idx``.
        experts: routed experts. Expert ``i`` runs only on the tokens whose
            ``topk_idx`` row contains ``i``.
        shared_expert: optional module applied to every token and added to the result.

    Returns:
        ``[T, hidden]`` in ``x.dtype`` (an all-zero tensor for empty ``x``).

    Raises:
        ValueError: if any index falls outside ``[0, len(experts))``.
    """
    if x.numel() == 0:
        return torch.zeros_like(x)

    n_tokens, hidden = x.shape
    num_experts = len(experts)
    if topk_idx.numel() > 0 and (topk_idx.min() < 0 or topk_idx.max() >= num_experts):
        raise ValueError(f"topk_idx 越界: [{topk_idx.min()}, {topk_idx.max()}],专家范围 [0, {num_experts})")

    # Dense [T, E] routing tables: membership flags and per-expert gains.
    idx = topk_idx.long()
    routed = torch.zeros(n_tokens, num_experts, device=x.device).scatter_(1, idx, 1.0)
    gains = torch.zeros(n_tokens, num_experts, device=x.device).scatter_(1, idx, topk_weight.float())

    combined = torch.zeros(n_tokens, hidden, device=x.device, dtype=x.dtype)
    for expert_id, expert in enumerate(experts):
        rows = routed[:, expert_id].bool()
        if rows.any():
            combined[rows] += expert(x[rows]) * gains[rows, expert_id:expert_id + 1]

    return combined if shared_expert is None else combined + shared_expert(x)


def generate_causal_mask(q_len: int, kv_len: int, device: torch.device,
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Build an additive causal mask for ``q_len`` new queries after ``kv_len - q_len`` cached keys.

    Query ``i`` (0-based within the current step) may attend to keys
    ``[0, past_len + i]`` with ``past_len = kv_len - q_len``.

    Args:
        q_len: number of query tokens in the current step.
        kv_len: total key length (cached history plus the current step).
        device: device for the returned tensor.
        dtype: floating dtype of the mask (default float32).

    Returns:
        ``[1, 1, q_len, kv_len]``: 0 where visible, ``-inf`` where masked.
        Zero-length inputs give a correctly shaped empty mask.

    Raises:
        ValueError: for negative lengths or ``kv_len < q_len``. That case would
            give queries with no visible key.
    """
    if q_len < 0 or kv_len < 0:
        raise ValueError(f"q_len ({q_len}) 与 kv_len ({kv_len}) 不能为负")
    if kv_len < q_len:
        raise ValueError(f"kv_len ({kv_len}) 不能小于 q_len ({q_len})")
    past_len = kv_len - q_len
    # triu keeps entries with column >= row + past_len + 1: the keys strictly after each query.
    blocked = torch.full((q_len, kv_len), float('-inf'), device=device, dtype=dtype)
    mask = torch.triu(blocked, diagonal=past_len + 1)
    return mask[None, None]


def compute_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy: position ``t`` predicts ``labels[:, t + 1]``.

    Args:
        logits: ``[B, L, V]`` model outputs (upcast to float32 internally).
        labels: ``[B, L]`` target ids; ``-100`` marks ignored positions.

    Returns:
        A scalar loss. If every shifted label is ``-100``, returns a zero that is
        still connected to ``logits``, so ``backward()`` works.

    Raises:
        ValueError: if batch/sequence sizes differ, or the logits contain NaN/Inf.
    """
    if any(logits.size(d) != labels.size(d) for d in (0, 1)):
        raise ValueError(f"logits形状 ({logits.shape}) 与 labels形状 ({labels.shape}) 不匹配")

    vocab = logits.size(-1)
    predictions = logits.float()[..., :-1, :].contiguous()
    targets = labels[..., 1:].contiguous()

    if not torch.isfinite(predictions).all():
        raise ValueError("logits 包含 NaN 或 Inf")

    if (targets == -100).all():
        return predictions.sum() * 0.0

    return F.cross_entropy(predictions.view(-1, vocab), targets.view(-1))


def init_parameters(module: nn.Module, std: float = 0.02):
    """Re-initialise every parameter of ``module`` with ``ndim >= 2`` from N(0, std).

    One-dimensional parameters (biases, norm scales) are left untouched.

    Args:
        module: the module whose parameters (recursively) are initialised in place.
        std: standard deviation of the normal distribution (default 0.02).
    """
    for param in module.parameters():
        if param.ndim >= 2:
            nn.init.normal_(param, mean=0.0, std=std)


def _get_cache_length(past_key_values, layer_idx: int = 0) -> int:
    """Return how many positions are already cached for ``layer_idx`` (0 if none).

    Supported formats:
      * ``None``: 0.
      * Objects exposing ``get_seq_length`` (e.g. a HuggingFace-style cache).
      * A list/tuple holding one ``(k, v)`` pair per layer; ``k.size(2)`` is taken
        as the cached length.

    A layer without an entry yet (index past the end, ``None`` entry, or ``None``
    key tensor) counts as 0. Unrecognised formats also return 0.
    """
    if past_key_values is None:
        return 0
    if hasattr(past_key_values, 'get_seq_length'):
        # Forward layer_idx only when it differs from the default, so cache
        # objects whose get_seq_length() takes no argument keep working.
        if layer_idx == 0:
            return past_key_values.get_seq_length()
        return past_key_values.get_seq_length(layer_idx)
    if isinstance(past_key_values, (list, tuple)):
        try:
            layer_cache = past_key_values[layer_idx]
        except IndexError:
            # Layer not populated yet (e.g. an empty list before the first step).
            return 0
        if isinstance(layer_cache, (list, tuple)) and len(layer_cache) >= 2 and layer_cache[0] is not None:
            return layer_cache[0].size(2)
    return 0


# ============================================================
# 第三层:数据流水渠 (Logistics Pipeline)
# ============================================================

class DeepseekDecoderLayer(nn.Module):
    """One pre-norm decoder block: grouped-query self-attention followed by a dense or MoE FFN.

    The layer holds raw projection weights (``q_proj_w`` etc.) and applies them
    through the module-level functional helpers. Whether the FFN is a routed MoE
    or a dense SwiGLU is decided once, from the layer index and the MoE fields
    of ``config``.
    """

    def __init__(self, config: DeepseekConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim

        # Projection weights in F.linear layout: [out_features, in_features].
        self.q_proj_w = nn.Parameter(torch.empty(config.num_attention_heads * config.head_dim, config.hidden_size))
        self.k_proj_w = nn.Parameter(torch.empty(config.num_key_value_heads * config.head_dim, config.hidden_size))
        self.v_proj_w = nn.Parameter(torch.empty(config.num_key_value_heads * config.head_dim, config.hidden_size))
        self.o_proj_w = nn.Parameter(torch.empty(config.hidden_size, config.num_attention_heads * config.head_dim))

        # MoE on every `moe_layer_freq`-th layer from `first_k_dense_replace` on.
        # bool(...) also treats n_routed_experts == 0 as "no MoE": a 0-expert
        # router would fail the top_k range check on the first forward.
        is_moe = (bool(config.n_routed_experts) and
                  layer_idx >= config.first_k_dense_replace and
                  layer_idx % config.moe_layer_freq == 0)

        if is_moe:
            self.moe_gate_weight = nn.Parameter(torch.empty(config.n_routed_experts, config.hidden_size))
            self.experts = nn.ModuleList([
                SwiGLU(config.hidden_size, config.moe_intermediate_size) for _ in range(config.n_routed_experts)
            ])
            # Shared experts are fused into one wider SwiGLU. With none configured
            # the attribute is None and moe_experts_flat skips it.
            if config.n_shared_experts:
                self.shared_expert = SwiGLU(config.hidden_size,
                                            config.moe_intermediate_size * config.n_shared_experts)
            else:
                self.shared_expert = None
            self.ffn_pipeline = self._ffn_moe
        else:
            self.mlp = SwiGLU(config.hidden_size, config.intermediate_size)
            self.ffn_pipeline = self._ffn_dense

        self.input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        init_parameters(self)

    def _ffn_dense(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Dense SwiGLU feed-forward."""
        return self.mlp(hidden_states)

    def _ffn_moe(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Routed-expert feed-forward on flattened tokens, reshaped back to ``[B, L, hidden]``."""
        bsz, q_len, _ = hidden_states.shape
        flat_x = hidden_states.view(-1, self.hidden_size)
        topk_idx, topk_weight = moe_gate(flat_x, self.moe_gate_weight,
                                         self.config.num_experts_per_tok,
                                         self.config.scoring_func,
                                         self.config.norm_topk_prob)
        out = moe_experts_flat(flat_x, topk_idx, topk_weight, self.experts, self.shared_expert)
        return out.view(bsz, q_len, -1)

    def _update_kv_cache(self, past_key_value, k: torch.Tensor, v: torch.Tensor):
        """Merge this step's K/V (``[B, Hk, L, D]``) with the cache and return the full K/V.

        * Objects with ``update(k, v, layer_idx)`` handle storage themselves.
        * A ``list`` of per-layer ``(k, v)`` pairs is concatenated along the
          sequence axis (dim 2) and written back in place. The pair is appended
          if this layer has no entry yet.
        * A ``tuple`` is read-only: the history is prepended for this call, but
          the tuple cannot be extended, so the caller must rebuild it.

        Raises:
            TypeError: for any other cache type.
        """
        if hasattr(past_key_value, 'update'):
            return past_key_value.update(k, v, self.idx)
        if not isinstance(past_key_value, (list, tuple)):
            raise TypeError(f"不支持的缓存类型: {type(past_key_value).__name__}")

        layer_cache = past_key_value[self.idx] if self.idx < len(past_key_value) else None
        if layer_cache is not None:
            k = torch.cat([layer_cache[0], k], dim=2)
            v = torch.cat([layer_cache[1], v], dim=2)
        if isinstance(past_key_value, list):
            if self.idx < len(past_key_value):
                past_key_value[self.idx] = (k, v)
            elif self.idx == len(past_key_value):
                past_key_value.append((k, v))
        return k, v

    def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor,
                freqs_cos: torch.Tensor, freqs_sin: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                past_key_value=None) -> torch.Tensor:
        """Run attention + FFN with pre-norm residual connections.

        Args:
            hidden_states: ``[B, L, hidden_size]``.
            position_ids: absolute positions used to index the RoPE tables.
            freqs_cos, freqs_sin: RoPE tables from ``precompute_rope_freqs``.
            attn_mask: additive mask passed to ``attention_core``.
            past_key_value: ``None``; an object with ``update(k, v, layer_idx)``;
                or a list/tuple holding one ``(k, v)`` pair per layer.

        Returns:
            ``[B, L, hidden_size]``.
        """
        residual = hidden_states
        hidden_states = self.input_norm(hidden_states)

        q, k, v = qkv_proj(hidden_states, self.q_proj_w, self.k_proj_w, self.v_proj_w)
        bsz, q_len, _ = q.shape
        q = _reshape_for_attention(q, self.num_heads, self.head_dim)
        k = _reshape_for_attention(k, self.num_kv_heads, self.head_dim)
        v = _reshape_for_attention(v, self.num_kv_heads, self.head_dim)

        q = apply_rope(q, freqs_cos, freqs_sin, position_ids)
        k = apply_rope(k, freqs_cos, freqs_sin, position_ids)

        if past_key_value is not None:
            k, v = self._update_kv_cache(past_key_value, k, v)

        attn_out = attention_core(q, k, v, attn_mask)
        attn_out = attn_out.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        hidden_states = residual + output_proj(attn_out, self.o_proj_w)

        residual = hidden_states
        hidden_states = self.post_attn_norm(hidden_states)
        hidden_states = self.ffn_pipeline(hidden_states)
        return residual + hidden_states


# ============================================================
# 模型外壳 (全局调度器)
# ============================================================

class DeepseekModel(nn.Module):
    """Token embedding, a stack of DeepseekDecoderLayer, then a final RMSNorm."""

    def __init__(self, config: DeepseekConfig):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([DeepseekDecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # RoPE tables shared by all layers. They are non-persistent, so they are
        # not stored in state_dict.
        cos, sin = precompute_rope_freqs(config.head_dim, config.max_position_embeddings, config.rope_theta)
        self.register_buffer('freqs_cos', cos, persistent=False)
        self.register_buffer('freqs_sin', sin, persistent=False)

    @staticmethod
    def _build_padding_mask(attention_mask: torch.Tensor, batch: int, seq_len: int,
                            kv_len: int, device: torch.device) -> torch.Tensor:
        """Convert a 2D keep-mask (nonzero = real token, 0 = padding) into an additive mask.

        The mask may cover only the current step (``[B, seq_len]``), in which case
        cached positions are treated as valid. It may also cover the whole key
        range (``[B, kv_len]``).

        Returns:
            ``[B, 1, 1, kv_len]`` float32: 0 for kept keys, ``finfo(float32).min`` for padding.

        Raises:
            ValueError: for a non-2D mask, a batch mismatch, or an unexpected length.
        """
        if attention_mask.dim() != 2:
            raise ValueError(f"attention_mask 仅支持 2D [B, L] 格式,收到 {attention_mask.dim()}D")
        if attention_mask.size(0) != batch:
            raise ValueError(f"attention_mask batch ({attention_mask.size(0)}) 与 input_ids batch ({batch}) 不一致")

        keep_given = attention_mask.to(device=device) != 0
        if keep_given.size(1) == kv_len:
            keep = keep_given
        elif keep_given.size(1) == seq_len:
            keep = torch.ones(batch, kv_len, dtype=torch.bool, device=device)
            keep[:, kv_len - seq_len:] = keep_given
        else:
            raise ValueError(
                f"attention_mask 长度 ({keep_given.size(1)}) 应为 seq_len ({seq_len}) 或 kv_len ({kv_len})"
            )

        # A finite fill rather than -inf is used for two reasons:
        # - `(1 - mask) * -inf` evaluates 0 * -inf = NaN on kept positions.
        # - A row that is entirely -inf (e.g. a left-padding query) makes softmax NaN.
        fill = torch.finfo(torch.float32).min
        return torch.zeros(batch, 1, 1, kv_len, device=device).masked_fill(~keep[:, None, None, :], fill)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[Union[List, Tuple, object]] = None) -> torch.Tensor:
        """Run the decoder stack for the current step.

        Args:
            input_ids: ``[B, L]`` token ids of the current step.
            attention_mask: optional 2D padding mask of shape ``[B, L]`` or
                ``[B, past + L]`` (nonzero = real token).
            position_ids: optional RoPE positions; defaults to ``past .. past + L - 1``.
            past_key_values: optional KV cache, passed unchanged to every layer.

        Returns:
            ``[B, L, hidden_size]`` hidden states after the final norm.
        """
        batch, seq_len = input_ids.shape
        device = input_ids.device

        # Length of the cached history; new positions continue from there.
        past_length = _get_cache_length(past_key_values)
        kv_len = past_length + seq_len

        if position_ids is None:
            position_ids = torch.arange(past_length, kv_len, dtype=torch.long, device=device).unsqueeze(0)

        hidden_states = self.embed_tokens(input_ids)

        attn_mask = generate_causal_mask(seq_len, kv_len, device)
        if attention_mask is not None:
            attn_mask = attn_mask + self._build_padding_mask(attention_mask, batch, seq_len, kv_len, device)

        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids, self.freqs_cos, self.freqs_sin,
                                  attn_mask, past_key_values)

        return self.norm(hidden_states)


class DeepseekForCausalLM(nn.Module):
    """Causal language model: the DeepseekModel trunk plus a bias-free vocabulary projection.

    ``forward`` returns a dict with ``'logits'`` (``[B, L, vocab_size]``). When
    ``labels`` are given, a leading ``'loss'`` entry from ``compute_loss`` is added.
    """

    def __init__(self, config: DeepseekConfig):
        super().__init__()
        self.model = DeepseekModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None, **kwargs):
        # Remaining keyword arguments (attention_mask, position_ids,
        # past_key_values) go straight to the trunk.
        logits = self.lm_head(self.model(input_ids, **kwargs))
        if labels is None:
            return {'logits': logits}
        return {'loss': compute_loss(logits, labels), 'logits': logits}
Logo

Agent 垂直技术社区,欢迎活跃、内容共建。

更多推荐