Why this matters
Most approaches to making Transformers efficient for very long contexts either use fixed placement rules (uniform interleaving) or score layers in isolation. Those heuristics ignore interactions between layers: the value of keeping a layer as full attention depends on which other layers also keep full attention. FlashMorph instead treats hybrid conversion as a combinatorial subset-selection problem and offers a practical relaxation that directly optimizes layer choices under a fixed budget, sharply reducing selection cost while preserving long-context recall.
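To see why isolated scoring can mislead, here is a hypothetical toy. Each layer's full attention "covers" some retrieval skills, and recall is the total weight of the skills covered by the layers that stay full attention. The skills, weights, and function names are invented for illustration; real layer interactions are learned from data, not hand-specified. With a budget of two layers, ranking layers one at a time keeps the redundant pair {0, 1} (recall 0.55), while a joint search keeps the complementary pair {0, 2} (recall 0.85).

```python
from itertools import combinations

# Hypothetical toy: retrieval "skills" each layer provides when it keeps full attention.
SKILL_WEIGHT = {"copy": 0.5, "induction": 0.3, "local": 0.05}
LAYER_SKILLS = {0: {"copy", "local"}, 1: {"copy"}, 2: {"induction"}, 3: {"local"}}


def recall(full_layers):
    """Toy recall: total weight of the skills covered by the full-attention layers."""
    covered = set().union(*(LAYER_SKILLS[i] for i in full_layers))
    return sum(SKILL_WEIGHT[s] for s in covered)


def isolated_topk(k):
    """Score each layer on its own and keep the k best (ignores interactions)."""
    ranked = sorted(LAYER_SKILLS, key=lambda i: recall({i}), reverse=True)
    return set(ranked[:k])


def joint_best(k):
    """Exhaustively search all size-k subsets (only feasible for tiny toys)."""
    return set(max(combinations(LAYER_SKILLS, k), key=recall))


if __name__ == "__main__":
    for name, chosen in [("isolated", isolated_topk(2)), ("joint", joint_best(2))]:
        print(name, sorted(chosen), f"{recall(chosen):.2f}")
```

Exhaustive search over all subsets is exactly what stops scaling at real model depths; a continuous relaxation is one way to account for such interactions without enumerating subsets.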
Key Findings
- Reformulation: Hybrid layer selection is treated as a budget-constrained subset optimization and relaxed into a continuous gate-optimization problem over a morphable model (each pretrained full-attention layer paired with a converted linear-attention branch).
- Lightweight joint selection: Both the pretrained full-attention branches and the converted linear-attention branches are frozen; only the scalar per-layer gates are optimized jointly on synthetic long-context retrieval data, so the gates can capture complementarity and redundancy between layers.
- Empirical efficiency: Layer selection for Qwen3-1.7B uses 20M tokens, ~2.5e17 FLOPs, and about 2.1 GPU hours. Prior methods spend far more tokens on selection (PostNAS ≈ 50B, KL-LS ≈ 20B, HALO ≈ 234M), roughly 2,500×, 1,000×, and 12× as many respectively.
- Practical pipeline: After gate optimization, gate values are discretized under a preset full-attention budget to instantiate the hybrid model, followed by logits distillation and long-context finetuning.
- Preservation of quality: Across Qwen3-series models and multiple linear-attention backbones, FlashMorph finds hybrid configurations that retain strong long-context recall (Needle-in-a-Haystack retrieval variants), maintain zero-shot commonsense-reasoning performance, and improve prefill and decode efficiency.
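The reformulation in the first finding can be written out as below. The notation is illustrative rather than taken from the paper: L layers, a full-attention budget K, frozen branches FA_ℓ and LA_ℓ, a long-context recall measure Q, and a retrieval loss L_ret. The convex-combination form of the gate and the L1-style choice of the linearization regularizer R are assumptions; the source only says that the gate interpolates and that the regularizer encourages linearization.

```latex
% Discrete problem: choose which K of the L layers keep full attention
\max_{S \subseteq \{1,\dots,L\},\ |S| = K} \ \mathcal{Q}(S)

% Relaxation (assumed form): a scalar gate mixes the frozen branches of layer \ell
h_\ell = \alpha_\ell \, \mathrm{FA}_\ell(x_\ell) + (1 - \alpha_\ell) \, \mathrm{LA}_\ell(x_\ell),
\qquad \alpha_\ell \in [0, 1]

% Gate optimization on synthetic retrieval data with a linearization penalty
\alpha^\star = \arg\min_{\alpha \in [0,1]^L} \ \mathcal{L}_{\mathrm{ret}}(\alpha) + \lambda \, R(\alpha),
\qquad \text{e.g. } R(\alpha) = \sum_{\ell=1}^{L} \alpha_\ell

% Discretization under the budget
\hat{S} = \operatorname{top\text{-}K}(\alpha^\star)
```

The discrete problem has \binom{L}{K} candidate subsets, which grows quickly with depth; the relaxation replaces that enumeration with continuous optimization of L scalar gates.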
How it works (concise)
- Morphable model: For each layer, pair the pretrained full-attention branch with a linear-attention replacement and add a continuous gate α∈[0,1] that interpolates between the two.
- Gate optimization: Freeze both attention branches; jointly optimize all α on synthetic retrieval tasks with a regularizer encouraging linearization (so the model prefers linear attention when possible).
- Discretization + training: Keep full attention only in the top-K layers ranked by learned α, where K is the full-attention budget; convert the remaining layers to linear attention, discard the gates, and train the resulting hybrid model with logits distillation and long-context finetuning.
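A minimal pure-Python sketch of the per-layer computations in these steps follows. It assumes the gate forms a convex combination of the two frozen branch outputs (α = 1 means full attention), a sigmoid parameterization for the gate, and an L1-style sum of gates as the linearization regularizer; none of these exact forms is given in the source. The names `gate`, `morphable_layer`, `gate_objective`, and `discretize` are hypothetical, and the outer optimization loop over synthetic retrieval data is omitted.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def gate(logit: float) -> float:
    """Assumed parameterization: a sigmoid keeps alpha inside (0, 1) during optimization."""
    return 1.0 / (1.0 + math.exp(-logit))


def morphable_layer(x: Vector,
                    full_attn: Callable[[Vector], Vector],
                    linear_attn: Callable[[Vector], Vector],
                    alpha: float) -> Vector:
    """Mix the outputs of the frozen full-attention and linear-attention branches."""
    f, l = full_attn(x), linear_attn(x)
    return [alpha * fi + (1.0 - alpha) * li for fi, li in zip(f, l)]


def gate_objective(task_loss: float, alphas: Sequence[float], lam: float) -> float:
    """Retrieval loss plus a hypothetical L1 penalty that pushes gates toward 0 (linear)."""
    return task_loss + lam * sum(alphas)


def discretize(alphas: Sequence[float], k: int) -> List[str]:
    """Keep full attention in the k layers with the largest learned alpha."""
    if not 0 <= k <= len(alphas):
        raise ValueError("budget k must be between 0 and the number of layers")
    keep = set(sorted(range(len(alphas)), key=lambda i: alphas[i], reverse=True)[:k])
    return ["full" if i in keep else "linear" for i in range(len(alphas))]


if __name__ == "__main__":
    # Made-up gate values for an 8-layer model with a budget of 2 full-attention layers.
    learned = [0.10, 0.85, 0.05, 0.40, 0.92, 0.12, 0.33, 0.07]
    print(discretize(learned, k=2))
```

With these made-up gates, layers 4 and 1 keep full attention and the other six become linear. A sigmoid over an unconstrained logit is one common way to keep a gate in range under gradient descent; clipping to [0, 1] would be an alternative.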
Who it's for + tradeoffs
Great fit if you need to convert pretrained Transformer/LLM backbones to hybrid attention for much longer contexts while minimizing the expensive layer-search stage, especially when you can run a lightweight selection phase and follow up with distillation and finetuning. It also suits settings where you want an explicit, controllable quality–efficiency trade-off set by the full-attention budget. Look elsewhere if you cannot afford any finetuning or distillation after selection, if your model architecture prevents straightforward linear-attention replacements, or if your production constraints forbid adding a selection training pass. Results also depend on the chosen linear-attention backend and on how representative the synthetic retrieval data used during gate optimization is.
Where it fits
FlashMorph sits between simple fixed-interleaving rules and expensive architecture-search or layerwise perturbation methods: it captures joint layer effects, which fixed rules and isolated scoring ignore, at a small fraction of the token and GPU cost of prior search-based methods, making hybridization scalable to larger models.
