Transformer Blocks

bunsen::blocks::transformers collects the building blocks used by transformer-family models — attention layers, their caching machinery, and positional embeddings.

API: https://docs.rs/bunsen/latest/bunsen/blocks/transformers/

Attention

The attention submodule houses the attention layers themselves and the helpers they’re built from.

CausalSelfAttention

CausalSelfAttention is multi-head causal self-attention with optional KV-grouping. The config carries:

  • n_head — number of query heads,
  • n_kv_head — number of key/value heads (must divide n_head; equals n_head for plain MHA, less for grouped-query attention),
  • n_embed — embedding dimension,
  • a pluggable NormalizationConfig applied inside the block.
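
The divisibility rule on n_kv_head can be illustrated with a small standalone sketch. The AttnShape type and its methods are hypothetical and are not the crate's config type. It assumes the common conventions that head_dim equals n_embed / n_head and that consecutive query heads share one key/value head.

```rust
/// Hypothetical stand-in for the head-count fields of the attention config.
#[derive(Debug, Clone, Copy, PartialEq)]
struct AttnShape {
    n_head: usize,
    n_kv_head: usize,
    n_embed: usize,
}

impl AttnShape {
    /// Checks the divisibility constraints before building the shape.
    fn new(n_head: usize, n_kv_head: usize, n_embed: usize) -> Result<Self, String> {
        if n_head == 0 || n_kv_head == 0 {
            return Err("head counts must be non-zero".to_string());
        }
        if n_head % n_kv_head != 0 {
            return Err(format!("n_kv_head ({}) must divide n_head ({})", n_kv_head, n_head));
        }
        if n_embed % n_head != 0 {
            return Err(format!("n_head ({}) must divide n_embed ({})", n_head, n_embed));
        }
        Ok(Self { n_head, n_kv_head, n_embed })
    }

    /// Width of one head (assumed convention: n_embed / n_head).
    fn head_dim(&self) -> usize {
        self.n_embed / self.n_head
    }

    /// Number of query heads that share each key/value head.
    fn group_size(&self) -> usize {
        self.n_head / self.n_kv_head
    }

    /// Key/value head read by query head `q` (assumed contiguous grouping).
    fn kv_head_for(&self, q: usize) -> usize {
        q / self.group_size()
    }
}
```

With n_head = 8 and n_kv_head = 2, query heads 0..4 read key/value head 0 and query heads 4..8 read key/value head 1. Setting n_kv_head = n_head gives a group size of 1, which is plain MHA.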

The module exposes a CausalSelfAttentionMeta trait, implemented on both the config and the live module. Parents can read n_head, n_kv_head, and head_dim of whichever form they’re holding, so larger transformers don’t need to cache those numbers themselves. This is the pattern documented in Building Reusable Modules.
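
The shape of that pattern can be sketched as follows. This is an illustration only: AttnMeta, ToyAttnConfig, ToyAttn, and kv_width are hypothetical names, and the n_embed accessor is an assumption added so that head_dim can be derived in one place.

```rust
/// Hypothetical analogue of a metadata trait shared by config and module.
trait AttnMeta {
    fn n_head(&self) -> usize;
    fn n_kv_head(&self) -> usize;
    fn n_embed(&self) -> usize;

    /// Derived once here, so every implementor agrees on the value.
    fn head_dim(&self) -> usize {
        self.n_embed() / self.n_head()
    }
}

/// Hypothetical config: plain numbers, no allocated weights.
#[derive(Clone, Copy)]
struct ToyAttnConfig {
    n_head: usize,
    n_kv_head: usize,
    n_embed: usize,
}

/// Hypothetical live module: a real layer would hold weight tensors.
struct ToyAttn {
    n_head: usize,
    n_kv_head: usize,
    n_embed: usize,
    weights: Vec<f32>,
}

impl ToyAttnConfig {
    /// Builds the live module from the config.
    fn init(&self) -> ToyAttn {
        ToyAttn {
            n_head: self.n_head,
            n_kv_head: self.n_kv_head,
            n_embed: self.n_embed,
            weights: vec![0.0; self.n_embed * self.n_embed],
        }
    }
}

impl AttnMeta for ToyAttnConfig {
    fn n_head(&self) -> usize { self.n_head }
    fn n_kv_head(&self) -> usize { self.n_kv_head }
    fn n_embed(&self) -> usize { self.n_embed }
}

impl AttnMeta for ToyAttn {
    fn n_head(&self) -> usize { self.n_head }
    fn n_kv_head(&self) -> usize { self.n_kv_head }
    fn n_embed(&self) -> usize { self.n_embed }
}

/// A parent can size its K/V buffers from either form.
fn kv_width<M: AttnMeta>(m: &M) -> usize {
    m.n_kv_head() * m.head_dim()
}
```

Because kv_width is generic over the trait, a parent model can call it on a config while planning allocations and on the initialized layer at run time, and both calls return the same number.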

forward takes the input embedding and an optional &mut KVCache for autoregressive decoding. When the cache is None, the layer runs in training/prefill mode and recomputes K and V on every call; when it is Some, the new K and V are appended to the cache and attention reads them back across the full cached sequence.
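
The two modes can be sketched with a toy stand-in. Everything here is hypothetical (ToyCache, project_k, keys_for_forward); the "projection" is a placeholder multiplication, and the function only returns the keys that attention would read, not a real attention output.

```rust
/// Hypothetical toy cache: the keys seen so far for one layer.
#[derive(Default)]
struct ToyCache {
    keys: Vec<f32>,
}

/// Placeholder for the real key projection.
fn project_k(x: &[f32]) -> Vec<f32> {
    x.iter().map(|v| v * 2.0).collect()
}

/// Returns the keys attention would read for this call.
fn keys_for_forward(x: &[f32], cache: Option<&mut ToyCache>) -> Vec<f32> {
    let new_keys = project_k(x);
    match cache {
        // No cache: only the keys computed from this call's input.
        None => new_keys,
        // With a cache: append the new keys, then read the full history.
        Some(c) => {
            c.keys.extend_from_slice(&new_keys);
            c.keys.clone()
        }
    }
}
```

Calling keys_for_forward twice with the same cache, first on a two-token prompt and then on a single new token, returns three keys on the second call, while the cache-free path always returns exactly as many keys as it was given inputs.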

KVCache

KVCache stores the key/value tensors for each layer so that incremental decoding does not recompute them. It is built from a KVCacheConfig carrying batch_size, num_heads, seq_len, head_dim, and num_layers, and provides:

  • pos() — the current write head position,
  • prefill(...) — bulk-load K/V from a prompt encode,
  • insert_kv(...) — append a single decoded step’s K/V,
  • reset() — rewind to position 0 without reallocating.
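
The bookkeeping behind those four methods can be sketched with a single-layer, single-batch toy. ToyKvCache is hypothetical and uses flat Vec<f32> buffers instead of tensors; the argument lists, the keys accessor, and the choice that prefill starts from position 0 are assumptions, not the crate's signatures.

```rust
/// Hypothetical single-layer cache with room for `seq_len` steps.
struct ToyKvCache {
    step_width: usize, // floats per step: num_heads * head_dim
    seq_len: usize,
    k: Vec<f32>,
    v: Vec<f32>,
    pos: usize,
}

impl ToyKvCache {
    /// Allocates the buffers once for the maximum sequence length.
    fn new(seq_len: usize, num_heads: usize, head_dim: usize) -> Self {
        let step_width = num_heads * head_dim;
        assert!(step_width > 0, "num_heads and head_dim must be non-zero");
        Self {
            step_width,
            seq_len,
            k: vec![0.0; seq_len * step_width],
            v: vec![0.0; seq_len * step_width],
            pos: 0,
        }
    }

    /// Current write position, counted in steps.
    fn pos(&self) -> usize {
        self.pos
    }

    /// Bulk-loads K/V for a whole prompt, starting from position 0 (assumed).
    fn prefill(&mut self, k: &[f32], v: &[f32]) -> Result<(), String> {
        self.reset();
        self.write(k, v)
    }

    /// Appends exactly one decoded step.
    fn insert_kv(&mut self, k: &[f32], v: &[f32]) -> Result<(), String> {
        if k.len() != self.step_width {
            return Err("insert_kv expects exactly one step".to_string());
        }
        self.write(k, v)
    }

    /// Rewinds to position 0; the buffers stay allocated.
    fn reset(&mut self) {
        self.pos = 0;
    }

    /// Keys written so far.
    fn keys(&self) -> &[f32] {
        &self.k[..self.pos * self.step_width]
    }

    /// Copies whole steps at the write position and advances it.
    fn write(&mut self, k: &[f32], v: &[f32]) -> Result<(), String> {
        if k.len() != v.len() || k.len() % self.step_width != 0 {
            return Err("K and V must hold the same whole number of steps".to_string());
        }
        let steps = k.len() / self.step_width;
        if self.pos + steps > self.seq_len {
            return Err("cache capacity exceeded".to_string());
        }
        let start = self.pos * self.step_width;
        self.k[start..start + k.len()].copy_from_slice(k);
        self.v[start..start + v.len()].copy_from_slice(v);
        self.pos += steps;
        Ok(())
    }
}
```

Because reset only moves the write position, a serving loop can reuse one allocation across many prompts: prefill, decode step by step with insert_kv, then reset for the next request.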

NanoChatGpt uses one shared KVCache across all its layers; see bunsen::kits::gpts::nanochat for the integrated example.

Scaled-dot-product helpers

When you need to wire attention by hand — for a custom block, a fused-kernel experiment, or unit tests — the functional API is available:

  • scaled_dot_product_attention — the full SDPA op given Q, K, V and an optional mask/bias.
  • sdpa_attn_weight — just the softmax-of-scaled-QK^T factor.
  • sdpa_bias — build an additive bias tensor (causal mask, ALiBi, etc.) of the right shape for SDPA.
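
For orientation, the computation these three helpers split between them can be written in plain Rust for a single head. This is a minimal sketch over Vec<Vec<f32>> matrices, not the crate's tensor API: the function names causal_bias, attn_weight, and sdpa are hypothetical, inputs are assumed non-empty, and the causal bias assumes the queries are the last n_q positions of the key sequence.

```rust
/// Additive causal bias: 0.0 where query i may see key j, -inf elsewhere.
fn causal_bias(n_q: usize, n_k: usize) -> Vec<Vec<f32>> {
    // Assumption: the queries are the last n_q positions of the keys.
    let offset = n_k.saturating_sub(n_q);
    (0..n_q)
        .map(|i| {
            (0..n_k)
                .map(|j| if j <= i + offset { 0.0 } else { f32::NEG_INFINITY })
                .collect::<Vec<f32>>()
        })
        .collect()
}

/// Row-wise softmax(Q K^T / sqrt(d) + bias).
fn attn_weight(q: &[Vec<f32>], k: &[Vec<f32>], bias: Option<&[Vec<f32>]>) -> Vec<Vec<f32>> {
    let scale = 1.0 / (q[0].len() as f32).sqrt();
    q.iter()
        .enumerate()
        .map(|(i, qi)| {
            let mut row: Vec<f32> = k
                .iter()
                .enumerate()
                .map(|(j, kj)| {
                    let dot: f32 = qi.iter().zip(kj).map(|(a, b)| a * b).sum();
                    dot * scale + bias.map_or(0.0, |b| b[i][j])
                })
                .collect();
            // Subtract the row maximum before exponentiating for stability.
            let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let mut sum = 0.0;
            for s in row.iter_mut() {
                *s = (*s - max).exp();
                sum += *s;
            }
            for s in row.iter_mut() {
                *s /= sum;
            }
            row
        })
        .collect()
}

/// Full SDPA: attention weights times V.
fn sdpa(
    q: &[Vec<f32>],
    k: &[Vec<f32>],
    v: &[Vec<f32>],
    bias: Option<&[Vec<f32>]>,
) -> Vec<Vec<f32>> {
    attn_weight(q, k, bias)
        .iter()
        .map(|wi| {
            let mut out = vec![0.0; v[0].len()];
            for (wij, vj) in wi.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(vj) {
                    *o += wij * x;
                }
            }
            out
        })
        .collect()
}
```

With a causal bias, the first query can only see the first key, so its output row is exactly the first row of V; later rows are weighted averages over the keys visible to them.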

Embedding

The embedding submodule collects positional embeddings.

RotaryEmbedding

RotaryEmbedding implements rotary position embedding (RoPE) with a precomputed frequency table:

  • RotaryEmbeddingConfig::new(seq_len, head_dim) then .init(device) allocates the table once for the maximum sequence length.
  • apply(q, k) rotates query and key tensors.
  • clip_range(t0..t1) returns a sliced view for serving a partial sequence — the natural fit for KV-cache decoding, where each step only needs the rotations for the new positions.
  • cast(dtype) converts the precomputed table between float dtypes without recomputing the trigonometric values.

The free functions inverse_frequency_table and positional_frequency_table are exposed for callers that want to build their own variant of rotary embedding without going through the packaged module.
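
The table-based rotation described in this section can be sketched in plain Rust. ToyRope and inverse_frequencies are hypothetical names; the base is left as a parameter (10000 is the value commonly used for RoPE), and the sketch assumes the interleaved pairing (x[2i], x[2i+1]), which may differ from the layout the crate uses.

```rust
use std::ops::Range;

/// inv_freq[i] = base^(-2i / head_dim) for i in 0..head_dim / 2.
fn inverse_frequencies(head_dim: usize, base: f32) -> Vec<f32> {
    (0..head_dim / 2)
        .map(|i| base.powf(-2.0 * i as f32 / head_dim as f32))
        .collect()
}

/// Hypothetical precomputed table: one row of cos/sin values per position.
struct ToyRope {
    cos: Vec<Vec<f32>>,
    sin: Vec<Vec<f32>>,
}

impl ToyRope {
    /// Computes the trigonometric table once for `seq_len` positions.
    fn new(seq_len: usize, head_dim: usize, base: f32) -> Self {
        let inv = inverse_frequencies(head_dim, base);
        let cos = (0..seq_len)
            .map(|t| inv.iter().map(|&f| (t as f32 * f).cos()).collect::<Vec<f32>>())
            .collect();
        let sin = (0..seq_len)
            .map(|t| inv.iter().map(|&f| (t as f32 * f).sin()).collect::<Vec<f32>>())
            .collect();
        Self { cos, sin }
    }

    /// Keeps only the table rows for the given positions.
    fn clip_range(&self, range: Range<usize>) -> ToyRope {
        ToyRope {
            cos: self.cos[range.clone()].to_vec(),
            sin: self.sin[range].to_vec(),
        }
    }

    /// Rotates each row of `x`; row t uses table row t.
    fn rotate(&self, x: &[Vec<f32>]) -> Vec<Vec<f32>> {
        x.iter()
            .enumerate()
            .map(|(t, row)| {
                let mut out = row.clone();
                for i in 0..row.len() / 2 {
                    let (c, s) = (self.cos[t][i], self.sin[t][i]);
                    let (a, b) = (row[2 * i], row[2 * i + 1]);
                    out[2 * i] = a * c - b * s;
                    out[2 * i + 1] = a * s + b * c;
                }
                out
            })
            .collect()
    }

    /// Rotates query and key tensors with the same table.
    fn apply(&self, q: &[Vec<f32>], k: &[Vec<f32>]) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
        (self.rotate(q), self.rotate(k))
    }
}
```

During cached decoding, a caller would clip the table to the new positions, for example clip_range(2..3) for the third token, and apply it to that step's single-row Q and K. The result matches row 2 of a full-sequence apply, which is why the clipped view suffices for incremental steps.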