
STAM Optimizer

Stable Training with Adaptive Momentum

Assem Sabry · TokenAI

Published: 1 May 2026

Abstract

Adaptive gradient methods such as Adam and AdamW fix the first-moment momentum coefficient β₁ (typically 0.9) for all timesteps and all parameters, regardless of gradient dynamics. This causes overshooting in high-variance regimes and forgoes faster convergence near stationarity. We propose Stable Training with Adaptive Momentum (STAM), which adapts β₁ using a per-tensor gradient-variance proxy derived from momentum residuals: high variance lowers β₁ to damp oscillations, while low variance preserves or restores β₁ to accelerate convergence. We further introduce STAMLite, a memory-efficient variant with only O(1) extra state per parameter, half the memory of the full variant and the same footprint as AdamW. Across 16 benchmark phases, STAM and STAMLite achieve top-3 performance on 83% of scored phases and win outright on hyperparameter-robustness benchmarks.

Key Innovations

  1. Adaptive first-moment control — β₁ becomes a variance-dependent signal, not a fixed hyperparameter
  2. Residual-based variance proxy — uses g_t − m_{t−1} to measure gradient surprise
  3. Auto-scaling normalization — the τ term eliminates manual temperature tuning
  4. Exact bias correction for time-varying β₁ via a running-product correction B_t (sketched after the mechanism below)
  5. Memory-efficient STAMLite variant with O(1) extra state per parameter

The Adaptive Mechanism

STAM adapts β₁ at every step using a per-tensor gradient variance proxy derived from momentum residuals:

# gradient residual — how "surprising" the new gradient is
r_t = g_t - m_{t-1}

# EMA variance proxy & auto-scaling
σ²_t = EMA(mean(r_t²))
τ_t = EMA(mean(|r_t|))

# bounded signal in [0, 1]
z_t = σ²_t / (τ_t² + ε)
s_t = z_t / (1 + z_t)

# adaptive momentum coefficient
β₁(t) = β₁_base · (1 - adapt_strength · s_t)
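
Because β₁ now varies with t, Adam's closed-form bias correction (dividing by 1 − β₁ᵗ) no longer applies. Innovation 4's running product B_t generalizes it; a sketch in the same notation, assuming the standard EMA update for m_t:

# first moment with a time-varying coefficient
m_t = β₁(t) · m_{t-1} + (1 - β₁(t)) · g_t

# running-product bias correction (B_0 = 1)
B_t = β₁(t) · B_{t-1}
m̂_t = m_t / (1 - B_t)

When β₁ is constant, B_t collapses to β₁ᵗ and the correction reduces to Adam's.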

High Variance (s_t → 1)

β₁(t) decreases, damping oscillations and reducing noise carry-over.

Low Variance (s_t → 0)

β₁(t) approaches β₁_base, maintaining momentum and accelerating convergence.
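
Putting the schedule together, here is a minimal JAX sketch of one adaptation step. The function name, the scalar per-tensor state layout, and the EMA decay are illustrative assumptions; only b1_base and adapt_strength mirror the constructor defaults shown in the usage section below.

import jax.numpy as jnp

def stam_beta1_step(g, m_prev, sigma2, tau, *, b1_base=0.9,
                    adapt_strength=0.2, ema_decay=0.99, eps=1e-8):
    # residual: how far the new gradient departs from the momentum buffer
    r = g - m_prev
    # per-tensor EMA statistics (scalars; decay value is an assumption)
    sigma2 = ema_decay * sigma2 + (1.0 - ema_decay) * jnp.mean(r ** 2)
    tau = ema_decay * tau + (1.0 - ema_decay) * jnp.mean(jnp.abs(r))
    # auto-scaled variance signal, squashed into [0, 1)
    z = sigma2 / (tau ** 2 + eps)
    s = z / (1.0 + z)
    # adaptive momentum coefficient
    b1 = b1_base * (1.0 - adapt_strength * s)
    return b1, sigma2, tau

With those defaults, s_t ∈ [0, 1) keeps β₁(t) inside (0.72, 0.9]: the schedule lowers momentum under noisy gradients and lets it recover toward the base as training stabilizes.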

Variant: STAM Full

Maximum adaptive fidelity with per-tensor variance estimation and auto-scaling

Memory: ~2× AdamW state
Best for: Non-stationary data, small batches, continual learning

Variant: STAMLite

Efficient approximation using gradient moments instead of residuals

Memory: ~1× AdamW state
Best for: Standard training, large models, default AdamW replacement
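
The page does not spell out which gradient moments STAMLite uses, so the following is a speculative sketch of one plausible reading: the residual r_t is replaced by the raw gradient g_t, and the statistics stay per-tensor scalars so the extra memory is negligible. Names and the decay value are invented for illustration.

import jax.numpy as jnp

def stamlite_signal(g, sigma2, tau, *, ema_decay=0.99, eps=1e-8):
    # EMA moments of the raw gradient (assumed stand-in for the
    # residual-based statistics of the full variant)
    sigma2 = ema_decay * sigma2 + (1.0 - ema_decay) * jnp.mean(g ** 2)
    tau = ema_decay * tau + (1.0 - ema_decay) * jnp.mean(jnp.abs(g))
    # same auto-scaled squashing as STAM Full
    z = sigma2 / (tau ** 2 + eps)
    return z / (1.0 + z), sigma2, tau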

Phase 9 — Long-Horizon Non-Stationary MLP

STAM Full is a very close 2nd to NAdam under sustained distribution shift

Optimizer        Accuracy    Loss        Rank
NAdam            0.978678    0.092183    1st
STAM Full        0.974365    0.111343    2nd
STAMLite         0.961507    0.155300    3rd
RMSProp          0.957764    0.212180
LAMB             0.926758    0.338485
SGD+Momentum     0.570150    1.2354
Adagrad          0.335286    1.7839

Phase 10 — Hyperparameter Sweep (Headline Result)

STAMLite wins outright — the clearest robustness result in the paper

Optimizer        Accuracy    Loss        Rank
STAMLite         0.614059    0.918407    1st
RMSProp          0.609918    0.932397    2nd
NAdam            0.602214    0.952715    3rd
STAM Full        0.598380    0.962011
LAMB             0.501736    1.1961
SGD+Momentum     0.438205    1.3058
Adagrad          0.348850    1.5452

All 11 Benchmark Phases — STAM Overview

Phase 2 — Convex Regression

Winner: SGD+Momentum · STAM is not designed for stationary convex tasks

Phase 3 — Non-Stationary MLP

Winner: RMSProp · STAM Full 0.7956, STAMLite 0.7826

Phase 4 — Small-Batch Stress

Winner: Lion · STAMLite more stable than AdamW

Phase 5b — Advanced Robustness

Winner: Lion · STAM competitive across all metrics

Phase 5c — LR Sweep

Winner: STAMLite (robustness) · Best learning-rate forgiveness

Phase 5e — Ablation Study

Winner: STAMLite (0.5063) · Validates the adaptive β₁ contribution

Phase 6 — MNIST / Fashion-MNIST

No outright winner · STAM/STAMLite competitive with all baselines

Phase 7 — CIFAR-10 CNN

Winner: NAdam (0.6096) · STAM Full 0.5912, a strong 2nd place

Phase 8 — Tiny Transformer

No outright winner · STAM competitive on sequence modeling

Phase 9 — Long-Horizon Shift

Winner: NAdam (0.9787) · STAM Full 0.9744, a very close 2nd

Phase 10 — Hyperparam Sweep

Winner: STAMLite (0.6141) · The paper's clearest robustness result

Installation & Usage

# Install from PyPI
pip install stam-optimizer

# Quick start (STAM Full)
from stam_optimizer import STAM
optimizer = STAM(learning_rate=1e-3, b1_base=0.9, adapt_strength=0.2)

# Quick start (STAMLite — recommended default)
from stam_optimizer import STAMLite
optimizer = STAMLite(learning_rate=1e-3, beta1_update_interval=5)
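
Given the JAX + Optax framework noted below, a training step would plausibly follow Optax's GradientTransformation pattern. Treat this as a sketch: the init/update signatures and optax.apply_updates are standard Optax, but whether stam_optimizer exposes exactly this interface is an assumption.

import jax
import jax.numpy as jnp
import optax
from stam_optimizer import STAMLite  # Optax-style interface assumed

def loss_fn(params, x, y):
    # toy linear regression loss for demonstration
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

optimizer = STAMLite(learning_rate=1e-3)
params = {"w": jnp.zeros((8, 1)), "b": jnp.zeros((1,))}
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state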

Framework

JAX 0.10.0 + Optax 0.2.8

License

MIT