
Paper deep dive

SplInterp: Improving our Understanding and Training of Sparse Autoencoders

Jeremy Budd, Javier Ideami, Benjamin Macdowall Rynne, Keith Duggar, Randall Balestriero

Year: 2025 · Venue: arXiv preprint · Area: Mechanistic Interp. · Type: Theoretical · Embeddings: 206

Models: Gemma-2-2B

Intelligence

Status: succeeded | Model: google/gemini-3.1-flash-lite-preview | Prompt: intel-v1 | Confidence: 96%

Last extracted: 3/11/2026, 12:37:41 AM

Summary

The paper introduces 'SplInterp', a framework that analyzes Sparse Autoencoders (SAEs) through the lens of spline theory. It demonstrates that SAEs are piecewise affine splines, characterizes their geometry using power diagrams, and proposes a novel training algorithm, PAM-SGD, which improves sample efficiency and sparsity in LLM and MNIST experiments.

Entities (5)

PAM-SGD · algorithm · 100%

Sparse Autoencoders · model-architecture · 100%

Mechanistic Interpretability · research-field · 98%

Power diagrams · mathematical-concept · 95%

Spline theory of deep learning · theoretical-framework · 95%

Relation Signals (3)

PAM-SGD trains Sparse Autoencoders

confidence 100% · we develop a novel proximal alternating method SGD (PAM-SGD) algorithm for training SAEs

Power diagrams characterize the geometry of Sparse Autoencoders

confidence 95% · We characterise the underlying geometry of (TopK) SAEs using power diagrams.

Spline theory of deep learning explains Sparse Autoencoders

confidence 95% · we seek to enhance the theoretical understanding of SAEs, using the spline theory of deep learning.

Cypher Suggestions (2)

Find all algorithms used to train Sparse Autoencoders · confidence 90% · unvalidated

MATCH (a:Algorithm)-[:TRAINS]->(s:ModelArchitecture {name: 'Sparse Autoencoders'}) RETURN a.name

Map the relationship between theoretical frameworks and model architectures · confidence 90% · unvalidated

MATCH (t:TheoreticalFramework)-[:EXPLAINS]->(m:ModelArchitecture) RETURN t.name, m.name

Abstract

Abstract: Sparse autoencoders (SAEs) have received considerable recent attention as tools for mechanistic interpretability, showing success at extracting interpretable features even from very large LLMs. However, this research has been largely empirical, and there have been recent doubts about the true utility of SAEs. In this work, we seek to enhance the theoretical understanding of SAEs, using the spline theory of deep learning. By situating SAEs in this framework: we discover that SAEs generalise “k-means autoencoders” to be piecewise affine, but sacrifice accuracy for interpretability vs. the optimal “k-means-esque plus local principal component analysis (PCA)” piecewise affine autoencoder. We characterise the underlying geometry of (TopK) SAEs using power diagrams. And we develop a novel proximal alternating method SGD (PAM-SGD) algorithm for training SAEs, with both solid theoretical foundations and promising empirical results in MNIST and LLM experiments, particularly in sample efficiency and (in the LLM setting) improved sparsity of codes. All code is available at: https://github.com/splInterp2025/splInterp

Tags

ai-safety (imported, 100%) · mechanistic-interp (suggested, 92%) · theoretical (suggested, 88%)

Links

PDF not stored locally. Use the link above to view on the source site.

Full Text

205,991 characters extracted from source content.


SplInterp: Improving our Understanding and Training of Sparse Autoencoders

Jeremy Budd (School of Mathematics, University of Birmingham); Javier Ideami (Ideami Studios); Benjamin Macdowall Rynne (Department of Mathematics and Statistics, University of Limerick); Keith Duggar (XRAI Inc.); Randall Balestriero (Department of Computer Science, Brown University). Corresponding author: j.m.budd@bham.ac.uk

Abstract: Sparse autoencoders (SAEs) have received considerable recent attention as tools for mechanistic interpretability, showing success at extracting interpretable features even from very large LLMs. However, this research has been largely empirical, and there have been recent doubts about the true utility of SAEs. In this work, we seek to enhance the theoretical understanding of SAEs, using the spline theory of deep learning. By situating SAEs in this framework: we discover that SAEs generalise “k-means autoencoders” to be piecewise affine, but sacrifice accuracy for interpretability vs. the optimal “k-means-esque plus local principal component analysis (PCA)” piecewise affine autoencoder. We characterise the underlying geometry of (TopK) SAEs using power diagrams. And we develop a novel proximal alternating method SGD (PAM-SGD) algorithm for training SAEs, with both solid theoretical foundations and promising empirical results in MNIST and LLM experiments, particularly in sample efficiency and (in the LLM setting) improved sparsity of codes. All code is available at: https://github.com/splInterp2025/splInterp

1 Introduction

Figure 1: Key ideas in this work. Left: Sparse autoencoders (SAEs) vs. regular autoencoders; in an SAE, the code z can have very high dimension, but most entries are zero (greyed out). Right: The superior sample efficiency (on MNIST and Gemma-2-2B) of our novel proximal alternating method SGD (PAM-SGD) algorithm.
One of the fundamental challenges in modern AI is the interpretability of machine learning systems: how can we look inside these increasingly complicated (and capable) “black boxes”? The better we understand what makes these systems tick, the better we can diagnose problems with them, direct their behaviour, and ultimately build more effective, reliable, and fair systems. Indeed, for some tasks interpretability may be a legal requirement: for example, Article 86 of the European Union’s AI Act describes a “right to explanation” for persons affected by the decisions of certain types of AI system.

A mechanistic interpretability technique that has seen considerable recent attention is to use sparse autoencoders (SAEs), see e.g. Bricken et al. (2023); Huben et al. (2024); Gao et al. (2025). A major obstacle in interpreting a neural network is that neurons seem to be polysemantic, responding to mixtures of seemingly unrelated features Olah et al. (2017). One hypothesis is that this is caused by superposition, a phenomenon in which a neural network represents features via (linear) combinations of neurons, to pack in more features. In Elhage et al. (2022), superposition was shown to arise in a toy model in response to sparsity of the underlying features. The key idea of using SAEs was that these might be able to disentangle this superposition and extract monosemantic features. Initial work found some success: e.g. Templeton et al. (2024) used SAEs to extract millions of features from Claude 3 Sonnet, including highly interpretable ones such as a “Golden Gate Bridge” feature, which could be used to make Claude incessantly talk about the bridge.

However, there have been some recent doubts about the utility of SAEs for mechanistic interpretability. In a recent post by the Google DeepMind mechanistic interpretability team Smith et al. (2025), works finding key issues with SAEs were highlighted, e.g. Leask et al. (2025), and SAEs were found to underperform linear probes at a downstream task. The team argued that whilst SAEs are not useless, they will not be a game-changer for interpretability, and speculated that the field is over-invested in them. Furthermore, in what we will dub a “dead salmon” experiment (in honour of Bennett et al. (2009)), Heap et al. (2025) found that SAEs can extract features from randomly weighted Transformers that have very similar auto-interpretability scores to features extracted from a trained network, suggesting that SAE “interpretations” may not reflect what is actually going on in a model.

This empirical uncertainty motivated us to look at SAEs through a more theoretical lens, inspired by the spline theory of deep learning Balestriero & Baraniuk (2018). Using this perspective, we:

I. Unify, situating SAEs within the spline theory framework, and showing how SAEs form a bridge between the classical ML techniques of k-means and principal component analysis (PCA) and contemporary deep learning. (Section 2, proofs in Appendix B)

II. Interpret, characterising and visualising the spline geometry of (TopK) SAEs in terms of weighted Voronoi diagrams. (Sections 2, A and B)

III. Innovate, developing a novel proximal alternating method SGD (PAM-SGD) algorithm for training SAEs, with both solid theoretical foundations and promising empirical results, which is inspired by the spline geometric way of thinking. In particular, we find in both MNIST and LLM experiments that PAM-SGD outperforms SGD in low-training-data settings, addressing an important concern with SAEs.
(Sections 3, C and E)

2 The spline geometry of sparse autoencoders (SAEs)

2.1 A primer on SAEs

An SAE composes an encoding, which maps an input $x \in \mathbb{R}^n$ to a code $z \in \mathbb{R}^d$ (engineered to be sparse), with a decoding which maps $z$ to an output $\hat{x} \in \mathbb{R}^n$ (engineered so that $\hat{x} \approx x$). Unlike a traditional autoencoder, in an SAE one may choose the hidden dimension $d \gg n$, but the sparsity of $z$ will be engineered to be much less than $n$, see Figure 1 (left). The SAE encoding is given by

$$z := \rho(W_{\mathrm{enc}} x + b_{\mathrm{enc}}),$$

where $W_{\mathrm{enc}} \in \mathbb{R}^{d \times n}$, $b_{\mathrm{enc}} \in \mathbb{R}^d$, and $\rho$ is a given activation function. Notable choices for $\rho$ include ReLU Bricken et al. (2023); JumpReLU Rajamanoharan et al. (2024), where $\tau \in \mathbb{R}$ is a parameter and

$$\rho(v)_i := \begin{cases} v_i, & \text{if } v_i > \tau, \\ 0, & \text{otherwise}; \end{cases}$$

and TopK Makhzani & Frey (2014); Gao et al. (2025), where $K \in \mathbb{N}$ is a parameter and

$$\rho(v)_i := \begin{cases} v_i, & \text{if } v_i \text{ is among the } K \text{ largest entries of } v, \\ 0, & \text{otherwise}. \end{cases}$$

The decoding is then given by

$$\hat{x} := W_{\mathrm{dec}} z + b_{\mathrm{dec}},$$

where $W_{\mathrm{dec}} \in \mathbb{R}^{n \times d}$ and $b_{\mathrm{dec}} \in \mathbb{R}^n$. The columns of $W_{\mathrm{dec}}$ can be understood as dictionary atoms (see Olshausen & Field (1996)) which are sparsely recombined (with bias) to recover $\hat{x}$. The full SAE is therefore given by

$$S_\rho(x) := W_{\mathrm{dec}} \, \rho(W_{\mathrm{enc}} x + b_{\mathrm{enc}}) + b_{\mathrm{dec}}.$$

Finally, following Rajamanoharan et al. (2024), given training data $\{x^r\}_{r=1}^N \subseteq \mathbb{R}^n$, we will consider loss functions for training an SAE of the form:

$$\mathcal{L} = \sum_{r=1}^N \|S_\rho(x^r) - x^r\|_2^2 + \lambda \, \mathcal{L}_{\mathrm{sparsity}}\big(\{\rho(W_{\mathrm{enc}} x^r + b_{\mathrm{enc}})\}_{r=1}^N\big) + \mathcal{L}_{\mathrm{aux}},$$

where $\mathcal{L}_{\mathrm{aux}}$ might include regularisation, e.g. weight decay. Some activations, e.g. TopK, always produce a sparse $z$, so one may set $\mathcal{L}_{\mathrm{sparsity}} = 0$. Others, e.g. ReLU, do not inherently make $z$ sparse, in which case common choices of $\mathcal{L}_{\mathrm{sparsity}}$ include the $\ell_1$ norm Bricken et al. (2023), the $\ell_0$ norm Rajamanoharan et al. (2024), and the Kullback–Leibler divergence to a sparse distribution Ng (2011).
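To make the definitions concrete, here is a minimal NumPy sketch of a TopK SAE forward pass and its reconstruction loss. All dimensions and the random parameters are illustrative assumptions, not values from the paper:

```python
import numpy as np

def topk(v, K):
    """TopK activation: keep the K largest entries of v, zero the rest."""
    z = np.zeros_like(v)
    idx = np.argsort(v)[-K:]          # indices of the K largest entries
    z[idx] = v[idx]
    return z

def sae_forward(x, W_enc, b_enc, W_dec, b_dec, K):
    """TopK SAE: encode to a sparse code z, then decode affinely."""
    z = topk(W_enc @ x + b_enc, K)    # z in R^d, at most K nonzeros
    x_hat = W_dec @ z + b_dec         # reconstruction in R^n
    return z, x_hat

rng = np.random.default_rng(0)
n, d, K = 8, 32, 4                    # hidden dim d >> n, code is K-sparse
W_enc, b_enc = rng.normal(size=(d, n)), rng.normal(size=d)
W_dec, b_dec = rng.normal(size=(n, d)), rng.normal(size=n)

x = rng.normal(size=n)
z, x_hat = sae_forward(x, W_enc, b_enc, W_dec, b_dec, K)

# Reconstruction term of the training loss; TopK enforces sparsity by
# construction, so one may take the sparsity penalty to be zero.
loss = np.sum((x_hat - x) ** 2)
```

Since TopK zeroes all but $K$ entries, the code `z` is $K$-sparse regardless of the input, which is why no $\mathcal{L}_{\mathrm{sparsity}}$ term is needed in this case.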
2.2 SAEs are piecewise affine splines

We first note a simple fact about our SAEs, also observed in Hindupur et al. (2025) (in different notation). In all three cases of ReLU, JumpReLU, and TopK, for some $S \subseteq \{1, \dots, d\}$ we have that

$$\rho(v) = P_S v,$$

where $P_S \in \mathbb{R}^{d \times d}$ is the projection that zeroes the entries of $v$ which are not in $S$. In the case of JumpReLU (of which ReLU is a special case), $S$ is the set of indices $i$ such that $v_i > \tau$; in the case of TopK, $S$ is the set of indices of the $K$ largest entries. Therefore, let us define

$$\Omega_S^{\mathrm{JumpReLU}} := \{x \in \mathbb{R}^n : \forall i \in S, \ (W_{\mathrm{enc}} x + b_{\mathrm{enc}})_i > \tau \ \text{ and } \ \forall j \notin S, \ (W_{\mathrm{enc}} x + b_{\mathrm{enc}})_j < \tau\},$$

$$\Omega_S^{\mathrm{TopK}} := \{x \in \mathbb{R}^n : \forall i \in S, \, j \notin S, \ (W_{\mathrm{enc}} x + b_{\mathrm{enc}})_i > (W_{\mathrm{enc}} x + b_{\mathrm{enc}})_j\},$$

where in the former $S$ can be any subset of $\{1, \dots, d\}$ and in the latter $S$ must be a subset of size $K$. Then for $\rho = \mathrm{JumpReLU}$ or $\rho = \mathrm{TopK}$ the SAE becomes:

$$S_\rho(x) = W_{\mathrm{dec}} P_S (W_{\mathrm{enc}} x + b_{\mathrm{enc}}) + b_{\mathrm{dec}}, \qquad x \in \Omega_S^\rho,$$

which is a piecewise affine spline. Note that the $\Omega_S^{\mathrm{JumpReLU}}$ and $\Omega_S^{\mathrm{TopK}}$ do not entirely partition the space: e.g. what if $W_{\mathrm{enc}} x + b_{\mathrm{enc}}$ has an entry equal to $\tau$, or has ties for the top $K$? Such $x$s (a set of measure zero) form the boundaries of these pieces, and $S_\rho$ is discontinuous at these boundaries (except in the ReLU case, i.e. $\tau = 0$). Both $\Omega_S^\rho$ can be written in the form $\{x \in \mathbb{R}^n : Hx > c\}$, for appropriate matrices $H$ and vectors $c$, see Theorem B.1. They are thus open and convex sets, and by Theorem B.1 are the interiors of convex polyhedra except in degenerate cases of $W_{\mathrm{enc}}$. Going beyond this simple characterisation, we introduce a new geometric characterisation of the TopK pieces as the cells of a power diagram. (For some visualisations of these notions, see Appendix A.)

Definition 2.1 (Power and Voronoi diagrams). A power diagram (a.k.a. a Laguerre–Voronoi diagram) is a partition of $\mathbb{R}^n$ into $k$ cells, defined by taking centroids $\{\mu_i\}_{i=1}^k \subseteq \mathbb{R}^n$ and weights $\{\alpha_i\}_{i=1}^k \subseteq \mathbb{R}$ and defining the $i$th cell to be

$$C_i := \{x \in \mathbb{R}^n : \|x - \mu_i\|_2^2 - \alpha_i < \|x - \mu_j\|_2^2 - \alpha_j \ \text{ for all } j \neq i\}.$$

A Voronoi diagram is given by the special case when the $\alpha_i$ are constant in $i$. We further define a $K$th-order power diagram with centroids $\{\mu_i\}_{i=1}^k \subseteq \mathbb{R}^n$, weights $\{\alpha_i\}_{i=1}^k \subseteq \mathbb{R}$, and $\binom{k}{K}$ cells, where for $S \subseteq \{1, \dots, k\}$ with $|S| = K$, the $S$th cell is

$$C_S := \{x \in \mathbb{R}^n : \|x - \mu_i\|_2^2 - \alpha_i < \|x - \mu_j\|_2^2 - \alpha_j \ \text{ for all } i \in S \text{ and } j \in S^c\}.$$

An identical power diagram (of any order) is given if a constant is added to all the weights.

Note 2.2. All $K$th-order power diagrams are power diagrams, see Theorem B.2, but the converse does not hold: not all power diagrams with $\binom{k}{K}$ centroids and weights can be written as a $K$th-order power diagram with $k$ centroids and weights. As a counterexample, let $k = 4$, $K = 2$, and let the power diagram centroids be the six vertices of a regular hexagon: there are no four vectors in $\mathbb{R}^2$ whose pairwise means are the vertices of a regular hexagon.

Theorem 2.3. The cells $\{\Omega_S^{\mathrm{TopK}}\}$ form a $K$th-order power diagram with $\binom{d}{K}$ cells. Conversely, for any $K$th-order power diagram with $\binom{d}{K}$ cells, with centroids $\{\mu_i\}_{i=1}^d$ and weights $\{\alpha_i\}_{i=1}^d$, there exist $W_{\mathrm{enc}} \in \mathbb{R}^{d \times n}$ and $b_{\mathrm{enc}} \in \mathbb{R}^d$ such that the resulting TopK SAE is affine on the cells of that $K$th-order power diagram.
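The correspondence in Theorem 2.3 can be checked numerically: under the translation $\mu_i = W_{\mathrm{enc}}^T e_i$, $\alpha_i = 2(b_{\mathrm{enc}})_i + \|\mu_i\|_2^2$, selecting the $K$ largest encoder pre-activations is the same as selecting the $K$ smallest power distances. A sketch with illustrative random parameters:

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, K = 2, 6, 2
W_enc = rng.normal(size=(d, n))
b_enc = rng.normal(size=d)

# Translation from Theorem 2.3: mu_i is the i-th row of W_enc, and
# alpha_i = 2*(b_enc)_i + ||mu_i||^2.
mu = W_enc                              # shape (d, n): row i is mu_i
alpha = 2 * b_enc + np.sum(mu ** 2, axis=1)

for _ in range(100):
    x = rng.normal(size=n)
    # Cell selected by the TopK encoder: the K largest pre-activations.
    pre = W_enc @ x + b_enc
    S_topk = set(np.argsort(pre)[-K:])
    # Cell of the Kth-order power diagram: the K smallest power distances.
    power_dist = np.sum((x - mu) ** 2, axis=1) - alpha
    S_power = set(np.argsort(power_dist)[:K])
    assert S_topk == S_power            # agree almost surely (no ties)
```

The agreement follows by expanding $\|x - \mu_i\|_2^2 - \alpha_i = \|x\|_2^2 - 2(\mu_i \cdot x + (b_{\mathrm{enc}})_i)$: the common $\|x\|_2^2$ term drops out of the comparison, leaving the ordering of the pre-activations.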
The translations between each setting are given by:

$$e_i^T W_{\mathrm{enc}} = \mu_i^T \quad \text{and} \quad (b_{\mathrm{enc}})_i = \tfrac{1}{2}\alpha_i - \tfrac{1}{2}\|\mu_i\|_2^2, \tag{2.1a}$$

$$\mu_i = W_{\mathrm{enc}}^T e_i \quad \text{and} \quad \alpha_i = 2(b_{\mathrm{enc}})_i + \|W_{\mathrm{enc}}^T e_i\|_2^2, \tag{2.1b}$$

for all $i$, where $e_i$ is the elementary basis vector with $1$ in coordinate $i$ and $0$ in every other coordinate. It follows from Theorem B.2 that the cells $\{\Omega_S^{\mathrm{TopK}}\}$ form a power diagram with $\binom{d}{K}$ cells, given by centroids $\{\nu_S\}$ and weights $\{\beta_S\}$ defined by

$$\nu_S := \frac{1}{K} \sum_{i \in S} W_{\mathrm{enc}}^T e_i \quad \text{and} \quad \beta_S := \Big\| \frac{1}{K} \sum_{i \in S} W_{\mathrm{enc}}^T e_i \Big\|_2^2 + \frac{1}{K} \sum_{i \in S} 2 (b_{\mathrm{enc}})_i. \tag{2.2}$$

The converse is false: a given power diagram with $\binom{d}{K}$ cells can describe the cells upon which a TopK SAE is piecewise affine if and only if it can be written as a $K$th-order power diagram.

What Theorem 2.3 tells us is exactly the spline geometries that TopK SAEs can have, namely that these are exactly the $K$th-order power diagrams. Indeed, we can explicitly derive the encoding parameters that give rise to a particular spline geometry. This opens the door to engineering TopK SAEs with desirable geometric features, by translating those features into constraints on the parameters. As an example of a geometric feature one might desire to encourage, Humayun et al. (2024) related the generalisability and robustness produced by neural network grokking (see Power et al. (2022)) to the local complexity of the spline geometry.

2.3 SAEs, k-means, and principal component analysis (PCA)

By the above, all of the above SAEs are piecewise affine functions on regions $\{\Omega_S\}$, with rank $|S|$ on $\Omega_S$. We can compare this to k-means clustering.

Definition 2.4 (k-means clustering). Given data $\{x^r\}_{r=1}^N \subseteq \mathbb{R}^n$, the k-means clustering Steinhaus (1957) seeks $k$ regions $\{R_i\}_{i=1}^k \subseteq \mathbb{R}^n$ and centroids $\{\nu_i\}_{i=1}^k \subseteq \mathbb{R}^n$ minimising:

$$\sum_{i=1}^k \sum_{x^r \in R_i} \|x^r - \nu_i\|_2^2.$$
This is achieved when the $\nu_i$ are the in-region means and the $R_i$ are the following Voronoi cells:

$$\nu_i = \bar{x}_i := \frac{1}{|\{r : x^r \in R_i\}|} \sum_{x^r \in R_i} x^r, \qquad R_i = \{x \in \mathbb{R}^n : \|x - \nu_i\|_2^2 \leq \|x - \nu_j\|_2^2 \ \text{ for all } j \neq i\}.$$

Note 2.5. Suppose we have $k$ regions $\{\Omega_{S_i}\}_{i=1}^k$. Consider the piecewise constant encoding $f_{\mathrm{enc}}(x) := e_i$ for $x \in \Omega_{S_i}$ and the linear decoding $\hat{x} = W_{\mathrm{dec}} z$, where $W_{\mathrm{dec}} \in \mathbb{R}^{n \times k}$ has $i$th column $\nu_{S_i}$. Then if we replace $S_\rho$ by $F(x) := W_{\mathrm{dec}} f_{\mathrm{enc}}(x)$ in $\mathcal{L}$ and take $\mathcal{L}_{\mathrm{sparsity}} = \mathcal{L}_{\mathrm{aux}} = 0$,

$$\mathcal{L} = \sum_{r=1}^N \|x^r - F(x^r)\|_2^2 = \sum_{i=1}^k \sum_{x^r \in \Omega_{S_i}} \|x^r - \nu_{S_i}\|_2^2,$$

which is exactly the k-means objective. An SAE is therefore a generalisation of this “k-means autoencoder”, where the encoding is allowed to be piecewise affine, the number of regions is allowed to be greater than the hidden dimension ($k = 2^d$ in the (Jump)ReLU case and $k = \binom{d}{K}$ in the TopK case), and the SAE is overall piecewise affine with piecewise ranks $|S_i|$.

But if SAEs are “k-means autoencoders” generalised to allow piecewise affine behaviour, how do they compare to the most general piecewise affine autoencoder?

Theorem 2.6. On any partition $\{R_i\}_{i=1}^k$ define the general piecewise affine autoencoder with piecewise ranks $K_i$ (where $U_i, V_i \in \mathbb{R}^{n \times K_i}$, $U_i^T U_i = I$, and $c_i \in \mathbb{R}^n$),

$$G(x) := U_i V_i^T x + c_i, \qquad x \in R_i.$$

Let $\mathcal{L}_{\mathrm{aux}} = 0$ and $\mathcal{L}_{\mathrm{sparsity}} = \|\cdot\|_0$, which counts the non-zero entries. Then we have the loss

$$\mathcal{L} = \sum_{i=1}^k \sum_{x^r \in R_i} \|x^r - G(x^r)\|_2^2 + \lambda \sum_{i=1}^k N_i K_i,$$

where $N_i := |\{r : x^r \in R_i\}|$. This has optimal parameters: $U_i = V_i$, where the columns of $U_i$ are the top $K_i$ (normalised) eigenvectors $\{\xi_\ell^i\}_{\ell=1}^{K_i}$ of the covariance matrix

$$X_i := \frac{1}{N_i} \sum_{x^r \in R_i} (x^r - \bar{x}_i)(x^r - \bar{x}_i)^T,$$

$c_i = (I - U_i V_i^T)\bar{x}_i$, and optimal regions $R_i$ minimising

$$\mathcal{L} = \sum_{i=1}^k \bigg[ \Big( \sum_{x^r \in R_i} \|x^r - \bar{x}_i\|_2^2 \Big) + N_i K_i \Big( \lambda - \frac{1}{K_i} \sum_{\ell=1}^{K_i} \lambda_\ell(X_i) \Big) \bigg],$$

where $\lambda_\ell(X_i)$ are the eigenvalues of $X_i$ in descending order. Therefore,

$$G(x) = \bar{x}_i + \sum_{\ell=1}^{K_i} (\xi_\ell^i)^T (x - \bar{x}_i) \, \xi_\ell^i, \qquad x \in R_i.$$

Thus, the general piecewise affine autoencoder combines a k-means-esque clustering with a variable-rank local PCA correction. The optimal $K_i$ will occur when $\lambda_{K_i}(X_i) > \lambda \geq \lambda_{K_i+1}(X_i)$, and hence $G$ is sensitive to the local intrinsic dimension of the data.

Note 2.7. The key takeaway from Theorem 2.6 is that whilst SAEs generalise k-means to allow piecewise affine behaviour, they are less accurate than the most general piecewise affine autoencoder, which is k-means-like with a local PCA that tracks the local intrinsic dimension. However, $G$ achieves this greater accuracy by using a different encoding/decoding in each region.
The code entry $z_j := (\xi_j^i)^T (x - \bar{x}_i)$ for an $x \in R_i$ is semantically unrelated to the $z_j$ for an $x$ in another region. By contrast, SAEs sacrifice some accuracy to encode all inputs as monosemantic sparse codes $z$. This leads to SAEs having $d$ decoding vectors (the columns of $W_{\mathrm{dec}}$), of which a subset are deployed per region, so these vectors are shared between regions.

2.4 Visualising the SAE bridge between k-means and PCA

We just saw mathematically how SAEs generalise k-means, but sacrifice accuracy for interpretability compared to the optimal piecewise affine autoencoder, which is k-means-esque extended via local PCA. Our first experiment is a quick empirical exploration of this bridge. We chose TopK SAEs, and trained them on 100 points in $\mathbb{R}^2$ drawn from $k = 3$ clusters. For details on the experimental set-up, see Section D.1, and for another figure see Appendix E. We compared the SAEs to (i) a k-means “autoencoding”, which maps each data point to its centroid, and (ii) a local 1-PCA extension of that k-means autoencoding. This is the optimal $G$ from Theorem 2.6 with $\{R_i\}_{i=1}^3$ fixed to be the k-means cells and $K_i \equiv 1$. We found (see Figures 2 and 8) that the SAE consistently had a lower mean squared error (MSE) than k-means but a higher MSE than the PCA extension, in accordance with the theory.

Figure 2: Visualising the SAE bridge between k-means clustering and PCA. The cyan/brown boundary intersecting the point cloud (left and right figure) is a visualiser bug.
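The k-means vs. local-1-PCA half of this comparison is easy to reproduce in miniature. A NumPy sketch (the synthetic clusters and the fixed assignment centroids are illustrative assumptions, and the trained SAE itself is omitted): the local 1-PCA correction of Theorem 2.6 can only reduce the in-region reconstruction error relative to reconstructing by the in-region mean.

```python
import numpy as np

rng = np.random.default_rng(2)
# Toy data: three elongated 2-D clusters (illustrative, not the paper's setup).
centers = np.array([[0.0, 0.0], [5.0, 5.0], [10.0, 0.0]])
X = np.concatenate([c + rng.normal(scale=[1.5, 0.2], size=(40, 2)) for c in centers])

# Hard-assign each point to its nearest generating centroid. (A full k-means
# would iterate assignment and update steps; this fixes the regions.)
labels = np.argmin(((X[:, None, :] - centers[None]) ** 2).sum(-1), axis=1)

mse_kmeans, mse_pca = 0.0, 0.0
for i in range(len(centers)):
    Xi = X[labels == i]
    xbar = Xi.mean(axis=0)
    mse_kmeans += np.sum((Xi - xbar) ** 2)      # reconstruct by in-region mean
    # Local 1-PCA correction: add back the projection of the residual onto
    # the top eigenvector of the in-region covariance (K_i = 1 in Theorem 2.6).
    cov = (Xi - xbar).T @ (Xi - xbar) / len(Xi)
    _, V = np.linalg.eigh(cov)                  # eigh: ascending eigenvalues
    u = V[:, -1]                                # top (normalised) eigenvector
    proj = (Xi - xbar) @ u
    recon = xbar + proj[:, None] * u[None, :]
    mse_pca += np.sum((Xi - recon) ** 2)
```

On elongated clusters like these, the gap between `mse_kmeans` and `mse_pca` is large, since most in-region variance lies along the top principal direction.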
3 A proximal alternating method (PAM-SGD) for training SAEs 3.1 Convergence theory and the PAM-SGD algorithm It turns out that for any SAE that has the sort of loss function we have been considering, if we fix the encoding, then learning the decoding reduces to linear least squares regression: the decoding seeks an affine map that sends codes zr:=ρ⁢(Wenc⁢xr+benc)assignsuperscriptsubscriptencsuperscriptsubscriptencz^r:=ρ(W_encx^r+b_enc)zitalic_r := ρ ( Wenc xitalic_r + benc ) to xrsuperscriptx^rxitalic_r. We can therefore solve for the decoding in closed form, which suggests the idea of training an SAE by alternating between updating our encoding holding the decoding fixed, e.g. by SGD, and then updating our decoding holding the encoding fixed, via the closed form optimum. This dovetails well with the spline theory perspective, as we saw in the previous section that the spline geometry depends entirely on the encoding parameters. Therefore, this alternation can be viewed as (i) updating the SAE spline geometry to better use a given decoding, and then (i) finding the most accurate decoding given that geometry. In particular, we will consider the following proximal alternating method with quadratic costs to move, as this has solid theoretical foundations Attouch et al. 
(2010):

$$(W_{enc}^{t+1}, b_{enc}^{t+1}) = \operatorname*{argmin}_{W_{enc},\,b_{enc}} \sum_{r=1}^N \|W_{dec}^t\,\rho(W_{enc} x^r + b_{enc}) + b_{dec}^t - x^r\|_2^2 + F(W_{enc}, b_{enc}) + \mu_{enc}^t \|W_{enc} - W_{enc}^t\|_F^2 + \nu_{enc}^t \|b_{enc} - b_{enc}^t\|_2^2 \tag{3.1a}$$

$$(W_{dec}^{t+1}, b_{dec}^{t+1}) = \operatorname*{argmin}_{W_{dec},\,b_{dec}} \sum_{r=1}^N \|W_{dec}\,\rho(W_{enc}^{t+1} x^r + b_{enc}^{t+1}) + b_{dec} - x^r\|_2^2 + G(W_{dec}, b_{dec}) + \mu_{dec}^t \|W_{dec} - W_{dec}^t\|_F^2 + \nu_{dec}^t \|b_{dec} - b_{dec}^t\|_2^2 \tag{3.1b}$$

We next use Attouch et al. (2010) to analyse the convergence of Equation 3.1, and find that under some assumptions (which require minor adjustments to our SAE settings, see C.2 and C.3) the sequence defined by (3.1) converges to a critical point of the following loss:

$$\mathcal{L}(W_{enc}, b_{enc}, W_{dec}, b_{dec}) := \sum_{r=1}^N \|W_{dec}\,\rho(W_{enc} x^r + b_{enc}) + b_{dec} - x^r\|_2^2 + F(W_{enc}, b_{enc}) + G(W_{dec}, b_{dec}).$$

We summarise the convergence result as follows; for details see Section C.1.

Theorem 3.1. If C.2 holds, then from any initialisation, the sequence of SAE parameters defined by Equation 3.1 monotonically decreases the loss $\mathcal{L}$, and every convergent subsequence converges to a critical point of $\mathcal{L}$. Furthermore, if the sequence is bounded (as would be ensured by e.g. weight decay) then the sequence converges to a critical point of $\mathcal{L}$, with a rate that can be characterised (if the Kurdyka–Łojasiewicz exponent of $\mathcal{L}$ is known).

Note 3.2. Theorem 3.1 (i.e., Theorem C.6) does not prove that our PAM-SGD method (see Algorithm 3.1 below) converges, as the theorem assumes that Equation 3.1a is solved exactly, whilst in Algorithm 3.1 it will be only approximated via SGD. However, it does give some indication that the PAM-SGD method will approach an approximation to a critical point of $\mathcal{L}$.
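As a sanity check on how quadratic proximal costs keep the decoder subproblem a least-squares problem, here is a sketch: the proximal and weight-decay terms fold into extra columns of an augmented system, which matches the pseudoinverse form used in Algorithm 3.1 (all values below are toy assumptions):

```python
import numpy as np

rng = np.random.default_rng(6)
n, d, N = 4, 6, 50
Psi = rng.standard_normal((d, N))   # codes, one column per sample
Phi = rng.standard_normal((n, N))   # reconstruction targets
W_t = rng.standard_normal((n, d))   # previous decoder iterate
mu, alpha = 0.7, 0.3                # proximal and weight-decay strengths

# The proximal decoder subproblem
#   min_W ||W Psi - Phi||_F^2 + mu ||W - W_t||_F^2 + alpha ||W||_F^2
# is one augmented least-squares problem: W = B A^+.
A = np.hstack([Psi, np.sqrt(mu) * np.eye(d), np.sqrt(alpha) * np.eye(d)])
B = np.hstack([Phi, np.sqrt(mu) * W_t, np.zeros((n, d))])
W = B @ np.linalg.pinv(A)

# Cross-check against the normal equations of the same objective.
W_normal = (Phi @ Psi.T + mu * W_t) @ np.linalg.inv(
    Psi @ Psi.T + (mu + alpha) * np.eye(d))
assert np.allclose(W, W_normal)
```

The bias-handling in the actual algorithm adds two more augmented columns, but the mechanism is the same.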
The optimal decoding for Equation 3.1b with weight decay $G := \alpha\|W_{dec}\|_F^2 + \beta\|b_{dec}\|_2^2$ can still be found in closed form, see Theorem C.1. This gives the following novel method for training an SAE by solving Equation 3.1b exactly, which we call a proximal alternating method SGD (PAM-SGD) algorithm.

Input: Initial SAE parameters $W_{enc}^0, b_{enc}^0, W_{dec}^0, b_{dec}^0$; iterations $t_{max}$; quadratic cost parameters $\{\mu_{enc}^t, \nu_{enc}^t, \mu_{dec}^t, \nu_{dec}^t\}_{t=0}^{t_{max}-1}$; weight decay parameters $\alpha, \beta$; activation $\rho$; learning rate $\eta$; batch size $B$; SGD steps $M$; training data $D = \{x^r\}_{r=1}^N \subset \mathbb{R}^n$.
Output: Final SAE parameters $W_{enc}^{t_{max}}, b_{enc}^{t_{max}}, W_{dec}^{t_{max}}, b_{dec}^{t_{max}}$.

1. $\bar{x} \leftarrow \frac{1}{N}\sum_{r=1}^N x^r$
2. for $t = 0$ to $t_{max}-1$ do
3.   /* SGD encoder update */
     $(W_{enc}^{t+1}, b_{enc}^{t+1}) \leftarrow \mathrm{SGD}(W_{enc}^t, b_{enc}^t, \mu_{enc}^t, \nu_{enc}^t, \eta, B, M)$  // computes Equation 3.1a via $M$ steps of SGD with learning rate $\eta$ and batch size $B$
4.   /* Optimal decoder update */
     foreach $x^r \in D$ do $z_{t+1}^r \leftarrow \rho(W_{enc}^{t+1} x^r + b_{enc}^{t+1})$
5.   $\bar{z}_{t+1} \leftarrow \frac{1}{N}\sum_{r=1}^N z_{t+1}^r$
6.   for $r = 1$ to $N$ do
7.     $\psi_t^r \leftarrow z_{t+1}^r - \frac{N}{N+\beta+\nu_{dec}^t}\,\bar{z}_{t+1}$
8.     $\phi_t^r \leftarrow x^r - \frac{\nu_{dec}^t}{N+\beta+\nu_{dec}^t}\, b_{dec}^t - \frac{N}{N+\beta+\nu_{dec}^t}\,\bar{x}$
9.   end for
10.  $(\psi_t^{N+1}, \phi_t^{N+1}) \leftarrow \sqrt{\nu_{dec}^t}\,\frac{1}{N+\beta+\nu_{dec}^t}\,\big(N\bar{z}_{t+1},\; N\bar{x} - (N+\beta)\,b_{dec}^t\big)$
11.  $(\psi_t^{N+2}, \phi_t^{N+2}) \leftarrow \sqrt{\beta}\,\frac{1}{N+\beta+\nu_{dec}^t}\,\big(N\bar{z}_{t+1},\; N\bar{x} + \nu_{dec}^t\, b_{dec}^t\big)$
12.  $\Psi_t \leftarrow \mathrm{cat}(\psi_t^r)$; $\Phi_t \leftarrow \mathrm{cat}(\phi_t^r)$  // concatenates the $\psi_t^r$ and $\phi_t^r$ into a $d\times(N+2)$ matrix and an $n\times(N+2)$ matrix, respectively
13.  $W_{dec}^{t+1} \leftarrow \begin{pmatrix}\Phi_t & \sqrt{\mu_{dec}^t}\,W_{dec}^t & 0_{n\times d}\end{pmatrix}\begin{pmatrix}\Psi_t & \sqrt{\mu_{dec}^t}\,I_d & \sqrt{\alpha}\,I_d\end{pmatrix}^\dagger$
14.  $b_{dec}^{t+1} \leftarrow \frac{\nu_{dec}^t}{N+\beta+\nu_{dec}^t}\, b_{dec}^t + \frac{N}{N+\beta+\nu_{dec}^t}\,(\bar{x} - W_{dec}^{t+1}\bar{z}_{t+1})$  // optimal $W_{dec}^{t+1}$ and $b_{dec}^{t+1}$, see Theorem C.1
15. end for

Algorithm 3.1: PAM-SGD method (with optional weight decay) for learning an SAE.

3.2 Sample-efficient sparse coding for MNIST and LLMs: PAM-SGD vs. SGD

We performed two experiments comparing the benefits of our PAM-SGD method (Algorithm 3.1) vs. standard SGD (using Adam for SGD in both cases) for training SAEs: (i) on simple visual domains, using the MNIST dataset, and (ii) on high-dimensional LLM activations, using Google DeepMind's Gemma-2-2B.

Figure 3: Training and test loss curves at different data sizes for MNIST, with ReLU activation. The chart highlights PAM-SGD's superior sample efficiency and convergence speed.

For MNIST we used $n = 28^2$ input dimensions and $d = 256$ hidden dimensions. Our LLM experiments used activations from the 12th layer of Gemma-2-2B, with $n = 2304$ and a highly overcomplete hidden dimension $d = 4096$.
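The alternation in Algorithm 3.1 can be sketched in plain numpy. This is a simplified toy version, not the paper's implementation: full-batch gradient steps stand in for the SGD encoder update, the proximal costs and weight decay are dropped, and all shapes are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(2)
N, n, d = 256, 10, 32
X = rng.standard_normal((N, n))

W_enc = 0.1 * rng.standard_normal((d, n)); b_enc = np.zeros(d)
W_dec = 0.1 * rng.standard_normal((n, d)); b_dec = np.zeros(n)

def loss():
    z = np.maximum(X @ W_enc.T + b_enc, 0.0)   # ReLU codes
    r = z @ W_dec.T + b_dec                    # reconstructions
    return ((r - X) ** 2).sum() / N

losses = [loss()]
eta, M = 0.05, 20
for t in range(10):
    # Encoder update (cf. 3.1a, simplified): M gradient steps, decoder frozen.
    for _ in range(M):
        pre = X @ W_enc.T + b_enc
        z = np.maximum(pre, 0.0)
        err = z @ W_dec.T + b_dec - X                  # N x n residuals
        dpre = (2.0 / N) * (err @ W_dec) * (pre > 0)   # backprop through ReLU
        W_enc -= eta * dpre.T @ X
        b_enc -= eta * dpre.sum(axis=0)
    # Decoder update (cf. 3.1b, simplified): exact linear least squares.
    z = np.maximum(X @ W_enc.T + b_enc, 0.0)
    A = np.hstack([z, np.ones((N, 1))])
    coef, *_ = np.linalg.lstsq(A, X, rcond=None)
    W_dec, b_dec = coef[:d].T, coef[d]
    losses.append(loss())

assert losses[-1] < losses[0]
```

Each decoder step is exact, so the loss after it can never exceed the loss under any other affine decoder for the current encoder.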
Experimental settings are described in Appendix D, and additional figures and ablation studies can be found in Appendix E.

PAM-SGD generalised better than SGD in low-data regimes. (Figures 3, 4, 9, 24, 25, 26, 29, 28 and 27) On MNIST, PAM-SGD consistently (using both ReLU and TopK) substantially outperformed SGD in test loss, especially when trained on just 1%–25% of the MNIST training data. In the LLM experiments, PAM-SGD again outperformed SGD when using ReLU activations, especially for low data. However, PAM-SGD became unstable when TopK was used unless $K$ was large; even for $K=320$ and $K=640$ it underperformed SGD, slightly for low (and high) data and substantially for medium data.

Figure 4: Training and test loss curves at different data sizes for Gemma-2-2B, with ReLU activation. PAM-SGD again has a huge advantage at low data, and remains superior throughout.

PAM-SGD was faster, more accurate, and more interpretable on MNIST. (Figures 12, 14, 10, 13, 15 and 11) Reconstruction comparisons over training epochs show that PAM-SGD reconstructions were cleaner and converged faster than those from SGD. This was particularly evident when tracking digit reconstructions over time; PAM-SGD's exhibited sharper edges and more localised structure by early training. By the end of training, both methods produced visually accurate reconstructions, but PAM-SGD's showed slightly better fidelity and smoothness, and more closely resembled the originals. Finally, visualising encoder and decoder filters in the TopK case reveals that both SGD and PAM-SGD learned edge- and stroke-like patterns, but PAM-SGD's filters were sharper and better structured.

Summary and practical implications. PAM-SGD demonstrated clear advantages over SGD on MNIST in terms of generalisation, convergence speed, reconstruction quality, and (using TopK) visual interpretability, particularly in low-data regimes.
Even in a much more challenging real-world LLM setting, PAM-SGD with ReLU still substantially outperformed in low data, and improved activation sparsity by about 15%. However, issues arose for TopK: small $K$ led to rapidly diverging test loss, and for larger $K$ PAM-SGD still underperformed SGD (though only slightly for low data). In summary, these results suggest that PAM-SGD is a powerful tool for learning overcomplete, sparse representations from visual data and LLM activations in low-data regimes, provided that the sparsity can adapt to the data. This insight is important in downstream applications where data may be scarce.

4 Conclusions and limitations

In this work, we have sought to apply a spline theoretical lens to SAEs, to gain insight into how, why, and whether SAEs work. Given the current prominence of SAEs in mechanistic interpretability, and the societal importance of interpreting AI systems, we hope that the development of SAE theory (and our small contribution to it) can help develop more efficient, fairer, and reliable AI systems. Building on the piecewise affine spline nature of SAEs, we characterised the spline geometry of TopK SAEs as exactly the $K^{th}$-order power diagrams, opening the door to directly incorporating geometric constraints into SAEs. We linked SAEs with traditional ML, showing how k-means can be viewed as a special kind of SAE, and how SAEs sacrifice accuracy for interpretability vs. the optimal piecewise affine autoencoder, which we showed to be a k-means-esque clustering with a local PCA correction. Finally, we developed a new proximal alternating training method (PAM-SGD) for SAEs, with both solid theoretical foundations and promising empirical results, particularly in sample efficiency and activation sparsity for LLMs, two pain points for mechanistic interpretability.
PAM-SGD’s separate updating of encoding and decoding dovetails well with the spline theory perspective of the encoding shaping the SAE’s spline geometry vs. the decoding driving the SAE’s autoencoding accuracy. This work is the beginning of a longer theoretical exploration, and is thus limited in ways we hope to address in future work. Our characterisation of the spline geometry of SAEs is currently limited to TopK SAEs; future work will seek to extend this, and explore more explicitly the incorporation of geometry into SAE training, perhaps giving insight into how to tailor an SAE architecture for a given task. Our bridge between SAEs and PCA-based autoencoders sets aside the matter (see 2.7) of shared decoding vectors. Future work will study the optimal autoencoding in that setting, and related results in the superposition hypothesis setting. Finally, PAM-SGD makes approximations which break assumptions of the theory, and had some empirical limitations. Future work will seek to understand more deeply the pros and cons of PAM-SGD, and incorporate the SGD step into the theory. Acknowledgments and Disclosure of Funding This collaboration did not form in the typical academic way, meeting at a conference or university. We instead thank the Machine Learning Street Talk (MLST) team, especially Tim Scarfe, for enabling all the authors to have met through the MLST Discord server. And we thank all the MLST Discord users involved in the discussion on “Learning in high dimension always amounts to extrapolation”, which set all this in motion. JB received financial support from start-up funds at the University of Birmingham. BMR received financial support from Taighde Éireann – Research Ireland under Grant number [12/RC/2289_P2]. We declare no conflicts of interest. References Ash & Bolker (1986) Ash, P. F. and Bolker, E. D. Generalized Dirichlet tessellations. Geometriae Dedicata, 20(2):209–243, Apr 1986. ISSN 1572-9168. doi: 10.1007/BF00164401. 
URL https://doi.org/10.1007/BF00164401. Attouch et al. (2010) Attouch, H., Bolte, J., Redont, P., and Soubeyran, A. Proximal alternating minimization and projection methods for nonconvex problems: An approach based on the Kurdyka-Łojasiewicz inequality. Mathematics of Operations Research, 35(2):438–457, 2010. Balestriero & Baraniuk (2018) Balestriero, R. and Baraniuk, R. A spline theory of deep learning. In Dy, J. and Krause, A. (eds.), Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pp. 374–383. PMLR, 10–15 Jul 2018. URL https://proceedings.mlr.press/v80/balestriero18b.html. Bennett et al. (2009) Bennett, C., Miller, M., and Wolford, G. Neural correlates of interspecies perspective taking in the post-mortem atlantic salmon: an argument for multiple comparisons correction. NeuroImage, 47:S125, 2009. ISSN 1053-8119. doi: https://doi.org/10.1016/S1053-8119(09)71202-9. URL https://www.sciencedirect.com/science/article/pii/S1053811909712029. Organization for Human Brain Mapping 2009 Annual Meeting. Bolte et al. (2007) Bolte, J., Daniilidis, A., and Lewis, A. The Łojasiewicz Inequality for Nonsmooth Subanalytic Functions with Applications to Subgradient Dynamical Systems. SIAM Journal on Optimization, 17(4):1205–1223, 2007. doi: 10.1137/050644641. Bricken et al. (2023) Bricken, T., Templeton, A., Batson, J., Chen, B., Jermyn, A., Conerly, T., Turner, N., Anil, C., Denison, C., Askell, A., Lasenby, R., Wu, Y., Kravec, S., Schiefer, N., Maxwell, T., Joseph, N., Hatfield-Dodds, Z., Tamkin, A., Nguyen, K., McLean, B., Burke, J. E., Hume, T., Carter, S., Henighan, T., and Olah, C. Towards monosemanticity: Decomposing language models with dictionary learning. Transformer Circuits Thread, 2023. https://transformer-circuits.pub/2023/monosemantic-features/index.html. Elhage et al.
(2022) Elhage, N., Hume, T., Olsson, C., Schiefer, N., Henighan, T., Kravec, S., Hatfield-Dodds, Z., Lasenby, R., Drain, D., Chen, C., et al. Toy models of superposition. arXiv preprint arXiv:2209.10652, 2022. Gao et al. (2025) Gao, L., la Tour, T. D., Tillman, H., Goh, G., Troll, R., Radford, A., Sutskever, I., Leike, J., and Wu, J. Scaling and evaluating sparse autoencoders. In The Thirteenth International Conference on Learning Representations, 2025. URL https://openreview.net/forum?id=tcsZt9ZNKD. Heap et al. (2025) Heap, T., Lawson, T., Farnik, L., and Aitchison, L. Sparse autoencoders can interpret randomly initialized transformers, 2025. URL https://arxiv.org/abs/2501.17727. Hindupur et al. (2025) Hindupur, S. S. R., Lubana, E. S., Fel, T., and Ba, D. Projecting assumptions: The duality between sparse autoencoders and concept geometry, 2025. URL https://arxiv.org/abs/2503.01822. Huben et al. (2024) Huben, R., Cunningham, H., Smith, L. R., Ewart, A., and Sharkey, L. Sparse autoencoders find highly interpretable features in language models. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=F76bwRSLeK. Humayun et al. (2024) Humayun, A. I., Balestriero, R., and Baraniuk, R. Deep networks always grok and here is why. In High-dimensional Learning Dynamics 2024: The Emergence of Structure and Reasoning, 2024. URL https://openreview.net/forum?id=NpufNsg1FP. Leask et al. (2025) Leask, P., Bussmann, B., Pearce, M., Bloom, J., Tigges, C., Moubayed, N. A., Sharkey, L., and Nanda, N. Sparse autoencoders do not find canonical units of analysis, 2025. URL https://arxiv.org/abs/2502.04878. Li & Pong (2017) Li, G. and Pong, T. K. Calculus of the Exponent of Kurdyka–Łojasiewicz Inequality and Its Applications to Linear Convergence of First-Order Methods. Foundations of Computational Mathematics, 18(5):1199–1232, aug 2017. doi: 10.1007/s10208-017-9366-8. Łojasiewicz (1964) Łojasiewicz, S. 
Triangulation of semi-analytic sets. Annali della Scuola Normale Superiore di Pisa-Classe di Scienze, 18(4):449–474, 1964. Makhzani & Frey (2014) Makhzani, A. and Frey, B. k-sparse autoencoders, 2014. URL https://arxiv.org/abs/1312.5663. Ng (2011) Ng, A. Sparse autoencoder. CS294A Lecture notes, 72(2011):1–19, 2011. Olah et al. (2017) Olah, C., Mordvintsev, A., and Schubert, L. Feature visualization. Distill, 2(11):e7, 2017. Olshausen & Field (1996) Olshausen, B. A. and Field, D. J. Emergence of simple-cell receptive field properties by learning a sparse code for natural images. Nature, 381(6583):607–609, 1996. Power et al. (2022) Power, A., Burda, Y., Edwards, H., Babuschkin, I., and Misra, V. Grokking: Generalization beyond overfitting on small algorithmic datasets, 2022. URL https://arxiv.org/abs/2201.02177. Rajamanoharan et al. (2024) Rajamanoharan, S., Lieberum, T., Sonnerat, N., Conmy, A., Varma, V., Kramár, J., and Nanda, N. Jumping ahead: Improving reconstruction fidelity with JumpReLU sparse autoencoders, 2024. URL https://arxiv.org/abs/2407.14435. Smith et al. (2025) Smith, L., Rajamanoharan, S., Conmy, A., McDougall, C., Kramar, J., Lieberum, T., Shah, R., and Nanda, N. Negative Results for SAEs On Downstream Tasks and Deprioritising SAE Research (GDM Mech Interp Team Progress Update #2), Mar 2025. URL https://www.alignmentforum.org/posts/4uXCAJNuPKtKBsi28/sae-progress-update-2-draft. Steinhaus (1957) Steinhaus, H. Sur la division des corps matériels en parties. Bull. Acad. Pol. Sci., Cl. I, 4:801–804, 1957. ISSN 0001-4095. Templeton et al. (2024) Templeton, A., Conerly, T., Marcus, J., Lindsey, J., Bricken, T., Chen, B., Pearce, A., Citro, C., Ameisen, E., Jones, A., Cunningham, H., Turner, N. L., McDougall, C., MacDiarmid, M., Freeman, C. D., Sumers, T. R., Rees, E., Batson, J., Jermyn, A., Carter, S., Olah, C., and Henighan, T. Scaling monosemanticity: Extracting interpretable features from Claude 3 Sonnet. Transformer Circuits Thread, 2024.
URL https://transformer-circuits.pub/2024/scaling-monosemanticity/index.html.

Appendix A Visualisations of Voronoi and power diagrams

A.1 Spatial Partitioning Methods: From Voronoi to Power Diagrams

This experiment explores the theoretical connections between different spatial partitioning methods. Starting with standard Voronoi diagrams that divide the plane based on proximity to generator points, we demonstrate how they relate to nearest-neighbor classification ($k=1$), centroidal clustering (k-means), and finally to power diagrams, which introduce weights to Voronoi cells. These relationships reveal that power diagrams emerge as a generalization of Voronoi diagrams, offering additional flexibility through weighted distance metrics and enabling richer geometric representations of data.

Figure 5: Spatial Partitioning Methods. (a) Standard Voronoi diagram where space is partitioned based on the nearest generator point using Euclidean distance; (b) nearest-neighbor ($k=1$) classification showing how Voronoi cells define decision boundaries for point classification; (c) k-means clustering with $k=3$ demonstrating how cluster centroids generate Voronoi cells that define cluster boundaries; (d) power diagram (weighted Voronoi) where each generator has an associated weight, creating curved boundaries between regions. Power diagrams generalise Voronoi diagrams and provide additional flexibility for modeling spatial relationships.

Figure 6: Progressive Generalization of Voronoi Diagrams. Comparison showing (left) standard Voronoi diagram, (center) first-order power diagram with linear distance weighting, and (right) second-order power diagram. This progression demonstrates how higher-order power diagrams can capture more sophisticated spatial relationships.

A.2 3D Voronoi diagrams to 2D power diagrams

There is a neat relationship between power diagrams and projections of Voronoi diagrams, observed in Ash & Bolker (1986).
Suppose that we consider a Voronoi diagram in $n+1$ dimensions with cells defined by centroids $\{(\mu_i, \zeta_i)\}_{i=1}^k \subset \mathbb{R}^{n+1}$, i.e. the $i^{th}$ cell is defined to be

$$C_i := \{(x,z) \in \mathbb{R}^{n+1} : \|(x,z) - (\mu_i, \zeta_i)\|_2^2 < \|(x,z) - (\mu_j, \zeta_j)\|_2^2 \text{ for all } j \neq i\}.$$

Now suppose that we project this diagram into $n$-dimensional space, for example defining

$$\hat{C}_i := \{x \in \mathbb{R}^n : (x,0) \in C_i\} = \{x \in \mathbb{R}^n : \|x - \mu_i\|_2^2 + \zeta_i^2 < \|x - \mu_j\|_2^2 + \zeta_j^2 \text{ for all } j \neq i\}.$$

This is precisely a power diagram with centroids $\{\mu_i\}_{i=1}^k \subset \mathbb{R}^n$ (the projections of the Voronoi centroids) and weights $\alpha_i = -\zeta_i^2$.
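This lifting identity is easy to check numerically. A small sketch with hypothetical sites (here $n = 2$):

```python
import numpy as np

rng = np.random.default_rng(3)
k = 6
mu = rng.standard_normal((k, 2))     # 2D centroids
zeta = rng.standard_normal(k)        # lift heights in the extra dimension
alpha = -zeta ** 2                   # the corresponding power-diagram weights

# Lifted centroids (mu_i, zeta_i) in R^3.
P3 = np.hstack([mu, zeta[:, None]])

for _ in range(500):
    x = 3 * rng.standard_normal(2)
    # Which 3D Voronoi cell does the lifted point (x, 0) fall into?
    x3 = np.append(x, 0.0)
    voronoi_cell = np.argmin(((x3 - P3) ** 2).sum(1))
    # Which 2D power cell does x fall into?
    power_cell = np.argmin(((x - mu) ** 2).sum(1) - alpha)
    assert voronoi_cell == power_cell
```

The agreement holds because $\|(x,0)-(\mu_i,\zeta_i)\|^2 = \|x-\mu_i\|^2 + \zeta_i^2 = \|x-\mu_i\|^2 - \alpha_i$, so both argmins compare the same quantities.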
Conversely, for any power diagram defined by centroids $\{\mu_i\}_{i=1}^k \subset \mathbb{R}^n$ and weights $\{\alpha_i\}_{i=1}^k \subset \mathbb{R}$, we can subtract a constant from the $\alpha_i$ to get an equivalent power diagram with all non-positive weights, and therefore compute $\zeta_i$ such that the corresponding Voronoi diagram with centroids $\{(\mu_i, \zeta_i)\}_{i=1}^k \subset \mathbb{R}^{n+1}$ projects to that power diagram. We visualise this mathematical relationship between 3D Voronoi diagrams and 2D power diagrams in Figure 7.

Figure 7: Projection of 3D Voronoi diagrams onto 2D power diagrams. The resulting power cells demonstrate the preservation of topological structure through projection.

Appendix B Proofs in Section 2

Theorem B.1. Both $\Omega_S^{\mathrm{JumpReLU}}$ and $\Omega_S^{\mathrm{TopK}}$ can be written in the form

$$\{x \in \mathbb{R}^n : Hx > c\},$$

where in the former $H = (P_S - P_{S^c})W_{enc}$ and $c = (P_S - P_{S^c})(\tau\mathbf{1} - b_{enc})$, and in the latter $H = M_S W_{enc}$, where $M_S \in \mathbb{R}^{K(d-K) \times d}$ has $(i,j)$-th row (with $i \in S$ and $j \notin S$) with a $1$ in column $i$, a $-1$ in column $j$, and $0$ otherwise, and $c = -M_S b_{enc}$.
These regions are therefore convex, and are the interiors of the convex polyhedra $\{x \in \mathbb{R}^n : Hx \geq c\}$ (unless $W_{enc}$ has a zero row, in the JumpReLU case, or has two identical rows, in the TopK case).

Proof of Theorem B.1. The forms of $H$ and $c$ can be immediately derived by rearranging the inequalities in the definitions of $\Omega_S^{\mathrm{JumpReLU}}$ and $\Omega_S^{\mathrm{TopK}}$. Convexity immediately follows, since if $Hx_1 > c$ and $Hx_2 > c$, then for all $t \in [0,1]$

$$H(tx_1 + (1-t)x_2) = tHx_1 + (1-t)Hx_2 > tc + (1-t)c = c.$$

Finally, suppose that $x \in \mathbb{R}^n$ lies in the interior of $\{x : Hx \geq c\}$. That is, there exists $\varepsilon > 0$ such that for all $\eta \in \mathbb{R}^n$ with $\|\eta\|_2 < \varepsilon$, $H(x + \eta) \geq c$. We wish to show that $Hx > c$. Suppose not; then for some $j$, $(Hx)_j = c_j$. Therefore for all $\eta \in \mathbb{R}^n$ with $\|\eta\|_2 < \varepsilon$, $(H\eta)_j \geq 0$, and therefore for all $\eta \in \mathbb{R}^n$, $(H\eta)_j = 0$. This is possible if and only if the $j^{th}$ row of $H$ is all zeroes. In the JumpReLU case, the $j^{th}$ row of $H$ is $\pm$ the $j^{th}$ row of $W_{enc}$ (depending on whether $j \in S$ or $j \notin S$) and hence is zero if and only if the $j^{th}$ row of $W_{enc}$ is zero.
In the TopK case, the $(i,j)^{th}$ row of $H$ is the difference between the $i^{th}$ and $j^{th}$ rows of $W_{enc}$, which is zero if and only if those rows are identical. ∎

Theorem B.2. Let $\{\mu_i\}_{i=1}^k \subset \mathbb{R}^n$ and $\{\alpha_i\}_{i=1}^k \subset \mathbb{R}$ define a $K^{th}$-order power diagram $\{C_S\}$ for $S$ the $K$-subsets of $\{1,\ldots,k\}$. Then the power diagram given by

$$\nu_S := \frac{1}{K}\sum_{i \in S} \mu_i \quad \text{and} \quad \beta_S := \left\|\frac{1}{K}\sum_{i \in S} \mu_i\right\|_2^2 - \frac{1}{K}\sum_{i \in S}\|\mu_i\|_2^2 + \frac{1}{K}\sum_{i \in S}\alpha_i, \tag{B.1}$$

i.e.,

$$R_S := \{x : \|x - \nu_S\|_2^2 - \beta_S < \|x - \nu_T\|_2^2 - \beta_T \text{ for all } T \neq S,\ T \text{ a } K\text{-subset of } \{1,\ldots,k\}\},$$

satisfies $R_S = C_S$ for all $S$.

Proof of Theorem B.2.
Define the power functions
$$P_i(x) := -2\mu_i^T x + \|\mu_i\|_2^2 - \alpha_i \qquad\text{and}\qquad Q_S(x) := -2\nu_S^T x + \|\nu_S\|_2^2 - \beta_S.$$
Then by subtracting $\|x\|_2^2$ from both sides of the defining inequalities we get:
$$C_S = \{x : P_i(x) < P_j(x) \text{ for all } i \in S \text{ and } j \in S^c\},$$
$$R_S = \{x : Q_S(x) < Q_T(x) \text{ for all } T \neq S,\ T \text{ a $K$-subset of } \{1,\ldots,k\}\}.$$
It is straightforward to check that $Q_S(x) = \frac{1}{K}\sum_{i\in S} P_i(x)$, and hence
$$x \in R_S \text{ if and only if } \sum_{i\in S} P_i(x) < \sum_{i\in T} P_i(x) \text{ for all $K$-subsets } T \neq S.$$
Suppose that $x \in R_S$, and let $i \in S$ and $j \in S^c$. Let $T = (S \setminus \{i\}) \cup \{j\}$. This is a $K$-subset distinct from $S$, and so
$$P_i(x) + \sum_{k\in S\setminus\{i\}} P_k(x) = \sum_{i\in S} P_i(x) < \sum_{i\in T} P_i(x) = P_j(x) + \sum_{k\in S\setminus\{i\}} P_k(x),$$
and hence $P_i(x) < P_j(x)$. Hence $x \in C_S$. Now suppose that $x \in C_S$ and let $T \neq S$ be a $K$-subset. Then
$$\sum_{i\in S} P_i(x) = \sum_{i\in S\cap T} P_i(x) + \sum_{i\in S\setminus T} P_i(x) < \sum_{i\in S\cap T} P_i(x) + \sum_{j\in T\setminus S} P_j(x) = \sum_{i\in T} P_i(x),$$
where we have used that $|S\setminus T| = |T\setminus S|$ and, for each $i \in S\setminus T$ and $j \in T\setminus S$, $P_i(x) < P_j(x)$. Hence $x \in R_S$. ∎

Proof of 2.2. Identify $\mathbb{R}^2$ with $\mathbb{C}$.
Then we would need $\{\mu_i\}_{i=1}^4 \subseteq \mathbb{C}$ such that
$$\begin{aligned}
\nu_{12} &= 1 &&= \tfrac{1}{2}\mu_1 + \tfrac{1}{2}\mu_2, &
\nu_{13} &= e^{\pi i/3} &&= \tfrac{1}{2}\mu_1 + \tfrac{1}{2}\mu_3, \\
\nu_{14} &= e^{2\pi i/3} &&= \tfrac{1}{2}\mu_1 + \tfrac{1}{2}\mu_4, &
\nu_{23} &= -1 &&= \tfrac{1}{2}\mu_2 + \tfrac{1}{2}\mu_3, \\
\nu_{24} &= -e^{\pi i/3} &&= \tfrac{1}{2}\mu_2 + \tfrac{1}{2}\mu_4, &
\nu_{34} &= -e^{2\pi i/3} &&= \tfrac{1}{2}\mu_3 + \tfrac{1}{2}\mu_4.
\end{aligned}$$
Hence
$$\mu_2 - \mu_4 = \bigl(\tfrac{1}{2}\mu_1 + \tfrac{1}{2}\mu_2\bigr) + \bigl(\tfrac{1}{2}\mu_2 + \tfrac{1}{2}\mu_3\bigr) - \bigl(\tfrac{1}{2}\mu_1 + \tfrac{1}{2}\mu_4\bigr) - \bigl(\tfrac{1}{2}\mu_3 + \tfrac{1}{2}\mu_4\bigr) = 1 + (-1) - e^{2\pi i/3} - (-e^{2\pi i/3}) = 0,$$
so $\mu_2 = \mu_4$, and hence
$$1 = \tfrac{1}{2}\mu_1 + \tfrac{1}{2}\mu_2 = \tfrac{1}{2}\mu_1 + \tfrac{1}{2}\mu_4 = e^{2\pi i/3},$$
a contradiction. ∎

Proof of Theorem 2.3. For each $S$, $\Omega_S^{\mathrm{TopK}}$ is given by
$$\Omega_S^{\mathrm{TopK}} = \bigl\{x \in \mathbb{R}^n : e_i^T W_{\mathrm{enc}} x + (b_{\mathrm{enc}})_i > e_j^T W_{\mathrm{enc}} x + (b_{\mathrm{enc}})_j \text{ for all } i \in S \text{ and } j \in S^c\bigr\}.$$
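The inequality characterisation of $\Omega_S^{\mathrm{TopK}}$ above is easy to check numerically: for a generic encoder, the set $S$ of a point's $K$ largest pre-activations satisfies exactly the defining pairwise inequalities. A minimal sketch (our own check; the encoder weights and sample points are random stand-ins, not from the paper):

```python
import numpy as np

# For a TopK encoder, the active set S of a point x -- the indices of its K
# largest pre-activations -- should satisfy every pairwise inequality
#   e_i^T W_enc x + (b_enc)_i > e_j^T W_enc x + (b_enc)_j,  i in S, j not in S.
rng = np.random.default_rng(0)
n, d, K = 2, 6, 3
W_enc = rng.normal(size=(d, n))
b_enc = rng.normal(size=d)

for _ in range(100):
    x = rng.normal(size=n)
    pre = W_enc @ x + b_enc              # pre-activations W_enc x + b_enc
    S = set(np.argsort(pre)[-K:])        # K largest (ties have measure zero)
    others = set(range(d)) - S
    assert all(pre[i] > pre[j] for i in S for j in others)
```

For continuously distributed inputs, ties among pre-activations occur with probability zero, so the strict inequalities hold almost surely.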
We can rewrite the definition of $C_S$ in a $K^{\mathrm{th}}$-order power diagram as
$$C_S := \bigl\{x \in \mathbb{R}^n : \mu_i^T x + \tfrac{1}{2}\alpha_i - \tfrac{1}{2}\|\mu_i\|_2^2 > \mu_j^T x + \tfrac{1}{2}\alpha_j - \tfrac{1}{2}\|\mu_j\|_2^2 \text{ for all } i \in S \text{ and } j \in S^c\bigr\}.$$
It follows that $\Omega_S^{\mathrm{TopK}} = C_S$ if for all $\ell$
$$e_\ell^T W_{\mathrm{enc}} = \mu_\ell^T \qquad\text{and}\qquad (b_{\mathrm{enc}})_\ell = \tfrac{1}{2}\alpha_\ell - \tfrac{1}{2}\|\mu_\ell\|_2^2,$$
i.e.
$$\mu_\ell = W_{\mathrm{enc}}^T e_\ell \qquad\text{and}\qquad \alpha_\ell = 2(b_{\mathrm{enc}})_\ell + \|W_{\mathrm{enc}}^T e_\ell\|_2^2.$$
Hence, any $(W_{\mathrm{enc}}, b_{\mathrm{enc}})$ gives rise to a $K^{\mathrm{th}}$-order power diagram, and any $K^{\mathrm{th}}$-order power diagram gives rise to a $(W_{\mathrm{enc}}, b_{\mathrm{enc}})$. Finally, Equation 2.2 follows from Equation B.1 and Equation 2.1. ∎

Proof of Theorem 2.6. We seek to minimise
$$\mathcal{L} = \sum_{i=1}^k \Bigl(\lambda N_i K_i + \sum_{x^r \in R_i} \|x^r - U_i V_i^T x^r - c_i\|_2^2\Bigr).$$
For the same reason as for the $\nu_i$ in $k$-means, the $c_i$ will be minimising when
$$c_i = \frac{1}{N_i}\sum_{x^r \in R_i} (I - U_i V_i^T)x^r = (I - U_i V_i^T)\bar{x}_i,$$
and so $\mathcal{L}$ simplifies to
$$\mathcal{L} = \sum_{i=1}^k \Bigl(\lambda N_i K_i + \sum_{x^r \in R_i} \|(I - U_i V_i^T)(x^r - \bar{x}_i)\|_2^2\Bigr).$$
We claim that this is minimised when $U_i = V_i$ and the columns of $U_i$ are the top $K_i$ eigenvectors of
$$X_i := \frac{1}{N_i}\sum_{x^r \in R_i} (x^r - \bar{x}_i)(x^r - \bar{x}_i)^T,$$
and in this case $\mathcal{L}$ reduces to
$$\mathcal{L} = \sum_{i=1}^k \Bigl(\lambda N_i K_i + \Bigl(\sum_{x^r \in R_i} \|x^r - \bar{x}_i\|_2^2\Bigr) - N_i \sum_{\ell=1}^{K_i} \lambda_\ell(X_i)\Bigr),$$
where $\lambda_\ell(X_i)$ are the eigenvalues of $X_i$ in descending order.
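The claimed loss value can be verified numerically: with $U = V$ whose columns are the top-$K$ eigenvectors of the empirical covariance, the reconstruction error equals $\sum_r \|x^r - \bar{x}\|_2^2 - N\sum_{\ell=1}^{K}\lambda_\ell$. A sketch with random data (our own check; all sizes are illustrative):

```python
import numpy as np

# Check: sum_r ||(I - U U^T)(x^r - xbar)||^2
#      = sum_r ||x^r - xbar||^2 - N * (sum of the top-K eigenvalues of X).
rng = np.random.default_rng(1)
N, n, K = 200, 5, 2
xs = rng.normal(size=(N, n)) @ rng.normal(size=(n, n))  # correlated sample
xc = xs - xs.mean(axis=0)                               # centred data
X = (xc.T @ xc) / N                                     # empirical covariance
evals, evecs = np.linalg.eigh(X)                        # ascending eigenvalues
U = evecs[:, -K:]                                       # top-K eigenvectors
loss = np.sum(np.linalg.norm(xc - xc @ U @ U.T, axis=1) ** 2)
closed_form = np.sum(np.linalg.norm(xc, axis=1) ** 2) - N * evals[-K:].sum()
assert np.isclose(loss, closed_form)
```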
This follows because
$$\begin{aligned}
\sum_{x^r} \|(I - UV^T)(x^r - \bar{x})\|_2^2
&= \sum_{x^r} (x^r - \bar{x})^T (I - VU^T)(I - UV^T)(x^r - \bar{x}) \\
&= \sum_{x^r} \operatorname{tr}\bigl((x^r - \bar{x})^T (I - VU^T)(I - UV^T)(x^r - \bar{x})\bigr) \\
&= \sum_{x^r} \operatorname{tr}\bigl((x^r - \bar{x})(x^r - \bar{x})^T (I - VU^T)(I - UV^T)\bigr) \\
&= N \operatorname{tr}\bigl(X(I - VU^T)(I - UV^T)\bigr) \\
&= N \operatorname{tr}\bigl((X - XVU^T)(I - UV^T)\bigr) \\
&= N \operatorname{tr}\bigl(X - XVU^T - XUV^T + XVV^T\bigr) \\
&= N \operatorname{tr}(X) + N \operatorname{tr}\bigl(-U^T X V - V^T X U + V^T X V\bigr) \\
&= N \operatorname{tr}(X) + N \operatorname{tr}\bigl((V - U)^T X (V - U)\bigr) - N \operatorname{tr}\bigl(U^T X U\bigr)
\end{aligned}$$
is minimised when $U = V$ and $\operatorname{tr}(U^T X U)$ is maximised (with the constraint that $U^T U = I$). This occurs when $U$ has as columns the top $K$ leading eigenvectors of $X$. At this choice:
$$\sum_{x^r} \|(I - UV^T)(x^r - \bar{x})\|_2^2 = N\operatorname{tr}(X) - N\sum_{\ell=1}^{K}\lambda_\ell = \sum_{x^r} \operatorname{tr}\bigl((x^r - \bar{x})^T(x^r - \bar{x})\bigr) - N\sum_{\ell=1}^{K}\lambda_\ell = \sum_{x^r} \|x^r - \bar{x}\|_2^2 - N\sum_{\ell=1}^{K}\lambda_\ell.$$
∎

Appendix C Proofs in Section 3

Theorem C.1.
The $(W_{\mathrm{dec}}^{t+1}, b_{\mathrm{dec}}^{t+1})$ solving Equation 3.1b for $G := \alpha\|W_{\mathrm{dec}}\|_F^2 + \beta\|b_{\mathrm{dec}}\|_2^2$ are given by
$$b_{\mathrm{dec}}^{t+1} = \frac{\nu_{\mathrm{dec}}^t}{N+\beta+\nu_{\mathrm{dec}}^t}\, b_{\mathrm{dec}}^t + \frac{N}{N+\beta+\nu_{\mathrm{dec}}^t}\bigl(\bar{x} - W_{\mathrm{dec}}\bar{z}_{t+1}\bigr),$$
$$W_{\mathrm{dec}}^{t+1} = \begin{pmatrix}\Phi_t & \sqrt{\mu_{\mathrm{dec}}^t}\,W_{\mathrm{dec}}^t & 0_{n\times d}\end{pmatrix}\begin{pmatrix}\Psi_t & \sqrt{\mu_{\mathrm{dec}}^t}\,I_d & \sqrt{\alpha}\,I_d\end{pmatrix}^\dagger,$$
where $z_{t+1}^r := \rho(W_{\mathrm{enc}}^{t+1} x^r + b_{\mathrm{enc}}^{t+1})$, $\bar{z}_{t+1}$ is the mean of the $z_{t+1}^r$, and $\Psi_t \in \mathbb{R}^{d\times(N+2)}$ and $\Phi_t \in \mathbb{R}^{n\times(N+2)}$ are matrices with columns (for $r = 1$ to $N$)
$$\psi_t^r := z_{t+1}^r - \frac{N\bar{z}_{t+1}}{N+\beta+\nu_{\mathrm{dec}}^t}, \qquad \psi_t^{N+1} := \frac{N\sqrt{\nu_{\mathrm{dec}}^t}}{N+\beta+\nu_{\mathrm{dec}}^t}\,\bar{z}_{t+1}, \qquad \psi_t^{N+2} := \frac{N\sqrt{\beta}}{N+\beta+\nu_{\mathrm{dec}}^t}\,\bar{z}_{t+1},$$
$$\phi_t^r := x^r - \frac{N\bar{x} + \nu_{\mathrm{dec}}^t b_{\mathrm{dec}}^t}{N+\beta+\nu_{\mathrm{dec}}^t}, \qquad \phi_t^{N+1} := \sqrt{\nu_{\mathrm{dec}}^t}\,\frac{N\bar{x} - (N+\beta)b_{\mathrm{dec}}^t}{N+\beta+\nu_{\mathrm{dec}}^t}, \qquad \phi_t^{N+2} := \sqrt{\beta}\,\frac{N\bar{x} + \nu_{\mathrm{dec}}^t b_{\mathrm{dec}}^t}{N+\beta+\nu_{\mathrm{dec}}^t}.$$

Proof of Theorem C.1. By completing the square, the optimal $b_{\mathrm{dec}}$ is given by
$$b_{\mathrm{dec}} = \frac{1}{N+\beta+\nu_{\mathrm{dec}}^t}\Bigl(\nu_{\mathrm{dec}}^t b_{\mathrm{dec}}^t + \sum_{r=1}^{N}\bigl(x^r - W_{\mathrm{dec}} z_{t+1}^r\bigr)\Bigr) = \frac{\nu_{\mathrm{dec}}^t}{N+\beta+\nu_{\mathrm{dec}}^t}\, b_{\mathrm{dec}}^t + \frac{N}{N+\beta+\nu_{\mathrm{dec}}^t}\bigl(\bar{x} - W_{\mathrm{dec}}\bar{z}_{t+1}\bigr),$$
where $z_{t+1}^r := \rho(W_{\mathrm{enc}}^{t+1} x^r + b_{\mathrm{enc}}^{t+1})$ and $\bar{z}_{t+1} := \frac{1}{N}\sum_{r=1}^N z_{t+1}^r$.
This reduces $\mathcal{L}$ to:
$$\begin{aligned}
&\sum_{r=1}^{N}\Bigl\|W_{\mathrm{dec}}\Bigl(z_{t+1}^r - \frac{N}{N+\beta+\nu_{\mathrm{dec}}^t}\bar{z}_{t+1}\Bigr) + \frac{\nu_{\mathrm{dec}}^t}{N+\beta+\nu_{\mathrm{dec}}^t}\, b_{\mathrm{dec}}^t + \frac{N}{N+\beta+\nu_{\mathrm{dec}}^t}\bar{x} - x^r\Bigr\|_2^2 \\
&\quad + \mu_{\mathrm{dec}}^t\|W_{\mathrm{dec}} - W_{\mathrm{dec}}^t\|_F^2 + \alpha\|W_{\mathrm{dec}}\|_F^2 + \nu_{\mathrm{dec}}^t\Bigl\|\frac{N}{N+\beta+\nu_{\mathrm{dec}}^t}\bigl(\bar{x} - W_{\mathrm{dec}}\bar{z}_{t+1}\bigr) - \frac{N+\beta}{N+\beta+\nu_{\mathrm{dec}}^t}\, b_{\mathrm{dec}}^t\Bigr\|_2^2 \\
&\quad + \beta\Bigl\|\frac{\nu_{\mathrm{dec}}^t}{N+\beta+\nu_{\mathrm{dec}}^t}\, b_{\mathrm{dec}}^t + \frac{N}{N+\beta+\nu_{\mathrm{dec}}^t}\bigl(\bar{x} - W_{\mathrm{dec}}\bar{z}_{t+1}\bigr)\Bigr\|_2^2.
\end{aligned}$$
Hence
$$\begin{aligned}
\mathcal{L} &= \sum_{r=1}^{N+2}\|W_{\mathrm{dec}}\psi_t^r - \phi_t^r\|_2^2 + \mu_{\mathrm{dec}}^t\|W_{\mathrm{dec}} - W_{\mathrm{dec}}^t\|_F^2 + \alpha\|W_{\mathrm{dec}}\|_F^2 \\
&= \|W_{\mathrm{dec}}\Psi_t - \Phi_t\|_F^2 + \mu_{\mathrm{dec}}^t\|W_{\mathrm{dec}} - W_{\mathrm{dec}}^t\|_F^2 + \alpha\|W_{\mathrm{dec}}\|_F^2 \\
&= \Bigl\|W_{\mathrm{dec}}\begin{pmatrix}\Psi_t & \sqrt{\mu_{\mathrm{dec}}^t}\,I_d & \sqrt{\alpha}\,I_d\end{pmatrix} - \begin{pmatrix}\Phi_t & \sqrt{\mu_{\mathrm{dec}}^t}\,W_{\mathrm{dec}}^t & 0_{n\times d}\end{pmatrix}\Bigr\|_F^2,
\end{aligned}$$
and so
$$W_{\mathrm{dec}} = \begin{pmatrix}\Phi_t & \sqrt{\mu_{\mathrm{dec}}^t}\,W_{\mathrm{dec}}^t & 0_{n\times d}\end{pmatrix}\begin{pmatrix}\Psi_t & \sqrt{\mu_{\mathrm{dec}}^t}\,I_d & \sqrt{\alpha}\,I_d\end{pmatrix}^\dagger = \bigl(\Phi_t\Psi_t^T + \mu_{\mathrm{dec}}^t W_{\mathrm{dec}}^t\bigr)\bigl(\Psi_t\Psi_t^T + (\alpha + \mu_{\mathrm{dec}}^t)I_d\bigr)^{-1}.$$
∎

C.1 Convergence proof

We will need to make the following assumptions.

Assumption C.2. $F$, $G$, $\rho$, and $\{\mu_{\mathrm{enc}}^t, \nu_{\mathrm{enc}}^t, \mu_{\mathrm{dec}}^t, \nu_{\mathrm{dec}}^t\}_{t=0}^\infty$ are such that:
1. $F$ and $G$ are continuous and bounded below.
2. For all $x \in \mathbb{R}^n$, $\|W_{\mathrm{dec}}\rho(W_{\mathrm{enc}}x + b_{\mathrm{enc}}) + b_{\mathrm{dec}} - x\|_2^2$ is $C^1$ in $W_{\mathrm{enc}}$, $b_{\mathrm{enc}}$, $W_{\mathrm{dec}}$, and $b_{\mathrm{dec}}$, with a gradient that is locally Lipschitz.
3. $\{\mu_{\mathrm{enc}}^t, \nu_{\mathrm{enc}}^t, \mu_{\mathrm{dec}}^t, \nu_{\mathrm{dec}}^t\}_{t=0}^\infty \subseteq (a, b)$ for some $0 < a < b < \infty$.
These together entail that (Attouch et al., 2010, Assumptions $(\mathcal{H})$ and $(\mathcal{H}_1)$) are satisfied.
4.
$F$, $G$, and $\rho$ are piecewise (real) analytic functions with finitely many pieces.
This entails that $\mathcal{L}$ is a continuous and piecewise analytic function, with finitely many pieces.

Note C.3. In the ReLU case, we must modify $\rho$ for C.2(2) to hold. For example, taking $\rho$ to be any smooth and analytic approximation of ReLU, such as the Swish activation function, will suffice. In the TopK case, $\rho = \mathrm{TopK}$ is not continuous. This can however be patched by using the following analytic approximation of TopK: for any $T > 0$, let
$$\rho_T(v) := \sum_{S \subseteq \{1,\ldots,d\},\ |S| = K} \Bigl(\sum_{S'} \exp\Bigl(\frac{1}{T}\sum_{j\in S'} v_j\Bigr)\Bigr)^{-1} \exp\Bigl(\frac{1}{T}\sum_{j\in S} v_j\Bigr) P_S v.$$
It will then follow that if $\rho = \rho_T$ then C.2 holds. Furthermore, as $T \to 0$, $\rho_T(v) \to \mathrm{TopK}(v)$, so long as $v$ has a unique set of $K$ largest entries.

Proof of C.3. It is straightforward to check that if $\rho$ is smooth and (real) analytic, then the $\rho$-dependent conditions of C.2 will be satisfied. As for the convergence of $\rho_T$ to $\mathrm{TopK}$, let $v$ have a unique set of $K$ largest entries, and denote this set $S^*$.
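Before completing the limit argument, note that $\rho_T$ and its $T \to 0$ behaviour are easy to check numerically for small $d$ (the subset sum has $\binom{d}{K}$ terms, so this direct implementation is an illustration only; the test vector $v$ is an arbitrary example of ours):

```python
import itertools
import numpy as np

# Sketch of the analytic approximation rho_T of TopK from Note C.3,
# with P_S the projection keeping the coordinates in S.
def rho_T(v, K, T):
    d = len(v)
    subsets = list(itertools.combinations(range(d), K))
    weights = np.array([np.exp(sum(v[j] for j in S) / T) for S in subsets])
    weights = weights / weights.sum()       # softmax over subset sums
    out = np.zeros(d)
    for w, S in zip(weights, subsets):
        proj = np.zeros(d)
        proj[list(S)] = v[list(S)]          # P_S v
        out += w * proj
    return out

def topk(v, K):
    out = np.zeros(len(v))
    idx = np.argsort(v)[-K:]
    out[idx] = v[idx]
    return out

v = np.array([0.3, -1.2, 2.0, 0.9, -0.1])
# As T -> 0 the soft weights concentrate on the unique top-K subset:
assert np.allclose(rho_T(v, K=2, T=0.01), topk(v, K=2), atol=1e-6)
```

For larger $T$ the output is a smooth blend of several projections $P_S v$, which is exactly what makes $\rho_T$ analytic.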
Then multiplying the numerator and denominator by $\exp\bigl(-\frac{1}{T}\sum_{j\in S^*} v_j\bigr)$ we get that
$$\rho_T(v) = \sum_{S \subseteq \{1,\ldots,d\},\ |S| = K} \frac{\exp\bigl(\frac{1}{T}\bigl[\sum_{j\in S} v_j - \sum_{j\in S^*} v_j\bigr]\bigr)}{1 + \sum_{S' \neq S^*} \exp\bigl(\frac{1}{T}\bigl[\sum_{j\in S'} v_j - \sum_{j\in S^*} v_j\bigr]\bigr)}\, P_S v.$$
As $T \to 0$, if $S \neq S^*$ then $\exp\bigl(\frac{1}{T}\bigl[\sum_{j\in S} v_j - \sum_{j\in S^*} v_j\bigr]\bigr) \to 0$, since $\sum_{j\in S} v_j - \sum_{j\in S^*} v_j$ is strictly negative. Hence, as $T \to 0$,
$$\rho_T(v) \to P_{S^*} v = \mathrm{TopK}(v).$$
∎

The theory in Attouch et al. (2010) relies crucially on the Kurdyka–Łojasiewicz property, which we now define.

Definition C.4 (Kurdyka–Łojasiewicz property).
A proper lower semi-continuous function $g: \mathbb{R}^n \to (-\infty, \infty]$ has the Kurdyka–Łojasiewicz property at $\hat{x} \in \mathrm{dom}\,\partial g$ (we denote by $\mathrm{dom}\,g$ the set of $x$ such that $g(x) < \infty$, and by $\mathrm{dom}\,\partial g$ the set of $x \in \mathrm{dom}\,g$ such that the (limiting) subdifferential of $g$ at $x$, $\partial g(x)$ (Attouch et al., 2010, Definition 2.1), is non-empty) if there exist $\eta \in (0, \infty]$, a neighbourhood $U$ of $\hat{x}$, and a continuous concave function $\varphi: [0, \eta) \to [0, \infty)$, such that
• $\varphi$ is $C^1$ with $\varphi(0) = 0$ and $\varphi' > 0$ on $(0, \eta)$, and
• for all $x \in U$ such that $g(\hat{x}) < g(x) < g(\hat{x}) + \eta$, the Kurdyka–Łojasiewicz inequality holds:
$$\varphi'\bigl(g(x) - g(\hat{x})\bigr)\,\operatorname{dist}\bigl(0, \partial g(x)\bigr) \geq 1.$$
If $\varphi(s) := cs^{1-\theta}$ is a valid concave function for the above, with $c > 0$ and $\theta \in [0, 1)$, then we will say that $g$ has the Kurdyka–Łojasiewicz property with exponent $\theta$ at $\hat{x}$. Note that if $g$ is differentiable on $U$ and $\varphi(s) := cs^{1-\theta}$, this inequality becomes
$$c(1-\theta)\|\nabla g(x)\|_2 \geq \bigl(g(x) - g(\hat{x})\bigr)^\theta.$$
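As a simple illustration of this definition (our example, not from the paper): $g(x) = \|x\|_2^2$ has the Kurdyka–Łojasiewicz property with exponent $\theta = 1/2$ at $\hat{x} = 0$. Taking $\varphi(s) = s^{1/2}$ (so $c = 1$), $U = \mathbb{R}^n$, and $\eta = \infty$, for all $x \neq 0$,

```latex
\varphi'\bigl(g(x) - g(\hat{x})\bigr)\,\operatorname{dist}\bigl(0,\partial g(x)\bigr)
  = \tfrac{1}{2}\bigl(\|x\|_2^2\bigr)^{-1/2}\,\|\nabla g(x)\|_2
  = \tfrac{1}{2}\|x\|_2^{-1}\cdot 2\|x\|_2
  = 1 \ge 1 ,
```

so the inequality holds with equality everywhere.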
If C.2 holds, then for all Wenc,benc,Wdec,bdecsubscriptencsubscriptencsubscriptdecsubscriptdecW_enc,b_enc,W_dec,b_decWenc , benc , Wdec , bdec, ℒLL has the Kurdyka–Łojasiewicz property, with some exponent θ∈[0,1)01θ∈[0,1)θ ∈ [ 0 , 1 ), at Wenc,benc,Wdec,bdecsubscriptencsubscriptencsubscriptdecsubscriptdecW_enc,b_enc,W_dec,b_decWenc , benc , Wdec , bdec. Proof. Since ℒLL is continuous and piecewise analytic with finitely many pieces, it follows that it is semi-analytic (see Łojasiewicz (1964)). Then if (Wenc,benc,Wdec,bdec)subscriptencsubscriptencsubscriptdecsubscriptdec(W_enc,b_enc,W_dec,b_dec)( Wenc , benc , Wdec , bdec ) is not a critical point of ℒLL, the result follows by (Li & Pong, 2017, Lemma 2.1), and if (Wenc,benc,Wdec,bdec)subscriptencsubscriptencsubscriptdecsubscriptdec(W_enc,b_enc,W_dec,b_dec)( Wenc , benc , Wdec , bdec ) is a critical point, the result follows by (Bolte et al., 2007, Theorem 3.1). ∎ We therefore prove a more detailed version of Theorem 3.1. Theorem C.6. If C.2 holds, then for all Wenc0,benc0,Wdec0,bdec0subscriptsuperscript0encsubscriptsuperscript0encsubscriptsuperscript0decsubscriptsuperscript0decW^0_enc,b^0_enc,W^0_dec,b^0_decW0enc , b0enc , W0dec , b0dec the sequence of SAE parameters Θt:=(Wenct,benct,Wdect,bdect)t=0∞superscriptsubscriptassignsubscriptΘsubscriptsuperscriptencsubscriptsuperscriptencsubscriptsuperscriptdecsubscriptsuperscriptdec0\ _t:=(W^t_enc,b^t_enc,W^t_dec,b^t_% dec)\_t=0^∞ Θitalic_t := ( Witalic_tenc , bitalic_tenc , Witalic_tdec , bitalic_tdec ) t = 0∞ defined by Equation 3.1 obeys: i. ℒ⁢(Θt+1)≤ℒ⁢(Θt)ℒsubscriptΘ1ℒsubscriptΘL( _t+1) ( _t)L ( Θitalic_t + 1 ) ≤ L ( Θitalic_t ) with equality if and only if Θt+1=ΘtsubscriptΘ1subscriptΘ _t+1= _tΘitalic_t + 1 = Θitalic_t. i. 
$\Theta_{t+1} - \Theta_t \to 0$ as $t \to \infty$, and the number of $t$ such that $\Theta_{t+1} - \Theta_t$ has norm greater than some threshold $\varepsilon > 0$ is proportional to at most $\varepsilon^{-2}$.

iii. Every limit point of $\{\Theta_t\}_{t=0}^{\infty}$ is a critical point of $\mathcal{L}$.

And furthermore, if the sequence $\{\Theta_t\}_{t=0}^{\infty}$ is bounded:

iv. The number of $\Theta_{t+1} - \Theta_t$ with norm greater than $\varepsilon$ is proportional to at most $\varepsilon^{-1}$.

v. $\Theta_t$ converges to a critical point of $\mathcal{L}$ as $t \to \infty$.

Finally, if $\Theta_t \to \Theta_\infty$ and $\theta \in [0, 1)$ is the Kurdyka–Łojasiewicz exponent of $\mathcal{L}$ at $\Theta_\infty$, then:

vi. If $\theta = 0$, $\{\Theta_t\}_{t=0}^{\infty}$ converges after finitely many steps.

vii. If $\theta \in (0, 1/2]$, there exist $c > 0$ and $\zeta \in [0, 1)$ such that $\|\Theta_t - \Theta_\infty\| \leq c\zeta^t$.

viii. If $\theta \in (1/2, 1)$, there exists $c > 0$ such that $\|\Theta_t - \Theta_\infty\| \leq c\,t^{-\frac{1-\theta}{2\theta-1}}$.

Proof. (i) and (ii) follow from (Attouch et al., 2010, Lemma 3.1). (iii) follows from (Attouch et al., 2010, Proposition 3.1). (iv) and (v) follow from Theorem C.5 and (Attouch et al., 2010, Theorem 3.2). (vi) to (viii) follow from Theorem C.5 and (Attouch et al., 2010, Theorem 3.4). ∎

Note C.7.
If $F$ and $G$ include weight decay terms, this will not impede C.2, and furthermore will by (i) ensure that $\{W^t_{\mathrm{enc}}, b^t_{\mathrm{enc}}, W^t_{\mathrm{dec}}, b^t_{\mathrm{dec}}\}_{t=0}^{\infty}$ is always bounded; hence weight decay ensures convergence of trajectories of Equation 3.1 to critical points of $\mathcal{L}$.

Appendix D Experimental settings

All computations were performed on: WS Obsidian 750D AirFlow / AMD Ryzen 9 3900X 12x3.8 GHz / 2x32GB DDR4 3600 / X570 WS / DIS. NOCTUA / 1000W Platinum / 2TB NVME Ent. / RTX 3090 24GB. All code is available at: https://github.com/splInterp2025/splInterp.

D.1 SAEs as a bridge between k-means and PCA

• Data sampling: 100 points in 2D, sampled as three clusters:
  – 40% on a noisy horizontal line ($x$ from $-1.5$ to $0$, $y \approx -0.8$)
  – 30% in a dense square ($x \in [-0.4, 0.4]$, $y \in [0.4, 1.2]$)
  – 30% along a noisy diagonal ($x \approx 0.8 + t$, $y \approx t - 0.5$, $t \in [0, 1]$)
• Number of data points: 100
• SAE architecture: Linear sparse autoencoder with 80 dictionary elements ($d = 80$), 2D input/output, using Top1 or Top3 sparse coding.
• Training method and hyperparameters:
  – Adam optimiser (learning rate $8 \times 10^{-3}$)
  – 5000 steps, full-batch (all data at once)
  – Dictionary initialised near data points
• Runtime: 6 minutes

D.2 PAM-SGD vs.
SGD on MNIST

• Number of data points and training/test split: uses the standard MNIST dataset:
  – 60,000 training images
  – 10,000 test images
  – Images are $28 \times 28$ grayscale digits
• SAE architecture(s): two models:
  – SGD autoencoder: linear encoder/decoder, tied weights, sparsity via TopK ($K = 15$) or ReLU (without L1 regularisation), 256 latent dimensions
  – PAM-SGD autoencoder: linear encoder, decoder weights solved analytically (not tied), same latent size and sparsity
• Training (hyper)parameters:
  – Optimiser: Adam (learning rate 0.003)
  – Batch size: 128 (SGD), 1024 (PAM-SGD encoder update)
  – Number of epochs: 50 (default, or fewer for small ablation subsets)
  – K-sparsity: $K = 15$ (for TopK)
  – L1 regularisation: not used
  – Input dimension: 784 ($28 \times 28$)
  – Ablation studies vary training set size, K-sparsity, number of SGD steps per batch, activation type, weight decay parameters, and cost-to-move parameters
• Running time: 1.5 to 2 hours

D.3 PAM-SGD vs. SGD on Gemma

• Details on Gemma version and license: uses activations from Gemma-2-2B (https://huggingface.co/google/gemma-2-2b) (Google, 2.2B parameters). License: see Gemma Terms of Use (https://ai.google.dev/gemma/terms) (accessed 8th May 2025).
• Number of data points and training/test split: up to 10,000 LLM activation vectors extracted from real text (default: 90% train, 10% test split).
• SAE architecture(s): two models:
  – SGD autoencoder: linear encoder/decoder (tied weights), 4096 latent dimensions, sparsity via TopK ($K = 320$) or ReLU, with L1 regularisation and "cost-to-move" penalties
  – PAM-SGD autoencoder: linear encoder, decoder weights solved analytically (not tied), same latent size and sparsity, with additional regularisation
• Training (hyper)parameters:
  – Optimiser: Adam (learning rate 0.001)
  – Batch size: 256 (SGD), 2048 (PAM-SGD encoder update)
  – Epochs: 100
  – K-sparsity: $K = 320$ (for TopK)
  – L1 regularisation: 0.01 (TopK), 0.00001 (ReLU)
  – "Cost-to-move" and weight decay regularisation for encoder/decoder
  – Ablation studies vary training set size, K-sparsity, number of SGD steps per batch, activation type, weight decay parameters, and cost-to-move parameters
• Runtimes: 3 to 7 minutes

Appendix E Additional figures and ablation studies

E.1 SAEs as a bridge between k-means and PCA

Figure 8: Visualising the SAE bridge between k-means clustering and PCA. Top-1 SAE.

E.2 MNIST experiments

E.2.1 TopK experiments

PAM-SGD similarly outperforms SGD with TopK. (Figure 9) We tested PAM-SGD using TopK activation with $K = 15$. We again saw PAM-SGD outperform SGD, especially at low training data levels.

Figure 9: Training and test loss curves at different data sizes for MNIST, with TopK ($K = 15$) activation. The chart highlights PAM-SGD's superior sample efficiency.

E.2.2 Reconstruction accuracy and interpretability

Figure 10: Learned Dictionary Elements using ReLU. Visualization of encoder and decoder weights as filters: SGD encoder (top row), PAM-SGD encoder (second row), SGD decoder (third row), and PAM-SGD decoder (bottom row).

Figure 11: Learned Dictionary Elements using TopK ($K = 15$). Visualization of encoder and decoder weights as filters: SGD encoder (top row), PAM-SGD encoder (second row), SGD decoder (third row), and PAM-SGD decoder (bottom row). With TopK, PAM-SGD produces more interpretable features representative of digit components.

Figure 12: Reconstruction Quality Comparison using ReLU. Original MNIST digits (top row) with their reconstructions using SGD optimization (middle row) and PAM-SGD optimization (bottom row). PAM-SGD produces cleaner, more accurate reconstructions.
Figure 13: Reconstruction Quality Comparison using TopK ($K = 15$). Original MNIST digits (top row) with their reconstructions using SGD optimization (middle row) and PAM-SGD optimization (bottom row). PAM-SGD produces cleaner, more accurate reconstructions.

Figure 14: Evolution of Reconstruction Quality Over Training using ReLU. Progression of a single digit's reconstruction across epochs, comparing SGD (top row) and PAM-SGD (bottom row) approaches, showing how representation quality improves with training. Here PAM-SGD converges almost immediately.

Figure 15: Evolution of Reconstruction Quality Over Training using TopK ($K = 15$). Progression of a single digit's reconstruction across epochs, comparing SGD (top row) and PAM-SGD (bottom row) approaches, showing how representation quality improves with training.

E.2.3 Ablation study varying SGD updates per batch in PAM-SGD

Stability Across SGD Updates. (Figures 16 and 17) Unlike in the LLM experiments, PAM-SGD on MNIST is robust to the number of SGD updates per batch. Varying this hyperparameter from 1 to 10 slightly improves final performance in the ReLU case and has little impact in the TopK case, suggesting that the optimization landscape is smoother and less sensitive in this setting.

Figure 16: Effect of Multiple SGD Updates Per Batch on PAM-SGD Performance with ReLU. Test loss decreases with more updates per batch.

Figure 17: Effect of Multiple SGD Updates Per Batch on PAM-SGD Performance with TopK ($K = 15$). Test loss increases with more updates per batch, suggesting simpler optimization (single updates) maintains a better balance.

E.2.4 TopK and ReLU activation patterns

Figure 18: Sparse activation patterns for ReLU activations. Plots showing which latent neurons activate for 5 different input digits, comparing SGD (left) and PAM-SGD (right) models. PAM-SGD activations are roughly five times denser.

Figure 19: Sparse activation patterns for TopK ($K = 15$) activations.
Plots showing which latent neurons activate for 5 different input digits, comparing SGD (left) and PAM-SGD (right) models. Each sample activates exactly $K = 15$ neurons from the 256-dimensional latent space.

E.2.5 Ablation study adding weight decay

Small amounts of weight decay make SGD competitive with PAM-SGD in the ReLU setting. (Figures 20 and 21) We experimented with adding weight decay in the 100% data setting. In the ReLU setting, we found that small amounts aided SGD performance enough to be competitive with PAM-SGD, and had little effect on PAM-SGD. Increasing weight decay further, however, degraded both performances, especially PAM-SGD's. In the TopK setting, weight decay just steadily degraded both performances.

Figure 20: Impact of weight decay on final test loss in the ReLU case.

Figure 21: Impact of weight decay on final test loss in the TopK ($K = 15$) case.

E.2.6 Ablation study varying μ and ν

Sensitivity to quadratic costs to move μ and ν. (Figures 22 and 23) We studied the effect of varying the values of the parameters $\mu_{\mathrm{enc}}, \mu_{\mathrm{dec}}, \nu_{\mathrm{enc}},$ and $\nu_{\mathrm{dec}}$ from Equation 3.1, in the 100% data setting. For ReLU activation, very small values for these parameters improve the test loss almost to zero for both SGD (where similar parameters can easily be introduced) and PAM-SGD. Further increases, however, degrade performance for both, rapidly in the case of PAM-SGD. For TopK, increasing these parameters steadily degrades performance in both cases, though this may simply be because these parameters slow convergence and therefore worsen performance at the 50-epoch cut-off.

Figure 22: Effect of modifying the cost-to-move parameters on final test loss in the ReLU case.

Figure 23: Effect of modifying the cost-to-move parameters on final test loss in the TopK ($K = 15$) case.
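The PAM-SGD scheme whose ablations are reported above (encoder updated by SGD steps, decoder solved analytically, with quadratic cost-to-move penalties) can be sketched in a few lines. The following is a minimal illustrative sketch, not the paper's implementation: the exact splitting, penalties, and variable names of Equation 3.1 differ, and the straight-through gradient gating for TopK used here is an assumption.

```python
import numpy as np

def topk(z, k):
    """Keep the k largest entries of each row, zeroing the rest."""
    idx = np.argpartition(z, -k, axis=1)[:, -k:]
    out = np.zeros_like(z)
    np.put_along_axis(out, idx, np.take_along_axis(z, idx, axis=1), axis=1)
    return out

def pam_sgd(X, d=64, k=5, epochs=30, lr=5e-2, mu=1e-3, seed=0):
    """Hypothetical PAM-SGD sketch: alternate an SGD step on the encoder
    with an analytic proximal least-squares solve for the decoder."""
    rng = np.random.default_rng(seed)
    n, p = X.shape
    W_enc = 0.1 * rng.standard_normal((p, d))
    b_enc = np.zeros(d)
    W_dec = 0.1 * rng.standard_normal((d, p))
    for _ in range(epochs):
        # Encoder step: gradient of the reconstruction loss, passed only
        # through the coordinates the TopK gate left active.
        Z = topk(X @ W_enc + b_enc, k)
        R = Z @ W_dec - X                       # reconstruction residual
        G = ((R @ W_dec.T) * (Z != 0)) / n
        W_enc -= lr * (X.T @ G)
        b_enc -= lr * G.sum(axis=0)
        # Decoder step: exact minimiser of ||Z W - X||^2 + mu ||W - W_old||^2,
        # i.e. a proximal ("cost-to-move") least-squares solve.
        Z = topk(X @ W_enc + b_enc, k)
        W_dec = np.linalg.solve(Z.T @ Z + mu * np.eye(d), Z.T @ X + mu * W_dec)
    Z = topk(X @ W_enc + b_enc, k)
    return W_enc, b_enc, W_dec, float(np.mean((Z @ W_dec - X) ** 2))
```

Increasing `mu` mimics the cost-to-move ablation: the decoder solve is pulled toward its previous value, damping the per-step change.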
E.3 LLM Activation Experiments (Gemma-2-2B)

E.3.1 Additional ReLU test runs

PAM-SGD consistently outperforms SGD. (Figure 24) Owing to the stochasticity of the training algorithms, different results are obtained in every re-run of the training. However, the pattern of PAM-SGD outperforming SGD remained consistent.

Figure 24: Final test loss across training set sizes for Gemma-2-2B, with ReLU activation, for two additional training runs. PAM-SGD retains an advantage at low data, and remains superior throughout, in both runs.

E.3.2 TopK activation experiments

Stable only at high sparsity, and underperforms SGD. (Figures 25, 26, 27, 28 and 29) PAM-SGD was highly unstable for low values of $K$, with the test loss diverging rapidly. Only for larger values was the test loss stable, but it remained fairly stagnant across epochs even for very large $K$ (over 30% of the hidden dimension), leading us to choose $K = 320$ as our default TopK sparsity. We speculate that this is because the LLM reconstruction is sufficiently complicated as to make capturing it with small $K$ unrealistic. We furthermore compared SGD and PAM-SGD at various training data sizes for $K = 320$ and $K = 640$. We found that PAM-SGD consistently underperformed SGD in both cases, with the difference smallest at low data sizes and again at high data sizes, and a surprisingly big rise in test loss for medium data sizes (with a maximum around 45%). This peak was consistent across multiple runs, so we suspect it reflects some fundamental issue, perhaps caused by numerical instability. Inspecting the loss curves in the two cases shows that PAM-SGD only has well-behaved training and test loss in the low-data regime, or at 100% data in the $K = 640$ case.

Figure 25: PAM-SGD test loss stable only for high values of K.

Figure 26: Final test loss across training set sizes for Gemma-2-2b with TopK activation, for $K = 320$.
PAM-SGD here consistently underperforms SGD, with a major peak at 45% data.

Figure 27: Final test loss across training set sizes for Gemma-2-2b with TopK activation, for $K = 640$. PAM-SGD here consistently underperforms SGD, with a major peak at 45% data.

Figure 28: Training and test loss curves across training set sizes for Gemma-2-2b with TopK activation, for $K = 320$. PAM-SGD has well-behaved training and test loss only for low data.

Figure 29: Training and test loss curves across training set sizes for Gemma-2-2b with TopK activation, for $K = 640$. PAM-SGD has well-behaved training and test loss only for low data and 100% data.

E.3.3 TopK and ReLU activation patterns

Sparsity Comparison. (Figures 30, 31 and 32) TopK by design produces a constant sparsity $K = 320$ for both SGD and PAM-SGD. ReLU produces much denser activations: around 58.5% (approx. 2400 active neurons) for SGD and 49.6% (approx. 2000) for PAM-SGD. This increased sparsity from PAM-SGD is an important advantage of the method.

Figure 30: Activation sparsity comparison. ReLU yields much denser activations than TopK, and PAM-SGD activations are about 15% sparser than SGD activations.

Figure 31: Sparse activation patterns. Plots showing which latent neurons were active in the ReLU case, comparing SGD (left) and PAM-SGD (right).

Figure 32: Sparse activation patterns. Plots showing which latent neurons were active in the TopK ($K = 320$) case, comparing SGD (left) and PAM-SGD (right).

E.3.4 Ablation study varying SGD updates per batch in PAM-SGD

The number of SGD steps per batch matters for PAM-SGD. (Figures 33 and 34) For both the ReLU and TopK activations (with $K = 320$), performance improves slightly when increasing SGD updates per batch from 1 to 3, but degrades beyond that. Too few updates prevent convergence of the inner optimization loop; too many may lead to overfitting within the inner loop or instability due to misaligned gradients.
PAM-SGD benefits from a moderate number of SGD updates per batch. An optimal value provides enough adaptation without overfitting, highlighting the importance of tuning this hyperparameter for practical deployments.

Figure 33: Effect of SGD Updates per Batch on PAM-SGD Test Loss with ReLU activation. Performance improves up to 3 updates but degrades beyond that, suggesting an optimal trade-off.

Figure 34: Effect of SGD Updates per Batch on PAM-SGD Test Loss with TopK activation. Performance again improves up to 3 updates but degrades beyond that.

E.3.5 Ablation study adding weight decay

Weight decay had a minor effect on performance. (Figures 35 and 36) Weight decay had a very minor effect on SAE performance using either ReLU or TopK activations, with final test loss relatively constant, and PAM-SGD slightly outperforming SGD in the ReLU case and underperforming SGD in the TopK case. However, in the TopK case, large values of weight decay are initially divergent before converging, whilst in the ReLU case this behaviour is less pronounced.

Figure 35: Training and test loss curves with various weight decay parameters, with ReLU activation.

Figure 36: Training and test loss curves with various weight decay parameters, with TopK ($K = 320$) activation.

E.3.6 Ablation study varying the quadratic costs to move μ and ν

Sensitivity to quadratic costs to move μ and ν. (Figures 37 and 38) We studied the effect of varying the values of the parameters $\mu_{\mathrm{enc}}, \mu_{\mathrm{dec}}, \nu_{\mathrm{enc}},$ and $\nu_{\mathrm{dec}}$ from Equation 3.1. For ReLU activation, we found that very small values of these parameters caused the test loss to begin diverging, perhaps due to numerical instability. Slightly larger values improved performance, but increases beyond that slowed the learning process to no clear gain.
In the TopK case, this divergence occurred at larger values than for ReLU, but went away once the parameters were sufficiently large.

Figure 37: Loss curves for ReLU activation at different "cost-to-move" parameters.

Figure 38: Loss curves for TopK activation ($K = 320$) at different "cost-to-move" parameters.
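For orientation, and as a schematic sketch only (the precise objective and splitting are those of Equation 3.1; the pairing of $\mu$ with weight matrices and $\nu$ with biases is assumed here for illustration), quadratic costs to move enter proximal alternating schemes as penalties on the per-step change of each parameter block:

```latex
% Schematic proximal ("cost-to-move") encoder update; the decoder update is
% analogous with \mu_{\mathrm{dec}}, \nu_{\mathrm{dec}}. Sketch only.
(W^{t+1}_{\mathrm{enc}}, b^{t+1}_{\mathrm{enc}})
  \in \operatorname*{arg\,min}_{W,\,b}\;
      \mathcal{L}\big(W, b, W^{t}_{\mathrm{dec}}, b^{t}_{\mathrm{dec}}\big)
      + \frac{\mu_{\mathrm{enc}}}{2}\,\big\|W - W^{t}_{\mathrm{enc}}\big\|_F^2
      + \frac{\nu_{\mathrm{enc}}}{2}\,\big\|b - b^{t}_{\mathrm{enc}}\big\|_2^2
```

Larger $\mu, \nu$ damp how far each update can move, which is consistent with the slower learning observed above at large parameter values.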