
Paper deep dive

Explaining Grokking Through Circuit Efficiency

Vikrant Varma, Rohin Shah, Zachary Kenton, János Kramár, Ramana Kumar

Year: 2023 · Venue: arXiv preprint · Area: Training Dynamics · Type: Empirical · Embeddings: 81

Models: 1-layer transformer (d_model=128, 4 heads, d_mlp=512)

Abstract

One of the most surprising puzzles in neural network generalisation is grokking: a network with perfect training accuracy but poor generalisation will, upon further training, transition to perfect generalisation. We propose that grokking occurs when the task admits a generalising solution and a memorising solution, where the generalising solution is slower to learn but more efficient, producing larger logits with the same parameter norm. We hypothesise that memorising circuits become more inefficient with larger training datasets while generalising circuits do not, suggesting there is a critical dataset size at which memorisation and generalisation are equally efficient. We make and confirm four novel predictions about grokking, providing significant evidence in favour of our explanation. Most strikingly, we demonstrate two novel and surprising behaviours: ungrokking, in which a network regresses from perfect to low test accuracy, and semi-grokking, in which a network shows delayed generalisation to partial rather than perfect test accuracy.

Tags

ai-safety (imported, 100%) · empirical (suggested, 88%) · training-dynamics (suggested, 92%)


Intelligence

Status: succeeded | Model: google/gemini-3.1-flash-lite-preview | Prompt: intel-v1 | Confidence: 93%

Last extracted: 3/12/2026, 7:42:38 PM

Summary

The paper proposes a theory of 'grokking' in neural networks based on circuit efficiency. It posits that grokking occurs when a task allows for both a memorizing circuit (C_mem) and a generalizing circuit (C_gen), where C_gen is more efficient but slower to learn. The authors introduce the concept of a 'critical dataset size' (D_crit) and predict two novel phenomena: 'ungrokking' (regression from generalization to memorization when the dataset shrinks) and 'semi-grokking' (partial generalization when dataset size is near D_crit).

Entities (6)

Grokking · phenomenon · 100%
C_gen · circuit · 95%
C_mem · circuit · 95%
Semi-grokking · phenomenon · 95%
Ungrokking · phenomenon · 95%
D_crit · metric · 90%

Relation Signals (3)

C_gen is more efficient than C_mem

confidence 90% · C_gen is more efficient than C_mem, that is, it can produce equivalent cross-entropy loss on the training set with a lower parameter norm.

C_mem learned faster than C_gen

confidence 90% · C_gen is learned more slowly than C_mem, such that during early phases of training C_mem is stronger than C_gen.

Ungrokking occurs when D' < D_crit

confidence 85% · Ungrokking should only be expected once D' < D_crit.

Cypher Suggestions (2)

Map the efficiency relationship between circuits · confidence 95% · unvalidated

MATCH (c1:Circuit)-[r:IS_MORE_EFFICIENT_THAN]->(c2:Circuit) RETURN c1.name, c2.name, r.confidence

Find all phenomena related to grokking · confidence 90% · unvalidated

MATCH (p:Phenomenon)-[:RELATED_TO]->(g:Phenomenon {name: 'Grokking'}) RETURN p

Full Text

81,194 characters extracted from source content.


5 September 2023

Explaining grokking through circuit efficiency

Vikrant Varma*¹, Rohin Shah*¹, Zachary Kenton¹, János Kramár¹ and Ramana Kumar¹ (*Equal contributions; ¹Google DeepMind)

One of the most surprising puzzles in neural network generalisation is grokking: a network with perfect training accuracy but poor generalisation will, upon further training, transition to perfect generalisation. We propose that grokking occurs when the task admits a generalising solution and a memorising solution, where the generalising solution is slower to learn but more efficient, producing larger logits with the same parameter norm. We hypothesise that memorising circuits become more inefficient with larger training datasets while generalising circuits do not, suggesting there is a critical dataset size at which memorisation and generalisation are equally efficient. We make and confirm four novel predictions about grokking, providing significant evidence in favour of our explanation. Most strikingly, we demonstrate two novel and surprising behaviours: ungrokking, in which a network regresses from perfect to low test accuracy, and semi-grokking, in which a network shows delayed generalisation to partial rather than perfect test accuracy.

1. Introduction

When training a neural network, we expect that once training loss converges to a low value, the network will no longer change much. Power et al. (2021) discovered a phenomenon dubbed grokking that drastically violates this expectation. The network first “memorises” the data, achieving low and stable training loss with poor generalisation, but with further training transitions to perfect generalisation. We are left with the question: why does the network’s test performance improve dramatically upon continued training, having already achieved nearly perfect training performance?
Recent answers to this question vary widely, including the difficulty of representation learning (Liu et al., 2022), the scale of parameters at initialisation (Liu et al., 2023), spikes in loss (“slingshots”) (Thilak et al., 2022), random walks among optimal solutions (Millidge, 2022), and the simplicity of the generalising solution (Nanda et al., 2023, Appendix E). In this paper, we argue that the last explanation is correct, by stating a specific theory in this genre, deriving novel predictions from the theory, and confirming the predictions empirically.

We analyse the interplay between the internal mechanisms that the neural network uses to calculate the outputs, which we loosely call “circuits” (Olah et al., 2020). We hypothesise that there are two families of circuits that both achieve good training performance: one which generalises well (C_gen) and one which memorises the training dataset (C_mem). The key insight is that when there are multiple circuits that achieve strong training performance, weight decay prefers circuits with high “efficiency”, that is, circuits that require less parameter norm to produce a given logit value.

Efficiency answers our question above: if C_gen is more efficient than C_mem, gradient descent can reduce nearly perfect training loss even further by strengthening C_gen while weakening C_mem, which then leads to a transition in test performance. With this understanding, we demonstrate in Section 3 that three key properties are sufficient for grokking: (1) C_gen generalises well while C_mem does not, (2) C_gen is more efficient than C_mem, and (3) C_gen is learned more slowly than C_mem.

Since C_gen generalises well, it automatically works for any new data points that are added to the training dataset, and so its efficiency should be independent of the size of the training dataset.
In contrast, C_mem must memorise any additional data points added to the training dataset, and so its efficiency should decrease as training dataset size increases. We validate these predictions by quantifying efficiencies for various dataset sizes for both C_mem and C_gen. This suggests that there exists a crossover point at which C_gen becomes more efficient than C_mem, which we call the critical dataset size D_crit.

Corresponding author(s): vikrantvarma@deepmind.com, rohinmshah@deepmind.com. arXiv:2309.02390v1 [cs.LG] 5 Sep 2023.

[Figure 1 | Novel grokking phenomena. When grokking occurs, we expect there are two algorithms that perform well at training: “memorisation” (C_mem, with poor test performance) and “generalisation” (C_gen, with strong test performance). Weight decay strengthens C_gen over C_mem as the training dataset size increases. By analysing the point at which C_gen and C_mem are equally strong (the critical dataset size D_crit), we predict and confirm two novel behaviours: ungrokking and semi-grokking. Each panel plots train loss and test loss (log scale) together with test accuracy against training epochs. (a) Grokking. The original behaviour from Power et al. (2021). We train on a dataset with D ≫ D_crit. As a result, C_gen is strongly preferred to C_mem at convergence, and we observe a transition to very low test loss (100% test accuracy). (b) Ungrokking. If we take a network that has already grokked and train it on a new dataset with D < D_crit, the network reverts to significant memorisation, leading to a transition back to poor test loss. (Note the log scale for the x-axis.) (c) Semi-grokking. When D ∼ D_crit, the memorising algorithm and generalising algorithm compete with each other at convergence, so we observe a transition to improved but not perfect test loss (test accuracy of 83%).]
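The three properties listed above can be illustrated with a small toy in the spirit of the construction in Section 3. This is a sketch under stated assumptions, not the authors’ code. The class count K = 10, the choice κ = 1 (so the weight-decay term is ½α(P_G²w_G² + P_M²w_M²), with the two circuit norms combined as a Euclidean norm), and all numeric defaults are assumptions chosen so the run is short. Each circuit puts a learned weight on one class per input. On training inputs both circuits back the correct class. On test inputs C_mem backs a wrong class, so test accuracy is 1 exactly when w_G is positive and exceeds w_M. Each weight is the product of two subweights whose first factor starts at zero. A small initial w_G2 therefore makes C_gen slow to learn, and P_G < P_M makes it more efficient.

```python
import math


def simulate(p_gen=1.0, p_mem=2.0, w_gen2_init=0.01, w_mem2_init=1.0,
             n_classes=10, alpha=0.01, lr=0.5, steps=3000):
    """Toy two-circuit grokking model (assumes kappa = 1); returns per-step records."""
    # w_gen = g1 * g2 and w_mem = m1 * m2; the first subweights start at zero
    g1, g2 = 0.0, w_gen2_init
    m1, m2 = 0.0, w_mem2_init
    history = []
    for _ in range(steps):
        w_gen, w_mem = g1 * g2, m1 * m2
        # Train inputs: both circuits back the correct class, so its logit is
        # s = w_gen + w_mem and the other n_classes - 1 logits are zero
        s = w_gen + w_mem
        dxent_ds = -(n_classes - 1) / (math.exp(s) + n_classes - 1)
        # Gradients of x-ent + 0.5 * alpha * ((p_gen*w_gen)**2 + (p_mem*w_mem)**2)
        d_gen = dxent_ds + alpha * p_gen ** 2 * w_gen
        d_mem = dxent_ds + alpha * p_mem ** 2 * w_mem
        # Chain rule through the products: dL/dg1 = d_gen * g2, dL/dg2 = d_gen * g1
        g1, g2 = g1 - lr * d_gen * g2, g2 - lr * d_gen * g1
        m1, m2 = m1 - lr * d_mem * m2, m2 - lr * d_mem * m1
        w_gen, w_mem = g1 * g2, m1 * m2
        history.append({
            "w_gen": w_gen,
            "w_mem": w_mem,
            "train_acc": 1.0 if w_gen + w_mem > 0 else 0.0,
            # Test inputs: C_gen backs the right class, C_mem backs a wrong one
            "test_acc": 1.0 if w_gen > max(w_mem, 0.0) else 0.0,
            "norm": math.hypot(p_gen * w_gen, p_mem * w_mem),
        })
    return history


hist = simulate()
first_train = next(i for i, h in enumerate(hist) if h["train_acc"] == 1.0)
first_test = next(i for i, h in enumerate(hist) if h["test_acc"] == 1.0)
print(f"train acc reaches 1 at step {first_train}, test acc at step {first_test}")
print(f"final w_gen={hist[-1]['w_gen']:.2f}, w_mem={hist[-1]['w_mem']:.2f}")
```

With the defaults (P_G = 1 < P_M = 2), train accuracy is perfect after the first step, but test accuracy only flips to 1 once w_G overtakes w_M. The recorded norm also falls from its memorisation peak. With p_gen above p_mem, test accuracy stays at 0 at the end of the run, in line with the “C_gen less efficient” setting. Because this κ = 1 weight-decay term is strictly convex, C_gen still keeps a small non-zero weight rather than never growing. Other κ values can behave differently, as the text notes for Theorem D.4.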
By analysing dynamics at D_crit, we predict and demonstrate two new behaviours (Figure 1). In ungrokking, a model that has successfully grokked returns to poor test accuracy when further trained on a dataset much smaller than D_crit. In semi-grokking, we choose a dataset size where C_gen and C_mem are similarly efficient, leading to a phase transition but only to middling test accuracy.

We make the following contributions:
1. We demonstrate the sufficiency of three ingredients for grokking through a constructed simulation (Section 3).
2. By analysing dynamics at the “critical dataset size” implied by our theory, we predict two novel behaviours: semi-grokking and ungrokking (Section 4).
3. We confirm our predictions through careful experiments, including demonstrating semi-grokking and ungrokking in practice (Section 5).

2. Notation

We consider classification using deep neural networks under the cross-entropy loss. In particular, we are given a set of inputs X, a set of labels Y, and a training dataset 𝒟 = {(x_1, y*_1), …, (x_D, y*_D)}. For an arbitrary classifier h: X × Y → ℝ, the softmax cross-entropy loss is given by:

$$\mathcal{L}_{\text{x-ent}}(h) = -\frac{1}{D} \sum_{(x, y^*) \in \mathcal{D}} \log \frac{\exp(h(x, y^*))}{\sum_{y' \in Y} \exp(h(x, y'))} \qquad (1)$$

The output of a classifier for a specific class is the class logit, denoted by o_h^y(x) := h(x, y). When the input x is clear from context, we will denote the logit as o_h^y. We denote the vector of the logits for all classes for a given input as o⃗_h(x), or o⃗_h when x is clear from context.

Parametric classifiers (such as neural networks) are parameterised with a vector θ that induces a classifier h_θ. The parameter norm of the classifier is P_{h_θ} := ‖θ‖. It is common to add weight decay regularisation, which is an additional loss term L_wd(h_θ) = ½ (P_{h_θ})². The overall loss is given by

$$\mathcal{L}(h_\theta) = \mathcal{L}_{\text{x-ent}}(h_\theta) + \alpha\, \mathcal{L}_{\text{wd}}(h_\theta) \qquad (2)$$

where α is a constant that trades off between softmax cross-entropy and weight decay.

Circuits. Inspired by Olah et al.
(2020), we use the term circuit to refer to an internal mechanism by which a neural network works. We only consider circuits that map inputs to logits, so that a circuit C induces a classifier h_C for the overall task. We elide this distinction and simply write C to refer to h_C, so that the logits are o_C^y, the loss is L(C), and the parameter norm is P_C. For any given algorithm, there exist multiple circuits that implement that algorithm. Abusing notation, we use C_gen (C_mem) to refer either to the family of circuits that implements the generalising (memorising) algorithm, or to a single circuit from the appropriate family.

3. Three ingredients for grokking

Given a circuit with perfect training accuracy (as with a pure memorisation approach like C_mem or a perfectly generalising solution like C_gen), the cross-entropy loss L_x-ent incentivises gradient descent to scale up the classifier’s logits, as that makes its answers more confident, leading to lower loss (see Theorem D.1). For typical neural networks, this would be achieved by making the parameters larger. Meanwhile, weight decay L_wd pushes in the opposite direction, directly decreasing the parameters. These two forces must be balanced at any local minimum of the overall loss.

When we have multiple circuits that achieve strong training accuracy, this constraint applies to each individually. But how will they relate to each other? Intuitively, the answer depends on the efficiency of each circuit, that is, the extent to which the circuit can convert relatively small parameters into relatively large logits. For more efficient circuits, the L_x-ent force towards larger parameters is stronger, and the L_wd force towards smaller parameters is weaker. So, we expect that more efficient circuits will be stronger at any local minimum.

Given this notion of efficiency, we can explain grokking as follows. In the first phase, C_mem is learned quickly, leading to strong train performance and poor test performance.
In the second phase, C_gen is now learned, and parameter norm is “reallocated” from C_mem to C_gen, eventually leading to a mixture of strong C_gen and weak C_mem, causing an increase in test performance.

This overall explanation relies on the presence of three ingredients:
1. Generalising circuit: There are two families of circuits that achieve good training performance: a memorising family C_mem with poor test performance, and a generalising family C_gen with good test performance.
2. Efficiency: C_gen is more “efficient” than C_mem, that is, it can produce equivalent cross-entropy loss on the training set with a lower parameter norm.
3. Slow vs fast learning: C_gen is learned more slowly than C_mem, such that during early phases of training C_mem is stronger than C_gen.

To illustrate the sufficiency of these ingredients, we construct a minimal example containing all three ingredients, and demonstrate that it leads to grokking. We emphasise that this example is to be treated as a validation of the three ingredients, rather than as a quantitative prediction of the dynamics of existing examples of grokking. Many of the assumptions and design choices were made on the basis of simplicity and analytical tractability, rather than a desire to reflect examples of grokking in practice. The clearest difference is that C_gen and C_mem are modelled as hardcoded input-output lookup tables whose outputs can be strengthened through learned scalar weights. In existing examples of grokking, by contrast, C_gen and C_mem are learned internal mechanisms in a neural network that can be strengthened by scaling up the parameters implementing those mechanisms.

Generalisation. To model generalisation, we introduce a training dataset 𝒟 and a test dataset 𝒟_test.
C_gen is a lookup table that produces logits that achieve perfect train and test accuracy. C_mem is a lookup table that achieves perfect train accuracy, but makes confident incorrect predictions on the test dataset. We denote by 𝒟_mem the predictions made by C_mem on the test inputs, with the property that there is no overlap between 𝒟_test and 𝒟_mem. Then we have:

$$o_G^y(x) = \mathbb{1}\left[(x, y) \in \mathcal{D} \text{ or } (x, y) \in \mathcal{D}_{\text{test}}\right]$$
$$o_M^y(x) = \mathbb{1}\left[(x, y) \in \mathcal{D} \text{ or } (x, y) \in \mathcal{D}_{\text{mem}}\right]$$

Slow vs fast learning. To model learning, we introduce weights for each of the circuits, and use gradient descent to update the weights. Thus, the overall logits are given by:

$$o^y(x) = w_G\, o_G^y(x) + w_M\, o_M^y(x)$$

Unfortunately, if we learn w_G and w_M directly with gradient descent, we have no control over the speed at which the weights are learned. Inspired by Jermyn and Shlegeris (2022), we instead compute each weight as the product of two “subweights”, and then learn the subweights with gradient descent. More precisely, we let w_G = w_G1 w_G2 and w_M = w_M1 w_M2, and update each subweight according to

$$w_i \leftarrow w_i - \lambda \cdot \frac{\partial \mathcal{L}}{\partial w_i}$$

The speed at which the weights are strengthened by gradient descent can then be controlled by the initial values of the subweights. Intuitively, the gradient with respect to the first subweight, ∂L/∂w_1, is proportional to the second subweight w_2 and vice versa, and so low initial values lead to slow learning. At initialisation, we set w_G1 = w_M1 = 0 to ensure the logits are initially zero, and then set w_G2 ≪ w_M2 to ensure C_gen is learned more slowly than C_mem.

Efficiency. Above, we operationalised circuit efficiency as the extent to which the circuit can convert relatively small parameters into relatively large logits. When the weights are all one, each circuit produces a one-hot vector as its logits, so their logit scales are the same, and efficiency is determined solely by parameter norm. We define P_G and P_M to be the parameter norms when weights are all one.
Since we want C_gen to be more efficient than C_mem, we set P_G < P_M. This still leaves the question of how to model parameter norm when the weights are not all one. Intuitively, increasing the weights corresponds to increasing the parameters in a neural network to scale up the resulting outputs. In a κ-layer MLP with ReLU activations and without biases, scaling all parameters by a constant c scales the outputs by c^κ. Inspired by this observation, we model the parameter norm of w_G C_G as w_G^{1/κ} P_G for some κ > 0, and similarly for w_M C_M.

Theoretical analysis. We first analyse the optimal solutions to the setup above. We can ignore the subweights, as they only affect the speed of learning: L_x-ent and L_wd depend only on the weights, not the subweights. Intuitively, to get minimal loss, we must assign higher weights to more efficient circuits. It is unclear, however, whether we should assign no weight to less efficient circuits, or merely smaller but still non-zero weights. Theorem D.4 shows that in our example both of these cases can arise: which one we get depends on the value of κ.

[Figure 2 | C_gen must be learned slowly for grokking to arise. Each panel plots train loss, test loss, the gen and mem logits, and the parameter norm against training steps. We learn weights w_M and w_G through gradient descent on the loss in Equation D.1. To model the fact that C_gen is more efficient than C_mem, we set P_M > P_G. We see that we only get grokking when C_gen is learned more slowly than C_mem. (a) All three ingredients. When C_gen is more efficient than C_mem but learned more slowly, we observe grokking. C_gen only starts to grow significantly by step 2500, and then substitutes for C_mem. Total parameter norm falls due to C_gen’s higher efficiency. (b) C_gen less efficient than C_mem. We set P_G > P_M. Since C_gen is now less efficient and learned more slowly, it never grows, and test loss stays high due to C_mem throughout training. (c) C_gen and C_mem learned at equal speeds. We set w_G2 = w_M2 so they are learned equally quickly. C_gen is prioritised at least as much as C_mem throughout training, due to its higher efficiency. Thus, test loss is very similar to train loss throughout training, and no grokking is observed.]

Experimental analysis. We run our example for various hyperparameters, and plot training and test loss in Figure 2. When all three ingredients are present (Figure 2a), we observe the standard grokking curves, with a delayed decrease in test loss. By contrast, when we make the generalising circuit less efficient (Figure 2b), the test loss never falls. When we remove the slow vs fast learning ingredient (Figure 2c), test loss decreases immediately. See Appendix C for details.

4. Why generalising circuits are more efficient

Section 3 demonstrated that grokking can arise when C_gen is more efficient than C_mem, but left open the question of why C_gen is more efficient. In this section, we develop a theory based on training dataset size D, and use it to predict two new behaviours: ungrokking and semi-grokking.

4.1. Relationship of efficiency with dataset size

Consider a classifier h_𝒟 obtained by training on a dataset 𝒟 of size D with weight decay, and a classifier h_𝒟′ obtained by training on the same dataset with one additional point: 𝒟′ = 𝒟 ∪ {(x, y*)}. Intuitively, h_𝒟′ cannot be more efficient than h_𝒟. If it were, then h_𝒟′ would outperform h_𝒟 even on just 𝒟, since it would get similar L_x-ent while doing better on weight decay. So we should expect that, on average, classifier efficiency is monotonically non-increasing in dataset size.

How does generalisation affect this picture? Let us suppose that h_𝒟 successfully generalises to predict y* for the new input x. Then, as we move from 𝒟 to 𝒟′, L_x-ent(h_𝒟) likely does not worsen with this new data point.
Thus, we could expect to see the same classifier arise, with the same average logit value, parameter norm, and efficiency. Now suppose h_𝒟 instead fails to predict the new data point (x, y*). Then the classifier learned for 𝒟′ will likely be less efficient: L_x-ent(h_𝒟) would be much higher due to this new data point, and so the new classifier must incur some additional regularisation loss to reduce L_x-ent on the new point.

Applying this analysis to our circuits, we should expect C_gen’s efficiency to remain unchanged as D increases arbitrarily high, since C_gen does not need to change to accommodate new training examples. In contrast, C_mem must change with nearly every new data point, and so we should expect its efficiency to decrease as D increases. Thus, when D is sufficiently large, we expect C_gen to be more efficient than C_mem. (Note, however, that when the set of possible inputs is small, even the maximal D may not be “sufficiently large”.)

Critical threshold for dataset size. Intuitively, we expect that for extremely small datasets (say, D < 5), it is extremely easy to memorise the training dataset. So, we hypothesise that for these very small datasets, C_mem is more efficient than C_gen. However, as argued above, C_mem will get less efficient as D increases, and so there will be a critical dataset size D_crit at which C_mem and C_gen are approximately equally efficient. When D ≫ D_crit, C_gen is more efficient and we expect grokking, and when D ≪ D_crit, C_mem is more efficient and so grokking should not happen.

Effect of weight decay on D_crit. D_crit is determined only by the relative efficiencies of C_gen and C_mem. Neither of these depends on the exact value of weight decay, only on weight decay being present at all, so our theory predicts that D_crit should not change as a function of weight decay. Of course, the strength of weight decay may still affect other properties such as the number of epochs until grokking.

4.2. Implications of crossover: ungrokking and semi-grokking

By thinking through the behaviour around the critical threshold for dataset size, we predict the existence of two phenomena that, to the best of our knowledge, have not previously been reported.

Ungrokking. Suppose we take a network that has been trained on a dataset with D > D_crit and has already exhibited grokking, and continue to train it on a smaller dataset with size D′ < D_crit. In this new training setting, C_mem is now more efficient than C_gen. We therefore predict that, with enough further training, gradient descent will reallocate weight from C_gen to C_mem, leading to a transition from high test performance to low test performance. Since this is exactly the opposite observation as in regular grokking, we term this behaviour “ungrokking”.

Ungrokking can be seen as a special case of catastrophic forgetting (McCloskey and Cohen, 1989; Ratcliff, 1990), where we can make much more precise predictions. First, since ungrokking should only be expected once D′ < D_crit, if we vary D′ we predict that there will be a sharp transition from very strong to near-random test accuracy (around D_crit). Second, we predict that ungrokking would arise even if we only remove examples from the training dataset, whereas catastrophic forgetting typically involves training on new examples as well. Third, since D_crit does not depend on weight decay, we predict the amount of “forgetting” (i.e. the test accuracy at convergence) also does not depend on weight decay.

Semi-grokking. Suppose we train a network on a dataset with D ≈ D_crit. C_gen and C_mem would be similarly efficient, and there are two possible cases for what we expect to observe (illustrated in Theorem D.4). In the first case, gradient descent would select either C_mem or C_gen, and then make it the maximal circuit.
This could happen in a consistent manner (for example, perhaps since C_mem is learned faster it always becomes the maximal circuit), or in a manner dependent on the random initialisation. In either case we would simply observe the presence or absence of grokking.

In the second case, gradient descent would produce a mixture of both C_mem and C_gen. Since neither C_mem nor C_gen would dominate the prediction on the test set, we would expect middling test performance. C_mem would still be learned faster, and so this would look similar to grokking: an initial phase with good train performance and bad test performance, followed by a transition to significantly improved test performance. Since, unlike in typical grokking, we only get to middling generalisation, we call this behaviour semi-grokking.

Our theory does not say which of the two cases will arise in practice, but in Section 5.3 we find that semi-grokking does happen in our setting.

5. Experimental evidence

Our explanation of grokking has some support from prior work:
1. Generalising circuit: Nanda et al. (2023, Figure 1) identify and characterise the generalising circuit learned at the end of grokking in the case of modular addition.
2. Slow vs fast learning: Nanda et al. (2023, Figure 7) demonstrate “progress measures” showing that the generalising circuit develops and strengthens long after the network achieves perfect training accuracy in modular addition.

To further validate our explanation, we empirically test our predictions from Section 4:
(P1) Efficiency: We confirm our prediction that C_gen efficiency is independent of dataset size, while C_mem efficiency decreases as training dataset size increases.
(P2) Ungrokking (phase transition): We confirm our prediction that ungrokking shows a phase transition around D_crit.
(P3) Ungrokking (weight decay): We confirm our prediction that the final test accuracy after ungrokking is independent of the strength of weight decay.
(P4) Semi-grokking: We demonstrate that semi-grokking occurs in practice.

Training details. We train 1-layer Transformer models with the AdamW optimiser (Loshchilov and Hutter, 2019) on cross-entropy loss (see Appendix A for more details). Unless otherwise stated, all results in this section are on the modular addition task (a + b mod P for a, b ∈ {0, …, P − 1} and P = 113). Results on 9 additional tasks can be found in Appendix A.

5.1. Relationship of efficiency with dataset size

We first test our prediction about memorisation and generalisation efficiency:

(P1) Efficiency. We predict (Section 4.1) that memorisation efficiency decreases with increasing train dataset size, while generalisation efficiency stays constant.

To test (P1), we look at training runs where only one circuit is present, and see how the logits o_i^y vary with the parameter norm P_i (by varying the weight decay) and the dataset size D.

Experiment setup. We produce C_mem-only networks by using completely random labels for the training data (Zhang et al., 2021), and assume that the entire parameter norm at convergence is allocated to memorisation. We produce C_gen-only networks by training on large dataset sizes and checking that >95% of the logit norm comes from just the trigonometric subspace (see Appendix B for details).

[Figure 3 | Efficiency of the C_mem and C_gen algorithms. We collect and visualise a dataset of triples (o^y, P_m, D) (correct logit, parameter norm, and dataset size), each corresponding to a training run with varying random seed, weight decay, and dataset size, for both C_mem and C_gen. Besides a standard scatter plot, we geometrically bucket logit values into six buckets, and plot “isologit curves” showing the dependence of parameter norm on dataset size for each bucket. The results validate our theory that (1) C_mem requires larger parameter norm to produce the same logits as dataset size increases, and (2) C_gen uses the same parameter norm to produce fixed logits, irrespective of dataset size. In addition, C_mem has a much wider range of parameter norms than C_gen, and at the extremes can be more efficient than C_gen. (a) C_mem scatter plot (correct logit against parameter norm, coloured by dataset size). At a fixed logit value (dotted horizontal lines), parameter norm increases with dataset size. (b) C_mem isologit curves (parameter norm against dataset size, coloured by correct logit). Curves go up and right, showing that parameter norm increases with dataset size when holding logits fixed. (c) C_gen scatter plot. There is no obvious structure to the colours, suggesting that the logit to parameter norm relationship is independent of dataset size. (d) C_gen isologit curves. The curves are flat, showing that for fixed logit values the parameter norm does not depend on dataset size.]

Results. Figures 3a and 3b confirm our theoretical prediction for memorisation efficiency. Specifically, to produce a fixed logit value, a higher parameter norm is required when dataset size is increased, implying decreased efficiency. In addition, for a fixed dataset size, scaling up logits requires scaling up parameter norm, as expected. Figures 3c and 3d confirm our theoretical prediction for generalisation efficiency. To produce a fixed logit value, the same parameter norm is required irrespective of the dataset size.
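The isologit bucketing described in the Figure 3 caption can be sketched as follows. The sketch makes three assumptions. Each run is summarised as a (correct logit, parameter norm, dataset size) triple with positive logits. The six buckets are spaced geometrically between the smallest and largest logit. Runs sharing a bucket and a dataset size are averaged. The function names are hypothetical, and the example runs are made-up numbers rather than results from the paper.

```python
import bisect
from collections import defaultdict


def geometric_edges(lo, hi, n_buckets=6):
    """n_buckets + 1 edges evenly spaced in log space between lo and hi (requires lo > 0)."""
    ratio = (hi / lo) ** (1.0 / n_buckets)
    return [lo * ratio ** k for k in range(n_buckets + 1)]


def isologit_curves(runs, n_buckets=6):
    """Map bucket index -> {dataset size: mean parameter norm}.

    runs: iterable of (correct_logit, param_norm, dataset_size) triples.
    """
    runs = list(runs)
    logits = [logit for logit, _, _ in runs]
    edges = geometric_edges(min(logits), max(logits), n_buckets)
    grouped = defaultdict(list)
    for logit, norm, size in runs:
        # Bucket k covers [edges[k], edges[k+1]); the maximum logit joins the last bucket
        k = min(bisect.bisect_right(edges, logit) - 1, n_buckets - 1)
        grouped[(k, size)].append(norm)
    curves = defaultdict(dict)
    for (k, size), norms in grouped.items():
        curves[k][size] = sum(norms) / len(norms)
    return {k: dict(sorted(by_size.items())) for k, by_size in sorted(curves.items())}


# Made-up runs where the norm needed for a given logit grows with dataset size,
# i.e. the qualitative pattern reported for C_mem in Figure 3b
toy_runs = [(logit, logit * (1 + size / 1000), size)
            for logit in (30, 45, 60, 90) for size in (500, 1000, 2000)]
for bucket, curve in isologit_curves(toy_runs).items():
    print(bucket, curve)
```

Each printed curve rises with dataset size, which is what an upward-sloping isologit curve means. Runs whose norm did not depend on size would instead give flat curves, the C_gen pattern in Figure 3d.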
Note that the figures show significant variance across random seeds. We speculate that there are many different circuits implementing the same overall algorithm, but they have different efficiencies, and the random initialisation determines which one gradient descent finds. For example, in the case of modular addition, the generalising algorithm depends on a set of “key frequencies” (Nanda et al., 2023); different choices of key frequencies could lead to different efficiencies. It may appear from Figure 3c that increasing parameter norm does not increase logit value, contradicting our theory. However, this is a statistical artefact caused by the variance from the random seed. Wedosee “stripes” of particular colours going up and right: these correspond to runs with the same seed and dataset size, but different weight decay, and they show that when the noise from the random seed is removed, increased parameter norm clearly leads to increased logits. 8 Explaining grokking through circuit efficiency 10 3 10 4 Reduced dataset size 0.0 0.5 1.0 Test accuracy x + y mod P 0.1 0.5 1.0 1.5 2.0 Weight decay Average accuracyPer seed accuracy Figure 4|Ungrokking.We train on the full dataset (achieving 100% test accuracy), and then continue training on a smaller subset of the full dataset. We plot test accuracy against reduced dataset size for a range of weight de- cays. We see a sharp transition from strong test accuracy to near-zero test accuracy, that is in- dependent of weight decay (different coloured lines almost perfectly overlap). See Figure 8 for more tasks. 0123 Epochs 1e7 0.0 0.5 1.0 Test accuracy 0.25 0.50 0.75 1.00 Final accuracy Figure 5|Semi-grokking.We plot test accu- racy against training epochs for a large sweep of training runs with varying dataset sizes. Lines are coloured by the final test accuracy at the end of training. Out of 200 runs, at least 6 show clear semi-grokking at the end of training. 
Many other runs show transient semi-grokking, hovering around middling test accuracy for millions of epochs, or having multiple plateaus, before fully generalising.

5.2. Ungrokking: overfitting after generalisation

We now turn to testing our predictions about ungrokking. Figure 1b demonstrates that ungrokking happens in practice. In this section we focus on testing that it has the properties we expect.

(P2) Phase transition. We predict (Section 4.2) that if we plot test accuracy at convergence against the size of the reduced training dataset $D'$, there will be a phase transition around $D_{\text{crit}}$.

(P3) Weight decay. We predict (Section 4.2) that test accuracy at convergence is independent of the strength of weight decay.

Experiment setup. We train a network to convergence on the full dataset to enable perfect generalisation, then continue training the model on a small subset of the full dataset, and measure the test accuracy at convergence. We vary both the size of the small subset and the strength of the weight decay.

Results. Figure 4 shows the results, and clearly confirms both (P2) and (P3). Appendix A has additional results; in particular, Figure 8 replicates the results for many additional tasks.

5.3. Semi-grokking: evenly matched circuits

Unlike the previous predictions, semi-grokking is not strictly implied by our theory. However, as we will see, it does occur in practice.

(P4) Semi-grokking. When training at around $D \approx D_{\text{crit}}$, where $C_{\text{mem}}$ and $C_{\text{gen}}$ have roughly equal efficiencies, the final network at convergence should either be entirely composed of the most efficient circuit, or of roughly equal proportions of $C_{\text{mem}}$ and $C_{\text{gen}}$. If the latter, we should observe a transition to middling test accuracy well after near-perfect train accuracy.

There are a number of difficulties in demonstrating an example of semi-grokking in practice.
First, the time to grok increases super-exponentially as the dataset size $D$ decreases (Power et al., 2021, Figure 1), and $D_{\text{crit}}$ is significantly smaller than the smallest dataset size at which grokking has been demonstrated. Second, the random seed causes significant variance in the efficiency of $C_{\text{gen}}$ and $C_{\text{mem}}$, which in turn affects $D_{\text{crit}}$ for that run. Third, accuracy changes sharply with the $C_{\text{gen}}$ to $C_{\text{mem}}$ ratio (Appendix A). To observe a transition to middling accuracy, we need balanced $C_{\text{gen}}$ and $C_{\text{mem}}$ outputs, but this is difficult to arrange due to the variance with random seed.

To address these challenges, we run many different training runs on dataset sizes slightly above our best estimate of the typical $D_{\text{crit}}$, so that some of the runs will (through random noise) have an unusually inefficient $C_{\text{gen}}$ or an unusually efficient $C_{\text{mem}}$; in those runs the efficiencies match and there is a chance to semi-grok.

Experiment setup. We train 10 seeds for each of 20 dataset sizes evenly spaced in the range [1500, 2050] (somewhat above our estimate of $D_{\text{crit}}$).

Results. Figure 1c shows an example of a single run that demonstrates semi-grokking, and Figure 5 shows test accuracies over time for every run. These validate our initial hypothesis that semi-grokking may be possible, but also raise new questions. In Figure 1c, we see two phenomena peculiar to semi-grokking: (1) test accuracy “spikes” several times throughout training before finally converging, and (2) training loss fluctuates in a set range. We leave investigation of these phenomena to future work. In Figure 5, we observe that there is often transient semi-grokking, where a run hovers around middling test accuracy for millions of epochs, or has multiple plateaus, before generalising perfectly.
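The semi-grokking sweep from the experiment setup can be written down as a small grid. This is a minimal sketch, not the authors' code: rounding the evenly spaced sizes to integers is our assumption, and `semi_grokking_sweep` is a hypothetical helper. It also shows what fraction of the full modular addition dataset ($p^2 = 12769$ examples for $p = 113$, as given in Appendix A) the range covers.

```python
import numpy as np

P = 113
FULL_DATASET_SIZE = P * P  # 12769 (x, y) pairs for modular addition


def semi_grokking_sweep(lo=1500, hi=2050, n_sizes=20, n_seeds=10):
    """(dataset_size, seed) pairs: n_sizes evenly spaced sizes, n_seeds seeds each."""
    # linspace gives non-integer sizes (step = 550 / 19, about 28.9); rounding is our assumption.
    sizes = np.rint(np.linspace(lo, hi, n_sizes)).astype(int)
    return [(int(d), seed) for d in sizes for seed in range(n_seeds)]


grid = semi_grokking_sweep()
print(len(grid))  # 200
print(f"{1500 / FULL_DATASET_SIZE:.1%} to {2050 / FULL_DATASET_SIZE:.1%} of the full dataset")
# prints: 11.7% to 16.1% of the full dataset
```

The 200 runs in this grid match the “Out of 200 runs” count in the Figure 5 caption.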
We speculate that each transition corresponds to gradient descent strengthening a new generalising circuit that is more efficient than any previously strengthened circuit, but took longer to learn. We would guess that if we had trained for longer, many of the semi-grokking runs would exhibit full grokking, and many of the runs that didn’t generalise at all would generalise at least partially to show semi-grokking.

Given the difficulty of demonstrating semi-grokking, we only run this experiment on modular addition. However, our experience with modular addition shows that if we only care about values at convergence, we can find them much faster by ungrokking from a grokked network (instead of semi-grokking from a randomly initialised network). Thus the ungrokking results on other tasks (Figure 8) provide some support that we would see semi-grokking on those tasks as well.

6. Related work

Grokking. Since Power et al. (2021) discovered grokking, many works have attempted to understand why it happens. Thilak et al. (2022) suggest that “slingshots” could be responsible for grokking, particularly in the absence of weight decay, and Notsawo Jr et al. (2023) discuss a similar phenomenon of “oscillations”. In contrast, our explanation applies even where there are no slingshots or oscillations (as in most of our experiments). Liu et al. (2022) show that for a specific (non-modular) addition task with an inexpressive architecture, perfect generalisation occurs when there is enough data to determine the appropriate structured representation. However, such a theory does not explain semi-grokking. We argue that for more typical tasks on which grokking occurs, the critical factor is instead the relative efficiencies of $C_{\text{mem}}$ and $C_{\text{gen}}$. Liu et al.
(2023) show that grokking occurs even in non-algorithmic tasks when parameters are initialised to be very large, because memorisation happens quickly but it takes longer for regularisation to reduce parameter norm to the “Goldilocks zone” where generalisation occurs. This observation is consistent with our theory: in non-algorithmic tasks, we expect that efficient, generalising circuits exist, and the increased parameter norm creates the final ingredient (slow learning), leading to grokking. However, we expect that in algorithmic tasks such as the ones we study, the slow learning is caused by some factor other than large parameter norm at initialisation.

Davies et al. (2023) is the most related work. The authors identify three ingredients for grokking that mirror ours: a generalising circuit that is slowly learned but is favoured by inductive biases, a perspective also mirrored in Nanda et al. (2023, Appendix E). We operationalise the “inductive bias” ingredient as efficiency at producing large logits with small parameter norm, and provide significant empirical support by predicting and verifying the existence of the critical threshold $D_{\text{crit}}$ and the novel behaviours of semi-grokking and ungrokking.

Nanda et al. (2023) identify the trigonometric algorithm by which the networks solve modular addition after grokking, and show that it grows smoothly over training, and Chughtai et al. (2023) extend this result to arbitrary group compositions. We use these results to define metrics for the strength of the different circuits (Appendix B), which we used in preliminary investigations and for some results in Appendix A (Figures 9 and 10). Merrill et al. (2023) show similar results on sparse parity: in particular, they show that a sparse subnetwork is responsible for the well-generalising logits, and that it grows as grokking happens.
Weight decay. While it is widely known that weight decay can improve generalisation (Krogh and Hertz, 1991), the mechanisms for this effect are multifaceted and poorly understood (Zhang et al., 2018). We propose a mechanism that is, to the best of our knowledge, novel: generalising circuits tend to be more efficient than memorising circuits at large dataset sizes, and so weight decay preferentially strengthens the generalising circuits.

Understanding deep learning through circuit-based analysis. One goal of interpretability is to understand the internal mechanisms by which neural networks exhibit specific behaviours, often through the identification and characterisation of specific parts of the network, especially computational subgraphs (“circuits”) that implement human-interpretable algorithms (Cammarata et al., 2021; Elhage et al., 2021; Erhan et al., 2009; Geva et al., 2020; Li et al., 2022; Meng et al., 2022; Olah et al., 2020; Wang et al., 2022). Such work can also be used to understand deep learning. Olsson et al. (2022) explain a phase change in the training of language models by reference to induction heads, a family of circuits that produce in-context learning. In concurrent work, Singh et al. show that the in-context learning from induction heads is later replaced by in-weights learning in the absence of weight decay, but remains strong when weight decay is present. We hypothesise that this effect is also explained through circuit efficiency: the in-context learning from induction heads is a generalising algorithm and so is favoured by weight decay given a large enough dataset size. Michaud et al. (2023) propose an explanation for power-law scaling (Barak et al., 2022; Hestness et al., 2017; Kaplan et al., 2020) based on a model in which there are many discrete quanta (algorithms) and larger models learn more of them.
Our explanation involves a similar structure: we posit the existence of two algorithms (quanta), and analyse the resulting training dynamics.

7. Discussion

Rethinking generalisation. Zhang et al. (2021) pose the question of why deep neural networks achieve good generalisation even when they are easily capable of memorising a random labelling of the training data. Our results gesture at a resolution: in the presence of weight decay, circuits that generalise well are likely to be more efficient given a large enough dataset, and thus are preferred over memorisation circuits, even when both achieve perfect training loss (Section 4.1). Similar arguments may hold for other types of regularisation as well.

Necessity of weight decay. The biggest limitation of our explanation is that it relies crucially on weight decay, but grokking has been observed even when weight decay is not present (Power et al., 2021; Thilak et al., 2022), though it is slower and often much harder to elicit (Nanda et al., 2023, Appendix D.1). This demonstrates that our explanation is incomplete. Does it also imply that our explanation is incorrect? We do not think so. We speculate that there is at least one other effect that has a similar regularising effect favouring $C_{\text{gen}}$ over $C_{\text{mem}}$, such as the implicit regularisation of gradient descent (Lyu and Li, 2019; Smith and Le, 2017; Soudry et al., 2018; Wang et al., 2021), and that the speed of the transition from $C_{\text{mem}}$ to $C_{\text{gen}}$ is based on the sum of these effects and the effect from weight decay. This would neatly explain why grokking takes longer as weight decay decreases (Power et al., 2021), and does not completely vanish in the absence of weight decay.
Given that there is a potential extension of our theory that explains grokking without weight decay, and the significant confirming evidence that we have found for our theory in settings with weight decay, we are overall confident that our explanation is at least one part of the true explanation when weight decay is present.

Broader applicability: beyond parameter norm. Another limitation is that we only consider one kind of constraint that gradient descent must navigate: parameter norm. Typically, there are many other constraints: fitting the training data, capacity in “bottleneck activations” (Elhage et al., 2021), interference between circuits (Elhage et al., 2022), and more. This may limit the broader applicability of our theory, despite its success in explaining grokking.

Broader applicability: realistic settings. A natural question is what implications our theory has for more realistic settings. We expect that the general concepts of circuits, efficiency, and speed of learning continue to apply. However, in realistic settings, good training performance is typically achieved when the model has many different circuit families that contribute different aspects (e.g. language modelling requires spelling, grammar, arithmetic, etc.). We expect that these will have a wide variety of learning speeds and efficiencies (although note that “efficiency” is not as well defined in this setting, because the circuits don’t get perfect training accuracy). In contrast, the key property for grokking in “algorithmic” tasks like modular arithmetic is that there are two clusters of circuit families: one slowly learned, efficient, generalising cluster, and one quickly learned, inefficient, memorising cluster. In particular, our explanation relies on there being no circuits in between the two clusters. Therefore we observe a sharp transition in test performance when shifting from the memorising to the generalising cluster.
Future work. Within grokking, several interesting puzzles are still left unexplained. Why does the time taken to grok rise super-exponentially as dataset size decreases? How does the random initialisation interact with efficiency to determine which circuits are found by gradient descent? What causes generalising circuits to develop more slowly? Investigating these puzzles is a promising avenue for further work.

While the direct application of our work is to understand the puzzle of grokking, we are excited about the potential for understanding deep learning more broadly through the lens of circuit efficiency. We would be excited to see work looking at the role of circuit efficiency in more realistic settings, and work that extends circuit efficiency to consider other constraints that gradient descent must navigate.

8. Conclusion

The central question of our paper is: in grokking, why does the network’s test performance improve dramatically upon continued training, having already achieved nearly perfect training performance? Our explanation is: the generalising solution is more “efficient” but slower to learn than the memorising solution. After quickly learning the memorising circuit, gradient descent can still decrease loss even further by simultaneously strengthening the efficient, generalising circuit and weakening the inefficient, memorising circuit.

Based on our theory we predict and demonstrate two novel behaviours: ungrokking, in which a model that has perfect generalisation returns to memorisation when it is further trained on a dataset with size smaller than the critical threshold, and semi-grokking, where we train a randomly initialised network on the critical dataset size, which results in a grokking-like transition to middling test accuracy.
Our explanation is the only one we are aware of that has made (and confirmed) such surprising advance predictions, and we have significant confidence in the explanation as a result.

Acknowledgements

Thanks to Paul Christiano, Xander Davies, Seb Farquhar, Geoffrey Irving, Tom Lieberum, Eric Michaud, Vlad Mikulik, Neel Nanda, Jonathan Uesato, and anonymous reviewers for valuable discussions and feedback.

References

B. Barak, B. Edelman, S. Goel, S. Kakade, E. Malach, and C. Zhang. Hidden progress in deep learning: SGD learns parities near the computational limit. Advances in Neural Information Processing Systems, 35:21750–21764, 2022.

N. Cammarata, G. Goh, S. Carter, C. Voss, L. Schubert, and C. Olah. Curve circuits. Distill, 2021. doi: 10.23915/distill.00024.006. https://distill.pub/2020/circuits/curve-circuits.

B. Chughtai, L. Chan, and N. Nanda. A toy model of universality: Reverse engineering how networks learn group operations. arXiv preprint arXiv:2302.03025, 2023.

X. Davies, L. Langosco, and D. Krueger. Unifying grokking and double descent. arXiv preprint arXiv:2303.06173, 2023.

N. Elhage, N. Nanda, C. Olsson, T. Henighan, N. Joseph, B. Mann, A. Askell, Y. Bai, A. Chen, T. Conerly, N. DasSarma, D. Drain, D. Ganguli, Z. Hatfield-Dodds, D. Hernandez, A. Jones, J. Kernion, L. Lovitt, K. Ndousse, D. Amodei, T. Brown, J. Clark, J. Kaplan, S. McCandlish, and C. Olah. A mathematical framework for Transformer circuits. Transformer Circuits Thread, 2021. https://transformer-circuits.pub/2021/framework/index.html.

N. Elhage, T. Hume, C. Olsson, N. Schiefer, T. Henighan, S. Kravec, Z. Hatfield-Dodds, R. Lasenby, D. Drain, C. Chen, R. Grosse, S. McCandlish, J. Kaplan, D. Amodei, M. Wattenberg, and C. Olah. Toy models of superposition. Transformer Circuits Thread, 2022. https://transformer-circuits.pub/2022/toy_model/index.html.

D. Erhan, Y. Bengio, A. Courville, and P. Vincent. Visualizing higher-layer features of a deep network. University of Montreal, 1341(3):1, 2009.

M. Geva, R. Schuster, J. Berant, and O. Levy. Transformer feed-forward layers are key-value memories. arXiv preprint arXiv:2012.14913, 2020.

J. Hestness, S. Narang, N. Ardalani, G. Diamos, H. Jun, H. Kianinejad, M. Patwary, M. Ali, Y. Yang, and Y. Zhou. Deep learning scaling is predictable, empirically. arXiv preprint arXiv:1712.00409, 2017.

A. Jermyn and B. Shlegeris. Multi-component learning and s-curves, 2022. URL https://w.alignmentforum.org/posts/RKDQCB6smLWgs2Mhr/multi-component-learning-and-s-curves.

J. Kaplan, S. McCandlish, T. Henighan, T. B. Brown, B. Chess, R. Child, S. Gray, A. Radford, J. Wu, and D. Amodei. Scaling laws for neural language models. arXiv preprint arXiv:2001.08361, 2020.

A. Krogh and J. Hertz. A simple weight decay can improve generalization. Advances in Neural Information Processing Systems, 4, 1991.

K. Li, A. K. Hopkins, D. Bau, F. Viégas, H. Pfister, and M. Wattenberg. Emergent world representations: Exploring a sequence model trained on a synthetic task. arXiv preprint arXiv:2210.13382, 2022.

Z. Liu, O. Kitouni, N. Nolte, E. J. Michaud, M. Tegmark, and M. Williams. Towards understanding grokking: An effective theory of representation learning. arXiv preprint arXiv:2205.10343, 2022.

Z. Liu, E. J. Michaud, and M. Tegmark. Omnigrok: Grokking beyond algorithmic data, 2023.

I. Loshchilov and F. Hutter. Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101, 2019.

K. Lyu and J. Li. Gradient descent maximizes the margin of homogeneous neural networks. arXiv preprint arXiv:1906.05890, 2019.

M. McCloskey and N. Cohen. Catastrophic interference in connectionist networks: The sequential learning problem. Psychology of Learning and Motivation, 24:109–165, 1989.

K. Meng, D. Bau, A. Andonian, and Y. Belinkov. Locating and editing factual knowledge in GPT. arXiv preprint arXiv:2202.05262, 2022.

W. Merrill, N. Tsilivis, and A. Shukla. A tale of two circuits: Grokking as competition of sparse and dense subnetworks. arXiv preprint arXiv:2303.11873, 2023.

E. J. Michaud, Z. Liu, U. Girit, and M. Tegmark. The quantization model of neural scaling. arXiv preprint arXiv:2303.13506, 2023.

B. Millidge. Grokking ’grokking’, 2022. URL https://w.beren.io/2022-01-11-Grokking-Grokking/.

N. Nanda, L. Chan, T. Lieberum, J. Smith, and J. Steinhardt. Progress measures for grokking via mechanistic interpretability. arXiv preprint arXiv:2301.05217, 2023.

P. Notsawo Jr, H. Zhou, M. Pezeshki, I. Rish, G. Dumas, et al. Predicting grokking long before it happens: A look into the loss landscape of models which grok. arXiv preprint arXiv:2306.13253, 2023.

C. Olah, N. Cammarata, L. Schubert, G. Goh, M. Petrov, and S. Carter. Zoom in: An introduction to circuits. Distill, 2020. doi: 10.23915/distill.00024.001. https://distill.pub/2020/circuits/zoom-in.

C. Olsson, N. Elhage, N. Nanda, N. Joseph, N. DasSarma, T. Henighan, B. Mann, A. Askell, Y. Bai, A. Chen, T. Conerly, D. Drain, D. Ganguli, Z. Hatfield-Dodds, D. Hernandez, S. Johnston, A. Jones, J. Kernion, L. Lovitt, K. Ndousse, D. Amodei, T. Brown, J. Clark, J. Kaplan, S. McCandlish, and C. Olah. In-context learning and induction heads. Transformer Circuits Thread, 2022. https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html.

A. Power, Y. Burda, H. Edwards, I. Babuschkin, and V. Misra. Grokking: Generalization beyond overfitting on small algorithmic datasets. In Mathematical Reasoning in General Artificial Intelligence Workshop, ICLR, 2021. URL https://mathai-iclr.github.io/papers/papers/MATHAI_29_paper.pdf.

R. Ratcliff. Connectionist models of recognition memory: Constraints imposed by learning and forgetting functions. Psychological Review, 97(2):285–308, 1990.

A. Singh, S. Chan, T. Moskovitz, E. Grant, A. Saxe, and F. Hill. The transient nature of emergent in-context learning in transformers. Forthcoming.

S. L. Smith and Q. V. Le. A Bayesian perspective on generalization and stochastic gradient descent. arXiv preprint arXiv:1710.06451, 2017.

D. Soudry, E. Hoffer, M. S. Nacson, S. Gunasekar, and N. Srebro. The implicit bias of gradient descent on separable data. The Journal of Machine Learning Research, 19(1):2822–2878, 2018.

V. Thilak, E. Littwin, S. Zhai, O. Saremi, R. Paiss, and J. Susskind. The slingshot mechanism: An empirical study of adaptive optimizers and the grokking phenomenon. arXiv preprint arXiv:2206.04817, 2022.

A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems, volume 30, 2017. URL https://proceedings.neurips.c/paper_files/paper/2017/file/3f5e243547dee91fbd053c1c4a845a-Paper.pdf.

B. Wang, Q. Meng, W. Chen, and T.-Y. Liu. The implicit bias for adaptive optimization algorithms on homogeneous neural networks. In International Conference on Machine Learning, pages 10849–10858. PMLR, 2021.

K. Wang, A. Variengien, A. Conmy, B. Shlegeris, and J. Steinhardt. Interpretability in the wild: A circuit for indirect object identification in GPT-2 small. arXiv preprint arXiv:2211.00593, 2022.

C. Zhang, S. Bengio, M. Hardt, B. Recht, and O. Vinyals. Understanding deep learning (still) requires rethinking generalization. Communications of the ACM, 64(3):107–115, 2021.

G. Zhang, C. Wang, B. Xu, and R. Grosse. Three mechanisms of weight decay regularization. arXiv preprint arXiv:1810.12281, 2018.
[Figure 6: train and test accuracy, loss, and parameter norm against epochs (0 to $4 \times 10^7$) for a single run]

Figure 6 | Examining a single semi-grokking run in detail. We plot accuracy, loss, and parameter norm over training for a single cherry-picked modular addition run at a dataset size of 1532 (12% of the full dataset). This run shows transient semi-grokking. At epoch $0.8 \times 10^7$, test accuracy rises to around 0.55, and then stays there for $10^7$ epochs, because $C_{\text{gen}}$ and $C_{\text{mem}}$ efficiencies are balanced. At epoch $1.8 \times 10^7$, we speculate that gradient descent finds an even more efficient $C_{\text{gen}}$ circuit, as parameter norm drops suddenly and test accuracy rises to 1. At epoch $3.2 \times 10^7$ we see test loss rise; we do not know why. There seem to be multiple phases, perhaps corresponding to the network transitioning between mixtures of multiple circuits with increasing efficiencies, but further investigation is needed.

A. Experimental details and more evidence

For all our experiments, we use 1-layer decoder-only transformer networks (Vaswani et al., 2017) with learned positional embeddings and untied embeddings/unembeddings. The hyperparameters are as follows: $d_{\text{model}} = 128$ is the residual stream width, $d_{\text{head}} = 32$ is the size of the query, key, and value vectors for each attention head, $d_{\text{mlp}} = 512$ is the number of neurons in the hidden layer of the MLP, and we have $d_{\text{model}} / d_{\text{head}} = 4$ heads per self-attention layer. We optimise the network with full batch training (that is, using the entire training dataset for each update) using the AdamW optimiser (Loshchilov and Hutter, 2019) with $\beta_1 = 0.9$, $\beta_2 = 0.98$, a learning rate of $10^{-3}$, and weight decay of 1.0. In some of our experiments we vary the weight decay in order to produce networks with varying parameter norm.

Following Power et al. (2021), for a binary operation $x \circ y$, we construct a dataset of the form $\langle x \rangle \langle \circ \rangle \langle y \rangle \langle = \rangle \langle x \circ y \rangle$, where $\langle a \rangle$ stands for the token corresponding to the element $a$.
We choose a fraction of this dataset at random as the train dataset, and the remainder as the test dataset. The first 4 tokens $\langle x \rangle \langle \circ \rangle \langle y \rangle \langle = \rangle$ are the input to the network, and we train with cross-entropy loss over the final token $\langle x \circ y \rangle$. For all modular arithmetic tasks we use the modulus $p = 113$, so for example the size of the full dataset for modular addition is $p^2 = 12769$, and $d_{\text{vocab}} = 115$, including the $\langle + \rangle$ and $\langle = \rangle$ tokens.

[Figure 7: test accuracy against epochs (0 to $10^6$) for many ungrokking runs, coloured by reduced dataset size (400 to 1400)]

Figure 7 | Many ungrokking runs. We show test accuracy over epochs for a range of ungrokking runs for modular addition. Each line represents a single run, and we sweep over 7 geometrically spaced dataset sizes in [390, 1494] with 10 seeds each. Each run is initialised with parameters from a network trained on the full dataset (the initialisation runs are not shown), so test accuracy starts at 1 for all runs. When the dataset size is small enough, the network ungroks to poor test accuracy, while train accuracy remains at 1 (not shown). For an intermediate dataset size, we see ungrokking to middling test accuracy as $C_{\text{gen}}$ and $C_{\text{mem}}$ efficiencies are balanced.

A.1. Semi-grokking

In Section 5.3 we looked at all semi-grokking training runs in Figure 5. Here, we investigate a single example of transient semi-grokking in more detail (see Figure 6). We speculate that there are multiple circuits with increasing efficiencies for $C_{\text{gen}}$, and in these cases the more efficient circuits are slower to learn. This would explain transient semi-grokking: gradient descent first finds a less efficient $C_{\text{gen}}$ and we see middling generalisation, but since we are using the upper range of $D_{\text{crit}}$, eventually gradient descent finds a more efficient $C_{\text{gen}}$, leading to full generalisation.

A.2. Ungrokking

In Figure 7, we show many ungrokking runs for modular addition, and in Figure 8 we show ungrokking across many other tasks.
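To make the dataset format at the start of this appendix and the Figure 7 sweep concrete, here is a minimal Python sketch. The token ids for $\langle + \rangle$ and $\langle = \rangle$, the helper names, the rounding of the geometrically spaced sizes, and the choice to evaluate on every example outside the reduced subset are our assumptions, not details taken from the paper.

```python
import itertools
import random

import numpy as np

P = 113
PLUS_TOKEN, EQUALS_TOKEN = P, P + 1  # hypothetical token ids for "+" and "="
VOCAB_SIZE = P + 2                   # 113 number tokens + "+" + "=" = 115


def full_modular_addition_dataset(p=P):
    """All p*p examples: input tokens <x><+><y><=>, target token (x + y) mod p."""
    return [([x, PLUS_TOKEN, y, EQUALS_TOKEN], (x + y) % p)
            for x, y in itertools.product(range(p), repeat=2)]


def ungrokking_sweep(lo=390, hi=1494, n_sizes=7, n_seeds=10):
    """Geometrically spaced reduced-dataset sizes, each paired with every seed."""
    sizes = np.rint(np.geomspace(lo, hi, n_sizes)).astype(int)
    return [(int(d), seed) for d in sizes for seed in range(n_seeds)]


def reduced_split(dataset, size, seed):
    """Random reduced training subset for the ungrokking phase; the rest is held out."""
    rng = random.Random(seed)
    order = list(range(len(dataset)))
    rng.shuffle(order)
    train = [dataset[i] for i in order[:size]]
    test = [dataset[i] for i in order[size:]]
    return train, test
```

In an ungrokking run, a network first trained on the full dataset would then continue training on `train` for one `(size, seed)` pair from `ungrokking_sweep()`.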
We have already seen that $D_{\text{crit}}$ is affected by the random initialisation. It is interesting to compare $D_{\text{crit}}$ when starting from a given random initialisation with $D_{\text{crit}}$ when ungrokking from a network that was trained to full generalisation with the same random initialisation. Figure 5 shows a semi-grokking run that achieves a test accuracy of ∼0.7 with a dataset size of ∼2000, while Figure 7 shows ungrokking runs that achieve a test accuracy of ∼0.7 with a dataset size of around 800–1000, less than half of what the semi-grokking run required.

In Figure 10b, the final test accuracy after ungrokking shows a smooth relationship with dataset size, which we might expect if $C_{\text{gen}}$ is getting stronger on a smoothly increasing number of inputs compared to $C_{\text{mem}}$. However, due to the difficulties discussed previously, we don’t see a smooth relationship between test accuracy and dataset size in semi-grokking. These results suggest that $D_{\text{crit}}$ is an oversimplified concept, because in reality the initialisation and training dynamics affect which circuits are found, and therefore the dataset size at which we see middling generalisation.

[Figure 8: test accuracy against reduced dataset size ($10^3$ to $10^4$), one panel per task: modular arithmetic tasks (panel titles include x/y mod P, $x^2 + y^2$ mod P, and $x^2 + xy + y^2$ mod P) and symmetric group tasks on $S_5$ (including $xyx^{-1}$); lines coloured by weight decay (0.1 to 2.0), showing average and per-seed accuracy]

Figure 8 | Ungrokking on many other tasks. We plot test accuracy against reduced dataset size for many other modular arithmetic and symmetric group tasks (Power et al., 2021). For each run, we train on the full dataset (achieving 100% accuracy), and then further train on a reduced subset of the dataset for 100k steps.
The results show clear ungrokking, since in many cases test accuracy falls below 100%, often to nearly 0%. For most datasets the transition point is independent of weight decay (different coloured lines almost perfectly overlap).

[Figure 9 plot: train and test loss (log scale, $10^{-9}$ to $10^{1}$) and, on a second axis (0 to 60), the $C_{\text{trig}}$ logit, $C_{\text{mem}}$ logit and parameter norm, over 0 to 10,000 epochs.]

Figure 9 | Grokking occurs because $C_{\text{gen}}$ is more efficient than $C_{\text{mem}}$. We show loss, parameter norm, and the value of the correct logit for $C_{\text{gen}}$ and $C_{\text{mem}}$ for a randomly-picked training run. By step 200, the train accuracy is already perfect (not shown), train loss is low while test loss has risen, and parameter norm is at its maximum value, indicating strong $C_{\text{mem}}$. Train loss continues to fall rapidly until step 1500, as parameter norm falls and the $C_{\text{gen}}$ logit becomes higher than the $C_{\text{mem}}$ logit. At step 3500, test loss starts to fall as the high $C_{\text{gen}}$ logit starts to dominate, and by step 6000 we get good generalisation.

A.3. $C_{\text{gen}}$ and $C_{\text{mem}}$ development during grokking

In Figure 9 we show $C_{\text{gen}}$ and $C_{\text{mem}}$ development via the proxy measures defined in Appendix B for a randomly-picked grokking run. Looking at these measures was very useful for forming a working theory of why grokking happens. However, as we note in Appendix B, these proxy measures tend to overestimate $C_{\text{gen}}$ and underestimate $C_{\text{mem}}$. We note some interesting phenomena in Figure 9:

1. Between epochs 200 and 1500, both the $C_{\text{gen}}$ and $C_{\text{mem}}$ logits are rising while parameter norm is falling, indicating that gradient descent is improving efficiency (possibly by removing irrelevant parameters).
2. After epoch 4000, the $C_{\text{gen}}$ logit falls while the $C_{\text{mem}}$ logit is already $\sim 0$. Since test loss continues to fall, we expect that incorrect logits from $C_{\text{mem}}$ on the test dataset are getting cleaned up, as described in Nanda et al. (2023).

A.4. Tradeoffs between $C_{\text{gen}}$ and $C_{\text{mem}}$

In Section 4.1 we looked at the efficiency of $C_{\text{gen}}$-only and $C_{\text{mem}}$-only circuits. In this section we train on varying dataset sizes so that the network develops a mixture of $C_{\text{gen}}$ and $C_{\text{mem}}$ circuits, and study their relative strength using the correct logit as a proxy measure (described in Appendix B). As we demonstrated previously, $C_{\text{mem}}$'s efficiency drops with increasing dataset size, while $C_{\text{gen}}$'s stays constant. Theorem D.4 (case 2) suggests that the parameter norm allocated to a circuit increases with its efficiency. Since logit values also increase with parameter norm, this implies that the ratio of the $C_{\text{gen}}$ to $C_{\text{mem}}$ logit, $o^y_t / o^y_m$, should increase monotonically with dataset size. In Figure 10a we see exactly this: the logit ratio changes monotonically (over 6 orders of magnitude) with increasing dataset size. Due to the difficulties in training to convergence at small dataset sizes, we initialised all parameters from a $C_{\text{gen}}$-only network trained on the full dataset. We confirmed that in all the runs, at convergence, the training loss from the $C_{\text{gen}}$-initialised network was lower than the training loss from a randomly initialised network, indicating that this initialisation allows our optimiser to find a better minimum than random initialisation does.

[Figure 10 plots: both panels use dataset size $D$ on the x-axis (log scale, $10^3$ to $10^4$); line colour encodes parameter norm (colour bar from about 20 to 40).]

(a) Logit ratio ($o^y_t / o^y_m$) vs dataset size ($D$). Colours correspond to different bucketed values of parameter norm ($P$). Each line shows that as dataset size increases, a fixed parameter norm (fixed colour) is being reallocated smoothly towards increasing the trigonometric logit compared to the memorisation logit.

(b) Test accuracy vs dataset size ($D$). We see a smooth dependence on dataset size. Each line shows that as dataset size increases, the reallocation of a fixed parameter norm (fixed colour) from $C_{\text{mem}}$ towards $C_{\text{gen}}$ results in increasing accuracy.

Figure 10 | Relative strength at convergence. We report logit ratios and test accuracy at convergence across a range of training runs, generated by sweeping over the weight decay and random seed to obtain different parameter norms at the same dataset size. We use the ungrokking runs from Figure 7, so every run is initialised with parameters obtained by training on the full dataset.

B. $C_{\text{gen}}$ and $C_{\text{mem}}$ in modular addition

In modular addition, given two integers $a, b$ and a modulus $p$ as input, where $0 \leq a, b < p$, the task is to predict $a + b \bmod p$. Nanda et al. (2023) identified the generalising algorithm implemented by a 1-layer transformer after grokking (visualised in Figure 11), which we call the "trigonometric" algorithm. In this section we summarise the algorithm, and explain how we produce our proxy metrics for the strength of $C_{\text{gen}}$ and $C_{\text{mem}}$.

Trigonometric logits. We explain the structure of the logits produced by the trigonometric algorithm. For each possible label $c \in \{0, 1, \ldots, p-1\}$, the trigonometric logit $o_c$ is given by $\sum_{\omega_k} \cos(\omega_k (a + b - c))$, for a few key frequencies $\omega_k = \frac{2\pi k}{p}$ with integer $k$. For the true label $c^* = a + b \bmod p$, the term $\omega_k (a + b - c^*)$ is an integer multiple of $2\pi$, and so $\cos(\omega_k (a + b - c^*)) = 1$. For any incorrect label $c \neq a + b \bmod p$, it is very likely that at least some of the key frequencies satisfy $\cos(\omega_k (a + b - c)) \ll 1$, creating a large difference between $o_c$ and $o_{c^*}$.

Trigonometric algorithm. There is a set of key frequencies $\omega_k$. (These frequencies are typically whichever frequencies were highest at the time of random initialisation.) For an arbitrary label $c$, the logit $o_c$ is computed as follows:

1. Embed the one-hot encoded number $a$ as $\sin(\omega_k a)$ and $\cos(\omega_k a)$ for the various frequencies $\omega_k$. Do the same for $b$.
2. Compute $\cos(\omega_k (a + b))$ and $\sin(\omega_k (a + b))$ using the intermediate attention and MLP layers via the trigonometric identities:
$$\cos(\omega_k (a + b)) = \cos(\omega_k a)\cos(\omega_k b) - \sin(\omega_k a)\sin(\omega_k b)$$
$$\sin(\omega_k (a + b)) = \sin(\omega_k a)\cos(\omega_k b) + \cos(\omega_k a)\sin(\omega_k b)$$
3. Use the output and unembedding matrices to implement the trigonometric identity:
$$o_c = \sum_{\omega_k} \cos(\omega_k (a + b - c)) = \sum_{\omega_k} \Big[\cos(\omega_k (a + b))\cos(\omega_k c) + \sin(\omega_k (a + b))\sin(\omega_k c)\Big].$$

Figure 11 | The trigonometric algorithm for modular arithmetic (reproduced from Nanda et al. (2023)). Given two numbers $a$ and $b$, the model projects each point onto a corresponding rotation using its embedding matrix. Using its attention and MLP layers, it then composes the rotations to get a representation of $a + b \bmod p$. Finally, it "reads off" the logits for each $c \in \{0, 1, \ldots, p-1\}$ by rotating by $-c$ to get $\cos(\omega(a + b - c))$, which is maximised when $a + b \equiv c \pmod p$ (since $\omega$ is a multiple of $2\pi/p$).

Isolating trigonometric logits. Given a classifier $h$, we can aggregate its logits on every possible input, resulting in a vector $\vec Z_h$ of length $p^3$ where $\vec Z_h^{a,b,c} = o^c_h(\text{"}a+b=\text{"})$ is the logit for label $c$ on the input $(a, b)$. We are interested in identifying the contribution of the trigonometric algorithm to $\vec Z_h$. We use the same method as Chughtai et al. (2023) and restrict $\vec Z_h$ to a much smaller trigonometric subspace.

For a frequency $\omega_k$, let us define the $p^3$-dimensional vector $\vec Z_{\omega_k}$ as $\vec Z_{\omega_k}^{a,b,c} = \cos(\omega_k (a + b - c))$. Since $\vec Z_{\omega_k} = \vec Z_{\omega_{p-k}}$, we set $1 \leq k \leq K$, where $K = \lceil (p-1)/2 \rceil$, to obtain $K$ distinct vectors, ignoring the constant bias vector. These vectors are orthogonal, as they are part of a Fourier basis. Notice that any circuit that was exactly following the learned algorithm described above would only produce logits in the directions $\vec Z_{\omega_k}$ for the key frequencies $\omega_k$. So, we can define the trigonometric contribution to $\vec Z_h$ as the projection of $\vec Z_h$ onto the directions $\vec Z_{\omega_k}$.
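The three steps of the algorithm and the projection just described can be sketched in plain Python. This is a minimal illustration under our own assumptions, not the authors' code. The modulus $p = 11$, the key frequencies $k \in \{2, 3\}$, and the constant offset that stands in for a non-trigonometric component are all hypothetical. The helper names (`trig_logits`, `fourier_direction`, `split_logits`, `correct_logit`) are ours. The sketch builds the logits step by step, checks that the correct label wins, and then recovers the trigonometric part by projecting onto all $K$ Fourier directions, treating the residual as the memorisation logits.

```python
import math
from itertools import product

def trig_logits(a, b, p, key_ks):
    """Logits o_c, c = 0..p-1, built with the three steps of the trigonometric algorithm."""
    logits = [0.0] * p
    for k in key_ks:
        w = 2 * math.pi * k / p  # key frequency omega_k
        # Step 1: embed a and b as (cos, sin) pairs at frequency omega_k.
        ca, sa = math.cos(w * a), math.sin(w * a)
        cb, sb = math.cos(w * b), math.sin(w * b)
        # Step 2: angle-addition identities give cos/sin of omega_k (a + b).
        cab = ca * cb - sa * sb
        sab = sa * cb + ca * sb
        # Step 3: read off cos(omega_k (a + b - c)) for every label c.
        for c in range(p):
            logits[c] += cab * math.cos(w * c) + sab * math.sin(w * c)
    return logits

def fourier_direction(p, k):
    """Z_{omega_k}: cos(omega_k (a + b - c)) over all (a, b, c), flattened to length p**3."""
    w = 2 * math.pi * k / p
    return [math.cos(w * (a + b - c)) for a, b, c in product(range(p), repeat=3)]

def split_logits(z, p):
    """Project z onto all K Fourier directions (Z_{h,T}); return it with the residual (Z_{h,M})."""
    K = math.ceil((p - 1) / 2)
    z_trig = [0.0] * len(z)
    for k in range(1, K + 1):
        d = fourier_direction(p, k)
        norm = math.sqrt(sum(x * x for x in d))
        coeff = sum(zi * di for zi, di in zip(z, d)) / norm  # Z_h . Zhat_{omega_k}
        for i, di in enumerate(d):
            z_trig[i] += coeff * di / norm
    z_mem = [zi - ti for zi, ti in zip(z, z_trig)]
    return z_trig, z_mem

def correct_logit(z, p, a, b):
    """Entry of a flattened p**3 logit vector for input (a, b) and true label (a + b) mod p."""
    return z[(a * p + b) * p + (a + b) % p]

p, key_ks = 11, [2, 3]  # hypothetical modulus and key frequencies
# Hypothetical classifier: the trigonometric algorithm plus a constant offset of 0.5 on every logit.
z_h = [o + 0.5 for a, b in product(range(p), repeat=2) for o in trig_logits(a, b, p, key_ks)]
z_trig, z_mem = split_logits(z_h, p)
print(correct_logit(z_trig, p, 4, 9), correct_logit(z_mem, p, 4, 9))  # approximately 2.0 and 0.5
```

A constant offset is orthogonal to every Fourier direction, so it lands entirely in the residual. This mirrors how anything outside the trigonometric subspace gets attributed to $C_{\text{mem}}$ by this proxy.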
We may not know the key frequencies in advance, but we can sum over all $K$ of them, giving the following definition for trigonometric logits:

$$\vec Z_{h,T} = \sum_{k=1}^{K} \left(\vec Z_h \cdot \hat Z_{\omega_k}\right) \hat Z_{\omega_k}$$

where $\hat Z_{\omega_k}$ is the normalised version of $\vec Z_{\omega_k}$. This corresponds to projecting onto a $K$-dimensional subspace of the $p^3$-dimensional space in which $\vec Z_h$ lives.

Memorisation logits. Early in training, neural networks memorise the training dataset without generalising, suggesting that there exists a memorisation algorithm, implemented by the circuit $C_{\text{mem}}$.¹ Unfortunately, we do not understand the algorithm underlying memorisation, and so cannot design a similar procedure to isolate $C_{\text{mem}}$'s contribution to the logits. However, we hypothesise that for modular addition, $C_{\text{gen}}$ and $C_{\text{mem}}$ are the only two circuit families of importance for the loss. This allows us to define the $C_{\text{mem}}$ contribution to the logits as the residual:

$$\vec Z_{h,M} = \vec Z_h - \vec Z_{h,T}$$

$C_{\text{trig}}$ and $C_{\text{mem}}$ circuits. We say that a circuit is a $C_{\text{trig}}$ circuit if it implements the $C_{\text{trig}}$ algorithm, and similarly for $C_{\text{mem}}$ circuits. Importantly, this is a many-to-one mapping: there are many possible circuits that implement a given algorithm. We isolate $C_{\text{trig}}$ ($\vec o_t$) and $C_{\text{mem}}$ ($\vec o_m$) logits by projecting the output logits ($\vec o$) as described in Appendix B. We cannot directly measure the circuit weights $w_t$ and $w_m$, but instead use an indirect measure: the value of the logit for the correct class given by each circuit, i.e. $o^y_t$ and $o^y_m$.

Flaws. These metrics should be viewed as an imperfect proxy measure for the true strength of the trigonometric and memorisation circuits, as they have a number of flaws:

1. When both $C_{\text{trig}}$ and $C_{\text{mem}}$ are present in the network, they are both expected to produce high values for the correct logits, and low values for incorrect logits, on the train dataset.
   Since the $C_{\text{trig}}$ and $C_{\text{mem}}$ logits are correlated, it becomes more likely that $\vec Z_{h,T}$ captures $C_{\text{mem}}$ logits too.
2. In this case we would expect our proxy measure to overestimate the strength of $C_{\text{trig}}$ and underestimate the strength of $C_{\text{mem}}$. In fact, in our experiments we do see large negative correct logit values for $C_{\text{mem}}$ on training for semi-grokking, which probably arises because of this effect.
3. Logits are not inherently meaningful; what matters for loss is the extent to which the correct logit is larger than the incorrect logits. This is not captured by our proxy metric, which only looks at the size of the correct logit. In a binary classification setting, we could instead use the difference between the correct and incorrect logit, but it is not clear what a better metric would be in the multiclass setting.

¹ In reality, there are at least two different memorisation algorithms: commutative memorisation (which predicts the same answer for $(a, b)$ and $(b, a)$) and non-commutative memorisation (which does not). However, this difference does not matter for our analyses, and we will call both of these "memorisation" in this paper.

C. Details for the minimal example

In Figure 2 we show that two ingredients, namely multiple circuits with different efficiencies and slow and fast circuit development, are sufficient to reproduce learning curves that qualitatively demonstrate grokking. In Table 1 we provide details about the simulation used to produce this figure.

As explained in Section 3, the logits produced by $C_{\text{gen}}$ and $C_{\text{mem}}$ are given by:

$$o^y_G(x) = \mathbb{1}\left[(x, y) \in \mathcal{D} \text{ or } (x, y) \in \mathcal{D}_{\text{test}}\right] \qquad (3)$$
$$o^y_M(x) = \mathbb{1}\left[(x, y) \in \mathcal{D} \text{ or } (x, y) \in \mathcal{D}_{\text{mem}}\right] \qquad (4)$$

These are scaled by two independent weights for each circuit, giving the overall logits as:

$$o^y(x) = w_{g1} w_{g2}\, o^y_G(x) + w_{m1} w_{m2}\, o^y_M(x) \qquad (5)$$

We model the parameter norms according to the scaling efficiency in Section D.1, inspired by a $\kappa$-layer MLP with ReLU activations and without biases:

$$P'_c = (w_{c1} w_{c2})^{1/\kappa} P_c \quad \text{for } c \in \{g, m\}.$$

From Equations (3) to (5) we get the following equations for train and test loss respectively:

$$\mathcal{L}_{\text{train}} = -\log \frac{\exp(w_{g1} w_{g2} + w_{m1} w_{m2})}{(q - 1) + \exp(w_{g1} w_{g2} + w_{m1} w_{m2})} + \mathcal{L}_{\text{wd}},$$
$$\mathcal{L}_{\text{test}} = -\log \frac{\exp(w_{g1} w_{g2})}{(q - 2) + \exp(w_{g1} w_{g2}) + \exp(w_{m1} w_{m2})} + \mathcal{L}_{\text{wd}},$$

where $q$ is the number of labels, and the weight decay loss is:

$$\mathcal{L}_{\text{wd}} = P'^2_g + P'^2_m.$$

The weights $w_{ci}$ are updated by gradient descent:

$$w_{ci}(\tau) \leftarrow w_{ci}(\tau - 1) - \lambda \frac{\partial \mathcal{L}_{\text{train}}}{\partial w_{ci}}$$

where $\lambda$ is a learning rate. The initial values of the parameters are $w_{ci}(0)$. In Table 1 we list the values of the simulation hyperparameters.

Table 1 | Hyperparameters used for our simulations. Columns: (a) $C_{\text{gen}}$ learned slower but more efficient than $C_{\text{mem}}$; (b) $C_{\text{gen}}$ less efficient than $C_{\text{mem}}$; (c) $C_{\text{gen}}$ and $C_{\text{mem}}$ learned at equal speeds.

Parameter      (a)      (b)      (c)
$P_g$          1        4        1
$P_m$          2        2        2
$\kappa$       1.2      1.2      1.2
$\alpha$       0.005    0.005    0.005
$w_{g1}(0)$    0        0        0
$w_{g2}(0)$    0.005    0.005    1
$w_{m1}(0)$    0        0        0
$w_{m2}(0)$    1        1        1
$q$            113      113      113
$\lambda$      0.01     0.01     0.01

D. Proofs of theorems

We assume we have a set of inputs $X$, a set of labels $Y$, and a training dataset, $\mathcal{D} = \{(x_1, y_1), \ldots, (x_D, y_D)\}$. Let $h$ be a classifier that assigns a real-valued logit for each possible label given an input. We denote an individual logit as $o^y_h(x) := h(x, y)$. When the input $x$ is clear from context, we will denote the logit as $o^y_h$. Excluding weight decay, the loss for the classifier is given by the softmax cross-entropy loss:

$$\mathcal{L}_{\text{x-ent}}(h) = -\frac{1}{D} \sum_{(x, y) \in \mathcal{D}} \log \frac{\exp(o^y_h)}{\sum_{y' \in Y} \exp(o^{y'}_h)}.$$

For any $c \in \mathbb{R}$, let $c \cdot h$ be the classifier whose logits are multiplied by $c$, that is, $(c \cdot h)(x, y) = c \times h(x, y)$.
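The scaled classifier $c \cdot h$ can be illustrated numerically, as a preview of Theorem D.1. Scaling a perfectly accurate classifier's logits by $c > 1$ lowers cross-entropy, whereas the same scaling raises the loss on a misclassified example. This is a sanity check on hypothetical toy logits, not a proof. The helper names `xent` and `scale` are ours. The loss is written in the $\log(1 + \sum_{y' \neq y^*} \exp(o^{y'} - o^{y^*}))$ form used in the proof.

```python
import math

def xent(logit_rows, labels):
    """Mean cross-entropy in the form log(1 + sum_{y' != y*} exp(o_{y'} - o_{y*}))."""
    total = 0.0
    for row, y in zip(logit_rows, labels):
        total += math.log1p(sum(math.exp(o - row[y]) for j, o in enumerate(row) if j != y))
    return total / len(labels)

def scale(logit_rows, c):
    """Logits of the scaled classifier c . h."""
    return [[c * o for o in row] for row in logit_rows]

# Hypothetical classifier with perfect accuracy: the true label has the largest logit in every row.
accurate = [[2.0, 0.5, -1.0], [0.1, 0.3, 0.2], [-0.5, -0.4, 1.5]]
labels = [0, 1, 2]
# Hypothetical classifier that gets its single example wrong.
inaccurate, wrong_labels = [[0.0, 1.0]], [0]

for c in [1.0, 2.0, 5.0]:
    # First loss falls as c grows; second loss rises.
    print(c, xent(scale(accurate, c), labels), xent(scale(inaccurate, c), wrong_labels))
```

The contrast with the misclassified example shows why the perfect-accuracy assumption is needed.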
Intuitively, once a classifier achieves perfect accuracy, the true class logit $o^{y^*}$ will be larger than any incorrect class logit $o^{y'}$, and so loss can be further reduced by scaling up all of the logits further (increasing the gap between $o^{y^*}$ and $o^{y'}$).

Theorem D.1. Suppose that the classifier $h$ has perfect accuracy, that is, for any $(x, y^*) \in \mathcal{D}$ and any $y' \neq y^*$ we have $o^{y^*}_h > o^{y'}_h$. Then, for any $c > 1$, we have $\mathcal{L}_{\text{x-ent}}(c \cdot h) < \mathcal{L}_{\text{x-ent}}(h)$.

Proof. First, note that we can rewrite the loss function as:

$$\begin{aligned}
\mathcal{L}_{\text{x-ent}}(h) &= -\frac{1}{D} \sum_{(x, y^*)} \log \frac{\exp(o^{y^*}_h)}{\sum_{y'} \exp(o^{y'}_h)} \\
&= \frac{1}{D} \sum_{(x, y^*)} \log \frac{\sum_{y'} \exp(o^{y'}_h)}{\exp(o^{y^*}_h)} \\
&= \frac{1}{D} \sum_{(x, y^*)} \log \Big(1 + \sum_{y' \neq y^*} \exp(o^{y'}_h - o^{y^*}_h)\Big)
\end{aligned}$$

Since we are given that $o^{y^*}_h > o^{y'}_h$, for any $c > 1$ we have $c(o^{y'}_h - o^{y^*}_h) < o^{y'}_h - o^{y^*}_h$. Since $\exp$, $\log$, and sums are all monotonically increasing, this gives us our desired result:

$$\begin{aligned}
\mathcal{L}_{\text{x-ent}}(c \cdot h) &= \frac{1}{D} \sum_{(x, y^*)} \log \Big(1 + \sum_{y' \neq y^*} \exp\big(c(o^{y'}_h - o^{y^*}_h)\big)\Big) \\
&< \frac{1}{D} \sum_{(x, y^*)} \log \Big(1 + \sum_{y' \neq y^*} \exp(o^{y'}_h - o^{y^*}_h)\Big) = \mathcal{L}_{\text{x-ent}}(h). \qquad \square
\end{aligned}$$

We now move on to Theorem D.4. First we establish some basic lemmas that will be used in the proof:

Lemma D.2. Let $a, b, r \in \mathbb{R}$ with $a, b \geq 0$ and $0 < r \leq 1$. Then $(a + b)^r \leq a^r + b^r$.

Proof. The case with $a = 0$ or $b = 0$ is clear, so let us consider $a, b > 0$. Let $x = \frac{a}{a+b}$ and $y = \frac{b}{a+b}$. Since $0 \leq x \leq 1$, we have $x^{1-r} \leq 1$, which implies $x \leq x^r$. Similarly $y \leq y^r$. Thus $x^r + y^r \geq x + y = 1$. Substituting in the values of $x$ and $y$ we get $\frac{a^r + b^r}{(a+b)^r} \geq 1$, which when rearranged gives us the desired result. $\square$

Lemma D.3. For any $x, c, r \in \mathbb{R}$ with $r \geq 1$, there exists some $\delta > 0$ such that for any $\epsilon < \delta$ we have $x^r - (x - \epsilon)^r > \delta(r x^{r-1} - c)$.

Proof. The function $f(x) = x^r$ is everywhere-differentiable and has derivative $r x^{r-1}$. Thus we can choose $\delta$ such that for any $\epsilon < \delta$ we have $-c < \frac{x^r - (x - \epsilon)^r}{\delta} - r x^{r-1} < c$. Rearranging, we get $x^r - (x - \epsilon)^r > \delta(r x^{r-1} - c)$ as desired. $\square$

D.1. Weight decay favours efficient circuits

To flesh out the argument in Section 3, we construct a minimal example of multiple circuits $C_1, \ldots, C_I$ of varying efficiencies that can be scaled up or down through a set of non-negative weights $w_i$. Our classifier is given by $h = \sum_{i=1}^{I} w_i C_i$, that is, the output $h(x, y)$ is given by $\sum_{i=1}^{I} w_i C_i(x, y)$. We take circuits $C_i$ that are normalised, that is, they produce the same average logit value. $P_i$ denotes the parameter norm of the normalised circuit $C_i$. We call a circuit with lower $P_i$ more efficient.

However, it is hard to define efficiency precisely. Consider instead the parameter norm $P'_i$ of the scaled circuit $w_i C_i$. If we define efficiency as either the ratio $\|\vec o_{C_i}\| / P'_i$ or the derivative $d\|\vec o_{C_i}\| / dP'_i$, then it would vary with $w_i$, since $\vec o_{C_i}$ and $P'_i$ can in general have different relationships with $w_i$. We prefer $P_i$ as a measure of relative efficiency as it is intrinsic to $C_i$ rather than depending on its scaling $w_i$.

Gradient descent operates over the weights $w_i$ (but not $C_i$ or $P_i$) to minimise $\mathcal{L} = \mathcal{L}_{\text{x-ent}} + \alpha \mathcal{L}_{\text{wd}}$. $\mathcal{L}_{\text{x-ent}}$ can easily be rewritten in terms of $w_i$, but for $\mathcal{L}_{\text{wd}}$ we need to model the parameter norm of the scaled circuits $w_i C_i$. Notice that, in a $\kappa$-layer MLP with ReLU activations and without biases, scaling all parameters by a constant $c$ scales the outputs by $c^\kappa$. Inspired by this observation, we model the parameter norm of $w_i C_i$ as $w_i^{1/\kappa} P_i$ for some $\kappa > 0$. This gives the following effective loss:

$$\mathcal{L}(\vec w) = \mathcal{L}_{\text{x-ent}}\Big(\sum_{i=1}^{I} w_i C_i\Big) + \frac{\alpha}{2} \sum_{i=1}^{I} \big(w_i^{1/\kappa} P_i\big)^2$$

We will generalise this to any $L_q$-norm (where $q > 0$). Standard weight decay corresponds to $q = 2$. We will also generalise to arbitrary differentiable training loss functions that are bounded below, instead of cross-entropy loss specifically. In particular, we assume that there is some differentiable $\mathcal{L}_{\text{train}}(h)$ with a finite bound $B \in \mathbb{R}$ such that $\forall h: \mathcal{L}_{\text{train}}(h) \geq B$. (In the case of cross-entropy loss, $B = 0$.)
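The modelling step above rests on positive homogeneity. In a bias-free ReLU MLP with $\kappa$ layers, multiplying every parameter by $c > 0$ multiplies the output by $c^\kappa$ and the parameter norm by $c$. Scaling the output by $w_i$ therefore costs a parameter norm of $w_i^{1/\kappa} P_i$. Below is a minimal check on a hypothetical 3-layer network with hand-picked weights. The network, the input, and all helper names are ours.

```python
import math

def matvec(W, v):
    """Matrix-vector product with W given as a list of rows."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in W]

def mlp(layers, x):
    """Bias-free MLP: ReLU after every layer except the last."""
    h = x
    for i, W in enumerate(layers):
        h = matvec(W, h)
        if i < len(layers) - 1:
            h = [max(0.0, v) for v in h]
    return h

def param_norm(layers):
    """L2 norm of all parameters."""
    return math.sqrt(sum(w * w for W in layers for row in W for w in row))

def scale_layers(layers, c):
    """Multiply every parameter by c."""
    return [[[c * w for w in row] for row in W] for W in layers]

# Hypothetical 3-layer network (so kappa = 3) and input.
layers = [[[0.5, -1.0], [1.5, 0.25]], [[1.0, -0.5], [0.3, 0.8]], [[0.7, 1.1]]]
x = [1.0, 2.0]
kappa = len(layers)

target = 8.0               # desired output scale w_i
c = target ** (1 / kappa)  # per-parameter scale c = w_i^(1/kappa)
scaled = scale_layers(layers, c)
print(mlp(scaled, x)[0] / mlp(layers, x)[0])    # close to target: output scales by c**kappa
print(param_norm(scaled) / param_norm(layers))  # close to c = target**(1/kappa)
```

The check relies on $\max(0, c v) = c \max(0, v)$ for $c > 0$. A network with biases would not scale this cleanly, which is why the model assumes none.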
With these generalisations, the overall loss is now given by:

$$\mathcal{L}(\vec w) = \mathcal{L}_{\text{train}}\Big(\sum_{i=1}^{I} w_i C_i\Big) + \frac{\alpha}{q} \sum_{i=1}^{I} \big(w_i^{1/\kappa} P_i\big)^q \qquad (6)$$

The following theorem establishes that the optimal weight vector allocates more weight to more efficient circuits, under the assumption that the circuits produce identical logits on the training dataset.

Theorem D.4. Given $I$ circuits $C_i$ and associated $L_q$ parameter norms $P_i$, assume that every circuit produces the same logits on the training dataset, i.e. $\forall i, j$, $\forall (x, \_) \in \mathcal{D}$, $\forall y' \in Y$ we have $o^{y'}_{C_i}(x) = o^{y'}_{C_j}(x)$. Then, any weight vector $\vec w^* \in \mathbb{R}^I$ that minimises the loss in Equation 6 subject to $w_i \geq 0$ satisfies:

1. If $\kappa \geq q$, then $w^*_i = 0$ for all $i$ such that $P_i > \min_j P_j$.
2. If $0 < \kappa < q$, then $w^*_i \propto P_i^{-\frac{q\kappa}{q - \kappa}}$.

Intuition. Since every circuit produces identical logits, their weights are interchangeable with each other from the perspective of $\mathcal{L}_{\text{x-ent}}$, and so we must analyse how interchanging weights affects $\mathcal{L}_{\text{wd}}$. For standard weight decay ($q = 2$), $\mathcal{L}_{\text{wd}}$ grows as $O(w_i^{2/\kappa})$. When $\kappa > 2$, $\mathcal{L}_{\text{wd}}$ grows sublinearly, and so it is cheaper to add additional weight to the largest weight, creating a "rich get richer" effect that results in a single maximally efficient circuit getting all of the weight. When $\kappa < 2$, $\mathcal{L}_{\text{wd}}$ grows superlinearly, and so it is cheaper to add additional weight to the smallest weight. As a result, every circuit is allocated at least some weight, though more efficient circuits are still allocated higher weight than less efficient circuits.

Sketch. The assumption that every circuit produces the same logits on the training dataset implies that $\mathcal{L}_{\text{train}}$ is purely a function of $\sum_{i=1}^{I} w_i$. So, for $\mathcal{L}_{\text{train}}$, a small increase $\delta w$ to $w_i$ can be balanced by a corresponding decrease $\delta w$ to some other weight $w_j$. For $\mathcal{L}_{\text{wd}}$, an increase $\delta w$ to $w_i$ produces a change of approximately

$$\frac{\delta \mathcal{L}_{\text{wd}}}{\delta w_i} \cdot \delta w = \frac{\alpha}{\kappa} \big(P_i w_i^{r}\big)^q \cdot \delta w, \qquad \text{where } r = \frac{1}{\kappa} - \frac{1}{q} = \frac{q - \kappa}{q\kappa}.$$
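So, an increase of $\delta w$ to $w_i$ can be balanced by a decrease of $\left(\frac{P_i w_i^r}{P_j w_j^r}\right)^q \delta w$ to some other weight $w_j$. The two cases correspond to $r \leq 0$ and $r > 0$ respectively.

Case 1: $r \leq 0$. Consider $i, j$ with $P_j > P_i$. The optimal weights must satisfy $w^*_i \geq w^*_j$ (else you could swap $w^*_i$ and $w^*_j$ to decrease loss). But then $w^*_j$ must be zero: if not, we could increase $w^*_i$ by $\delta w$ and decrease $w^*_j$ by $\delta w$, which keeps $\mathcal{L}_{\text{x-ent}}$ constant and decreases $\mathcal{L}_{\text{wd}}$ (since $P_i (w^*_i)^r < P_j (w^*_j)^r$).

Case 2: $r > 0$. Consider $i, j$ with $P_j > P_i$. As before we must have $w^*_i \geq w^*_j$. But now $w^*_j$ must not be zero: otherwise we could increase $w^*_j$ by $\delta w$ and decrease $w^*_i$ by $\delta w$ to keep $\mathcal{L}_{\text{x-ent}}$ constant and decrease $\mathcal{L}_{\text{wd}}$, since $P_j (w^*_j)^r = 0 < P_i (w^*_i)^r$. The balance occurs when $P_j (w^*_j)^r = P_i (w^*_i)^r$, which means $w^*_i \propto P_i^{-1/r} = P_i^{-\frac{q\kappa}{q - \kappa}}$.

Proof. First, notice that our conclusions trivially hold for $\vec w^* = \vec 0$ (which can be a minimum if, e.g., the circuits are worse than random). Thus for the rest of the proof we will assume that at least one weight is non-zero. In addition, $\mathcal{L} \to \infty$ whenever any $w_i \to \infty$ (because $\mathcal{L}_{\text{train}} \geq B$ and $\mathcal{L}_{\text{wd}} \to \infty$ as any one $w_i \to \infty$). Thus, any global minimum must have finite $\vec w$.

Notice that, since the circuit logits are independent of $i$, we have $h = \left(\sum_i w_i\right) f$, where $f$ denotes the output shared by every circuit on the training dataset. So $\mathcal{L}_{\text{train}}(\vec w)$ is purely a function of the sum of weights $\sum_{i=1}^{I} w_i$, and the overall loss can be written as:

$$\mathcal{L}(\vec w) = \mathcal{L}_{\text{train}}\Big(\sum_{i=1}^{I} w_i\Big) + \frac{\alpha}{q} \sum_{i=1}^{I} \big(w_i^{1/\kappa} P_i\big)^q$$

We will now consider each case in order.

Case 1: $\kappa \geq q$. Assume towards contradiction that there is a global minimum $\vec w^*$ where $w^*_j > 0$ for some circuit $C_j$ with non-minimal $P_j$. Let $C_i$ be a circuit with minimal $P_i$ (so that $P_i < P_j$), and let its weight be $w^*_i$. Consider an alternate weight assignment $\vec w'$ that is identical to $\vec w^*$ except that $w'_j = 0$ and $w'_i = w^*_i + w^*_j$. Clearly $\sum_i w^*_i = \sum_i w'_i$, and so $\mathcal{L}_{\text{train}}(\vec w^*) = \mathcal{L}_{\text{train}}(\vec w')$. Thus, we have:

$$\begin{aligned}
\mathcal{L}(\vec w^*) - \mathcal{L}(\vec w') &= \frac{\alpha}{q} \sum_{m=1}^{I} \big((w^*_m)^{1/\kappa} P_m\big)^q - \frac{\alpha}{q} \sum_{m=1}^{I} \big((w'_m)^{1/\kappa} P_m\big)^q \\
&= \frac{\alpha}{q} \Big((w^*_i)^{q/\kappa} P_i^q + (w^*_j)^{q/\kappa} P_j^q - (w'_i)^{q/\kappa} P_i^q\Big) \\
&> \frac{\alpha}{q} P_i^q \Big((w^*_i)^{q/\kappa} + (w^*_j)^{q/\kappa} - (w'_i)^{q/\kappa}\Big) \qquad \text{(since } P_j > P_i\text{)} \\
&= \frac{\alpha}{q} P_i^q \Big((w^*_i)^{q/\kappa} + (w^*_j)^{q/\kappa} - (w^*_i + w^*_j)^{q/\kappa}\Big) \qquad \text{(definition of } w'_i\text{)} \\
&\geq \frac{\alpha}{q} P_i^q \Big((w^*_i)^{q/\kappa} + (w^*_j)^{q/\kappa} - \big((w^*_i)^{q/\kappa} + (w^*_j)^{q/\kappa}\big)\Big) \qquad \text{(Lemma D.2, since } 0 < q/\kappa \leq 1\text{)} \\
&= 0
\end{aligned}$$

Thus we have $\mathcal{L}(\vec w^*) > \mathcal{L}(\vec w')$, contradicting our assumption that $\vec w^*$ is a global minimum of $\mathcal{L}$. This completes the proof for the case that $\kappa \geq q$.

Case 2: $\kappa < q$. First, we will show that all weights are non-zero at a global minimum (excluding the case where $\vec w^* = \vec 0$, discussed at the beginning of the proof). Assume towards contradiction that there is a global minimum $\vec w^*$ with $w^*_j = 0$ for some $j$. Choose some arbitrary circuit $C_i$ with nonzero weight $w^*_i$. Choose some $\epsilon_1 > 0$ satisfying $\epsilon_1 < \frac{q}{2\kappa} (w^*_i)^{q/\kappa - 1}$. By applying Lemma D.3 with $x = w^*_i$, $c = \epsilon_1$, $r = q/\kappa$, we can get some $\delta > 0$ such that for any $\epsilon < \delta$ we have $(w^*_i)^{q/\kappa} - (w^*_i - \epsilon)^{q/\kappa} > \delta\big(\frac{q}{\kappa} (w^*_i)^{q/\kappa - 1} - \epsilon_1\big)$. Choose some $\epsilon_2 > 0$ satisfying

$$\epsilon_2 < \min\left(w^*_i,\ \delta,\ \left[\frac{q}{2\kappa} (w^*_i)^{q/\kappa - 1} \frac{P_i^q}{P_j^q}\right]^{\frac{1}{q/\kappa - 1}}\right).$$

Consider an alternate weight assignment $\vec w'$ that is identical to $\vec w^*$ except that $w'_j = \epsilon_2$ and $w'_i = w^*_i - \epsilon_2$. As in the previous case, $\mathcal{L}_{\text{train}}(\vec w^*) = \mathcal{L}_{\text{train}}(\vec w')$. Thus, we have:

$$\begin{aligned}
\mathcal{L}(\vec w^*) - \mathcal{L}(\vec w') &= \frac{\alpha}{q} \Big((w^*_i)^{q/\kappa} P_i^q - (w^*_i - \epsilon_2)^{q/\kappa} P_i^q - \epsilon_2^{q/\kappa} P_j^q\Big) \\
&= \frac{\alpha}{q} \Big(P_i^q \big((w^*_i)^{q/\kappa} - (w^*_i - \epsilon_2)^{q/\kappa}\big) - \epsilon_2^{q/\kappa} P_j^q\Big) \\
&> \frac{\alpha}{q} \Big(P_i^q\, \delta \big(\tfrac{q}{\kappa} (w^*_i)^{q/\kappa - 1} - \epsilon_1\big) - \epsilon_2^{q/\kappa} P_j^q\Big) \qquad \text{(Lemma D.3, as discussed above)} \\
&> \frac{\alpha}{q} \Big(P_i^q\, \delta \big(\tfrac{q}{\kappa} (w^*_i)^{q/\kappa - 1} - \tfrac{q}{2\kappa} (w^*_i)^{q/\kappa - 1}\big) - \epsilon_2^{q/\kappa} P_j^q\Big) \qquad \text{(we chose } \epsilon_1 < \tfrac{q}{2\kappa} (w^*_i)^{q/\kappa - 1}\text{)} \\
&> \frac{\alpha}{q} \Big(P_i^q\, \epsilon_2\, \tfrac{q}{2\kappa} (w^*_i)^{q/\kappa - 1} - \epsilon_2^{q/\kappa} P_j^q\Big) \qquad \text{(we chose } \epsilon_2 < \delta\text{)} \\
&= \frac{\alpha \epsilon_2}{q} \Big(\tfrac{q}{2\kappa} (w^*_i)^{q/\kappa - 1} P_i^q - \epsilon_2^{q/\kappa - 1} P_j^q\Big) \\
&> \frac{\alpha \epsilon_2}{q} \Big(\tfrac{q}{2\kappa} (w^*_i)^{q/\kappa - 1} P_i^q - \tfrac{q}{2\kappa} (w^*_i)^{q/\kappa - 1} \tfrac{P_i^q}{P_j^q} P_j^q\Big) \qquad \text{(we chose } \epsilon_2 < \big[\tfrac{q}{2\kappa} (w^*_i)^{q/\kappa - 1} P_i^q / P_j^q\big]^{\frac{1}{q/\kappa - 1}}\text{)} \\
&= 0
\end{aligned}$$

Note that in the last step, we rely on the fact that $\kappa < q$: this lets us use an upper bound on $\epsilon_2$ to get an upper bound on $\epsilon_2^{q/\kappa - 1}$, and so a lower bound on the overall expression.

Thus we have $\mathcal{L}(\vec w^*) > \mathcal{L}(\vec w')$, contradicting our assumption that $\vec w^*$ is a global minimum of $\mathcal{L}$. So, for all $i$ we have $w^*_i > 0$.

In addition, as $w_i \to \infty$ we have $\mathcal{L}(\vec w) \to \infty$, so $\vec w^*$ cannot be at the boundaries, and instead lies in the interior. Since $q > \kappa$, $\mathcal{L}(\vec w)$ is differentiable everywhere. Thus, we can conclude that its gradient at $\vec w^*$ is zero:

$$\begin{aligned}
\frac{\partial \mathcal{L}}{\partial w_i} &= 0 \\
\frac{\partial \mathcal{L}_{\text{train}}}{\partial w_i} + \frac{\alpha P_i^q}{\kappa} (w^*_i)^{q/\kappa - 1} &= 0 \\
P_i^q (w^*_i)^{\frac{q - \kappa}{\kappa}} &= -\frac{\kappa}{\alpha} \frac{\partial \mathcal{L}_{\text{train}}}{\partial w_i} \\
w^*_i P_i^{\frac{q\kappa}{q - \kappa}} &= \left(-\frac{\kappa}{\alpha} \frac{\partial \mathcal{L}_{\text{train}}}{\partial w_i}\right)^{\frac{\kappa}{q - \kappa}}
\end{aligned}$$

Since $\mathcal{L}_{\text{train}}(\vec w)$ is a function of $\sum_{j=1}^{I} w_j$, we can conclude that

$$\frac{\partial \mathcal{L}_{\text{train}}}{\partial w_i} = \frac{\partial \mathcal{L}_{\text{train}}}{\partial \sum_j w_j} \cdot \frac{\partial \sum_j w_j}{\partial w_i} = \frac{\partial \mathcal{L}_{\text{train}}}{\partial \sum_j w_j},$$

which is independent of $i$. So the right-hand side of the equation is independent of $i$, allowing us to conclude that $w^*_i \propto P_i^{-\frac{q\kappa}{q - \kappa}}$. $\square$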
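Theorem D.4 can also be checked numerically. The sketch below runs projected gradient descent on Equation (6) under assumptions of our own choosing: three circuits with hypothetical norms $P = (1, 2, 3)$, $\alpha = 0.01$, and a stand-in training loss $\mathcal{L}_{\text{train}}(s) = \log(1 + (Q - 1)e^{-s})$. That loss is the cross-entropy of a correctly classified example with $Q$ labels when the shared circuit output is 1 on the correct label and 0 elsewhere. It depends only on $s = \sum_i w_i$ and is bounded below by 0. The names `Q_LABELS`, `train_loss_grad` and `optimise` are ours, not the paper's. The sketch covers case 2 and the boundary $\kappa = q$ of case 1. It leaves out $\kappa > q$, because there the penalty's derivative is unbounded at $w_i = 0$, which plain projected gradient descent does not handle.

```python
import math

Q_LABELS = 113  # hypothetical number of labels

def train_loss_grad(s):
    """d/ds of L_train(s) = log(1 + (Q - 1) exp(-s)), a loss that depends only on s = sum(w)."""
    return -1.0 / (1.0 + math.exp(s) / (Q_LABELS - 1))

def optimise(P, kappa, q, alpha=0.01, lr=1.0, steps=20_000):
    """Projected gradient descent on Equation (6), keeping every w_i >= 0.

    Restricted to kappa <= q: for kappa > q the penalty's derivative blows up at w_i = 0.
    """
    if kappa > q:
        raise ValueError("this sketch only handles kappa <= q")
    w = [1.0] * len(P)
    for _ in range(steps):
        g_train = train_loss_grad(sum(w))
        # d/dw_i of (alpha/q) (w_i^(1/kappa) P_i)^q = (alpha/kappa) P_i^q w_i^(q/kappa - 1)
        w = [max(0.0, wi - lr * (g_train + alpha / kappa * Pi ** q * wi ** (q / kappa - 1)))
             for wi, Pi in zip(w, P)]
    return w

P = [1.0, 2.0, 3.0]  # hypothetical parameter norms; circuit 0 is the most efficient

# Case 2 (0 < kappa < q): w_i * P_i^(q kappa / (q - kappa)) should be the same for every i.
kappa, q = 1.0, 2.0
w2 = optimise(P, kappa, q)
print([wi * Pi ** (q * kappa / (q - kappa)) for wi, Pi in zip(w2, P)])

# Boundary of case 1 (kappa = q): only the circuit with minimal P_i should keep any weight.
w1 = optimise(P, kappa=2.0, q=2.0)
print(w1)
```

With these settings the iterates should settle well within the step budget. In the first run the printed products should agree. In the second run the two less efficient weights should be clamped to exactly zero.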