Modular Manifolds

Author: Jeremy Bernstein
Published: Sep 26, 2025
Source: Thinking Machines Lab

---

Overview

Training large neural networks requires maintaining the "health" of tensors (weights, activations, gradients) to prevent problems such as numerical underflow, overflow, and unstable training dynamics caused by drifting tensor magnitudes. Normalization is widely used for activations and gradients but less commonly for weight matrices. Weight constraints can normalize weights, improve optimization stability, and provide robustness guarantees.

The post introduces the idea of constraining weight matrices to submanifolds and co-designing optimization algorithms with these constraints.

Example: a manifold version of the Muon optimizer with weights constrained to the Stiefel manifold, where matrices have unit condition number.

The post ends with the theory of modular manifolds, which aims to scale manifold constraints and optimization to entire networks.

---

The Shape of a Manifold Optimizer (Warm-up)

Consider optimizing a vector constrained to a hypersphere (unit Euclidean norm). The tangent space at a point on the manifold is the set of directions orthogonal to the current vector, and the optimization step is taken within the tangent space so that the manifold constraint is maintained precisely. Distance is typically measured with the Euclidean norm, but alternative norms change the step direction.

Formulated as a constrained optimization problem: minimize the first-order change in loss subject to

- a step-length constraint (the norm of the step equals the learning rate), and
- a tangent-space constraint (the step is orthogonal to the weight vector).

The optimal step:

\[
a_{\text{opt}} = -\eta \, \frac{g - w w^\top g}{\| g - w w^\top g \|_2}
\]

After taking the tangent step, a retraction map projects the weights back exactly onto the manifold (for the hypersphere, normalization by \(\sqrt{1 + \eta^2}\)).

Summarized manifold optimizer steps (a code sketch appears after the Manifold Muon section):

1. Find the unit tangent direction best aligned with the negative gradient.
2. Scale it by the learning rate and subtract it from the weights.
3. Retract the weights back to the manifold.

Varying the choice of manifold and norm recovers familiar optimizers (vanilla gradient descent, sign descent) as well as novel ones like Muon and manifold Muon for matrices.

---

Manifold Muon

A typical weight matrix \(W\) in a neural network maps input vectors \(x\) to outputs \(y = Wx\). The goal is to constrain \(W\) so that its singular values are all 1, ensuring the matrix neither excessively shrinks nor stretches its inputs. This places \(W\) on the Stiefel manifold:

\[
\mathsf{Stiefel}(m, n) := \{ W \in \mathbb{R}^{m \times n} \mid W^\top W = I_n \}
\]

Tangent space condition for \(A \in \mathbb{R}^{m \times n}\) at \(W\):

\[
A^\top W + W^\top A = 0
\]

The spectral norm (largest singular value) is used as the distance measure. The constrained optimization problem for manifold Muon:

\[
\min_{A} \; \operatorname{trace}(G^\top A) \quad \text{subject to} \quad \|A\|_{\text{spectral}} \leq \eta, \quad A^\top W + W^\top A = 0
\]

It is solved via dual ascent:

- Introduce a dual variable \(\Lambda\).
- Reformulate the constrained problem as a saddle-point problem.
- Solve the inner minimization using the matrix sign function (which projects singular values to ±1).
- Iterate gradient ascent on \(\Lambda\) to find an optimal solution.

Algorithm steps (sketched in code below):

1. Update the dual variable \(\Lambda\) by gradient ascent.
2. Compute the update \(A_{\text{opt}} = -\eta \, \operatorname{msign}\!\big(G + 2W(\Lambda_{\text{opt}} + \Lambda_{\text{opt}}^\top)\big)\).
3. Update the weights: \(W \gets W + A_{\text{opt}}\).
4. Retract: \(W \gets \operatorname{msign}(W)\).
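A minimal sketch of one manifold Muon update, assuming NumPy. The helper names (`msign`, `manifold_muon_step`), the step sizes, the number of dual-ascent iterations, and the use of the constraint residual \(A^\top W + W^\top A\) as the dual gradient are illustrative assumptions, not the post's reference implementation.

```python
import numpy as np

def msign(M):
    """Matrix sign via the SVD: keep the singular vectors, set singular values to 1."""
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ Vt

def manifold_muon_step(W, G, lr=0.1, dual_lr=0.1, dual_steps=20):
    """One (approximate) manifold Muon update for W on Stiefel(m, n), m >= n.

    W : current weights satisfying W^T W = I_n
    G : gradient of the loss at W
    """
    n = W.shape[1]
    Lam = np.zeros((n, n))      # dual variable for the tangent-space constraint
    A = np.zeros_like(W)
    for _ in range(dual_steps):
        # Inner minimization over the spectral-norm ball ||A|| <= lr:
        # the minimizer is -lr times the matrix sign of the dual-adjusted gradient.
        A = -lr * msign(G + 2.0 * W @ (Lam + Lam.T))
        # Dual ascent on Lam along the constraint residual A^T W + W^T A,
        # which is driven toward zero at the optimum.
        Lam += dual_lr * (A.T @ W + W.T @ A)
    # Take the (approximately tangent) step, then retract back onto the manifold.
    return msign(W + A)
```

In training, a step like this would replace the plain `W -= lr * G` update for each matrix-shaped parameter, with `G` taken from backpropagation (optionally after momentum).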
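For comparison, the warm-up hypersphere step from the first section takes only a few lines under the same assumptions (NumPy; `hypersphere_step` is a hypothetical name):

```python
import numpy as np

def hypersphere_step(w, g, lr, eps=1e-12):
    """One manifold optimizer step on the unit hypersphere (Euclidean norm).

    w  : current weight vector with ||w||_2 = 1
    g  : gradient of the loss at w
    lr : learning rate, i.e. the length of the tangent step
    """
    # Project the gradient onto the tangent space: directions orthogonal to w.
    g_tan = g - w * (w @ g)
    # Unit tangent direction best aligned with the negative gradient, scaled by lr.
    a_opt = -lr * g_tan / (np.linalg.norm(g_tan) + eps)
    # Tangent step followed by retraction; since a_opt is orthogonal to w with
    # length lr, renormalizing divides by sqrt(1 + lr^2), matching the text.
    w_new = w + a_opt
    return w_new / np.linalg.norm(w_new)
```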
Demonstrated on a small CIFAR-10 MLP, where manifold Muon improved test and train accuracy compared to AdamW.

---

Modular Manifolds

Extends the manifold constraints and optimization logic to entire networks by treating each module (layer or subnetwork) as an object that carries its own manifold constraint and norm.
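To hint at how the per-layer constraint could be applied network-wide, here is a rough sketch (again NumPy, with hypothetical names) that keeps every matrix-valued parameter on its own Stiefel manifold by retracting after each update; the composition rules of the full modular-manifolds theory are not reproduced here.

```python
import numpy as np

def msign(M):
    """Matrix sign via the SVD: keep singular vectors, set singular values to 1."""
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ Vt

def retract_layers(weights):
    """Retract every layer's weight matrix back onto its Stiefel manifold,
    so that W^T W = I continues to hold layer by layer after an optimizer step."""
    return [msign(W) for W in weights]

# Usage (illustrative): after updating each layer's W however the optimizer
# prescribes, call retract_layers(weights) to re-impose the constraints.
```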