GEMMA 4 ARCHITECTURE

Complete technical comparison of Google's Gemma 4 model family — released April 2, 2026 — Apache 2.0

Gemma 4 is Google DeepMind's fourth-generation open model family, released under Apache 2.0. It is the first generation where the architecture itself -- not just scale or training data -- becomes the primary axis of differentiation across variants. The family spans four models (31B, 26B-A4B, E4B, E2B) that share a common structural DNA but diverge in how they allocate capacity, reflecting a design philosophy where the same building blocks are composed differently for server, edge, and on-device deployment. The smaller E-series variants (E2B, E4B) support all four modalities -- text, image, video, and audio -- making them true any-to-any models.

The Shared Skeleton

Every Gemma 4 model is a decoder-only transformer with hybrid sliding/full attention. Layers alternate between local sliding-window attention and global full attention in a fixed ratio (5:1 or 4:1). This is not new -- Gemma 3 and Mistral pioneered it. What is new is that Gemma 4 makes these two layer types structurally different:

Sliding Layers (local)
head_dim=256, more KV heads, standard RoPE (theta=10K), full rotation. Optimized for fine-grained local patterns within a window of 512-1024 tokens.
Full Layers (global)
head_dim=512, fewer KV heads, p-RoPE (theta=1M, partial=0.25), K=V weight sharing. Optimized for long-range semantic attention across the entire 256K context.

This dual-config approach means that every fifth or sixth layer (depending on the ratio) operates with a completely different attention geometry -- wider heads, fewer KV groups, and only 25% of dimensions carrying positional information. The transition from Gemma 3 to Gemma 4 can be summarized as moving from "same attention, different window" to "different attention, different window."

Sliding window layers can only attend to the last 512-1024 tokens directly. Information from beyond the window propagates indirectly through hidden states across layers -- earlier tokens influence the current window's hidden states, which then influence the next layer, like a chain of handoffs. The periodic full attention layers (every 5th or 6th) break this indirection by attending to the entire sequence at once, ensuring long-range information isn't lost. Gemma 4 also enforces that the final layer is always global attention, guaranteeing full-context awareness in the last representation.
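
The layer schedule itself is easy to reproduce. Below is a minimal sketch -- an illustrative helper, not the actual Gemma 4 config code -- of how a 5:1 sliding/full pattern with a forced-global final layer can be generated:

# Sketch: per-layer attention types for a hybrid model.
# `ratio` sliding layers are followed by one full-attention layer,
# and the last layer is always forced to full attention (as Gemma 4 does).
def make_layer_types(num_layers: int, ratio: int = 5) -> list[str]:
    types = [
        "full_attention" if (i + 1) % (ratio + 1) == 0 else "sliding_attention"
        for i in range(num_layers)
    ]
    types[-1] = "full_attention"  # guarantee full-context awareness at the top
    return types

# 31B: 60 layers at 5:1 -> 10 full layers (every 6th); E2B: 35 layers at 4:1 -> 7 full
assert make_layer_types(60, ratio=5).count("full_attention") == 10
assert make_layer_types(35, ratio=4).count("full_attention") == 7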

Four Targets, Four Compositions

The family spans four deployment regimes, each making distinct architectural trade-offs within this shared framework:

31B Server · Dense
The flagship dense model (Arena score: 1452). 60 layers of GQA + GeGLU with the full Gemma 4 attention design. No compression tricks -- every layer computes fresh K, V, and FFN. Image-Text-to-Text with a 27-layer ViT vision encoder using 2D RoPE and learned spatial embeddings. Supports variable image token budgets (70, 140, 280, 560, 1120 tokens).
26B-A4B Server · MoE
The efficiency variant (Arena score: 1441 with only 4B active). 26B total parameters, but only ~4B active per token. Every layer runs a dense GeGLU FFN (hidden=2,112) in parallel with a 128-expert MoE (top-8, hidden=704 each) -- their outputs are summed. This is architecturally unusual: most MoE models replace the FFN with experts, but Gemma 4 keeps both, giving each layer always-on dense capacity plus sparse expert specialization.
E4B Edge · Any-to-Any
The mid-size on-device model. ~8B total parameters with ~4.5B effective, 128K context. Shares the E-series architecture with per-layer input embeddings and KV cache sharing. Supports all four modalities (text, image, video with audio, and standalone audio), using the same ViT vision encoder and USM Conformer audio encoder as E2B. Not analyzed on this page (config not yet published at time of writing).
E2B On-device · Any-to-Any
The most architecturally novel variant. ~5.1B total parameters but only ~2.3B effective -- the gap comes from a per-layer input embedding table that maps each token to a unique 256-dim gated residual for every layer. 20 of 35 layers share KV caches from earlier layers, and KV-shared layers compensate with 2x wider MLPs. Includes both a Conformer audio encoder (USM-style, 12 layers) and a ViT vision encoder, making it a true any-to-any model supporting text, image, video, and audio.

What Changed from Gemma 3

Compared to Gemma 3 27B, the architectural delta is substantial. The list below focuses on the 31B (the most direct comparison), but most innovations carry over to the other variants as well:

  • Per-type attention geometry -- Gemma 3 used identical head_dim/kv_heads for all layers. Gemma 4 splits: 256/16kv for sliding, 512/4kv for full. This is the single biggest structural change.
  • p-RoPE replaces linear scaling -- Full attention layers now rotate only 25% of dimensions (proportional RoPE) instead of applying 8x linear frequency scaling. This is grounded in the Oxford/DeepMind finding that low-frequency channels carry semantic (not positional) information.
  • K=V weight sharing -- Full attention layers eliminate the V projection entirely, reusing key states as values. Combined with fewer KV heads (4 vs 16), this dramatically cuts per-layer parameters.
  • V-norm (value normalization) -- All variants apply RMSNorm to values (without learned scale), a stabilization technique absent in Gemma 3.
  • Logit soft-capping -- Output logits are bounded via tanh(x/30)*30, preventing extreme values during generation.
  • Final layer always global attention -- Gemma 4 enforces that the last decoder layer uses full (global) attention regardless of the sliding/full pattern. Gemma 3 did not enforce this, e.g. Gemma 3 4B's last layer was local. This ensures every model's final representation has full-context awareness.
  • Head dimensions doubled in global layers -- Full attention layers use head_dim=512 (vs 256 for sliding), compensating for more aggressive KV sharing (4 KV heads vs 16). Gemma 3 used a uniform head_dim=128 for all layers. The wider heads give global attention higher per-head capacity to capture long-range dependencies.
  • Vision encoder upgrade -- SigLIP (Gemma 3) is replaced by a ViT with 2D RoPE and learned 2D positional embeddings, yielding 280 tokens per image (vs 256) through 3x3 pooling (vs 4x4). Supports variable resolution via token budgets (70-1120 tokens). See the Vision Encoder Deep Dive section below.
  • Per-Layer Embeddings / PLE (E-series) -- A completely new mechanism giving each decoder layer a unique token-dependent signal via two pathways: a dedicated per-layer embedding table (token-identity) and a learned projection from main embeddings (context-aware). These are combined and injected as a gated third residual block. New to Gemma 4, not present in any Gemma 3 variant.
Model Overview

Gemma 4 31B IT · Dense · Vision
Total Params: ~31B · Active Params: 31B · Context: 256K tokens · Hidden Size: 5,376 · Layers: 60 (5:1 sliding/full)

Gemma 4 26B-A4B IT · MoE · Vision
Total Params: ~26B · Active Params: ~4B · Context: 256K tokens · Hidden Size: 2,816 · Layers: 30 (5:1 sliding/full)

Gemma 4 E2B IT · Dense · Vision · Audio
Total Params: ~5B (~2B effective) · Active Params: ~2B eff · Context: 128K tokens · Hidden Size: 1,536 · Layers: 35 (4:1 sliding/full)

Gemma 3 27B IT (comparison) · Dense · Vision
Total Params: 27B · Active Params: 27B · Context: 128K tokens · Hidden Size: 5,376 · Layers: 62 (5:1 sliding/full)

Parameter Comparison
Parameter Gemma 4 31B Gemma 4 26B-A4B Gemma 4 E2B Gemma 3 27B
Total Params ~31B ~26B ~5B 27B
Active Params 31B ~4B ~2B eff 27B
Context 256K 256K 128K 128K
Hidden Size 5,376 2,816 1,536 5,376
Layers 60 30 35 62
Layer Pattern 5:1 sliding/full 5:1 sliding/full 4:1 sliding/full 5:1 sliding/full
Attention GQA GQA GQA GQA
Q Heads 32 16 8 32
KV Heads (sliding) 16 8 1 16
KV Heads (full) 4 2 1 16
Head Dim (sliding) 256 256 256 128
Head Dim (full) 512 512 512 128
FFN Type GeGLU MoE + GeGLU GeGLU GeGLU
FFN Hidden 21,504 2,112 + MoE(704) 6,144 21,504
MoE Experts -- 128 (top-8) -- --
Vocab 262,144 262,144 262,144 262,208
RoPE (sliding) theta=10K theta=10K theta=10K theta=10K
RoPE (full) theta=1M, 25% partial theta=1M, 25% partial theta=1M, 25% partial theta=1M, linear 8x
QK Norm RMSNorm RMSNorm RMSNorm RMSNorm
V Norm RMSNorm (no scale) RMSNorm (no scale) RMSNorm (no scale) --
K=V Sharing yes (full layers) yes (full layers) -- --
KV Shared Layers -- -- 20 of 35 --
Per-Layer Input -- -- dim=256 --
Logit Cap 30.0 30.0 30.0 --
Vision Encoder ViT 27L, d=1152 ViT 27L, d=1152 ViT 16L, d=768 SigLIP 27L, d=1152
Audio Encoder -- -- Conformer 12L --
Tie Weights yes yes yes yes
Benchmarks

Source: Hugging Face blog. Instruction-tuned variants. E4B included from blog (architecture page covers E2B, 26B-A4B, 31B in detail).

Benchmark                     31B      26B-A4B   E4B      E2B
Arena Score (LMArena est.)    1452     1441      --       --
MMLU Pro                      85.2%    82.6%     69.4%    60.0%
GPQA Diamond                  84.3%    82.3%     58.6%    43.4%
AIME 2026 (no tools)          89.2%    88.3%     42.5%    37.5%
LiveCodeBench v6              80.0%    77.1%     52.0%    44.0%
Codeforces ELO                2150     1718      940      633
MMMU Pro (vision)             76.9%    73.8%     52.6%    44.2%
MRCR v2 8-needle 128K         66.4%    44.1%     25.4%    19.1%
CoVoST (audio)                --       --        35.54    33.47
FLEURS (audio)                --       --        0.08     0.09

The 26B-A4B MoE achieves 98% of the 31B's Arena score with only 4B active parameters -- a 7.5x compute reduction. The E2B achieves competitive audio scores despite being a 2.3B-effective model.

Per-Block Parameter Estimates

Computed from: Q=dim×q_heads×head_dim, K/V=dim×kv_heads×head_dim, O=q_heads×head_dim×dim. GeGLU FFN=3×dim×hidden (gate+up+down). V proj=0 when K=V. MoE: per_expert=3×dim×expert_hidden, router=dim×num_experts.

Gemma 4 31B — 60 layers (50 sliding + 10 full)

Component         Sliding Block   Full Block   Formula
Q proj            44.0M           88.1M        5376 × 32 × dh
K proj            22.0M           11.0M        5376 × kv × dh
V proj            22.0M           0 (K=V)      eliminated on full layers
O proj            44.0M           88.1M        q_heads × dh × 5376
Attention total   132.1M          187.2M
GeGLU FFN         346.8M          346.8M       3 × 5376 × 21504
Block total       478.9M          534.0M

50 × 478.9M + 10 × 534.0M + 1.4B embed = ~30.7B total
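
As a sanity check, the per-block numbers can be recomputed directly from the formulas above. A short sketch (values from the 31B column; helper names are illustrative):

def attn_params(dim, q_heads, kv_heads, head_dim, k_eq_v=False):
    q = dim * q_heads * head_dim
    k = dim * kv_heads * head_dim
    v = 0 if k_eq_v else dim * kv_heads * head_dim   # V proj eliminated when K=V
    o = q_heads * head_dim * dim
    return q + k + v + o

def geglu_params(dim, hidden):
    return 3 * dim * hidden                          # gate + up + down

dim, ffn_hidden, vocab = 5376, 21504, 262_144
sliding = attn_params(dim, 32, 16, 256) + geglu_params(dim, ffn_hidden)              # ~478.9M
full    = attn_params(dim, 32, 4, 512, k_eq_v=True) + geglu_params(dim, ffn_hidden)  # ~534.0M
total   = 50 * sliding + 10 * full + vocab * dim                                     # ~30.7B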
Gemma 4 26B-A4B — 30 layers (25 sliding + 5 full)

Component              Sliding Block   Full Block   Formula
Q proj                 11.5M           23.1M        2816 × 16 × dh
K proj                 5.8M            2.9M         2816 × kv × dh
V proj                 5.8M            0 (K=V)
O proj                 11.5M           23.1M
Attention total        34.6M           49.0M
Dense GeGLU            17.8M           17.8M        3 × 2816 × 2112
MoE experts (128×)     761.3M          761.3M       128 × 3 × 2816 × 704
MoE active (top-8)     47.6M           47.6M        8 × 5.9M/expert
Router                 0.4M            0.4M         2816 × 128
FFN total capacity     779.5M          779.5M       dense + all experts + router
FFN active/token       65.8M           65.8M        dense + top-8 + router
Block total capacity   814.1M          828.5M
Block active/token     100.4M          114.8M

Total capacity: 25 × 814M + 5 × 829M + 0.7B embed = ~25.2B
Active per token: 25 × 100M + 5 × 115M + 0.7B embed = ~3.8B
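
The MoE FFN capacity/active split follows the same arithmetic; a short sketch using the 26B-A4B values (names illustrative):

dim, dense_hidden, expert_hidden = 2816, 2112, 704
num_experts, top_k = 128, 8

dense    = 3 * dim * dense_hidden                  # ~17.8M, always on
per_exp  = 3 * dim * expert_hidden                 # ~5.9M per expert
router   = dim * num_experts                       # ~0.4M
capacity = dense + num_experts * per_exp + router  # ~779.5M stored per layer
active   = dense + top_k * per_exp + router        # ~65.8M used per token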
Gemma 4 E2B — 35 layers (28 sliding + 7 full)

Component               Sliding Block   Full Block   Formula
Q proj                  3.1M            6.3M         1536 × 8 × dh
K proj                  0.4M            0.8M         1536 × 1 × dh
V proj                  0.4M            0.8M         no K=V sharing on E2B
O proj                  3.1M            6.3M
Attention total         7.1M            14.2M
GeGLU FFN               28.3M           28.3M        3 × 1536 × 6144
GeGLU FFN (2× wide)     56.6M           56.6M        KV-shared layers only
Block total             35.4M           42.5M        standard layers
Block total (2× wide)   63.7M           70.8M        KV-shared layers

15 standard blocks + 20 double-wide blocks + 0.4B embed + 2.3B per-layer embed = ~5.1B total
Effective (excl. per-layer embed table): ~2.3B
Gemma 3 27B — 62 uniform layers (previous generation)

Component         Per Block   Formula
Q proj            22.0M       5376 × 32 × 128
K proj            11.0M       5376 × 16 × 128
V proj            11.0M       5376 × 16 × 128
O proj            22.0M       32 × 128 × 5376
Attention total   66.1M
GeGLU FFN         346.8M      3 × 5376 × 21504
Block total       412.9M

62 × 412.9M + 1.4B embed = ~27.0B total
GPU Memory Requirements

Estimates: weight memory + KV cache (FP16). KV formula: 2 × kv_heads × head_dim × 2 bytes per token per layer, summed across sliding and full layers. E2B benefits from KV sharing (20 of 35 layers reuse cache).

Weight Memory (all params loaded)

Precision   31B       26B-A4B   E2B       Gemma 3 27B
FP16        62.0 GB   52.0 GB   10.2 GB   54.0 GB
INT8        31.0 GB   26.0 GB   5.1 GB    27.0 GB
INT4        15.5 GB   13.0 GB   2.6 GB    13.5 GB

KV Cache (FP16, batch=1)

Context   31B       26B-A4B   E2B       Gemma 3 27B
4K        3.4 GB    0.9 GB    0.07 GB   1.9 GB
32K       27.5 GB   6.9 GB    0.6 GB    15.5 GB
128K      110 GB    27.5 GB   2.3 GB    62 GB
256K      220 GB    55 GB     --        --

31B: 50 sliding layers (16 kv_heads × d=256) + 10 full layers (4 kv_heads × d=512). 26B-A4B: 25 sliding (8 × 256) + 5 full (2 × 512). E2B: 15 unique KV layers after sharing (1 × 256/512). Gemma 3: 62 uniform layers (16 × 128).
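
These cache sizes follow directly from the formula above; a minimal sketch (FP16, batch=1, no window capping, layer counts as listed; the split of E2B's 15 unique KV layers into sliding/full is an assumption):

def kv_cache_gb(layer_groups, num_tokens, bytes_per_el=2):
    # layer_groups: (num_layers, kv_heads, head_dim); factor 2 covers K and V
    per_token = sum(2 * n * kv * hd * bytes_per_el for n, kv, hd in layer_groups)
    return per_token * num_tokens / 1024**3

ctx = 128 * 1024
kv_cache_gb([(50, 16, 256), (10, 4, 512)], ctx)   # 31B      ~110 GB
kv_cache_gb([(25, 8, 256), (5, 2, 512)], ctx)     # 26B-A4B  ~27.5 GB
kv_cache_gb([(12, 1, 256), (3, 1, 512)], ctx)     # E2B      ~2.3 GB (assumed 12 sliding + 3 full unique)
kv_cache_gb([(62, 16, 128)], ctx)                 # Gemma 3  ~62 GB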

Total VRAM & GPU Recommendation (batch=1)

Scenario         31B                        26B-A4B                      E2B                       Gemma 3 27B
FP16, 4K ctx     65.4 GB · 1× H100 80GB     52.9 GB · 1× H100 80GB       10.3 GB · RTX 4070 12GB   55.9 GB · 1× H100 80GB
INT4, 4K ctx     18.9 GB · RTX 4090 24GB    13.9 GB · RTX 4070 Ti 16GB   2.7 GB · any GPU          15.4 GB · RTX 4090 24GB
INT4, 128K ctx   125.5 GB · 2× H100 160GB   40.5 GB · A6000 48GB         4.9 GB · RTX 4060 8GB     75.5 GB · 1× H100 80GB

The 26B-A4B MoE loads all 26B parameters into VRAM (all experts must be resident), but its KV cache is 4x smaller than the 31B due to fewer KV heads. At INT4 + 128K context, it fits in a single A6000 -- impossible for the 31B. E2B with KV sharing runs full 128K context on a laptop GPU.

Key Architectural Innovations in Gemma 4
Deep Dive: Why partial_rotary_factor=0.25?
Based on: "Round and Round We Go! What makes Rotary Positional Encodings useful?" Barbero, Vitvitskyi, Perivolaropoulos, Pascanu, Veličković (Oxford & Google DeepMind, 2024)
The Problem
Standard RoPE rotates all head dimensions. High-frequency components create positional heads (tracking token positions), while low-frequency components carry semantic information (content meaning). For long context, these slow-rotating low-frequency channels can break -- they complete full rotations across the sequence, losing the semantic signal they encode.
The Insight
Not all RoPE dimensions are equal. The paper proves that high frequencies build positional heads robustly, while low frequencies are "most invariant" to position -- they encode semantic attention patterns. For long context, increasing theta alone (e.g. 10K → 1M) isn't enough because the lowest frequencies still eventually break.
p-RoPE Solution
Proportional RoPE (p-RoPE) simply sets the rotation frequency to zero for the lowest (1-p) fraction of dimensions -- making them position-independent (NoPE). With p=0.25, only the top 25% of frequencies are rotated (positional), while the remaining 75% become pure semantic channels that never break regardless of context length.
Gemma 4's Implementation
Gemma 4 applies p-RoPE on full attention layers only:
Sliding layers: standard RoPE, theta=10,000, all dims rotated
Full layers: p-RoPE, theta=1,000,000, partial=0.25
This gives sliding layers fine-grained local positioning, while full layers get robust long-range semantic attention with only 25% positional capacity -- exactly the p-RoPE prescription.
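
A minimal sketch of how the p-RoPE frequencies can be constructed (illustrative, not the transformers implementation): build the standard RoPE inverse-frequency vector, then zero out all but the highest-frequency fraction p, so the zeroed dimensions see cos=1, sin=0 at every position.

import torch

def prope_inv_freq(head_dim: int, theta: float, partial: float) -> torch.Tensor:
    # Standard RoPE frequencies, one per dimension pair, fastest-rotating first
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    keep = int(partial * inv_freq.numel())
    inv_freq[keep:] = 0.0   # zero frequency -> angle 0 everywhere -> NoPE
    return inv_freq

# Full attention layers: head_dim=512, theta=1M, partial=0.25
inv_freq = prope_inv_freq(512, 1_000_000.0, 0.25)
angles = torch.outer(torch.arange(4096).float(), inv_freq)   # (seq, head_dim/2)
cos, sin = angles.cos(), angles.sin()                         # 75% of columns: cos=1, sin=0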
[Figure: (A) RoPE frequency spectrum per head -- each head_dim has d/2 frequency pairs; low frequencies rotate slowly and carry semantic, position-invariant content, high frequencies rotate fast and build positional heads, with a mixed band in between that is vulnerable at long context. (B) The problem -- at 4K tokens a low-frequency wave completes only a partial rotation and the signal stays intact; at 256K tokens it completes full rotations and the signal is destroyed. (C) The p-RoPE fix -- standard RoPE (Gemma 3) rotates 100% of dimensions; with p=0.25 (Gemma 4), 75% are zeroed (cos=1, sin=0, NoPE) and become position-free semantic attention that works at any context length, while the remaining 25% keep RoPE for positional heads. Key insight: p-RoPE gives global attention layers robust semantic capacity at any context length, while sliding layers keep full positional precision for local patterns (standard RoPE, theta=10K).]
RoPE Strategy Evolution: Gemma 3 → Gemma 4
Aspect                Gemma 3 (27B)               Gemma 4 (31B)              Impact
Sliding RoPE          theta=10K, full rotation    theta=10K, full rotation   Same -- local attention unchanged
Full RoPE theta       theta=1M                    theta=1M                   Same base frequency
Full RoPE scaling     linear 8x                   partial=0.25 (p-RoPE)      p-RoPE replaces linear scaling
Rotated dims (full)   128 of 128 (100%)           128 of 512 (25%)           75% of dims are position-free
Semantic capacity     All dims position-coupled   384 dims pure semantic     Robust long-context semantics
Max context           128K                        256K                       2x context with p-RoPE
Validation Perplexity (from paper)

The paper validates p-RoPE on Gemma 7B-scale models. Lower is better.

Encoding                 Wiki     PlanV2   Properties
NoPE                     4.8594   6.6429   Semantic only, no position
RoPE (theta=10K)         4.4627   6.4429   Standard
RoPE (theta=500K)        4.4485   6.4593   High theta
0.25-RoPE                4.4592   6.4683   = Gemma 4's setting
0.75-RoPE (inverted)     4.4537   6.4562   More rotation
0.25-RoPE (full model)   4.5302   6.5111   p-RoPE on all layers
0.75-RoPE (full model)   4.4414   6.4422   Best overall perplexity

Gemma 4 uses 0.25-RoPE on full attention layers only (not all layers), combining the best of both: standard RoPE for local sliding attention + p-RoPE for global full attention.

Deep Dive: Per-Layer Embeddings (PLE)

Per-Layer Embeddings (PLE) is an architectural innovation introduced in the Gemma 4 E-series (E2B, E4B) that gives every decoder layer its own unique token representation. Instead of all layers sharing the same input embedding, each layer receives a dedicated 256-dimensional vector per token, created by combining two complementary signals: a token-identity component (direct embedding lookup from a per-layer table) and a context-aware component (learned projection from the main input embeddings). This combined signal is then injected into each layer via a gated bottleneck mechanism as a third residual block after attention and FFN.

Why PLE?
In standard transformers, all layers share the same input embedding — the same token representation enters every layer. But different layers learn different things: early layers focus on syntax, later layers on semantics. PLE lets each layer receive a custom token-dependent signal that can encode layer-specific information, allowing the model to "re-read" the input at every depth with a different lens. For small models like E2B (~2.3B effective), this is especially powerful: PLE lets a shallow model encode information that would otherwise require more layers to learn.
Token-Identity Component
A single nn.Embedding of shape (vocab_size, num_layers × 256) maps each token to a unique vector for every layer. This is a direct lookup — no context from surrounding tokens. Scaled by √256 = 16. For E2B with 35 layers, each token maps to an 8,960-dim vector that is reshaped into (35, 256). This table alone accounts for ~2.35B parameters — the majority of the PLE cost.
Context-Aware Component
A learned nn.Linear(hidden_size, num_layers × 256) projects the main inputs_embeds (which already contain vision/audio soft tokens for multimodal inputs) into per-layer space. Scaled by 1/√hidden_size, then RMSNorm'd. This component sees surrounding context through the embedding — critically, for multimodal models, this is where vision and audio features enter the PLE pathway, since the token-identity component can only do discrete token lookups.
Gated Injection (per layer)
Each layer receives its 256-dim PLE slice and applies a gated bottleneck: (1) down-project hidden states from d → 256, (2) apply GeLU activation, (3) element-wise multiply with the PLE vector, (4) up-project 256 → d, (5) RMSNorm, (6) residual add. This makes PLE a third residual block after attention and FFN. The element-wise multiply is the key: the PLE vector gates which dimensions of the down-projected hidden state survive, effectively letting each token modulate each layer's output differently.
Model-Level: Creating PLE Inputs
# In Gemma4TextModel.__init__():
# --- PLE embedding table: one huge embedding, reshaped per-layer ---
self.embed_tokens_per_layer = Gemma4TextScaledWordEmbedding(
    vocab_size,                                          # 262,144
    num_hidden_layers * hidden_size_per_layer_input,     # 35 * 256 = 8,960
    padding_idx,
    embed_scale=hidden_size_per_layer_input ** 0.5,      # sqrt(256) = 16.0
)

# --- Context-aware projection from main embeddings ---
self.per_layer_model_projection = nn.Linear(
    hidden_size,                                         # 1536 (E2B)
    num_hidden_layers * hidden_size_per_layer_input,     # 35 * 256 = 8,960
    bias=False,
)
self.per_layer_model_projection_scale = hidden_size ** -0.5  # 1/sqrt(1536)
self.per_layer_projection_norm = Gemma4RMSNorm(hidden_size_per_layer_input)
self.per_layer_input_scale = 2.0 ** -0.5                    # 1/sqrt(2)
Step 1: Token-Identity Lookup
def get_per_layer_inputs(self, input_ids, inputs_embeds):
    # Direct embedding lookup → reshape to (batch, seq, num_layers, 256)
    return self.embed_tokens_per_layer(input_ids).reshape(
        *input_ids.shape,
        self.config.num_hidden_layers,   # 35
        self.hidden_size_per_layer_input, # 256
    )
    # Each token gets a unique 256-dim vector for EVERY layer
    # No context — pure token identity signal
Step 2: Context-Aware Projection + Combination
def project_per_layer_inputs(self, inputs_embeds, per_layer_inputs=None):
    # Project main embeddings → per-layer space
    per_layer_projection = self.per_layer_model_projection(inputs_embeds)
    per_layer_projection *= self.per_layer_model_projection_scale  # * 1/sqrt(hidden_size)
    per_layer_projection = per_layer_projection.reshape(
        *inputs_embeds.shape[:-1],
        self.config.num_hidden_layers,      # 35
        self.hidden_size_per_layer_input,    # 256
    )
    per_layer_projection = self.per_layer_projection_norm(per_layer_projection)  # RMSNorm

    if per_layer_inputs is None:
        return per_layer_projection  # context-aware only (e.g., for soft tokens)

    # Combine: (context_signal + token_signal) * 1/sqrt(2)
    return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
Step 3: Forward Pass — Slicing & Distribution
# In Gemma4TextModel.forward():
per_layer_inputs = self.get_per_layer_inputs(input_ids, inputs_embeds)
per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)
# Shape: (batch, seq_len, num_layers, 256)

# Each layer gets its own 256-dim slice:
for i, decoder_layer in enumerate(self.layers):
    per_layer_input = per_layer_inputs[:, :, i, :]  # (batch, seq, 256)
    hidden_states = decoder_layer(hidden_states, per_layer_input, ...)
Complete PLE Data Flow
[Diagram: complete PLE data flow. Main embedding path: input_ids (batch, seq_len) → embed_tokens (vocab, hidden_size), scaled by √hidden_size → inputs_embeds (batch, seq, hidden_size), fed to the decoder layers. Dedicated PLE path: input_ids → embed_tokens_per_layer (vocab, layers × 256), a unique table not shared with the main embedding, scaled by √256 = 16 and reshaped to (B, S, layers, 256) = the token-identity signal. In parallel, per_layer_model_projection = Linear(hidden_size, layers × 256) maps inputs_embeds, scaled by 1/√hidden_size and RMSNorm'd = the context-aware signal. The two are summed and scaled by 1/√2 to give per_layer_inputs (B, S, layers, 256); the slice [:, :, i, :] goes to layer i.]
Multimodal Handling
For multimodal inputs (images, video, audio), the two PLE components diverge in how they handle non-text tokens:
  • Token-identity component: Positions corresponding to vision/audio soft tokens receive the pad token embedding (since they have no meaningful token ID). The model computes PLE lookups before vision features replace image placeholders in inputs_embeds.
  • Context-aware component: Uses the final inputs_embeds which already contain the actual vision/audio feature vectors. So the projection sees real multimodal content at those positions.

This asymmetry is intentional: the token-identity path provides a fixed anchor (pad = "I'm a multimodal token"), while the context-aware path carries the actual vision/audio semantics into PLE. Together, they let the model distinguish what kind of non-text token this is (identity) from what it means (context).

# In Gemma4Model.forward() — multimodal PLE setup:
pad_embedding = self.language_model.embed_tokens.weight[pad_token_id, :]
llm_inputs_embeds = torch.where(
    multimodal_mask[..., None],    # True at image/audio positions
    pad_embedding.view(1, 1, -1),  # → use pad token for PLE lookup
    inputs_embeds,                 # → use real token for text positions
)
per_layer_inputs = self.language_model.get_per_layer_inputs(
    llm_input_ids, llm_inputs_embeds)  # token-identity: pad at multimodal positions

# Later: project_per_layer_inputs uses FINAL inputs_embeds (with vision features)
per_layer_inputs = self.language_model.project_per_layer_inputs(
    inputs_embeds,      # ← these contain actual vision/audio embeddings
    per_layer_inputs,   # ← these used pad for those positions
)
Parameter Cost (E2B: 35 layers, hidden=1536, PLE dim=256, vocab=262K)

Module                          Shape             Parameters   Notes
embed_tokens_per_layer          (262K, 8960)      ~2.35B       Dominates total cost
per_layer_model_projection      (1536, 8960)      ~13.8M       Context-aware projection
per_layer_projection_norm       (256,)            256          RMSNorm on projection
per_layer_input_gate ×35        (1536, 256) ×35   ~13.8M       Per-layer down-projection
per_layer_projection ×35        (256, 1536) ×35   ~13.8M       Per-layer up-projection
post_per_layer_input_norm ×35   (1536,) ×35       ~54K         Per-layer RMSNorm
PLE Total                                         ~2.38B       ~5.1B total − ~2.3B effective = PLE cost

The PLE embedding table (~2.35B params) accounts for the gap between E2B's total parameter count (~5.1B) and its effective parameter count (~2.3B). The per-layer gate/projection pairs add only ~28M total — negligible in comparison. This is a deliberate tradeoff: a massive embedding table provides rich per-layer token representations, while the actual per-layer compute overhead is minimal. For on-device deployment, the PLE table can be stored on flash memory rather than VRAM/RAM — only the embeddings for the current input tokens need to be loaded. This is critical for phones and edge devices where RAM is extremely limited.
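
The totals in the table follow from the shapes directly; a quick arithmetic check using the E2B values above:

vocab, layers, ple_dim, hidden = 262_144, 35, 256, 1536

table      = vocab * layers * ple_dim    # embed_tokens_per_layer      ~2.35B
projection = hidden * layers * ple_dim   # per_layer_model_projection  ~13.8M
gates      = layers * hidden * ple_dim   # per-layer down-projections  ~13.8M
ups        = layers * ple_dim * hidden   # per-layer up-projections    ~13.8M
norms      = ple_dim + layers * hidden   # RMSNorms, negligible
total      = table + projection + gates + ups + norms   # ~2.39B, dominated by the table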

Deep Dive: Vision Encoder & Variable Resolution

Gemma 4 replaces Gemma 3's SigLIP vision encoder with a Vision Transformer (ViT) that supports variable aspect ratios and configurable resolution via token budgets. The key innovations are 2D RoPE for spatial awareness (replacing learned 1D position embeddings), 3×3 spatial pooling to control token count, and a flexible resolution system where users choose how many visual tokens to send to the LLM. Two ViT sizes are used: 150M parameters (E2B, E4B) and 550M parameters (31B, 26B-A4B).

Variable Aspect Ratio
Unlike standard ViT (which resizes all images to a fixed square), Gemma 4's encoder preserves the original aspect ratio. Images are resized so that 16×16 pixel patches still tile cleanly. Where a perfect fit isn't possible, minimal padding is added. This means a tall portrait and a wide landscape produce different numbers of patches -- the model sees the image's true shape rather than a distorted square.
2D RoPE for Spatial Awareness
Standard ViT treats patches as a 1D sequence with learned position embeddings, losing explicit 2D structure. Gemma 4 uses 2D Rotary Position Embeddings: each patch embedding is split into two equal halves. RoPE is applied independently to each half -- one encodes the width (column) position, the other the height (row) position. This lets the model natively reason about spatial relationships: "patch A is above and to the left of patch B" is directly encoded in the position signal.
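
A minimal sketch of the split-halves idea (illustrative only; the real encoder applies the rotation to queries and keys inside attention): half of the dimensions are rotated by the patch's column index, the other half by its row index.

import torch

def rope_angles(pos: torch.Tensor, dim: int, theta: float = 100.0) -> torch.Tensor:
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    return torch.outer(pos.float(), inv_freq)            # (num_patches, dim/2)

def rotate(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    # Standard RoPE rotation applied to one half of the embedding
    x1, x2 = x.chunk(2, dim=-1)
    cos, sin = angles.cos(), angles.sin()
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

H, W, dim = 4, 6, 64                            # patch grid, flattened row-major
x = torch.randn(H * W, dim)
rows = torch.arange(H).repeat_interleave(W)     # row index per patch
cols = torch.arange(W).repeat(H)                # column index per patch

x_2d = torch.cat([
    rotate(x[:, : dim // 2], rope_angles(cols, dim // 2)),   # width (column) position
    rotate(x[:, dim // 2 :], rope_angles(rows, dim // 2)),   # height (row) position
], dim=-1)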
3×3 Spatial Pooling
After the ViT processes all patches, every 3×3 block of neighboring patches is merged into a single embedding by averaging. This 9× reduction converts raw patch embeddings into compact visual tokens suitable for the LLM. The pooling is spatial-aware: it merges patches that are physically adjacent in the image, preserving local structure. This yields 280 tokens per standard image (vs 256 with Gemma 3's 4×4 pooling of SigLIP patches).
Variable Token Budget
The 31B model supports 5 resolution tiers via a user-selectable token budget that controls how many visual tokens the LLM receives. The budget determines the maximum number of pre-pooling patches: max_patches = budget × 9 (because 3×3 pooling). The image is then resized so the patch count stays within this limit, at a resolution that must be a multiple of 48 pixels (3 patches × 16 px/patch = 48 px per pooled unit). Higher budgets preserve more resolution but cost more compute.
Token Budget → Resolution Mapping (31B, approximate for square images)

Token Budget           Max Pre-Pool Patches   Approx. Resolution   Use Case
70 tokens              630                    ~336 × 336           Thumbnail / fast processing
140 tokens             1,260                  ~480 × 480           Standard preview
280 tokens (default)   2,520                  ~672 × 672           Balanced quality/speed
560 tokens             5,040                  ~1,008 × 1,008       High detail
1,120 tokens           10,080                 ~1,344 × 1,344       Maximum resolution

Exact resolution depends on aspect ratio. Non-square images use the same patch budget but distribute patches across width and height according to the original ratio. E-series models use a fixed budget of 280 tokens.
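
A short sketch of this arithmetic (illustrative helper, not the actual image processor): count 16×16 patches at a given resolution, divide by the 3×3 pooling to get visual tokens, and confirm that each row of the table stays within its budget.

def visual_tokens(width_px: int, height_px: int, patch_px: int = 16, pool: int = 3):
    # Returns (pre-pool patches, post-pool visual tokens) for a 48px-aligned resolution
    pw, ph = width_px // patch_px, height_px // patch_px
    return pw * ph, (pw // pool) * (ph // pool)

for budget, side in [(70, 336), (140, 480), (280, 672), (560, 1008), (1120, 1344)]:
    patches, tokens = visual_tokens(side, side)
    assert patches <= budget * 9 and tokens <= budget   # e.g. 672px -> 1,764 patches, 196 tokens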

Vision Encoder Pipeline
[Figure: vision encoder pipeline -- input image (any aspect ratio, any resolution) → adaptive resize (preserve aspect ratio, align to 48 px grid) → 16×16 patches (variable count, padded if needed, up to budget×9) → ViT encoder (16-27 layers, 2D RoPE self-attention + FFN, 150M or 550M params) → 3×3 average pooling (9× reduction) → linear projection from ViT dim to LLM dim + RMSNorm → visual tokens interleaved with text. 2D RoPE: each patch embedding is split in half; one half is rotated by column position, the other by row position (theta=100, lower than text), so "above" and "left of" map directly to row/column differences. ViT configurations: 31B / 26B-A4B (550M params): 27 layers, dim=1152, 16 heads, head_dim=72, FFN=4304, GELU, 2D RoPE (theta=100). E2B / E4B (150M params): 16 layers, dim=768, 12 heads, head_dim=64, FFN=3072, GELU, 2D RoPE (theta=100), clipped_linears.]
Vision Encoder: Gemma 3 → Gemma 4
Aspect               Gemma 3 27B            Gemma 4 31B             Impact
Encoder              SigLIP (contrastive)   ViT (trained jointly)   Better alignment with LLM
Position encoding    Learned 1D             2D RoPE                 Native spatial reasoning
Pooling              4×4 (16× reduction)    3×3 (9× reduction)      More tokens per image
Tokens per image     256 (fixed)            70-1,120 (variable)     Resolution flexibility
Aspect ratio         Fixed square           Variable (preserved)    No distortion
Connector            AvgPool + Linear       Linear + RMSNorm        Simpler, scale-matched
ViT params (large)   ~400M (SigLIP-400M)    ~550M                   Higher capacity
Deep Dive: Audio Encoder (Conformer)

The E-series models (E2B, E4B) include a native Conformer audio encoder based on the USM (Universal Speech Model) architecture. This enables direct audio understanding -- speech recognition, audio classification, and spoken-language tasks -- without external ASR. The encoder converts raw audio waveforms into a sequence of soft tokens that are interleaved with text and vision tokens in the LLM's input sequence.

Step 1: Feature Extraction
Raw audio is converted to a mel-spectrogram -- a 2D representation with time on the horizontal axis and frequency bands on the vertical axis. This is a standard audio preprocessing step that decomposes the waveform into perceptually relevant frequency components, mimicking how the human cochlea processes sound.
Step 2: Subsampling (4× reduction)
The mel features pass through two stages of 2D convolution with channel sizes [128, 32], producing a 4× temporal reduction. This converts the high-resolution spectrogram into a manageable sequence of continuous embeddings ("soft tokens"). Overlapping chunks ensure no information is lost at boundaries.
Step 3: Conformer Encoder
12 Conformer layers process the soft tokens. Each layer combines:
  • Self-attention with chunked local attention (chunk_size=12, context_left=13, context_right=0 = causal)
  • Causal Conv1d (kernel=5) -- the key difference from standard Transformers, capturing local audio patterns
  • Feed-forward with SiLU activation
  • Attention logit capping at 50.0 (vs 30.0 for text)
Residual connections use a 0.5 weight (vs standard 1.0), providing smoother gradient flow.
Step 4: Projection to LLM Space
The final Conformer embeddings (dim=1,024) are projected via a linear layer to the LLM's hidden dimension (1,536 for E2B). This matches the scale and distribution of text/vision embeddings, allowing the LLM to treat audio tokens identically to text and vision tokens. The projected audio tokens are interleaved into the LLM's input sequence at the positions indicated by audio placeholder tokens.
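
The description above maps onto a fairly standard Conformer block. Below is a minimal sketch of one layer under the stated settings (dim=1024, causal depthwise Conv1d with kernel 5, SiLU feed-forward, 0.5-weighted residuals); it is illustrative only -- the real encoder adds chunked local attention masking, attention logit capping at 50.0, and its own normalization details.

import torch
import torch.nn as nn

class ConformerBlockSketch(nn.Module):
    def __init__(self, dim: int = 1024, heads: int = 8, kernel: int = 5):
        super().__init__()
        self.attn_norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.conv_norm = nn.LayerNorm(dim)
        # Depthwise conv; left padding + trim makes it causal
        self.conv = nn.Conv1d(dim, dim, kernel, groups=dim, padding=kernel - 1)
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:   # x: (batch, time, dim)
        h = self.attn_norm(x)
        a, _ = self.attn(h, h, h, need_weights=False)
        x = x + 0.5 * a                                    # 0.5-weighted residual
        c = self.conv(self.conv_norm(x).transpose(1, 2))[..., : x.shape[1]].transpose(1, 2)
        x = x + 0.5 * c
        return x + 0.5 * self.ffn(self.ffn_norm(x))

# After 12 such layers, a linear projection maps dim=1024 to the LLM hidden size (1536 for E2B)
audio_proj = nn.Linear(1024, 1536)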
Audio Encoder Pipeline
[Figure: audio encoder pipeline -- raw audio waveform → mel spectrogram (time × frequency bands) → 2× Conv2d subsampling (channels [128, 32], 4× temporal reduction) → Conformer encoder (12 layers, dim=1024, attention + Conv1d(k=5), chunked local causal attention, logit_cap=50, residual weight 0.5) → linear projection 1024 → 1536 → audio soft tokens, interleaved with text and vision tokens in the LLM input sequence. Conformer = Transformer + convolution: attention captures global audio patterns, Conv1d captures local acoustic features. The encoder follows the USM (Universal Speech Model) architecture, the same family as Google's production speech systems.]

The Conformer architecture differs from a standard Transformer encoder by adding a convolutional module (causal depthwise Conv1d with kernel=5) in each layer. Self-attention captures global temporal patterns across the audio sequence, while the convolution captures fine-grained local acoustic features -- phonemes, consonant boundaries, pitch transitions -- that are too narrow for attention to model efficiently. The chunked local attention (chunk_size=12, left_context=13) limits each chunk's receptive field for efficiency while the left context ensures continuity across chunk boundaries.

Code-Verified Architecture Details

Verified against transformers/models/gemma4/modeling_gemma4.py. Code snippets are exact quotes.

1. Attention Scaling = 1.0 (not 1/√d)
Gemma 4 does NOT use the standard 1/√head_dim scaling. Instead, scaling is fixed to 1.0 and the QK norms (with learned scale) absorb the normalization function.
self.scaling = 1.0
2. K=V Sharing: Keys and Values Diverge Through Norms
When attention_k_eq_v=True, the V projection is eliminated (v_proj=None). The key tensor is cloned as the value before normalization. Then K gets k_norm (with learned scale) + RoPE, while V gets v_norm (without scale, no RoPE). So K and V start identical but diverge through different normalizations.
# v_proj is None when attention_k_eq_v=True
value_states = self.v_proj(hidden_states).view(hidden_shape) \
    if self.v_proj is not None else key_states

# K path: scaled norm + RoPE
key_states = self.k_norm(key_states)          # with_scale=True
key_states = apply_rotary_pos_emb(key_states, cos, sin)

# V path: unscaled norm, NO RoPE
value_states = self.v_norm(value_states)      # with_scale=False
[Diagram: the W_k (K proj) output is cloned; the K path goes through k_norm (with_scale=True) and RoPE, while the V path goes through v_norm (with_scale=False) with no RoPE -- same source, different paths.]
3. RMSNorm: with_scale vs without
Gemma 4 introduces a with_scale flag. Q/K norms have learnable scale (multiplicative weight), V norm does not. The weight is initialized to ones (standard RMSNorm), unlike Gemma 2/3's (1 + weight) parameterization.
class Gemma4RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6, with_scale=True):
        super().__init__()
        self.eps = eps
        self.with_scale = with_scale
        if self.with_scale:
            self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, hidden_states):
        normed = self._norm(hidden_states.float())
        if self.with_scale:
            normed = normed * self.weight.float()
        return normed.type_as(hidden_states)
4. Embedding Scaling (registered buffer)
Token embeddings are multiplied by √hidden_size. This is a registered buffer (not a gradient-tracked parameter), initialized from the config. The per-layer embedding uses √per_layer_dim instead.
class Gemma4TextScaledWordEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim,
                 padding_idx, embed_scale=1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_buffer(
            "embed_scale", torch.tensor(embed_scale))

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale

# Main embedding: scale = sqrt(5376) ≈ 73.3
# Per-layer embedding: scale = sqrt(256) = 16.0
5. Per-Layer Input: Third Residual Block
The per-layer input is NOT injected during attention or FFN. It's a third residual block applied after both. A gating linear projects hidden states to the per-layer dim, applies activation, multiplies element-wise with the per-layer embedding slice, projects back, norms, then adds as residual.
# After attention + FFN residuals:
if self.hidden_size_per_layer_input:
    residual = hidden_states
    hidden_states = self.per_layer_input_gate(hidden_states)  # d → 256
    hidden_states = self.act_fn(hidden_states)
    hidden_states = hidden_states * per_layer_input  # gated element-wise
    hidden_states = self.per_layer_projection(hidden_states)  # 256 → d
    hidden_states = self.post_per_layer_input_norm(hidden_states)
    hidden_states = residual + hidden_states
[Diagram: hidden_states (after attn + FFN) → per_layer_input_gate (d → 256) → act_fn → element-wise × per-layer embedding (the 256-dim slice for this layer) → per_layer_projection (256 → d) → RMSNorm → add back to the residual.]
6. MoE: Parallel Dense + Routed with 3 Post-Norms
The dense MLP and MoE branch run in parallel from the same pre-MLP residual. Each output gets its own post-norm, then they're summed and the sum goes through a third post-norm before the residual add. MoE layers have 7 RMSNorm modules (vs 4 for standard layers).
# Dense path (standard MLP)
hidden_states = self.pre_feedforward_layernorm(residual)
hidden_states = self.mlp(hidden_states)

if self.enable_moe_block:
    hidden_states_1 = self.post_feedforward_layernorm_1(
        hidden_states)

    # MoE path (parallel, from pre-MLP residual)
    hidden_states_flat = residual.reshape(-1, residual.shape[-1])
    _, top_k_weights, top_k_index = self.router(
        hidden_states_flat)
    hidden_states_2 = self.pre_feedforward_layernorm_2(
        hidden_states_flat)
    hidden_states_2 = self.experts(
        hidden_states_2, top_k_index, top_k_weights)
    hidden_states_2 = self.post_feedforward_layernorm_2(
        hidden_states_2)

    # Sum both paths, then final post-norm
    hidden_states = hidden_states_1 + hidden_states_2
    hidden_states = self.post_feedforward_layernorm(
        hidden_states)

hidden_states = residual + hidden_states
[Diagram: from the pre-MLP residual, a dense path (pre_ff_norm → GeGLU MLP → post_ff_norm_1) and an MoE path (router → pre_ff_norm_2 → 128 experts → post_ff_norm_2) run in parallel; their outputs are summed, pass through post_feedforward_norm (the MoE-only third post-norm), and are added back to the residual.]
7. Router: Norm → Scale → Softmax → TopK → Per-Expert Scale
The router applies RMSNorm (without scale), multiplies by a learned per-dim scale × 1/√hidden_size, projects to num_experts, softmax, selects top-K, normalizes weights, then applies per-expert learned scales.
def forward(self, hidden_states):
    hidden_states = self.norm(hidden_states)        # RMSNorm, no scale
    hidden_states = hidden_states * self.scale * self.scalar_root_size

    expert_scores = self.proj(hidden_states)        # Linear → num_experts
    router_probs = softmax(expert_scores, dim=-1)

    top_k_weights, top_k_index = torch.topk(
        router_probs, k=self.config.top_k_experts)

    top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
    top_k_weights = top_k_weights * self.per_expert_scale[top_k_index]

    return router_probs, top_k_weights, top_k_index
8. Layer Scalar (buffer, not parameter)
Each decoder layer multiplies its output by a per-layer scalar. This is a register_buffer (loaded from checkpoint, not trained via gradient), initialized to 1.0. Applied after all residual blocks (attention + FFN + optional per-layer input).
hidden_states *= self.layer_scalar  # buffer, not nn.Parameter
9. Logit Soft-Capping
Applied after the LM head projection. The config default is None (disabled), but published model checkpoints set it to 30.0. Bounds logits to [-30, 30] smoothly.
if self.config.final_logit_softcapping is not None:
    logits = logits / self.config.final_logit_softcapping
    logits = torch.tanh(logits)
    logits = logits * self.config.final_logit_softcapping
10. Proportional RoPE Config
Full attention layers use rope_type: "proportional" which sets the lowest 75% of frequency dimensions to zero (cos=1, sin=0), leaving those dimensions position-independent. Sliding layers use standard default RoPE. This is configured per-layer-type in the config, not in the modeling code.
# From Gemma4TextConfig:
rope_parameters = {
    "sliding_attention": {
        "rope_type": "default",
        "rope_theta": 10_000.0
    },
    "full_attention": {
        "rope_type": "proportional",
        "partial_rotary_factor": 0.25,
        "rope_theta": 1_000_000.0
    }
}

# Last layer forced to full_attention:
if self.layer_types[-1] != "full_attention":
    self.layer_types[-1] = "full_attention"
[Diagram: sliding layers (head_dim=256, theta=10K) rotate all 256 dims. Full layers (head_dim=512, theta=1M, partial=0.25) rotate only 128 of 512 dims; the remaining 384 dims get cos=1, sin=0 (NoPE). Gemma 3 full layers: 128/128 rotated (100%). Gemma 4 full layers: 128/512 rotated (25%) -- 75% of dims are pure semantic channels that never break at long context.]
11. KV Cache Sharing (E2B/E4B)
Shared layers skip K/V projection entirely and load cached KV from the last non-shared layer of the same attention type (sliding or full). The query is still computed fresh. The stored KV includes the full sequence length for reuse.
# In shared layer forward:
if self.is_kv_shared_layer and past_key_values is not None:
    key_states, value_states = \
        past_key_values.shared_layers[self.kv_shared_layer_index]

# In non-shared layer: store KV for later reuse
if self.store_full_length_kv:
    past_key_values.shared_layers[self.layer_idx] = \
        key_states, value_states
Architecture Diagrams

[Interactive layer-by-layer diagrams: Gemma 4 31B IT, Gemma 4 26B-A4B IT, Gemma 4 E2B IT, and the previous-generation Gemma 3 27B IT.]