STAM Optimizer
Stable Training with Adaptive Momentum
Assem Sabry · TokenAI
Published: 1 May 2026
Abstract
Adaptive gradient methods such as Adam and AdamW fix the first-moment momentum coefficient β₁ (typically 0.9) for all timesteps and all parameters, regardless of gradient dynamics. A fixed β₁ overshoots in high-variance regimes and forgoes faster convergence near stationarity. We propose Stable Training with Adaptive Momentum (STAM), which adapts β₁ using a per-tensor gradient-variance proxy derived from momentum residuals: high variance lowers β₁ to damp oscillations, while low variance keeps β₁ near its base value to accelerate convergence. We further introduce STAMLite, a memory-efficient variant with only O(1) extra state per parameter tensor — half the memory of the full variant and the same footprint as AdamW. Across 16 benchmark phases, STAM and STAMLite achieve top-3 performance on 83% of scored phases and win outright on the hyperparameter-robustness benchmarks.
Key Innovations
1. Adaptive first-moment control — β₁ becomes a variance-dependent signal, not a fixed hyperparameter
2. Residual-based variance proxy — uses g_t − m_{t−1} to measure gradient surprise
3. Auto-scaling normalization — the τ term eliminates manual temperature tuning
4. Exact bias correction for time-varying β₁ via a running-product correction B_t (see the sketch after this list)
5. Memory-efficient STAMLite variant with O(1) extra state per parameter tensor
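The correction term B_t is not defined in this section. The lines below sketch the natural running-product form, inferred (not quoted from the paper) from the requirement that it reduce to Adam's 1 − β₁^t correction when β₁ is constant:

# running-product bias correction (assumed form, not quoted from the
# paper; with constant β₁ it reduces to Adam's familiar 1 − β₁^t)
m_t = β₁(t) · m_{t−1} + (1 − β₁(t)) · g_t
B_t = 1 − Π_{i=1..t} β₁(i)    # running product over all past coefficients
m̂_t = m_t / B_t               # unbiased when the gradient mean is stationary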
The Adaptive Mechanism
STAM adapts β₁ at every step using a per-tensor gradient variance proxy derived from momentum residuals:
# gradient residual — how "surprising" the new gradient is
r_t = g_t - m_{t-1}
# EMA variance proxy & auto-scaling
σ²_t = EMA(mean(r_t²))
τ_t = EMA(mean(|r_t|))
# scale-free variance ratio, squashed to a bounded signal in [0, 1)
z_t = σ²_t / (τ_t² + ε)
s_t = z_t / (1 + z_t)
# adaptive momentum coefficient
β₁(t) = β₁_base · (1 - adapt_strength · s_t)
- High variance (s_t → 1): β₁(t) decreases, damping oscillations and reducing noise carry-over.
- Low variance (s_t → 0): β₁(t) approaches β₁_base, maintaining momentum and accelerating convergence.
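The following is a minimal, self-contained JAX sketch of this mechanism for a single parameter tensor. It paraphrases the pseudocode above rather than the package's source; the proxy decay nu, the eps value, and the function name are illustrative assumptions.

# JAX sketch of the per-tensor β₁ adaptation (illustrative, not the
# package's implementation); sigma2 and tau are per-tensor scalars.
import jax.numpy as jnp

def adapt_beta1(g, m_prev, sigma2, tau, b1_base=0.9,
                adapt_strength=0.2, nu=0.99, eps=1e-8):
    r = g - m_prev                                        # gradient residual ("surprise")
    sigma2 = nu * sigma2 + (1.0 - nu) * jnp.mean(r ** 2)  # EMA variance proxy
    tau = nu * tau + (1.0 - nu) * jnp.mean(jnp.abs(r))    # EMA scale (auto-scaling)
    z = sigma2 / (tau ** 2 + eps)                         # scale-free ratio
    s = z / (1.0 + z)                                     # bounded in [0, 1)
    beta1 = b1_base * (1.0 - adapt_strength * s)          # adaptive β₁(t)
    return beta1, sigma2, tau

A caller would thread sigma2 and tau through the optimizer state, then apply the usual first-moment update m_t = beta1 · m_prev + (1 − beta1) · g together with the running-product correction sketched earlier.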
| Variant | Summary |
|---|---|
| STAM Full | Maximum adaptive fidelity with per-tensor variance estimation and auto-scaling |
| STAMLite | Efficient approximation using gradient moments instead of residuals |
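STAMLite's moment-based proxy is not spelled out here. One plausible reading, sketched below with hypothetical names, reuses the Adam-style moments m and v that the optimizer already stores, so the only extra state is a few per-tensor scalars, and refreshes β₁ only every beta1_update_interval steps (the constructor argument shown under Usage).

# Hypothetical sketch of a moment-based signal for STAMLite; not the
# package's implementation. m and v are the Adam-style first and second
# moments the optimizer already maintains.
import jax.numpy as jnp

def lite_signal(m, v, eps=1e-8):
    var = jnp.maximum(jnp.mean(v - m ** 2), 0.0)  # E[g²] − E[g]² ≈ variance
    scale = jnp.mean(jnp.abs(m))                  # per-tensor gradient scale
    z = var / (scale ** 2 + eps)                  # scale-free ratio, as in STAM
    return z / (1.0 + z)                          # bounded signal in [0, 1)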
Phase 9 — Long-Horizon Non-Stationary MLP
STAM Full is a very close second to NAdam under sustained distribution shift
| Optimizer | Accuracy | Loss | Rank |
|---|---|---|---|
| NAdam | 0.978678 | 0.092183 | 1st |
| STAM Full | 0.974365 | 0.111343 | 2nd |
| STAMLite | 0.961507 | 0.155300 | 3rd |
| RMSProp | 0.957764 | 0.212180 | |
| LAMB | 0.926758 | 0.338485 | |
| SGD+Momentum | 0.570150 | 1.2354 | |
| Adagrad | 0.335286 | 1.7839 | |
Phase 10 — Hyperparameter Sweep (Headline Result)
STAMLite wins outright — the clearest robustness result in the paper
| Optimizer | Accuracy | Loss | Rank |
|---|---|---|---|
| STAMLite | 0.614059 | 0.918407 | 1st |
| RMSProp | 0.609918 | 0.932397 | 2nd |
| NAdam | 0.602214 | 0.952715 | 3rd |
| STAM Full | 0.598380 | 0.962011 | |
| LAMB | 0.501736 | 1.1961 | |
| SGD+Momentum | 0.438205 | 1.3058 | |
| Adagrad | 0.348850 | 1.5452 | |
All 11 Benchmark Phases — STAM Overview
- Phase 2 — Convex Regression
- Phase 3 — Non-Stationary MLP
- Phase 4 — Small-Batch Stress
- Phase 5b — Advanced Robustness
- Phase 5c — LR Sweep
- Phase 5e — Ablation Study
- Phase 6 — MNIST / Fashion-MNIST
- Phase 7 — CIFAR-10 CNN
- Phase 8 — Tiny Transformer
- Phase 9 — Long-Horizon Shift
- Phase 10 — Hyperparam Sweep
Installation & Usage
# Install from PyPI
pip install stam-optimizer
# Quick start (STAM Full)
from stam_optimizer import STAM
optimizer = STAM(learning_rate=1e-3, b1_base=0.9, adapt_strength=0.2)
# Quick start (STAMLite — recommended default)
from stam_optimizer import STAMLite
optimizer = STAMLite(learning_rate=1e-3, beta1_update_interval=5)
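For context, here is a hedged sketch of a complete training step, assuming STAM/STAMLite return an Optax-style GradientTransformation with init/update methods; the package's actual API may differ, and loss_fn, params, and batch are illustrative placeholders.

import jax
import jax.numpy as jnp
import optax
from stam_optimizer import STAMLite

def loss_fn(params, batch):
    # toy linear-regression loss; stands in for a real model
    preds = batch["x"] @ params["w"] + params["b"]
    return jnp.mean((preds - batch["y"]) ** 2)

params = {"w": jnp.zeros((8, 1)), "b": jnp.zeros((1,))}
optimizer = STAMLite(learning_rate=1e-3, beta1_update_interval=5)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss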
Framework: JAX 0.10.0 + Optax 0.2.8
License: MIT