Gemma 4 is Google DeepMind's fourth-generation open model family, released under Apache 2.0. It is the first generation where the architecture itself -- not just scale or training data -- becomes the primary axis of differentiation across variants. The family spans four models (31B, 26B-A4B, E4B, E2B) that share a common structural DNA but diverge in how they allocate capacity, reflecting a design philosophy where the same building blocks are composed differently for server, edge, and on-device deployment. The smaller E-series variants (E2B, E4B) support all four modalities -- text, image, video, and audio -- making them true any-to-any models.
Every Gemma 4 model is a decoder-only transformer with hybrid sliding/full attention. Layers alternate between local sliding-window attention and global full attention in a fixed ratio (5:1 or 4:1). This is not new -- Gemma 3 and Mistral pioneered it. What is new is that Gemma 4 makes these two layer types structurally different:
Sliding layers: head_dim=256, more KV heads, standard RoPE (theta=10K), full rotation. Optimized for fine-grained local patterns within a window of 512-1024 tokens.
Full attention layers: head_dim=512, fewer KV heads, p-RoPE (theta=1M, partial=0.25), K=V weight sharing. Optimized for long-range semantic attention across the entire 256K context.
This dual-config approach means that every 6th layer operates with a completely different attention geometry -- wider heads, fewer KV groups, and only 25% of dimensions carrying positional information. The transition from Gemma 3 to Gemma 4 can be summarized as moving from "same attention, different window" to "different attention, different window."
Sliding window layers can only attend to the last 512-1024 tokens directly. Information from beyond the window propagates indirectly through hidden states across layers -- earlier tokens influence the current window's hidden states, which then influence the next layer, like a chain of handoffs. The periodic full attention layers (every 5th or 6th) break this indirection by attending to the entire sequence at once, ensuring long-range information isn't lost. Gemma 4 also enforces that the final layer is always global attention, guaranteeing full-context awareness in the last representation.
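The layer-pattern rule described above can be sketched in a few lines (the function name and signature are illustrative, not the actual config API):

```python
# Illustrative sketch (not the real config code): build a 5:1 sliding/full
# pattern and force the final layer to full attention.
def build_layer_types(num_layers: int, pattern: int = 6) -> list:
    types = [
        "full_attention" if (i + 1) % pattern == 0 else "sliding_attention"
        for i in range(num_layers)
    ]
    if types[-1] != "full_attention":   # guarantee full-context final layer
        types[-1] = "full_attention"
    return types

layers = build_layer_types(60)          # 31B: 60 layers, 5:1 ratio
print(layers.count("full_attention"))   # 10
print(layers[-1])                       # full_attention
```

With 60 layers and a period of 6, layer 60 already falls on a full-attention slot, so the forced override only matters for layer counts that don't divide evenly.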
The family spans four deployment regimes, each making distinct architectural trade-offs within this shared framework:
In the 26B-A4B, each layer pairs a shared dense GeGLU FFN (hidden=2,112) in parallel with a 128-expert MoE (top-8, hidden=704 each) -- their outputs are summed. This is architecturally unusual: most MoE models replace the FFN with experts, but Gemma 4 keeps both, giving each layer always-on dense capacity plus sparse expert specialization.

Compared to Gemma 3 27B, the architectural delta is substantial. The table below focuses on the 31B (the most direct comparison), but most innovations propagate to all three variants:
Among the headline changes: logit soft-capping via tanh(x/30)*30, preventing extreme values during generation, and full attention layers with head_dim=512 (vs 256 for sliding), compensating for more aggressive KV sharing (4 KV heads vs 16). Gemma 3 used a uniform head_dim=128 for all layers. The wider heads give global attention higher per-head capacity to capture long-range dependencies.

| Parameter | Gemma 4 31B | Gemma 4 26B-A4B | Gemma 4 E2B | Gemma 3 27B |
|---|---|---|---|---|
| Total Params | ~31B | ~26B | ~5B | 27B |
| Active Params | 31B | ~4B | ~2B eff | 27B |
| Context | 256K | 256K | 128K | 128K |
| Hidden Size | 5,376 | 2,816 | 1,536 | 5,376 |
| Layers | 60 | 30 | 35 | 62 |
| Layer Pattern | 5:1 sliding/full | 5:1 sliding/full | 4:1 sliding/full | 5:1 sliding/full |
| Attention | GQA | GQA | GQA | GQA |
| Q Heads | 32 | 16 | 8 | 32 |
| KV Heads (sliding) | 16 | 8 | 1 | 16 |
| KV Heads (full) | 4 | 2 | 1 | 16 |
| Head Dim (sliding) | 256 | 256 | 256 | 128 |
| Head Dim (full) | 512 | 512 | 512 | 128 |
| FFN Type | GeGLU | MoE + GeGLU | GeGLU | GeGLU |
| FFN Hidden | 21,504 | 2,112 + MoE(704) | 6,144 | 21,504 |
| MoE Experts | -- | 128 (top-8) | -- | -- |
| Vocab | 262,144 | 262,144 | 262,144 | 262,208 |
| RoPE (sliding) | theta=10K | theta=10K | theta=10K | theta=10K |
| RoPE (full) | theta=1M, 25% partial | theta=1M, 25% partial | theta=1M, 25% partial | theta=1M, linear 8x |
| QK Norm | RMSNorm | RMSNorm | RMSNorm | RMSNorm |
| V Norm | RMSNorm (no scale) | RMSNorm (no scale) | RMSNorm (no scale) | -- |
| K=V Sharing | yes (full layers) | yes (full layers) | -- | -- |
| KV Shared Layers | -- | -- | 20 of 35 | -- |
| Per-Layer Input | -- | -- | dim=256 | -- |
| Logit Cap | 30.0 | 30.0 | 30.0 | -- |
| Vision Encoder | ViT 27L, d=1152 | ViT 27L, d=1152 | ViT 16L, d=768 | SigLIP 27L, d=1152 |
| Audio Encoder | -- | -- | Conformer 12L | -- |
| Tie Weights | yes | yes | yes | yes |
Source: Hugging Face blog (instruction-tuned variants). E4B numbers come from the blog; the architecture pages cover E2B, 26B-A4B, and 31B in detail.
| Benchmark | 31B | 26B-A4B | E4B | E2B |
|---|---|---|---|---|
| Arena Score (LMArena est.) | 1452 | 1441 | -- | -- |
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% |
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% |
| AIME 2026 (no tools) | 89.2% | 88.3% | 42.5% | 37.5% |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 |
| MMMU Pro (vision) | 76.9% | 73.8% | 52.6% | 44.2% |
| MRCR v2 8-needle 128K | 66.4% | 44.1% | 25.4% | 19.1% |
| CoVoST (audio) | -- | -- | 35.54 | 33.47 |
| FLEURS (audio) | -- | -- | 0.08 | 0.09 |
The 26B-A4B MoE lands within 1% of the 31B's Arena score (1441 vs 1452) with only ~4B active parameters -- a roughly 7.5× compute reduction. The E2B achieves competitive audio scores despite being a 2.3B-effective model.
Computed from: Q=dim×q_heads×head_dim, K/V=dim×kv_heads×head_dim, O=q_heads×head_dim×dim. GeGLU FFN=3×dim×hidden (gate+up+down). V proj=0 when K=V. MoE: per_expert=3×dim×expert_hidden, router=dim×num_experts.
| Component | Sliding Block | Full Block | Formula |
|---|---|---|---|
| Q proj | 44.0M | 88.1M | 5376 × 32 × dh |
| K proj | 22.0M | 11.0M | 5376 × kv × dh |
| V proj | 22.0M | 0 (K=V) | eliminated |
| O proj | 44.0M | 88.1M | q×dh × 5376 |
| Attention total | 132.1M | 187.2M | |
| GeGLU FFN | 346.8M | 346.8M | 3 × 5376 × 21504 |
| Block total | 478.9M | 534.0M | |
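The 31B per-block numbers above can be reproduced directly from the stated formulas; a quick sanity check in Python:

```python
# Reproduce the 31B per-block parameter counts from the table above.
d, q_heads, ffn_hidden = 5376, 32, 21504

def attn_params(kv_heads, head_dim, k_eq_v):
    q = d * q_heads * head_dim
    k = d * kv_heads * head_dim
    v = 0 if k_eq_v else k          # V projection eliminated when K=V
    o = q_heads * head_dim * d
    return q + k + v + o

sliding_attn = attn_params(kv_heads=16, head_dim=256, k_eq_v=False)
full_attn = attn_params(kv_heads=4, head_dim=512, k_eq_v=True)
geglu = 3 * d * ffn_hidden          # gate + up + down
print(round(sliding_attn / 1e6, 1))            # 132.1
print(round(full_attn / 1e6, 1))               # 187.2
print(round((sliding_attn + geglu) / 1e6, 1))  # 478.9
```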
| Component | Sliding Block | Full Block | Formula |
|---|---|---|---|
| Q proj | 11.5M | 23.1M | 2816 × 16 × dh |
| K proj | 5.8M | 2.9M | 2816 × kv × dh |
| V proj | 5.8M | 0 (K=V) | |
| O proj | 11.5M | 23.1M | |
| Attention total | 34.6M | 49.0M | |
| Dense GeGLU | 17.8M | 17.8M | 3 × 2816 × 2112 |
| MoE experts (128×) | 761.3M | 761.3M | 128 × 3 × 2816 × 704 |
| MoE active (top-8) | 47.6M | 47.6M | 8 × 5.9M/expert |
| Router | 0.4M | 0.4M | 2816 × 128 |
| FFN total capacity | 779.5M | 779.5M | dense + all experts + router |
| FFN active/token | 65.8M | 65.8M | dense + top-8 + router |
| Block total capacity | 814.1M | 828.5M | |
| Block active/token | 100.4M | 114.8M | |
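The MoE capacity-vs-active split in this table follows directly from the per-expert formula; a small check:

```python
# FFN capacity vs active parameters for a 26B-A4B block.
d, dense_h, expert_h, n_experts, top_k = 2816, 2112, 704, 128, 8

per_expert = 3 * d * expert_h            # gate + up + down per expert
dense = 3 * d * dense_h                  # always-on shared GeGLU
router = d * n_experts                   # linear router
capacity = dense + n_experts * per_expert + router
active = dense + top_k * per_expert + router
print(round(capacity / 1e6, 1))   # 779.5
print(round(active / 1e6, 1))     # 65.8
```

The ~12× gap between capacity and active parameters per FFN is what makes the 26B-A4B behave like a 4B-compute model while storing 26B weights.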
| Component | Sliding Block | Full Block | Formula |
|---|---|---|---|
| Q proj | 3.1M | 6.3M | 1536 × 8 × dh |
| K proj | 0.4M | 0.8M | 1536 × 1 × dh |
| V proj | 0.4M | 0.8M | no K=V on E2B |
| O proj | 3.1M | 6.3M | |
| Attention total | 7.1M | 14.2M | |
| GeGLU FFN | 28.3M | 28.3M | 3 × 1536 × 6144 |
| GeGLU FFN (2× wide) | 56.6M | 56.6M | KV-shared layers only |
| Block total | 35.4M | 42.5M | standard layers |
| Block total (2× wide) | 63.7M | 70.8M | KV-shared layers |
| Component | Per Block | Formula |
|---|---|---|
| Q proj | 22.0M | 5376 × 32 × 128 |
| K proj | 11.0M | 5376 × 16 × 128 |
| V proj | 11.0M | 5376 × 16 × 128 |
| O proj | 22.0M | 32 × 128 × 5376 |
| Attention total | 66.1M | |
| GeGLU FFN | 346.8M | 3 × 5376 × 21504 |
| Block total | 412.9M | |
Estimates: weight memory + KV cache (FP16). KV formula: 2 × kv_heads × head_dim × 2 bytes per token per layer, summed across sliding and full layers. E2B benefits from KV sharing (20 of 35 layers reuse cache).
| Precision | 31B | 26B-A4B | E2B | Gemma 3 27B |
|---|---|---|---|---|
| FP16 | 62.0 GB | 52.0 GB | 10.2 GB | 54.0 GB |
| INT8 | 31.0 GB | 26.0 GB | 5.1 GB | 27.0 GB |
| INT4 | 15.5 GB | 13.0 GB | 2.6 GB | 13.5 GB |
| Context | 31B | 26B-A4B | E2B | Gemma 3 27B |
|---|---|---|---|---|
| 4K | 3.4 GB | 0.9 GB | 0.07 GB | 1.9 GB |
| 32K | 27.5 GB | 6.9 GB | 0.6 GB | 15.5 GB |
| 128K | 110 GB | 27.5 GB | 2.3 GB | 62 GB |
| 256K | 220 GB | 55 GB | -- | -- |
31B: 50 sliding layers (16 kv_heads × d=256) + 10 full layers (4 kv_heads × d=512). 26B-A4B: 25 sliding (8 × 256) + 5 full (2 × 512). E2B: 15 unique KV layers after sharing (1 × 256/512). Gemma 3: 62 uniform layers (16 × 128).
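These KV-cache figures can be reproduced from the formula above (FP16, GB = 2³⁰ bytes); note that sliding layers are counted at full sequence length here, matching the table's conservative estimates:

```python
# KV cache per the stated formula: 2 tensors (K and V) × kv_heads ×
# head_dim × 2 bytes (FP16), per token, summed over layer groups.
def kv_cache_gb(tokens, layer_specs):
    """layer_specs: list of (num_layers, kv_heads, head_dim) groups."""
    bytes_per_token = sum(n * 2 * kv * dh * 2 for n, kv, dh in layer_specs)
    return tokens * bytes_per_token / 2**30

gemma4_31b = [(50, 16, 256), (10, 4, 512)]   # sliding + full layers
gemma3_27b = [(62, 16, 128)]                 # uniform layers
print(round(kv_cache_gb(4096, gemma4_31b), 1))   # 3.4
print(round(kv_cache_gb(4096, gemma3_27b), 1))   # 1.9
print(round(kv_cache_gb(131072, gemma4_31b)))    # 110
```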
| Scenario | 31B | 26B-A4B | E2B | Gemma 3 27B |
|---|---|---|---|---|
| FP16, 4K ctx | 65.4 GB (1× H100 80GB) | 52.9 GB (1× H100 80GB) | 10.3 GB (RTX 4070 12GB) | 55.9 GB (1× H100 80GB) |
| INT4, 4K ctx | 18.9 GB (RTX 4090 24GB) | 13.9 GB (RTX 4070 Ti 16GB) | 2.7 GB (any GPU) | 15.4 GB (RTX 4090 24GB) |
| INT4, 128K ctx | 125.5 GB (2× H100 160GB) | 40.5 GB (A6000 48GB) | 4.9 GB (RTX 4060 8GB) | 75.5 GB (1× H100 80GB) |
The 26B-A4B MoE loads all 26B parameters into VRAM (all experts must be resident), but its KV cache is 4x smaller than the 31B due to fewer KV heads. At INT4 + 128K context, it fits in a single A6000 -- impossible for the 31B. E2B with KV sharing runs full 128K context on a laptop GPU.
Dual attention geometry: sliding layers use head_dim=256 with more KV heads, while full attention layers use head_dim=512 with fewer KV heads. This allows efficient local attention with fine-grained heads and powerful global attention with large heads, all within the same model.
K=V weight sharing: on full attention layers, the V projection is eliminated (v_proj=None). The key tensor is cloned as the value before normalization, then K and V diverge: K gets k_norm (learned scale) + RoPE, V gets v_norm (no scale, no RoPE). So K and V start identical but develop different representations through their separate norm paths. Present in 31B and 26B-A4B.
p-RoPE: full attention layers use partial RoPE (partial=0.25) with theta=1M, rather than the linear scaling used in Gemma 3. This provides better long-context extrapolation by leaving 75% of dimensions position-independent.
Parallel MoE: in the 26B-A4B, a shared dense GeGLU FFN (hidden=2,112) runs in parallel with a 128-expert routed MoE (each expert hidden=704, top-8 routing). The shared dense FFN is 3× larger than individual experts (2112 vs 704), ensuring always-on general capacity. The outputs are summed and scaled by 1/√2.
Per-Layer Embeddings (PLE): the E-series combines a token-identity component (lookup in a vocab × layers×256 table) and a context-aware component (linear projection from main embeddings, RMSNorm'd). These are summed and scaled by 1/√2, then each layer receives its 256-dim slice via a gated bottleneck residual. This accounts for the ~5B total vs ~2.3B effective parameter count in E2B. See the PLE Deep Dive section below for the full mechanism.
Wider FFN on KV-shared layers: E2B doubles the FFN on its KV-shared layers (hidden=12,288 vs 6,144).
Logit soft-capping: applies tanh(x/30) * 30 to final logits, bounding them to [-30, 30]. This prevents extreme logit values during generation without hard clipping, improving training stability. Gemma 3 does not use this technique.
Raising theta alone (e.g. 10K → 1M) isn't enough, because the lowest frequencies still eventually break.
p-RoPE instead sets the (1-p) lowest-frequency dimensions to zero -- making them position-independent (NoPE). With p=0.25, only the top 25% of frequencies are rotated (positional), while 75% become pure semantic channels that never break, regardless of context length.
Sliding layers keep theta=10,000 with all dims rotated; full attention layers use theta=1,000,000 with partial=0.25.

| Aspect | Gemma 3 (27B) | Gemma 4 (31B) | Impact |
|---|---|---|---|
| Sliding RoPE | theta=10K, full rotation | theta=10K, full rotation | Same -- local attention unchanged |
| Full RoPE theta | theta=1M | theta=1M | Same base frequency |
| Full RoPE scaling | linear 8x | partial=0.25 (p-RoPE) | p-RoPE replaces linear scaling |
| Rotated dims (full) | 128 of 128 (100%) | 128 of 512 (25%) | 75% dims are position-free |
| Semantic capacity | All dims position-coupled | 384 dims pure semantic | Robust long-context semantics |
| Max context | 128K | 256K | 2x context with p-RoPE |
The paper validates p-RoPE at Gemma 7B scale; values are perplexity (lower is better).
| Encoding | Wiki | PlanV2 | Properties |
|---|---|---|---|
| NoPE | 4.8594 | 6.6429 | Semantic only, no position |
| RoPE (theta=10K) | 4.4627 | 6.4429 | Standard |
| RoPE (theta=500K) | 4.4485 | 6.4593 | High theta |
| 0.25-RoPE | 4.4592 | 6.4683 | = Gemma 4's setting |
| 0.75-RoPE (inverted) | 4.4537 | 6.4562 | More rotation |
| 0.25-RoPE (full model) | 4.5302 | 6.5111 | p-RoPE on all layers |
| 0.75-RoPE (full model) | 4.4414 | 6.4422 | Best overall perplexity |
Gemma 4 uses 0.25-RoPE on full attention layers only (not all layers), combining the best of both: standard RoPE for local sliding attention + p-RoPE for global full attention.
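The frequency construction can be sketched concretely (illustrative, not the library implementation): zeroing an inverse frequency makes the rotation angle pos × 0 = 0 for that dim pair, i.e. cos=1, sin=0 at every position.

```python
# p-RoPE sketch: keep the highest `p` fraction of RoPE frequencies,
# zero the lowest (1-p) fraction so those dims carry no position signal.
def p_rope_inv_freq(head_dim, theta=1_000_000.0, p=0.25):
    half = head_dim // 2                 # one frequency per dim pair
    inv_freq = [theta ** (-2 * i / head_dim) for i in range(half)]
    keep = int(half * p)                 # highest-frequency pairs stay rotated
    return inv_freq[:keep] + [0.0] * (half - keep)

freqs = p_rope_inv_freq(512)             # full-attention head_dim
print(sum(1 for f in freqs if f > 0))    # 64  (64 pairs = 128 of 512 dims)
```

This matches the "128 of 512 (25%)" rotated-dims figure in the comparison table above.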
Per-Layer Embeddings (PLE) is an architectural innovation introduced in the Gemma 4 E-series (E2B, E4B) that gives every decoder layer its own unique token representation. Instead of all layers sharing the same input embedding, each layer receives a dedicated 256-dimensional vector per token, created by combining two complementary signals: a token-identity component (direct embedding lookup from a per-layer table) and a context-aware component (learned projection from the main input embeddings). This combined signal is then injected into each layer via a gated bottleneck mechanism as a third residual block after attention and FFN.
Token-identity component: an nn.Embedding of shape (vocab_size, num_layers × 256) maps each token to a unique vector for every layer. This is a direct lookup -- no context from surrounding tokens. Scaled by √256 = 16. For E2B with 35 layers, each token maps to an 8,960-dim vector that is reshaped into (35, 256). This table alone accounts for ~2B parameters -- the majority of the PLE cost.
Context-aware component: an nn.Linear(hidden_size, num_layers × 256) projects the main inputs_embeds (which already contain vision/audio soft tokens for multimodal inputs) into per-layer space. Scaled by 1/√hidden_size, then RMSNorm'd. This component sees surrounding context through the embedding -- critically, for multimodal models, this is where vision and audio features enter the PLE pathway, since the token-identity component can only do discrete token lookups.
Gated injection: for each layer, the model (1) down-projects the hidden state d → 256, (2) applies GeLU activation, (3) element-wise multiplies with the PLE vector, (4) up-projects 256 → d, (5) applies RMSNorm, (6) residual-adds. This makes PLE a third residual block after attention and FFN. The element-wise multiply is the key: the PLE vector gates which dimensions of the down-projected hidden state survive, effectively letting each token modulate each layer's output differently.
# In Gemma4TextModel.__init__():
# --- PLE embedding table: one huge embedding, reshaped per-layer ---
self.embed_tokens_per_layer = Gemma4TextScaledWordEmbedding(
    vocab_size,                                       # 262,144
    num_hidden_layers * hidden_size_per_layer_input,  # 35 * 256 = 8,960
    padding_idx,
    embed_scale=hidden_size_per_layer_input ** 0.5,   # sqrt(256) = 16.0
)
# --- Context-aware projection from main embeddings ---
self.per_layer_model_projection = nn.Linear(
    hidden_size,                                      # 1536 (E2B)
    num_hidden_layers * hidden_size_per_layer_input,  # 35 * 256 = 8,960
    bias=False,
)
self.per_layer_model_projection_scale = hidden_size ** -0.5  # 1/sqrt(1536)
self.per_layer_projection_norm = Gemma4RMSNorm(hidden_size_per_layer_input)
self.per_layer_input_scale = 2.0 ** -0.5  # 1/sqrt(2)
def get_per_layer_inputs(self, input_ids, inputs_embeds):
    # Direct embedding lookup → reshape to (batch, seq, num_layers, 256).
    # Each token gets a unique 256-dim vector for EVERY layer;
    # no context -- pure token identity signal.
    return self.embed_tokens_per_layer(input_ids).reshape(
        *input_ids.shape,
        self.config.num_hidden_layers,      # 35
        self.hidden_size_per_layer_input,   # 256
    )
def project_per_layer_inputs(self, inputs_embeds, per_layer_inputs=None):
    # Project main embeddings → per-layer space
    per_layer_projection = self.per_layer_model_projection(inputs_embeds)
    per_layer_projection *= self.per_layer_model_projection_scale  # * 1/sqrt(hidden_size)
    per_layer_projection = per_layer_projection.reshape(
        *inputs_embeds.shape[:-1],
        self.config.num_hidden_layers,      # 35
        self.hidden_size_per_layer_input,   # 256
    )
    per_layer_projection = self.per_layer_projection_norm(per_layer_projection)  # RMSNorm
    if per_layer_inputs is None:
        return per_layer_projection  # context-aware only (e.g., for soft tokens)
    # Combine: (context_signal + token_signal) * 1/sqrt(2)
    return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
# In Gemma4TextModel.forward():
per_layer_inputs = self.get_per_layer_inputs(input_ids, inputs_embeds)
per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)
# Shape: (batch, seq_len, num_layers, 256)
# Each layer gets its own 256-dim slice:
for i, decoder_layer in enumerate(self.layers):
    per_layer_input = per_layer_inputs[:, :, i, :]  # (batch, seq, 256)
    hidden_states = decoder_layer(hidden_states, per_layer_input, ...)
For multimodal inputs, the two PLE components deliberately see different things. The token-identity lookup uses the pad token at image/audio positions, while the context-aware projection uses the final inputs_embeds, which already contain the actual vision/audio feature vectors -- so the projection sees real multimodal content at those positions. This asymmetry is intentional: the token-identity path provides a fixed anchor (pad = "I'm a multimodal token"), while the context-aware path carries the actual vision/audio semantics into PLE. Together, they let the model distinguish what kind of non-text token this is (identity) from what it means (context).
# In Gemma4Model.forward() -- multimodal PLE setup:
pad_embedding = self.language_model.embed_tokens.weight[pad_token_id, :]
llm_inputs_embeds = torch.where(
    multimodal_mask[..., None],    # True at image/audio positions
    pad_embedding.view(1, 1, -1),  # → use pad token for PLE lookup
    inputs_embeds,                 # → use real token for text positions
)
per_layer_inputs = self.language_model.get_per_layer_inputs(
    llm_input_ids, llm_inputs_embeds)  # token-identity: pad at multimodal positions
# Later: project_per_layer_inputs uses FINAL inputs_embeds (with vision features)
per_layer_inputs = self.language_model.project_per_layer_inputs(
    inputs_embeds,     # ← these contain actual vision/audio embeddings
    per_layer_inputs,  # ← these used pad for those positions
)
| Module | Shape | Parameters | Notes |
|---|---|---|---|
| embed_tokens_per_layer | (262K, 8960) | ~2.35B | Dominates total cost |
| per_layer_model_projection | (1536, 8960) | ~13.8M | Context-aware projection |
| per_layer_projection_norm | (256,) | 256 | RMSNorm on projection |
| per_layer_input_gate ×35 | (1536, 256) ×35 | ~13.8M | Per-layer down-projection |
| per_layer_projection ×35 | (256, 1536) ×35 | ~13.8M | Per-layer up-projection |
| post_per_layer_input_norm ×35 | (1536,) ×35 | ~54K | Per-layer RMSNorm |
| PLE Total | | ~2.38B | ~5.1B total - ~2.3B effective = PLE cost |
The PLE embedding table (~2.35B params) accounts for the gap between E2B's total parameter count (~5.1B) and its effective parameter count (~2.3B). The per-layer gate/projection pairs add only ~28M total — negligible in comparison. This is a deliberate tradeoff: a massive embedding table provides rich per-layer token representations, while the actual per-layer compute overhead is minimal. For on-device deployment, the PLE table can be stored on flash memory rather than VRAM/RAM — only the embeddings for the current input tokens need to be loaded. This is critical for phones and edge devices where RAM is extremely limited.
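The accounting above follows from the module shapes; a quick check of the dominant terms:

```python
# PLE parameter accounting for E2B (shapes from the table above).
vocab, layers, ple_dim, hidden = 262_144, 35, 256, 1536

table = vocab * layers * ple_dim          # embed_tokens_per_layer
projection = hidden * layers * ple_dim    # per_layer_model_projection
gates = 2 * layers * hidden * ple_dim     # down + up projections, all layers
print(round(table / 1e9, 2))   # 2.35 -- dominates the total
print(round(gates / 1e6, 1))   # 27.5 -- the "~28M, negligible" overhead
```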
Gemma 4 replaces Gemma 3's SigLIP vision encoder with a Vision Transformer (ViT) that supports variable aspect ratios and configurable resolution via token budgets. The key innovations are 2D RoPE for spatial awareness (replacing learned 1D position embeddings), 3×3 spatial pooling to control token count, and a flexible resolution system where users choose how many visual tokens to send to the LLM. Two ViT sizes are used: 150M parameters (E2B, E4B) and 550M parameters (31B, 26B-A4B).
For a chosen token budget, max_patches = budget × 9 (because of 3×3 pooling). The image is then resized so the patch count stays within this limit, at a resolution that must be a multiple of 48 pixels (3 patches × 16 px/patch = 48 px per pooled unit). Higher budgets preserve more resolution but cost more compute.
| Token Budget | Max Pre-Pool Patches | Approx. Resolution | Use Case |
|---|---|---|---|
| 70 tokens | 630 | ~336 × 336 | Thumbnail / fast processing |
| 140 tokens | 1,260 | ~480 × 480 | Standard preview |
| 280 tokens (default) | 2,520 | ~672 × 672 | Balanced quality/speed |
| 560 tokens | 5,040 | ~1,008 × 1,008 | High detail |
| 1,120 tokens | 10,080 | ~1,344 × 1,344 | Maximum resolution |
Exact resolution depends on aspect ratio. Non-square images use the same patch budget but distribute patches across width and height according to the original ratio. E-series models use a fixed budget of 280 tokens.
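The budget-to-patch relation in the table reduces to a one-liner (resolutions are approximate because they also depend on aspect ratio and the 48-px grid):

```python
# With 3×3 pooling, each output token covers 9 pre-pool patches
# (one 48×48 px pooled unit of 16-px patches).
def max_pre_pool_patches(token_budget: int) -> int:
    return token_budget * 9

print(max_pre_pool_patches(70))     # 630
print(max_pre_pool_patches(280))    # 2520  (default budget)
print(max_pre_pool_patches(1120))   # 10080
```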
| Aspect | Gemma 3 27B | Gemma 4 31B | Impact |
|---|---|---|---|
| Encoder | SigLIP (contrastive) | ViT (trained jointly) | Better alignment with LLM |
| Position encoding | Learned 1D | 2D RoPE | Native spatial reasoning |
| Pooling | 4×4 (16× reduction) | 3×3 (9× reduction) | More tokens per image |
| Tokens per image | 256 (fixed) | 70-1,120 (variable) | Resolution flexibility |
| Aspect ratio | Fixed square | Variable (preserved) | No distortion |
| Connector | AvgPool + Linear | Linear + RMSNorm | Simpler, scale-matched |
| ViT params (large) | ~400M (SigLIP-400M) | ~550M | Higher capacity |
The E-series models (E2B, E4B) include a native Conformer audio encoder based on the USM (Universal Speech Model) architecture. This enables direct audio understanding -- speech recognition, audio classification, and spoken-language tasks -- without external ASR. The encoder converts raw audio waveforms into a sequence of soft tokens that are interleaved with text and vision tokens in the LLM's input sequence.
The encoder front-end applies subsampling convolutions ([128, 32]), producing a 4× temporal reduction. This converts the high-resolution spectrogram into a manageable sequence of continuous embeddings ("soft tokens"). Overlapping chunks ensure no information is lost at boundaries.
Residual connections use a 0.5 weight (vs the standard 1.0), providing smoother gradient flow.
A linear projection maps audio embeddings to the LLM's hidden size (1,536 for E2B). This matches the scale and distribution of text/vision embeddings, allowing the LLM to treat audio tokens identically to text and vision tokens. The projected audio tokens are interleaved into the LLM's input sequence at the positions indicated by audio placeholder tokens.
The Conformer architecture differs from a standard Transformer encoder by adding a convolutional module (causal depthwise Conv1d with kernel=5) in each layer. Self-attention captures global temporal patterns across the audio sequence, while the convolution captures fine-grained local acoustic features -- phonemes, consonant boundaries, pitch transitions -- that are too narrow for attention to model efficiently. The chunked local attention (chunk_size=12, left_context=13) limits each chunk's receptive field for efficiency while the left context ensures continuity across chunk boundaries.
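The chunked-attention rule can be made concrete with a toy mask predicate (a simplified sketch; the actual USM masking may differ in boundary details):

```python
# Simplified chunked local attention (chunk_size=12, left_context=13):
# each query attends causally within its own chunk, plus up to 13 tokens
# of left context reaching into the previous chunk.
CHUNK, LEFT_CONTEXT = 12, 13

def can_attend(q_pos: int, k_pos: int) -> bool:
    chunk_start = (q_pos // CHUNK) * CHUNK
    return chunk_start - LEFT_CONTEXT <= k_pos <= q_pos  # causal, bounded

print(can_attend(14, 14))   # True  (attends to itself)
print(can_attend(14, 1))    # True  (previous chunk, inside left context)
print(can_attend(14, 15))   # False (no lookahead)
print(can_attend(26, 10))   # False (beyond the left-context window)
```

Because the lower bound depends on the chunk start rather than the query position, every query in a chunk shares the same leftmost reachable key, which is what keeps the mask cheap to compute per chunk.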
Verified against transformers/models/gemma4/modeling_gemma4.py. Code snippets are exact quotes.
Gemma 4 drops the usual 1/√head_dim scaling. Instead, scaling is fixed to 1.0 and the QK norms (with learned scale) absorb the normalization function.

self.scaling = 1.0
When attention_k_eq_v=True, the V projection is eliminated (v_proj=None). The key tensor is cloned as the value before normalization. Then K gets k_norm (with learned scale) + RoPE, while V gets v_norm (without scale, no RoPE). So K and V start identical but diverge through different normalizations.

# v_proj is None when attention_k_eq_v=True
value_states = (self.v_proj(hidden_states).view(hidden_shape)
                if self.v_proj is not None else key_states)
# K path: scaled norm + RoPE
key_states = self.k_norm(key_states)  # with_scale=True
key_states = apply_rotary_pos_emb(key_states, cos, sin)
# V path: unscaled norm, NO RoPE
value_states = self.v_norm(value_states)  # with_scale=False
The V norm is distinguished by a with_scale flag. Q/K norms have a learnable scale (multiplicative weight); the V norm does not. The weight is initialized to ones (standard RMSNorm), unlike Gemma 2/3's (1 + weight) parameterization.

class Gemma4RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6, with_scale=True):
        super().__init__()
        self.eps = eps
        self.with_scale = with_scale
        if self.with_scale:
            self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, hidden_states):
        normed = self._norm(hidden_states.float())
        if self.with_scale:
            normed = normed * self.weight.float()
        return normed.type_as(hidden_states)
Token embeddings are scaled by √hidden_size. This is a registered buffer (not a gradient-tracked parameter), initialized from the config. The per-layer embedding uses √per_layer_dim instead.

class Gemma4TextScaledWordEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim,
                 padding_idx, embed_scale=1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_buffer(
            "embed_scale", torch.tensor(embed_scale))

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale

# Main embedding: scale = sqrt(5376) ≈ 73.3
# Per-layer embedding: scale = sqrt(256) = 16.0
# After attention + FFN residuals:
if self.hidden_size_per_layer_input:
    residual = hidden_states
    hidden_states = self.per_layer_input_gate(hidden_states)  # d → 256
    hidden_states = self.act_fn(hidden_states)
    hidden_states = hidden_states * per_layer_input  # gated element-wise
    hidden_states = self.per_layer_projection(hidden_states)  # 256 → d
    hidden_states = self.post_per_layer_input_norm(hidden_states)
    hidden_states = residual + hidden_states
# Dense path (standard MLP)
hidden_states = self.pre_feedforward_layernorm(residual)
hidden_states = self.mlp(hidden_states)
if self.enable_moe_block:
    hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states)
    # MoE path (parallel, from pre-MLP residual)
    hidden_states_flat = residual.reshape(-1, residual.shape[-1])
    _, top_k_weights, top_k_index = self.router(hidden_states_flat)
    hidden_states_2 = self.pre_feedforward_layernorm_2(hidden_states_flat)
    hidden_states_2 = self.experts(hidden_states_2, top_k_index, top_k_weights)
    hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2)
    # Sum both paths, then final post-norm
    hidden_states = hidden_states_1 + hidden_states_2
    hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
The router normalizes the input with an unscaled RMSNorm, scales by 1/√hidden_size, projects to num_experts, applies softmax, selects the top-K, normalizes those weights, then applies per-expert learned scales.

def forward(self, hidden_states):
    hidden_states = self.norm(hidden_states)  # RMSNorm, no scale
    hidden_states = hidden_states * self.scale * self.scalar_root_size
    expert_scores = self.proj(hidden_states)  # Linear → num_experts
    router_probs = softmax(expert_scores, dim=-1)
    top_k_weights, top_k_index = torch.topk(
        router_probs, k=self.config.top_k_experts)
    top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
    top_k_weights = top_k_weights * self.per_expert_scale[top_k_index]
    return router_probs, top_k_weights, top_k_index
The per-layer scalar is a register_buffer (loaded from the checkpoint, not trained via gradient), initialized to 1.0, and applied after all residual blocks (attention + FFN + optional per-layer input).

hidden_states *= self.layer_scalar  # buffer, not nn.Parameter
The config default is None (disabled), but published model checkpoints set final_logit_softcapping to 30.0, bounding logits to [-30, 30] smoothly.

if self.config.final_logit_softcapping is not None:
    logits = logits / self.config.final_logit_softcapping
    logits = torch.tanh(logits)
    logits = logits * self.config.final_logit_softcapping
Full attention layers use rope_type: "proportional", which sets the lowest 75% of frequency dimensions to zero (cos=1, sin=0), leaving those dimensions position-independent. Sliding layers use standard default RoPE. This is configured per layer type in the config, not in the modeling code.

# From Gemma4TextConfig:
rope_parameters = {
    "sliding_attention": {
        "rope_type": "default",
        "rope_theta": 10_000.0,
    },
    "full_attention": {
        "rope_type": "proportional",
        "partial_rotary_factor": 0.25,
        "rope_theta": 1_000_000.0,
    },
}
# Last layer forced to full_attention:
if self.layer_types[-1] != "full_attention":
    self.layer_types[-1] = "full_attention"
# In shared layer forward:
if self.is_kv_shared_layer and past_key_values is not None:
    key_states, value_states = \
        past_key_values.shared_layers[self.kv_shared_layer_index]

# In non-shared layer: store KV for later reuse
if self.store_full_length_kv:
    past_key_values.shared_layers[self.layer_idx] = (key_states, value_states)