Paper deep dive
Let Me Grok for You: Accelerating Grokking via Embedding Transfer from a Weaker Model
Zhiwei Xu, Zhiyu Ni, Yixin Wang, Wei Hu
Models: Fully-connected neural networks, Transformers (small)
Abstract
“Grokking” is a phenomenon where a neural network first memorizes training data and generalizes poorly, but then suddenly transitions to near-perfect generalization after prolonged training. While intriguing, this delayed generalization phenomenon compromises predictability and efficiency. Ideally, models should generalize directly without delay. To this end, this paper proposes GrokTransfer, a simple and principled method for accelerating grokking in training neural networks, based on the key observation that data embedding plays a crucial role in determining whether generalization is delayed. GrokTransfer first trains a smaller, weaker model to reach a nontrivial (but far from optimal) test performance. Then, the learned input embedding from this weaker model is extracted and used to initialize the embedding in the target, stronger model. We rigorously prove that, on a synthetic XOR task where delayed generalization always occurs in normal training, GrokTransfer enables the target model to generalize directly without delay. Moreover, we demonstrate that, across empirical studies of different tasks, GrokTransfer effectively reshapes the training dynamics and eliminates delayed generalization, for both fully-connected neural networks and Transformers.
Tags
Links
- Source: https://arxiv.org/abs/2504.13292
- Canonical: https://arxiv.org/abs/2504.13292
PDF not stored locally. Use the link above to view on the source site.
Intelligence
Status: succeeded | Model: google/gemini-3.1-flash-lite-preview | Prompt: intel-v1 | Confidence: 95%
Last extracted: 3/12/2026, 6:14:53 PM
Summary
The paper introduces 'GrokTransfer', a method to accelerate the 'grokking' phenomenon in neural networks by transferring learned input embeddings from a smaller, weaker model to a larger target model. This approach reshapes training dynamics, allowing the target model to generalize without the typical delayed phase transition observed in standard training, as validated on synthetic XOR tasks and various algorithmic benchmarks.
Entities (5)
Relation Signals (3)
GrokTransfer → accelerates → Grokking
confidence 98% · proposes GrokTransfer, a simple and principled method for accelerating grokking in training neural networks
GrokTransfer → utilizes → Embedding
confidence 95% · based on the key observation that data embedding plays a crucial role in determining whether generalization is delayed
GrokTransfer → appliedto → Transformer
confidence 90% · eliminates delayed generalization, for both fully-connected neural networks and Transformers
Cypher Suggestions (2)
Find all methods that address the grokking phenomenon · confidence 90% · unvalidated
MATCH (m:Method)-[:ACCELERATES]->(p:Phenomenon {name: 'Grokking'}) RETURN m.name
List tasks where GrokTransfer has been evaluated · confidence 85% · unvalidated
MATCH (m:Method {name: 'GrokTransfer'})-[:APPLIED_TO]->(t:Task) RETURN t.name
Full Text
269,981 characters extracted from source content.
Let Me Grok for You: Accelerating Grokking via Embedding Transfer from a Weaker Model Zhiwei Xu†, Zhiyu Ni‡∗, Yixin Wang†♢, Wei Hu†♢ †University of Michigan, ‡University of California, Berkeley {zhiweixu, yixinw, vvh}@umich.edu, zhiyuni@berkeley.edu ∗Equal contribution; ♢Equal advising. Abstract “Grokking” (Power et al., 2022) is a phenomenon where a neural network first memorizes training data and generalizes poorly, but then suddenly transitions to near-perfect generalization after prolonged training. While intriguing, this delayed generalization phenomenon compromises predictability and efficiency. Ideally, models should generalize directly without delay. To this end, this paper proposes GrokTransfer, a simple and principled method for accelerating grokking in training neural networks, based on the key observation that data embedding plays a crucial role in determining whether generalization is delayed. GrokTransfer first trains a smaller, weaker model to reach a nontrivial (but far from optimal) test performance. Then, the learned input embedding from this weaker model is extracted and used to initialize the embedding in the target, stronger model. We rigorously prove that, on a synthetic XOR task where delayed generalization always occurs in normal training, GrokTransfer enables the target model to generalize directly without delay. Moreover, we demonstrate that, across empirical studies of different tasks, GrokTransfer effectively reshapes the training dynamics and eliminates delayed generalization, for both fully-connected neural networks and Transformers. 1 Introduction “Grokking” is an intriguing phenomenon recently discovered by Power et al. (2022), where a neural network first memorizes the training dataset but has poor test performance, and after much longer training, it suddenly transitions to near-perfect generalization.
Initially reported for Transformer models trained on modular arithmetic tasks, the grokking phenomenon has since been observed in other settings such as learning group operations (Chughtai et al., 2023), sparse parity (Barak et al., 2022), and image classification (Liu et al., 2023). While grokking is an interesting phenomenon, it introduces unpredictability into the training process and compromises its practical efficiency. When the model has interpolated the training data with small training loss but still performs poorly on the validation set, it becomes difficult to predict whether or when the model will eventually generalize. Ideally, we would like the model to make continuous progress during training, keeping the gap between training and validation errors minimal. This raises the question: How can we effectively modify the training dynamics so that the model generalizes without delay? In this work, we show that data embedding plays a crucial role in determining the training dynamics; an informative embedding enables continuous progress during training. To obtain such an informative embedding without excessive computational cost, we propose a novel method called GrokTransfer, which leverages the embedding learned by a weaker, smaller model to accelerate the generalization of a larger target model. See Figure 1(a) for an overview of GrokTransfer. Specifically, GrokTransfer involves two main steps: (1) Train a weaker model until it groks to non-trivial test performance; (2) Extract the weak model’s learned embedding and use a linear mapping of this embedding to initialize the embedding of the target model. Then, proceed to train the target model. We theoretically study GrokTransfer in the setting of a two-layer neural network trained on a high-dimensional XOR classification task, where normal training exhibits grokking. We prove that GrokTransfer enables the target model to directly generalize without delay.
We further empirically verify the effectiveness of GrokTransfer on typical algorithmic tasks that show grokking. This is done for both fully-connected neural networks with trainable embeddings and Transformers. Figure 1(b) shows typical training curves of GrokTransfer vs. training a target model from scratch, on a modular addition task. It shows that GrokTransfer effectively eliminates grokking and significantly improves efficiency. In summary, our contributions are as follows: • We propose a novel method, GrokTransfer, which leverages the embedding learned from a smaller, weaker model to accelerate grokking in the target model. • We theoretically justify GrokTransfer in an XOR classification task. We further empirically validate our method on several algorithmic tasks that exhibit grokking in normal training, demonstrating that GrokTransfer can effectively eliminate delayed generalization. (a) (b) Figure 1: (a) Overview of the GrokTransfer framework. (b) Comparison of the training dynamics of a model trained using GrokTransfer versus one trained from scratch. There is a clear phase transition between memorization and generalization if we train the model from scratch (blue lines). GrokTransfer (red lines) enables the model to make continuous progress, significantly reducing the gap between memorization and generalization. See Appendix A.3 for the detailed experimental setup. 1.1 Related Work Our work draws on two themes around grokking and weak-to-strong knowledge transfer. Grokking. Liu et al. (2022) reported that the model starts grokking when it learns the hidden structure of the data. Gromov (2023) showed that grokking is robust to different optimizers such as vanilla gradient descent and Adam; and regularization methods including no regularization, weight decay, and dropout. Davies et al. (2023) hypothesized that grokking and double descent, another surprising phenomenon, are caused by the same hidden mechanism. Nanda et al. 
(2023) reverse-engineered a grokked transformer model for modular addition and reported that the learned algorithm is a composition of trigonometric and inverse trigonometric functions. Merrill et al. (2023) and Varma et al. (2023) attributed the occurrence of grokking to the competition between sparse (generalizing) and dense (complementary) subnetworks during training. Zhu et al. (2024) showed that models only grok when the training data exceeds some critical size. Liu et al. (2023) attributed grokking to large initialization scale and induced grokking on real-world datasets such as MNIST and IMDb by initializing models with large weight norm. Further work (Miller et al., 2023; Humayun et al., 2024) showed that grokking can also be observed in other scenarios such as Gaussian Process regression and multi-class classification with adversarial samples. A series of theoretical papers have established rigorous results for grokking/delayed generalization in several settings outside of algorithmic tasks: linear regression with linear models (Žunkovič & Ilievski, 2022), and binary classification with neural networks (Lyu et al., 2024; Xu et al., 2024). Lyu et al. (2024) proved that grokking can be induced by a sharp phase transition from the kernel regime to the rich regime. Mallinar et al. (2024) trained Recursive Feature Machines on algorithmic tasks and found their training dynamics similar to neural networks, showing that grokking is not restricted to neural networks. He et al. (2024) and Wang et al. (2024) found that Transformers achieve out-of-distribution generalization on some tasks through grokking. Doshi et al. (2024) provided analytical solutions for complex modular arithmetic tasks and hypothesized that some complex modular polynomial tasks cannot be learned by shallow neural networks. Mohamadi et al. (2024) showed that learning modular addition is fundamentally hard for neural networks in the kernel regime.
A related phenomenon, termed “sudden drop in the loss” (Chen et al., 2024; Gopalani et al., 2025; Yang et al., 2025), describes an abrupt drop in loss after an extended plateau during online training. Recent work has proposed several methods to accelerate grokking. Liu et al. (2023) explained grokking through the concept of a “Goldilocks zone”, a spherical shell of weights, and found that restricting the weight norm to a sphere of the appropriate radius during training can accelerate generalization. However, this method introduces instability in the training process and still involves a phase transition. Furuta et al. (2024) suggested that initializing the model with weights or embeddings from another model that has already generalized on a different task may accelerate grokking; however, this requires training the same model on additional data, whereas our method does not. Lee et al. (2024) decomposed the gradient at each step and accelerated grokking by amplifying part of the gradient. Interestingly, Minegishi et al. (2024) demonstrated that the gap between memorization and generalization can be nearly eliminated if a lottery ticket, a set of sparse mask matrices, is applied to the model during training. However, this lottery ticket can only be obtained by first training the same model under the same initialization until generalization. In contrast, our approach can nearly eliminate the phase transition without requiring additional data or pretraining on the same model. Weak to strong knowledge transfer. Burns et al. (2023) proposed a method where a small model is first fine-tuned as a teacher model. This teacher model is then used to generate pseudo-labels to fine-tune a larger student model. Surprisingly, the student model can outperform the teacher. Wang et al.
(2023) designed a learned linear growth operator, which uses a learnable linear map of a pretrained small model’s weights as the initialization for the large model’s weights, to accelerate the training of large models. In contrast to these works, our method focuses on transferring the embedding layer from a weaker model to the target model and reshaping the training dynamics to accelerate grokking. 1.2 Notation For a set S with finitely many elements, we denote its cardinality by |S| and use Uniform(S) to represent the uniform distribution over S. We denote the set {1, 2, …, n} by [n]. We use sgn(x) to represent the sign of a scalar x. For a matrix A ∈ ℝ^{n×m}, we denote by A_{i,·} = [A_{i,1}, …, A_{i,m}] the i-th row, by A_{i:j,·} = [A_{i,·}^⊤, …, A_{j,·}^⊤]^⊤ ∈ ℝ^{(j−i+1)×m} the i-th to j-th rows, and by ‖A‖_F the Frobenius norm. We use ϕ(x) = max{0, x} to denote the ReLU activation function. We denote the inner product between two vectors a and b by ⟨a, b⟩. For two sequences {x_n} and {y_n}, we write x_n = O(y_n) if there exists a constant C > 0 such that x_n ≤ C y_n for all n, and x_n = Ω(y_n) if y_n = O(x_n).
2 Accelerating Grokking via Embedding Transfer from a Weaker Model 2.1 Motivation: The Role of Data Embedding To demonstrate the pivotal role of data embedding in shaping training dynamics, we examine the modular addition task a + b mod p. Following the settings in Nanda et al. (2023) and Liu et al. (2023), we take p = 113. The dataset consists of {((a, b), y)}_{0 ≤ a, b ≤ p−1} with label y = (a + b) mod p. 25% of the dataset is randomly sampled as the training set. We evaluate four types of embeddings: • One-hot embedding: Each integer a ∈ [0, p−1] is represented by its one-hot encoding. • Binary embedding: Each a is encoded in binary, padded with zeros to the maximum length ⌊log₂(p−1)⌋ + 1. • Fourier embedding: Each a is encoded as a vector of trigonometric functions [cos(2πi₁a/p), sin(2πi₁a/p), …, cos(2πi_k a/p), sin(2πi_k a/p)], where i₁, …, i_k ∈ ℕ are predetermined frequencies. • GPT embedding: Each a is embedded using OpenAI’s text-embedding-3-small model (OpenAI, 2024). Figure 2: FNN training dynamics using different embeddings for the modular addition task (p = 113). The training dynamics vary significantly across different embeddings. The one-hot embedding and GPT embedding exhibit a sharp phase transition. See Appendix A.3 for details of the experimental setup.
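As a concrete illustration, the four embedding types can be sketched in pure Python. This is a minimal sketch, not the paper's code; the Fourier frequency list below is a hypothetical choice (the paper only states that the frequencies i_1, …, i_k are predetermined), and the GPT embedding is omitted since it requires an external API.

```python
import math

p = 113  # modulus for the modular addition task

def one_hot(a, p=p):
    # Length-p indicator vector: carries no prior information about a.
    return [1.0 if i == a else 0.0 for i in range(p)]

def binary(a, p=p):
    # Binary encoding zero-padded to floor(log2(p-1)) + 1 bits
    # (captures ordinal information about the integer a).
    width = int(math.log2(p - 1)) + 1
    return [float(b) for b in format(a, f"0{width}b")]

def fourier(a, freqs, p=p):
    # Pairs (cos(2*pi*i*a/p), sin(2*pi*i*a/p)) for each frequency i.
    out = []
    for i in freqs:
        angle = 2 * math.pi * i * a / p
        out.extend([math.cos(angle), math.sin(angle)])
    return out

emb = fourier(5, freqs=[1, 2, 3])  # hypothetical frequencies i_1, i_2, i_3
```

For p = 113 the binary embedding has ⌊log₂(112)⌋ + 1 = 7 dimensions, versus 113 for one-hot and 2k for a Fourier embedding with k frequencies.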
One-hot embeddings contain no prior information about the data, while binary embeddings capture the ordinal information of integers. Fourier embeddings, inspired by the analytical solutions learned by neural networks (Nanda et al., 2023; Morwani et al., 2024), encode task-specific information. GPT embeddings encode general information about integers. Figure 2 shows the training dynamics of a feed-forward neural network using these embeddings. The training dynamics with one-hot and GPT embeddings exhibit clear grokking behavior, whereas those with binary and Fourier embeddings show continuous generalization progress. Notably, Fourier embeddings enable the model to simultaneously achieve memorization and perfect generalization. We observe that general embeddings like one-hot and GPT embeddings suffer from generalization delay, while embeddings encoded with task-related information allow the model to generalize continuously. Figure 3: Change of empirical NTK. A series of works (Liu et al., 2023; Kumar et al., 2024; Lyu et al., 2024; Mohamadi et al., 2024) found that the default initialization scale is relatively large and causes generalization delay. They observed that reducing the initialization scale can accelerate grokking and hypothesized that grokking arises from a time gap between the Neural Tangent Kernel (NTK) regime and the feature-learning regime. However, our empirical findings indicate that grokking persists even after carefully tuning the initialization scale (see Appendix A.3.1). This suggests that grokking occurs even when the model is not initialized in the kernel regime, implying that the kernel regime may not be the sole cause of grokking. In Figure 3, we compare the changes in the empirical NTK (Mohamadi et al., 2023) corresponding to the dynamics in Figure 2. The change of empirical NTK evolves similarly across all four types of embeddings (see Appendix A.3 for details). 
In conclusion, the choice of embedding significantly impacts training dynamics, and an informative embedding can close the gap between memorization and generalization. However, finding such an informative embedding for specific tasks is not always straightforward. Binary embedding, for example, reduces the sharp phase transition for modular addition but fails to do so for modular multiplication. In the next section, we will show that constructing a task-specific embedding from training data can be a promising approach to obtaining an informative embedding that can accelerate grokking. The embedding construction can be achieved by training a much smaller, weaker model. Here “small” refers to smaller model expressivity. This weaker model can learn an informative embedding without achieving optimal generalization. This embedding can then be used to positively influence the training dynamics of the larger target model. 2.2 Our Method: GrokTransfer We propose GrokTransfer, a simple and principled method for accelerating grokking in training neural networks. In more detail, given a specific task and a training set G, we consider a target model f_T that has a trainable embedding layer E_T with vocabulary size d_v and embedding dimension d_T. Our proposed method GrokTransfer works as follows: 1. Train a Weaker Model: Train a weaker model f_W with a trainable embedding table E_W ∈ ℝ^{d_v×d_W} on G, where d_W is the embedding dimension in the weak model. Train f_W until it groks to a non-trivial performance on the validation set. 2. Train the Target Model: Initialize A = E_W and randomly initialize a matrix B ∈ ℝ^{d_W×d_T}.
Train the target model with an embedding layer set to E_T = A·B, where both A and B are trainable. By training a weaker model, the first step aims to obtain an informative embedding that aids the training of the target model. In practice, the weak model can be much smaller than the target model or can even have a different architecture (e.g., the weak model can be a fully-connected network when the target model is a Transformer; see Section 4). As a result, training a weak model greatly reduces the computational cost of acquiring an informative embedding. This contrasts with the method proposed in Minegishi et al. (2024), which requires the target model to be trained till perfect generalization first. In the next sections, we will demonstrate, both theoretically and empirically, that even if the weak model only partially generalizes (i.e., has a non-trivial but non-optimal test error), its embedding still allows the large model to generalize optimally without delay. In the second step, we impose a low-rank structure A·B on the embedding E_T while training the target model. This constraint alters the empirical risk landscape and provides a favorable initialization for the embedding table. The intuition behind our method is as follows: by initializing with an informative embedding from the weak model, the target model can bypass the initial phase of pure memorization. Instead, it can start generalizing almost immediately as it begins to optimize the training loss. 3 Case Study: GrokTransfer on XOR Cluster Data In this section, we theoretically study an XOR classification task and prove that GrokTransfer can eliminate grokking for this task.
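The embedding factorization in step 2 of Section 2.2 can be sketched in a few lines. This is a minimal plain-Python illustration, not the paper's implementation: the dimensions and the Gaussian initialization scale for B are hypothetical choices, and real training would use framework tensors with gradients flowing through both factors.

```python
import random

def transfer_embedding(E_W, d_T, seed=0):
    """Initialize the target embedding as E_T = A @ B, with A set to the
    weak model's learned embedding table E_W (d_v x d_W) and B a randomly
    initialized d_W x d_T matrix. Both A and B remain trainable."""
    rng = random.Random(seed)
    d_v, d_W = len(E_W), len(E_W[0])
    A = [row[:] for row in E_W]  # A = E_W (copied so E_W is untouched)
    B = [[rng.gauss(0.0, d_W ** -0.5) for _ in range(d_T)]
         for _ in range(d_W)]
    # E_T = A @ B has shape d_v x d_T and rank at most d_W.
    E_T = [[sum(A[i][k] * B[k][j] for k in range(d_W)) for j in range(d_T)]
           for i in range(d_v)]
    return A, B, E_T

# Hypothetical sizes: vocabulary 113, weak embedding dim 3, target dim 8.
E_W = [[float(i + j) for j in range(3)] for i in range(113)]
A, B, E_T = transfer_embedding(E_W, d_T=8)
```

The low-rank product means the target model starts from (a linear map of) the weak model's embedding geometry rather than from random token vectors.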
3.1 The Setup of XOR Cluster Data We study the setting where the data x = [x_1, x_2, …, x_p]^⊤ = [x_signal^⊤, x_noise^⊤]^⊤ ∈ ℝ^p, with x_signal ∼ Uniform({±1}²), x_noise ∼ Uniform({±ε}^{p−2}), and the label y = x_1 x_2. Here ε is the parameter that controls the scale of the noise. We denote this data distribution by P and consider n training datapoints {(x_i, y_i)}_{i=1}^n drawn i.i.d. from the distribution P. We assume the sample size n to be sufficiently large, specifically larger than any universal constant mentioned in this paper. The data distribution comprises four feature vectors (see Figure 5(a) for a projected visualization), and the model needs to learn all four features to achieve perfect generalization. We denote a width-m two-layer neural network by f(x) = ∑_{j=1}^m a_j ϕ(⟨w_j, x⟩), where w_j ∈ ℝ^p, j ∈ [m], are neurons in the hidden layer and a_j ∈ ℝ, j ∈ [m], are second-layer weights.
The model is randomly initialized by w_j ∼ i.i.d. N(0, w_init² I_p), a_j ∼ i.i.d. N(0, a_init²), j ∈ [m]. Define the empirical risk with the exponential loss as L̂(f) = (1/n) ∑_{i=1}^n ℓ(y_i, f(x_i)), where ℓ(y, ŷ) = exp(−yŷ). We use gradient descent (GD) with weight decay, θ_j^{(t+1)} = (1 − λ) θ_j^{(t)} − α ∇_{θ_j} L̂(f^{(t)}), to update both layers {w_j, a_j}_{j=1}^m, where λ is the coefficient of L2 regularization. Setting p = 80000, n = 400, ε = 0.05, this configuration approximates one of the distributions explored in Xu et al. (2024), where grokking was observed. Under this setup, we train a two-layer neural network on {(x_i, y_i)}_{i=1}^n with default PyTorch initialization.
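The data distribution P and the exponential loss above can be sketched directly. A minimal pure-Python sketch, with the dimension p shrunk from the paper's 80000 so it runs quickly (n = 400 and ε = 0.05 match the paper's configuration):

```python
import math
import random

def sample_xor(n, p, eps, seed=0):
    """Sample n points from the XOR cluster distribution P:
    x = [x_signal, x_noise] with x_signal uniform on {±1}^2,
    x_noise uniform on {±eps}^(p-2), and label y = x_1 * x_2."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        signal = [rng.choice([-1.0, 1.0]) for _ in range(2)]
        noise = [rng.choice([-eps, eps]) for _ in range(p - 2)]
        x = signal + noise
        data.append((x, x[0] * x[1]))  # label y = x_1 * x_2
    return data

def exp_loss(y, y_hat):
    # Exponential loss l(y, y_hat) = exp(-y * y_hat) from the empirical risk.
    return math.exp(-y * y_hat)

# Reduced-p sketch of the paper's setting (paper uses p = 80000).
data = sample_xor(n=400, p=100, eps=0.05)
```

Note that the label depends only on the first two coordinates; the remaining p − 2 coordinates are pure ±ε noise, which is what makes the high-dimensional task hard.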
We observe grokking, as shown in Figure 4(a), where overfitting is achieved by the fifth epoch and generalization begins around the 80th epoch. In the subsequent sections, we show how our method GrokTransfer constructs a new embedding and eliminates the observed delay in generalization. Figure 4: (a) Training dynamics of a two-layer neural network with a hidden width of 2048, where grokking is observed. (b) Training dynamics of a two-layer neural network with a hidden width of 3. The model can only achieve around 75% validation accuracy, and a phase transition near the 100th epoch is observed. (c) Visualization of individual neuron weights from the model trained in (b). It shows three distinct patterns, each corresponding to a feature direction of the XOR data distribution. See Appendix A.3 for details of the experimental setup. 3.2 Empirical Analysis of the Weaker Model Applying GrokTransfer, we first train a small two-layer neural network with only 3 neurons, f_S(x) = ∑_{j=1}^3 a_j ϕ(⟨w_j, x⟩), till convergence (Figure 4(b)). Denote the first-layer weight matrix by W = [w_1, w_2, w_3] ∈ ℝ^{p×3}, the number of training steps by T, and the model after training by f_S^{(T)}. Due to the complexity of the training dynamics, it is hard to derive the closed form of f_S^{(T)} and W^{(T)}. Below we empirically investigate what information the model has gained and how well it learns. Figure 4(b) shows that, after training, this weak model has non-trivial performance with test accuracy around 75%.
The neurons {w_j^{(T)}}_{j=1}^3 are visualized in Figure 4(c), displaying patterns [−1, 1, 0, …, 0], [1, −1, 0, …, 0], and [−1, −1, 0, …, 0]. Note that the specific features learned by the model are sensitive to its initialization. Nevertheless, we find that empirically, the learned features are always three among the four features [±1, ±1, 0, …, 0], provided the test accuracy is around 75%. Notice that an optimal function for this classification task is f(x) = sgn(ϕ(x_1 + x_2) + ϕ(−x_1 − x_2) − ϕ(−x_1 + x_2) − ϕ(x_1 − x_2)), which needs four neurons to represent all features [±1, ±1]. It thus follows intuitively that the weak model f_S cannot achieve better generalization with only three neurons. Formally, we establish the following lemma regarding the expressive power of f_S. Lemma 3.1. For any f(x) = ∑_{j=1}^3 a_j ϕ(w_j^⊤ x), where ϕ is the ReLU activation function, we have ℙ_{(x,y)∼P}(y = sgn(f(x))) ≤ 75%.
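The optimal four-neuron function above can be checked mechanically on the four signal patterns. A small pure-Python sketch of that check (restricted to the two signal coordinates, i.e., ignoring the noise coordinates):

```python
def relu(z):
    return max(0.0, z)

def f_opt(x1, x2):
    # Optimal classifier from the text: one ReLU neuron per feature [±1, ±1],
    # with positive output weight on the y = +1 features [1, 1], [-1, -1]
    # and negative output weight on the y = -1 features [-1, 1], [1, -1].
    s = (relu(x1 + x2) + relu(-x1 - x2)
         - relu(-x1 + x2) - relu(x1 - x2))
    return 1 if s > 0 else -1

# Since y = x1 * x2, all four signal patterns are classified correctly.
labels = {(x1, x2): f_opt(x1, x2) for x1 in (-1, 1) for x2 in (-1, 1)}
```

Dropping any one of the four ReLU terms leaves one signal pattern mapped to the wrong sign, which is the intuition behind the 75% ceiling in Lemma 3.1 for three-neuron networks.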
Although the model f_S^{(T)} fails to generalize perfectly due to the inherent limitation of capacity, it has correctly selected the subset that contains features after training, as shown in Figure 4(c). Consequently, for any input x ∼ P, W^{(T)⊤}x becomes a high-quality embedding for x in a much lower dimensional space. Figure 5(a) shows that, with this new embedding, data points are well-separated in a three-dimensional space with a relatively high signal-to-noise ratio (SNR) compared to the original embedding. Next, we empirically examine the order of the ratio between the norm of the complementary subnetwork and the norm of the generalizing subnetwork. This will be used to estimate the SNR of P with the new embedding. Given the structure of the XOR cluster data, the first two rows of W^{(T)} correspond to the generalizing subnetwork. We define the norm ratio between the complementary and generalizing subnetworks as r_W = (‖W^{(T)}_{3:p,·}‖_F / √(p−2)) / (‖W^{(T)}_{1:2,·}‖_F / √2). Figures 5(b) and 5(c) show that the norm ratio is proportional to ε and to 1/√n, i.e., r_W ∝ ε/√n. We will use this property to show that, under mild assumptions, the target model can learn this low-dimensional XOR task with just one step of gradient descent. Figure 5: (a) 3-D visualization of the distribution P with the embedding from the weak model. The clusters are well-separated under the new embedding.
(b) Norm ratio r_W for different values of p and ε with fixed sample size n, indicating that r_W does not depend on p. (c) Norm ratio r_W for different values of n and ε with fixed feature dimension p. For each ε, the slope is around −1/2, indicating that r_W is proportional to 1/√n. See Appendix A.3 for details of the experimental setup. 3.3 Theoretical Analysis of the Target Model In this section, we theoretically analyze the behavior of GrokTransfer on the XOR cluster data. We consider the target model as a large model with width m of the form f_L(x) = ∑_{j=1}^m a_j ϕ(⟨v_j, U^⊤x⟩), where U = [u_1, u_2, u_3] ∈ ℝ^{p×3} comes from the first-layer weight matrix W^{(T)} learned by the weak model (visualized in Figure 4(c)). Here, U is the embedding matrix being transferred from the weak model f_S, which will then go through another linear transformation (given by the v_j’s) to form the embedding in the target model.
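The norm ratio r_W from Section 3.2 is a direct computation on the weak model's first-layer weight matrix. A pure-Python sketch, where the example matrix is synthetic (signal rows with ±1 entries, noise rows with ±ε entries) and chosen only to exercise the formula:

```python
import math

def norm_ratio(W):
    """r_W = (||W_{3:p}||_F / sqrt(p-2)) / (||W_{1:2}||_F / sqrt(2)),
    where W is a p x 3 weight matrix: rows 1-2 hold the generalizing
    subnetwork (signal coordinates), rows 3-p the complementary one."""
    p = len(W)
    fro = lambda rows: math.sqrt(sum(v * v for row in rows for v in row))
    gen = fro(W[:2]) / math.sqrt(2)        # per-row scale of signal part
    comp = fro(W[2:]) / math.sqrt(p - 2)   # per-row scale of noise part
    return comp / gen

# Synthetic example: signal entries of magnitude 1, noise entries of 0.05,
# so the ratio of per-coordinate scales recovers exactly 0.05.
W = ([[1.0, -1.0, 1.0], [1.0, 1.0, -1.0]]
     + [[0.05, 0.05, 0.05] for _ in range(98)])
r = norm_ratio(W)  # -> 0.05 for this construction
```

In this idealized case r_W equals the per-coordinate noise scale, consistent with the empirical finding r_W ∝ ε/√n once averaging over n samples is taken into account.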
Following our observation in Section 3.2, we can write
$$ u_1 = [\mu_2^\top, \delta_1^\top]^\top, \quad u_2 = [-\mu_2^\top, \delta_2^\top]^\top, \quad u_3 = [-\mu_1^\top, \delta_3^\top]^\top, $$
where $\mu_1 = [1,1]^\top$ and $\mu_2 = [-1,1]^\top$ are two orthogonal features of $P$, and $\delta_j = [\delta_{j,1}, \cdots, \delta_{j,p-2}]^\top \in \mathbb{R}^{p-2}$ for $j \in [3]$.¹ Here we let $\delta = [\delta_1, \delta_2, \delta_3] = W^{(T)}_{3:p,\cdot}$. Given a universal constant $C > 1$, we assume:

(A1) The noise scale satisfies $\varepsilon \le (n/(p \log^3 n))^{1/4}$.
(A2) The norm of the complementary subnetwork satisfies $\|\delta\|_F \le C \varepsilon \sqrt{p/n}$.
(A3) The initialization scale satisfies $v_{\mathrm{init}} \le C \log^{-3/2}(n)$.

¹We assume without loss of generality that the weak model learned the three features $[1,1], [-1,1], [-1,-1]$. Our result holds the same for any three of the four features $[\pm 1, \pm 1]$.
(A4) The step size satisfies $\sqrt{m}\, v_{\mathrm{init}} / C \le \alpha \le \sqrt{m}\, v_{\mathrm{init}}$.
(A5) The number of neurons satisfies $m \ge 2 \log^3 n$.

Here, Assumption (A2) corresponds to the finding that $r_W \propto \varepsilon/\sqrt{n}$ in Section 3.2. Assumptions (A1) and (A2) together ensure that the SNR of the distribution $P$ in the new embedding space is large enough. Assumption (A3) controls the initial weight norm of the target model so that the empirical risk starts within a reasonable range. Assumption (A4) guarantees that the step size is appropriately balanced: it is neither so small that a single gradient descent step produces no meaningful update, nor so large that it causes overly drastic movements. Assumption (A5) ensures that the model is wide enough for certain concentration results about the random initialization to hold. All assumptions are satisfied in the empirical setup discussed in Section 3.2.

We denote $a = [a_1, \cdots, a_m]^\top \in \mathbb{R}^m$ and $V = [v_1, \cdots, v_m] \in \mathbb{R}^{3 \times m}$. We initialize $a$ and $V$ as follows:
$$ a_j \overset{\text{i.i.d.}}{\sim} \mathrm{Uniform}(\{\pm 1/\sqrt{m}\}), \quad v_j \overset{\text{i.i.d.}}{\sim} \mathrm{Uniform}(\{\pm v_{\mathrm{init}}\}^3), \quad j \in [m], $$
and keep $a$ and $U$ fixed during training.² Following the training method outlined in Section 3.1, we use gradient descent
$$ V^{(t+1)} = V^{(t)} - \alpha \nabla_V \hat{L}(f_L^{(t)}) $$
at step $t$ to update the linear layer $V$, where $\alpha$ is the step size and the empirical risk $\hat{L}(\cdot)$ is defined in Section 3.1. With these assumptions and initializations, we state the theorem that characterizes the training and test error of the target model after one step.

²Our result is unaffected if $a$ and $U$ are also trainable. We fix them to simplify the analysis while still conveying the main ideas.

Theorem 3.2. Suppose that Assumptions (A1)-(A5) hold. With probability at least $1 - O(1/n^2)$ over the generation of the training data and the initial weights of $f_L$, after one step of training, the classifier $\mathrm{sgn}(f_L^{(1)}(x))$ correctly classifies all training datapoints and generalizes with test error no greater than $\exp(-\Omega(\log^2 n))$.

Theorem 3.2 shows that with GrokTransfer, after just one step of gradient descent, the target model fits all training data and achieves near-perfect test accuracy. Notably, this happens not in a kernel regime but in a feature-learning regime. Since models with normal training cannot achieve generalization in one step (Figure 4(a)), this result indicates that GrokTransfer effectively boosts the generalization speed of the target model and eliminates the time gap between overfitting and generalization. Empirically, the model continues to generalize with further training (see Figure 9 in Appendix A.3).
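The one-step update $V^{(1)} = V^{(0)} - \alpha \nabla_V \hat{L}(f_L^{(0)})$ can be sketched in NumPy as follows. The embedded clusters around $\pm\nu_1, \pm\nu_2$ and all scales here are illustrative stand-ins rather than the paper's exact distribution and constants, so the sketch shows the mechanics of the update, not a verification of Theorem 3.2.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m, v_init = 200, 256, 0.01
alpha = np.sqrt(m) * v_init            # upper end of Assumption (A4)

# Toy embedded data: z_i near four features +-nu1, +-nu2 in R^3,
# with XOR-style labels (+1 on the +-nu1 clusters, -1 on the +-nu2 clusters).
nu1, nu2 = np.array([0.0, 0.0, -2.0]), np.array([2.0, -2.0, 0.0])
signals = np.stack([nu1, -nu1, nu2, -nu2])
idx = rng.integers(0, 4, size=n)
Z = signals[idx] + 0.05 * rng.normal(size=(n, 3))
y = np.where(idx < 2, 1.0, -1.0)

a = rng.choice([-1.0, 1.0], size=m) / np.sqrt(m)     # fixed second layer
V = v_init * rng.choice([-1.0, 1.0], size=(3, m))    # v_j in {+-v_init}^3

def forward(V):
    return np.maximum(Z @ V, 0.0) @ a                # f(x_i) for all i, ReLU activation

# One gradient step on the exponential loss L_hat = mean_i exp(-y_i f(x_i)).
f0 = forward(V)
g = -y * np.exp(-y * f0) / n                         # dL_hat/df_i
mask = (Z @ V > 0).astype(float)                     # ReLU gates at V^(0)
grad_V = Z.T @ (g[:, None] * mask * a[None, :])      # dL_hat/dV, shape (3, m)
V1 = V - alpha * grad_V

print("train accuracy after one step:", np.mean(np.sign(forward(V1)) == y))
```

Each column of `grad_V` is $\sum_i g_i\, a_j\, \mathbf{1}[\langle v_j^{(0)}, z_i\rangle > 0]\, z_i$, so after one step each neuron moves toward a label-weighted average of the datapoints it is active on.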
Given that the weak model $f_S$ has only three neurons, the computational cost of training $f_S$ is negligible compared to the cost of training the target model $f_L$ with sufficiently large width. This implies that GrokTransfer may reduce the overall computational cost. In the next section, we compare the computational cost of our method to that of standard training procedures.

4 Experiments

This section empirically studies GrokTransfer on modular addition and multiplication, as well as the sparse parity task. Our experiments verify that GrokTransfer effectively reshapes the training dynamics and eliminates delayed generalization for both fully-connected neural networks (FNNs) and Transformers (TFs). The AdamW optimizer (Loshchilov & Hutter, 2019) is used in all experiments in this section.

Figure 6: Training dynamics of FNNs on various tasks. The rows represent different models/training methods: the first row shows the dynamics of the weak model used in GrokTransfer, and the second row shows the dynamics of the target model trained using GrokTransfer and the target model trained from scratch. The columns represent different tasks: the first column is the modular addition task, the second column is the modular multiplication task, and the third column is the $(40,3)$-parity task. The comparison between the first and second rows shows that the target model trained via GrokTransfer can surpass the weak model's performance. The comparison within the second row shows that GrokTransfer eliminates the sharp phase transition and enables the model to make continuous progress. See Appendix A.3 for details of the experimental setup.
4.1 FNN → FNN

We first consider a three-layer FNN as the target model and conduct GrokTransfer on tasks including modular addition, modular multiplication, and $(q,k)$-parity (Barak et al., 2022). These results are compared to training the target model from scratch. The modular addition task is introduced in Section 2.1, and modular multiplication is defined similarly with the label $y = ab \bmod p$. The $(q,k)$-parity task consists of a dataset $\{(x, y) : x \in \{\pm 1\}^q,\ y = \prod_{i \in S} x_i,\ |S| = k\}$. Following the setting in Merrill et al. (2023), we choose $q = 40$, $k = 3$, and $S = \{1, 2, 3\}$.

For the modular addition and multiplication tasks, we employ a two-layer neural network with a trainable embedding as the weak model, which we train for $10^4$ epochs. We then initialize the target model by setting its embedding layer $A$ to the embedding learned by the weak model. Figure 6 shows the training dynamics of the weak model, the target model trained via GrokTransfer, and the target model trained from scratch. Notably, GrokTransfer nearly eliminates the sharp phase transition observed in normal training. All training hyperparameters (initialization scale, learning rate, weight decay) are selected by grid search, where the best configuration is defined as the one that reaches $99\%$ test accuracy the fastest. The oscillations in accuracy in the second row of Figure 6 are related to the "slingshot mechanism" (Thilak et al., 2022) and to training instabilities associated with large learning rates (Wortsman et al., 2024).
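For concreteness, the $(q,k)$-parity data described above can be generated directly; a minimal sketch (with $S$ written zero-indexed as $\{0, 1, 2\}$):

```python
import numpy as np

rng = np.random.default_rng(0)
q, k = 40, 3
S = [0, 1, 2]   # the paper's S = {1, 2, 3}, zero-indexed here

def sample_parity(n):
    """Sample n points of the (q, k)-parity task: x in {+-1}^q, y = prod_{i in S} x_i."""
    X = rng.choice([-1, 1], size=(n, q))
    return X, X[:, S].prod(axis=1)

X, y = sample_parity(1000)
print(X.shape)   # (1000, 40)
```

Only the $k = 3$ coordinates in $S$ determine the label; the remaining $q - k = 37$ coordinates are pure distractors, which is what makes this a sparse-feature task.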
Since large learning rates and this kind of oscillation are believed to help generalization (Damian et al., 2023; Lu et al., 2024), we do not change our configuration selection criteria.

For the parity task, we use a three-layer FNN as the weak model, as empirical evidence suggests that a two-layer FNN without bias terms cannot generalize on this task. The weak model is trained until it achieves $70\%$ test accuracy, after which its first-layer weight matrix is transferred to the target model. As shown in Figure 6, the weak model undergoes a generalization delay, but the large model inheriting its embedding generalizes continuously.

Ablation study: To further understand the empirical effectiveness of GrokTransfer, we perform an ablation study by varying the number of training epochs of the weak model on the modular addition task.

Figure 7: Ablation study showing the effect of the weak model's performance on the test accuracy of the target model (initialized via GrokTransfer and trained for $10^4$ epochs).

We extract the embeddings of the weak model at epochs 100, 500, 800, 900, 1000, 1100, 1500, and 2000. For each embedding, we apply GrokTransfer to the target model and train it for $10^4$ epochs. To measure the generalization delay of the target model, we define the Time Gap as the difference between the first epoch that achieves $95\%$ training accuracy and the first epoch that achieves $95\%$ test accuracy. If the target model fails to reach $95\%$ accuracy, we set $1/\text{Time Gap} = 0$. Figure 7 shows that the test performance of the target model, initialized with the weak model's embedding, is positively correlated with the test performance of the weak model.
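The Time Gap defined above can be computed directly from logged accuracy curves; a small sketch (the curves below are synthetic and purely illustrative):

```python
import numpy as np

def time_gap(train_acc, test_acc, threshold=0.95):
    """First epoch with test accuracy >= threshold minus first epoch with
    train accuracy >= threshold; None if either never reaches the threshold
    (so that 1/TimeGap is treated as 0)."""
    hit_train = np.flatnonzero(np.asarray(train_acc) >= threshold)
    hit_test = np.flatnonzero(np.asarray(test_acc) >= threshold)
    if len(hit_train) == 0 or len(hit_test) == 0:
        return None
    return int(hit_test[0] - hit_train[0])

# Toy grokking-like curves: training accuracy saturates around epoch 10,
# test accuracy only around epoch 400.
epochs = np.arange(1000)
train_curve = np.clip(epochs / 10, 0, 1)
test_curve = np.clip((epochs - 390) / 10, 0, 1)
print(time_gap(train_curve, test_curve))   # prints 390
```

A model that generalizes as soon as it memorizes has a Time Gap near zero; a model that groks has a large one.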
A grokked weak model is essential for the target model to achieve near-perfect generalization with minimal generalization delay. We hypothesize that the target model can only generalize well after the weak model has grokked.

4.2 FNN → Transformers

Interestingly, we find that the embeddings extracted from the weak FNN model can be transferred to the target model even when the target model is a Transformer comparable in scale to GPT-2 small (Radford et al., 2019). Under this FNN → TF setting, GrokTransfer still mitigates the generalization delay of the target model. Specifically, we choose the target model to be a Transformer with 8 attention layers and $(d_{\mathrm{embed}}, d_{\mathrm{mlp}}, n_{\mathrm{head}}) = (512, 512, 4)$. For each sample, the input is a sequence with two tokens $(a, b)$. We extract the embeddings of the weak model in Figure 6 at the point where it first reaches $30\%$ test accuracy. Figure 8(a) shows that GrokTransfer enables the target model to generalize much faster than training from scratch, with little generalization delay. Here, both methods suffer from the training instability of large learning rates.

In terms of computational cost, we use wall-clock time as the measure. The computational cost of GrokTransfer comprises the training of the weak model and the training of the target model. Table 1 shows the total wall-clock time for the weak model, the target model with GrokTransfer, and the target model trained from scratch. The time spent training the weak FNN model is negligible compared to training the target Transformer model. The total wall-clock time of GrokTransfer is approximately five times shorter than that of training from scratch.

Figure 8: Training dynamics of Transformers on the modular addition task. The weak model is a three-layer FNN.
(a) Dynamics of the target model (an 8-layer Transformer) trained via GrokTransfer, and the target model trained from scratch. (b) Dynamics of the target model (a two-layer Transformer) trained via GrokTransfer, and the target model trained via GrokFast (Lee et al., 2024).

| Model | Weak | Target (GrokTransfer) | Target (scratch) |
|---|---|---|---|
| Total wall-clock time (ms) | 2828 | 71079 | 392667 |

Table 1: Comparison of total wall-clock times (forward and backward passes) for different models. The weak model is a three-layer FNN. The target/large model is an 8-layer Transformer.

Lee et al. (2024) proposed GrokFast, a gradient amplification algorithm to accelerate grokking. We compare GrokTransfer with GrokFast in Figure 8(b). The weak-model embedding we transfer is the same as the one used in Figure 8(a). For the target model, we follow the model used in Lee et al. (2024), a two-layer decoder-only Transformer with $(d_{\mathrm{embed}}, d_{\mathrm{mlp}}, n_{\mathrm{head}}) = (128, 512, 4)$. The Time Gap of GrokTransfer is 46, while the Time Gap of GrokFast is 1119.

5 Conclusion

To eliminate the unpredictability associated with grokking, we proposed GrokTransfer, a novel method that effectively accelerates grokking by transferring the embedding from a weaker model. Our method was inspired by the key observation that the data embedding critically shapes training dynamics. We theoretically justified GrokTransfer on an XOR classification task. We also empirically evaluated it on various algorithmic tasks known to exhibit grokking under standard training. Our results showed that GrokTransfer can effectively modify training dynamics, enabling continuous progression in model performance. One limitation of our work is that the theoretical result only considers a relatively simple XOR task.
For this task, after transferring the embedding from the smaller model, one step of gradient descent suffices for both memorization and generalization. Theoretical justification for more complex problems is an important future direction. Furthermore, our method focuses solely on accelerating grokking and was only investigated on problems where grokking occurs. It would be interesting to study whether similar ideas can be applied to improve training dynamics or enable weak-to-strong generalization in a broader context.

Acknowledgments

This work was supported in part by the Office of Naval Research under grant number N00014-23-1-2590; the National Science Foundation under Grants No. 2231174, No. 2310831, No. 2428059, No. 2435696, and No. 2440954; and a Michigan Institute for Data Science Propelling Original Data Science (PODS) grant.

References

Barak et al. (2022) Boaz Barak, Benjamin Edelman, Surbhi Goel, Sham Kakade, Eran Malach, and Cyril Zhang. Hidden progress in deep learning: SGD learns parities near the computational limit. Advances in Neural Information Processing Systems, 35:21750–21764, 2022.

Burns et al. (2023) Collin Burns, Pavel Izmailov, Jan Hendrik Kirchner, Bowen Baker, Leo Gao, Leopold Aschenbrenner, Yining Chen, Adrien Ecoffet, Manas Joglekar, Jan Leike, Ilya Sutskever, and Jeff Wu. Weak-to-strong generalization: Eliciting strong capabilities with weak supervision. arXiv preprint arXiv:2312.09390, 2023.

Chen et al. (2024) Angelica Chen, Ravid Shwartz-Ziv, Kyunghyun Cho, Matthew L. Leavitt, and Naomi Saphra. Sudden drops in the loss: Syntax acquisition, phase transitions, and simplicity bias in MLMs. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=MO5PiKHELW.

Chughtai et al. (2023) Bilal Chughtai, Lawrence Chan, and Neel Nanda. A toy model of universality: Reverse engineering how networks learn group operations. In International Conference on Machine Learning, pp. 6243–6267. PMLR, 2023.
Damian et al. (2023) Alex Damian, Eshaan Nichani, and Jason D. Lee. Self-stabilization: The implicit bias of gradient descent at the edge of stability. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=nhKHA59gXz.

Davies et al. (2023) Xander Davies, Lauro Langosco, and David Krueger. Unifying grokking and double descent. arXiv preprint arXiv:2303.06173, 2023.

Doshi et al. (2024) Darshil Doshi, Tianyu He, Aritra Das, and Andrey Gromov. Grokking modular polynomials. arXiv preprint arXiv:2406.03495, 2024.

Furuta et al. (2024) Hiroki Furuta, Gouki Minegishi, Yusuke Iwasawa, and Yutaka Matsuo. Interpreting grokked transformers in complex modular arithmetic. arXiv preprint arXiv:2402.16726, 2024.

Gopalani et al. (2025) Pulkit Gopalani, Ekdeep S. Lubana, and Wei Hu. Abrupt learning in transformers: A case study on matrix completion. Advances in Neural Information Processing Systems, 37:55053–55085, 2025.

Gromov (2023) Andrey Gromov. Grokking modular arithmetic. arXiv preprint arXiv:2301.02679, 2023.

He et al. (2024) Tianyu He, Darshil Doshi, Aritra Das, and Andrey Gromov. Learning to grok: Emergence of in-context learning and skill composition in modular arithmetic tasks. arXiv preprint arXiv:2406.02550, 2024.

Humayun et al. (2024) Ahmed Imtiaz Humayun, Randall Balestriero, and Richard Baraniuk. Deep networks always grok and here is why. arXiv preprint arXiv:2402.15555, 2024.

Kaplan et al. (2020) Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B. Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. Scaling laws for neural language models. arXiv preprint arXiv:2001.08361, 2020.

Kumar et al. (2024) Tanishq Kumar, Blake Bordelon, Samuel J. Gershman, and Cengiz Pehlevan. Grokking as the transition from lazy to rich training dynamics. In The Twelfth International Conference on Learning Representations, 2024.

Lee et al. (2024) Jaerin Lee, Bong Gyun Kang, Kihoon Kim, and Kyoung Mu Lee. Grokfast: Accelerated grokking by amplifying slow gradients. arXiv preprint arXiv:2405.20233, 2024.

Liu et al. (2022) Ziming Liu, Ouail Kitouni, Niklas S. Nolte, Eric Michaud, Max Tegmark, and Mike Williams. Towards understanding grokking: An effective theory of representation learning. Advances in Neural Information Processing Systems, 35:34651–34663, 2022.

Liu et al. (2023) Ziming Liu, Eric J. Michaud, and Max Tegmark. Omnigrok: Grokking beyond algorithmic data. In The Eleventh International Conference on Learning Representations, 2023.

Loshchilov & Hutter (2019) Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=Bkg6RiCqY7.

Lu et al. (2024) Miao Lu, Beining Wu, Xiaodong Yang, and Difan Zou. Benign oscillation of stochastic gradient descent with large learning rate. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=wYmvN3sQpG.

Lyu et al. (2024) Kaifeng Lyu, Jikai Jin, Zhiyuan Li, Simon Shaolei Du, Jason D. Lee, and Wei Hu. Dichotomy of early and late phase implicit biases can provably induce grokking. In The Twelfth International Conference on Learning Representations, 2024.

Mallinar et al. (2024) Neil Mallinar, Daniel Beaglehole, Libin Zhu, Adityanarayanan Radhakrishnan, Parthe Pandit, and Mikhail Belkin. Emergence in non-neural models: grokking modular arithmetic via average gradient outer product. arXiv preprint arXiv:2407.20199, 2024.

Merrill et al. (2023) William Merrill, Nikolaos Tsilivis, and Aman Shukla. A tale of two circuits: Grokking as competition of sparse and dense subnetworks. arXiv preprint arXiv:2303.11873, 2023.

Miller et al. (2023) Jack Miller, Charles O'Neill, and Thang Bui. Grokking beyond neural networks: An empirical exploration with model complexity. arXiv preprint arXiv:2310.17247, 2023.
Minegishi et al. (2024) Gouki Minegishi, Yusuke Iwasawa, and Yutaka Matsuo. Bridging lottery ticket and grokking: Is weight norm sufficient to explain delayed generalization? In ICLR 2024 Workshop on Bridging the Gap Between Practice and Theory in Deep Learning, 2024.

Mohamadi et al. (2023) Mohamad Amin Mohamadi, Wonho Bae, and Danica J. Sutherland. A fast, well-founded approximation to the empirical neural tangent kernel. In International Conference on Machine Learning, pp. 25061–25081. PMLR, 2023.

Mohamadi et al. (2024) Mohamad Amin Mohamadi, Zhiyuan Li, Lei Wu, and Danica J. Sutherland. Why do you grok? A theoretical analysis of grokking modular addition. arXiv preprint arXiv:2407.12332, 2024.

Morwani et al. (2024) Depen Morwani, Benjamin L. Edelman, Costin-Andrei Oncescu, Rosie Zhao, and Sham M. Kakade. Feature emergence via margin maximization: case studies in algebraic tasks. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=i9wDX850jR.

Nanda et al. (2023) Neel Nanda, Lawrence Chan, Tom Lieberum, Jess Smith, and Jacob Steinhardt. Progress measures for grokking via mechanistic interpretability. In The Eleventh International Conference on Learning Representations, 2023.

OpenAI (2024) OpenAI. text-embedding-3-small model. https://platform.openai.com/docs/models/embedding-3, 2024.

Power et al. (2022) Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin, and Vedant Misra. Grokking: Generalization beyond overfitting on small algorithmic datasets. arXiv preprint arXiv:2201.02177, 2022.

Radford et al. (2019) Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever, et al. Language models are unsupervised multitask learners. OpenAI blog, 1(8):9, 2019.

Thilak et al. (2022) Vimal Thilak, Etai Littwin, Shuangfei Zhai, Omid Saremi, Roni Paiss, and Joshua Susskind. The slingshot mechanism: An empirical study of adaptive optimizers and the grokking phenomenon. arXiv preprint arXiv:2206.04817, 2022.

Varma et al. (2023) Vikrant Varma, Rohin Shah, Zachary Kenton, János Kramár, and Ramana Kumar. Explaining grokking through circuit efficiency. arXiv preprint arXiv:2309.02390, 2023.

Wang et al. (2024) Boshi Wang, Xiang Yue, Yu Su, and Huan Sun. Grokked transformers are implicit reasoners: A mechanistic journey to the edge of generalization. arXiv preprint arXiv:2405.15071, 2024.

Wang et al. (2023) Peihao Wang, Rameswar Panda, Lucas Torroba Hennigen, Philip Greengard, Leonid Karlinsky, Rogerio Feris, David Daniel Cox, Zhangyang Wang, and Yoon Kim. Learning to grow pretrained models for efficient transformer training. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=cDYRS5iZ16f.

Wortsman et al. (2024) Mitchell Wortsman, Peter J. Liu, Lechao Xiao, Katie E. Everett, Alexander A. Alemi, Ben Adlam, John D. Co-Reyes, Izzeddin Gur, Abhishek Kumar, Roman Novak, Jeffrey Pennington, Jascha Sohl-Dickstein, Kelvin Xu, Jaehoon Lee, Justin Gilmer, and Simon Kornblith. Small-scale proxies for large-scale transformer training instabilities. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=d8w0pmvXbZ.

Xu et al. (2024) Zhiwei Xu, Yutong Wang, Spencer Frei, Gal Vardi, and Wei Hu. Benign overfitting and grokking in ReLU networks for XOR cluster data. In The Twelfth International Conference on Learning Representations, 2024.

Yang et al. (2025) Yongyi Yang, Core Francisco Park, Ekdeep Singh Lubana, Maya Okawa, Wei Hu, and Hidenori Tanaka. Dynamics of concept learning and compositional generalization. In The Thirteenth International Conference on Learning Representations, 2025. URL https://openreview.net/forum?id=s1zO0YBEF8.

Zhu et al. (2024) Xuekai Zhu, Yao Fu, Bowen Zhou, and Zhouhan Lin. Critical data size of language models from a grokking perspective. arXiv preprint arXiv:2401.10463, 2024.

Žunkovič & Ilievski (2022) Bojan Žunkovič and Enej Ilievski. Grokking phase transitions in learning local rules with gradient descent. arXiv preprint arXiv:2210.15435, 2022.

Appendix A

A.1 Proofs

A.1.1 Proof of Lemma 3.1

Proof. For any $(x, y) \sim P$, define $x' = (x_1, -x_2, x_3, \cdots, x_p)$ and $y' = \mathrm{sgn}(x'_1 x'_2) = -y$. It suffices to show that if $y = \mathrm{sgn}(f(x))$, $y = \mathrm{sgn}(f(-x))$, and $y' = \mathrm{sgn}(f(x'))$, then $y' \ne \mathrm{sgn}(f(-x'))$ with probability 1.

Assume $y = \mathrm{sgn}(f(x))$ and $y = \mathrm{sgn}(f(-x))$. Since $\phi(z) \ge 0$ for all $z$, $y = \mathrm{sgn}(f(x))$ implies that there exists at least one $i \in [3]$ such that $a_i$ has the same sign as $y$ and $w_i^\top x > 0$. Without loss of generality, assume $\mathrm{sgn}(a_1) = y$ and $w_1^\top x > 0$. Then for $(-x, -y)$, it follows that $f(-x) = \sum_{j=1}^{3} a_j \phi(-w_j^\top x) = a_2 \phi(-w_2^\top x) + a_3 \phi(-w_3^\top x)$ has the same sign as $y$. Again without loss of generality, we assume $\mathrm{sgn}(a_2) = y$.
Now suppose for contradiction that $y' = \mathrm{sgn}(f(x'))$ and $y' = \mathrm{sgn}(f(-x'))$ both hold. Following the same discussion, at least two of the $a_i$'s have the same sign as $y' = -y$, which contradicts the previous conclusion that $\mathrm{sgn}(a_1) = \mathrm{sgn}(a_2) = y$. ∎

A.1.2 Proof of Theorem 3.2

Additional notation: For the training dataset $\{(x_i, y_i)\}_{i=1}^n$, we denote the signal of $x_i$ by $\bar{x}_i = [x_{i,1}, x_{i,2}]^\top \in \{\pm\mu_1, \pm\mu_2\}$. For each $\mu \in \{\pm\mu_1, \pm\mu_2\}$, define $\mathcal{I}_\mu = \{i \in [n] : \bar{x}_i = \mu\}$ and $n_\mu = |\mathcal{I}_\mu|$. Denote the new embedding of the $i$-th datapoint by $z_i = U^\top x_i$, $i \in [n]$. Define
$$ \nu_i = [\mu_2, -\mu_2, -\mu_1]^\top \mu_i, \quad i = 1, 2. $$
Then $\nu_1 = [0, 0, -2]$ and $\nu_2 = [2, -2, 0]$, and $\{\pm\nu_1, \pm\nu_2\}$ become the features of $P$ under the new embedding.
Denote the signal of $z_i$ by $\bar{z}_i = [\mu_2, -\mu_2, -\mu_1]^\top \bar{x}_i$. Define the set of training data
$$ \mathcal{G}_{\mathrm{data}} = \left\{ \{(x_i, y_i)\}_{i=1}^n : \|z_i - \bar{z}_i\| \le \varepsilon^2 \sqrt{\frac{p}{n}} \log n \text{ for all } i \in [n] \right\}. $$
By Lemma A.2, $\mathbb{P}(\{(x_i, y_i)\}_{i=1}^n \in \mathcal{G}_{\mathrm{data}}) \ge 1 - \exp(-\Omega(\log^2 n))$.

We further define sets that separate the second-layer coefficients for ease of discussion:
$$ \mathcal{J}_{\mathrm{Pos}} = \{j \in [m] : a_j > 0\}; \quad \mathcal{J}_{\mathrm{Neg}} = \{j \in [m] : a_j < 0\}. $$
We divide the neuron indices by their initialization, defining $\mathcal{J}_e = \{j \in [m] : v_j^{(0)} = v_{\mathrm{init}}\, e\}$ for each $e \in \{\pm 1\}^3$. We further define
$$ \mathcal{J}_{\mathrm{Pos}, e} = \mathcal{J}_{\mathrm{Pos}} \cap \mathcal{J}_e; \quad \mathcal{J}_{\mathrm{Neg}, e} = \mathcal{J}_{\mathrm{Neg}} \cap \mathcal{J}_e. $$
For each initialization value of $v_j^{(0)}$, we denote the set of datapoints that have positive inner product with it by
$$ \mathcal{I}_{e,\mu} = \{i \in \mathcal{I}_\mu : \langle e, z_i \rangle > 0\}, \quad e \in \{\pm 1\}^3, \ \mu \in \{\pm\mu_1, \pm\mu_2\}. $$

Proof. For brevity, we omit the subscript $L$ of $f_L$ in the proof below. At step $t = 0$: for each $(x_i, y_i)$, we have
$$ f^{(0)}(x_i) = \sum_{j=1}^{m} a_j \phi(\langle v_j^{(0)}, z_i \rangle), $$
where the $a_j \phi(\langle v_j^{(0)}, z_i \rangle)$, $j \in [m]$, are bounded random variables with zero mean. The absolute bound is
$$ |a_j \phi(\langle v_j^{(0)}, z_i \rangle)| \le \frac{\sqrt{3}\, v_{\mathrm{init}}}{\sqrt{m}} \left( \max_i \|\bar{z}_i\| + \varepsilon^2 \sqrt{p/n} \log n \right) \le 5 v_{\mathrm{init}} / \sqrt{m}, $$
where the first inequality uses Lemma A.2 and the second inequality uses $\max_i \|\bar{z}_i\| = 2\sqrt{2}$ and Assumption (A1).
Then by Hoeffding's inequality and the law of total probability,
$$\mathbb{P}\big(|f^{(0)}(x_i)|>t\big)\le \mathbb{P}\big(|f^{(0)}(x_i)|>t\,\big|\,\mathcal{G}_{\mathrm{data}}\big)+\mathbb{P}\big(\mathcal{G}_{\mathrm{data}}^c\big)\le 2\exp\Big(-\frac{2t^2}{25v_{\mathrm{init}}^2}\Big)+\exp(-\Omega(\log^2 n)).$$
Let $t=v_{\mathrm{init}}\log n$. It follows that
$$\begin{split}\mathbb{P}\Big(\max_{i\in[n]}|f^{(0)}(x_i)|\le t\Big)&\ge 1-\sum_{i=1}^n\mathbb{P}\big(|f^{(0)}(x_i)|>t\big)\\ &\ge 1-2n\exp\Big(-\frac{2\log^2 n}{25}\Big)-n\exp(-\Omega(\log^2 n))=1-\exp(-\Omega(\log^2 n)).\end{split}\tag{1}$$
We define a set of training data and initial weights:
$$\mathcal{G}=\Big\{\big(\{(x_i,y_i)\}_{i=1}^n,a,V^{(0)}\big):\ \{(x_i,y_i)\}_{i=1}^n\in\mathcal{G}_{\mathrm{data}},\ \text{condition (1) and all conditions in Lemmas A.1 and A.4 hold}\Big\}.$$
Combining (1) with Lemmas A.1, A.2, and A.4 and applying the union bound, we have
$$\mathbb{P}\Big(\big(\{(x_i,y_i)\}_{i=1}^n,a,V^{(0)}\big)\in\mathcal{G}\Big)\ge 1-\exp(-\Omega(\log^2 n))-O\Big(\frac{1}{n^2}\Big)-O\Big(\frac{1}{n^4}\Big)=1-O\Big(\frac{1}{n^2}\Big).$$
Denote $l_i^{(t)}=l\big(y_i,f^{(t)}(x_i)\big)=\exp\big(-y_if^{(t)}(x_i)\big)$. Conditioning on $\mathcal{G}$, the ratio between the maximum and minimum loss is bounded by
$$R^{(0)}:=\frac{\max_{i\in[n]}l_i^{(0)}}{\min_{i\in[n]}l_i^{(0)}}\le \exp(2v_{\mathrm{init}}\log n).\tag{2}$$
For each $j$, we analyze below the gradient descent update for all possible combinations of $(a_j,v_j^{(0)})$, conditioning on the event $\mathcal{G}$.
(1) When $a_j>0$: If $v_j^{(0)}=v_{\mathrm{init}}[1,1,1]$, then according to Lemma A.1, we have
$$\mathcal{I}_{[1,1,1],+\mu_1}=\emptyset;\qquad \mathcal{I}_{[1,1,1],-\mu_1}=\mathcal{I}_{-\mu_1};\qquad \Big||\mathcal{I}_{[1,1,1],\mu}|-\frac{n_\mu}{2}\Big|\le\sqrt{n\log n},\ \ \mu=\pm\mu_2.\tag{3}$$
Recall that the gradient descent update of $v_j^{(t)}$ is
$$v_j^{(t+1)}=v_j^{(t)}+\frac{\alpha}{n}a_j\sum_{i=1}^n y_i\exp\big(-y_if^{(t)}(x_i)\big)\phi'\big(\langle v_j^{(t)},z_i\rangle\big)z_i.\tag{4}$$
It follows that
$$\begin{split}v_{j,3}^{(1)}&=v_{j,3}^{(0)}+\frac{\alpha}{n}a_j\sum_{i=1}^n y_i l_i^{(0)}\phi'\big(\langle v_j^{(0)},z_i\rangle\big)z_{i,3}\\ &=v_{j,3}^{(0)}+\frac{\alpha}{n}a_j\sum_{i\in\mathcal{I}_{-\mu_1}}y_i l_i^{(0)}z_{i,3}+\frac{\alpha}{n}a_j\sum_{i\in\mathcal{I}_{[1,1,1],\mu_2}\cup\,\mathcal{I}_{[1,1,1],-\mu_2}}y_i l_i^{(0)}z_{i,3}\\ &\ge v_{\mathrm{init}}+\frac{2\alpha}{n\sqrt m}\sum_{i\in\mathcal{I}_{-\mu_1}}l_i^{(0)}-O\Big(\frac{\alpha}{\sqrt m}\max_i l_i^{(0)}\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n\Big)\\ &\ge v_{\mathrm{init}}+\frac{1.9\alpha|\mathcal{I}_{-\mu_1}|}{n\sqrt m}\exp(-v_{\mathrm{init}}\log n)\ge v_{\mathrm{init}}+\frac{1.9\alpha}{4\sqrt m}\Big(1-\frac{4}{\log n}\Big)\big(1-v_{\mathrm{init}}\log n\big)\\ &\ge v_{\mathrm{init}}+\frac{2\alpha}{5\sqrt m},\end{split}\tag{5}$$
where the first inequality uses Lemma A.2 and $\bar z_{i,3}=0$ for $i\in\mathcal{I}_{\pm\mu_2}$; the second inequality uses Assumptions (A1), (A3), and (A4); and the third inequality uses Lemma A.4 and $\exp(x)\ge 1+x$.
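The final constant in (5) can be verified numerically (a side calculation under the stated assumptions, where $4/\log n$ and $v_{\mathrm{init}}\log n$ are $o(1)$):

```latex
\frac{1.9\alpha}{4\sqrt m}\Big(1-\frac{4}{\log n}\Big)\big(1-v_{\mathrm{init}}\log n\big)
=\frac{1.9\alpha}{4\sqrt m}\big(1-o(1)\big)
\longrightarrow 0.475\,\frac{\alpha}{\sqrt m}
\;>\;\frac{2\alpha}{5\sqrt m}=0.4\,\frac{\alpha}{\sqrt m},
```

so the last inequality holds for all sufficiently large $n$.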
Further, for $l=1,2$, we have
$$\begin{split}\big|v_{j,l}^{(1)}-v_{j,l}^{(0)}\big|&=\Big|\frac{\alpha}{n}a_j\sum_{i\in\mathcal{I}_{-\mu_1}}y_i l_i^{(0)}z_{i,l}+\frac{\alpha}{n}a_j\sum_{i\in\mathcal{I}_{[1,1,1],\mu_2}\cup\,\mathcal{I}_{[1,1,1],-\mu_2}}y_i l_i^{(0)}z_{i,l}\Big|\\ &=\frac{\alpha}{n}a_j\Big|\sum_{i\in\mathcal{I}_{-\mu_1}\cup\,\mathcal{I}_{[1,1,1],\mu_2}\cup\,\mathcal{I}_{[1,1,1],-\mu_2}}y_i l_i^{(0)}\big(z_{i,l}-\bar z_{i,l}\big)-\Big[\sum_{i\in\mathcal{I}_{[1,1,1],\mu_2}}l_i^{(0)}\bar z_{i,l}+\sum_{i\in\mathcal{I}_{[1,1,1],-\mu_2}}l_i^{(0)}\bar z_{i,l}\Big]\Big|\\ &\le \frac{\alpha}{\sqrt m}\exp(v_{\mathrm{init}}\log n)\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n+\frac{2\alpha}{n\sqrt m}\Big|\sum_{i\in\mathcal{I}_{[1,1,1],\mu_2}}l_i^{(0)}-\sum_{i\in\mathcal{I}_{[1,1,1],-\mu_2}}l_i^{(0)}\Big|\\ &\le \frac{\alpha}{\sqrt m}\exp(v_{\mathrm{init}}\log n)\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n\\ &\quad+\frac{2\alpha}{n\sqrt m}\exp(v_{\mathrm{init}}\log n)\Big(\frac{n}{8}+\frac{n}{2\log n}+\sqrt{n\log n}-\exp(-2v_{\mathrm{init}}\log n)\Big(\frac{n}{8}-\frac{n}{2\log n}-\sqrt{n\log n}\Big)\Big)\\ &\le C\frac{\alpha\varepsilon^2\sqrt{p}}{\sqrt{mn}}\log n+C\frac{\alpha}{n\sqrt m}\Big(\frac{n}{\log n}+v_{\mathrm{init}}n\log n\Big)\le \frac{C\alpha}{\sqrt{m\log n}},\end{split}\tag{6}$$
where the first inequality uses (2) and $\bar z_{i,l}=-\bar z_{j,l}$ for $i\in\mathcal{I}_{[1,1,1],\mu_2}$, $j\in\mathcal{I}_{[1,1,1],-\mu_2}$; the second inequality uses (2), (3), and (B4) in Lemma A.4; the third inequality uses Assumptions (A1)-(A5); and the last inequality uses Assumptions (A1), (A3), and (A4). For a datapoint $(x,y)\sim P$, define $z=[z_1,z_2,z_3]^\top=U^\top x$. Applying Lemma A.3, we obtain
$$\mathbb{P}\Big(\|z-\bar z\|\le \varepsilon^2\sqrt{\tfrac{p}{n}}\log n\Big)\ge 1-\exp(-\Omega(\log^2 n)).$$
Conditioning on
$$\|z-\bar z\|\le \varepsilon^2\sqrt{\tfrac{p}{n}}\log n,\tag{7}$$
if $x_{\mathrm{signal}}=-\mu_1$, we combine (5) and (6) to obtain
$$\langle v_j^{(1)},z\rangle=\langle v_j^{(1)},\bar z\rangle+\langle v_j^{(1)},z-\bar z\rangle\ge 2\Big(v_{\mathrm{init}}+\frac{2\alpha}{5\sqrt m}\Big)-Cv_{\mathrm{init}}\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n\ge \frac{3}{2}\Big(v_{\mathrm{init}}+\frac{2\alpha}{5\sqrt m}\Big).\tag{8}$$
Further, for any pair $j_1,j_2$ with $v_{j_1}^{(0)}=v_{j_2}^{(0)}=v_{\mathrm{init}}[1,1,1]$ and $a_{j_1}>0$, $a_{j_2}<0$: if $\langle v_{j_2}^{(1)},z\rangle<0$, it follows that
$$\begin{split}z_3\frac{\alpha}{n\sqrt m}\sum_{i=1}^n y_i l_i^{(0)}\phi'\big(\langle v_{j_2}^{(0)},z_i\rangle\big)z_{i,3}&=-\langle v_{j_2}^{(1)},z\rangle+z_3 v_{j_2,3}^{(0)}+\sum_{l=1}^2 z_l v_{j_2,l}^{(1)}\\ &\ge z_3 v_{\mathrm{init}}-\sum_{l=1}^2 |z_l-\bar z_l|\,\big|v_{j_2,l}^{(1)}\big|\\ &\ge z_3 v_{\mathrm{init}}-2\varepsilon^2\sqrt{\tfrac{p}{n}}\log n\Big(v_{\mathrm{init}}+\frac{C\alpha}{\sqrt{m\log n}}\Big)\ge \frac{z_3}{2}v_{\mathrm{init}},\end{split}\tag{9}$$
where the first inequality uses $v_{j_2,3}^{(0)}=v_{\mathrm{init}}$ and $\bar z_l=0$ for $l=1,2$; the second inequality uses (7); and the last inequality uses condition (7), $\bar z_3=2$, and Assumptions (A1), (A3), and (A4). Combining (5) and (9), we have
$$\frac{\alpha}{n\sqrt m}\sum_{i=1}^n y_i l_i^{(0)}\phi'\big(\langle v_{j_2}^{(0)},z_i\rangle\big)z_{i,3}\ge \max\Big\{\frac{v_{\mathrm{init}}}{2},\frac{2\alpha}{5\sqrt m}\Big\},$$
which together with (6) yields
$$\begin{split}a_{j_1}\phi\big(\langle v_{j_1}^{(1)},z\rangle\big)+a_{j_2}\phi\big(\langle v_{j_2}^{(1)},z\rangle\big)&=a_{j_1}\langle v_{j_1}^{(1)},z\rangle\\ &=\frac{1}{\sqrt m}\Big[\langle v_{j_1}^{(0)},z\rangle+\frac{\alpha}{n\sqrt m}z_3\sum_{i=1}^n y_i l_i^{(0)}\phi'\big(\langle v_{j_2}^{(0)},z_i\rangle\big)z_{i,3}+\sum_{l=1}^2\big(v_{j_1,l}^{(1)}-v_{j_1,l}^{(0)}\big)z_l\Big]\\ &\ge \frac{1}{\sqrt m}\Big[v_{\mathrm{init}}+\max\Big\{\frac{v_{\mathrm{init}}}{2},\frac{2\alpha}{5\sqrt m}\Big\}-Cv_{\mathrm{init}}\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n-\frac{C\alpha}{\sqrt{m\log n}}\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n\Big]\\ &\ge \frac{1}{\sqrt m}\Big[v_{\mathrm{init}}+\frac{v_{\mathrm{init}}}{4}+\frac{\alpha}{5\sqrt m}-Cv_{\mathrm{init}}\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n-\frac{C\alpha}{\sqrt{m\log n}}\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n\Big]\\ &\ge \frac{v_{\mathrm{init}}}{\sqrt m}+\frac{\alpha}{10m},\end{split}\tag{10}$$
where the second inequality uses $\max(x,y)\ge (x+y)/2$ and the last inequality uses the fact that $n$ is sufficiently large.
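The second inequality in (10) is the elementary bound $\max(x,y)\ge (x+y)/2$ applied to $x=v_{\mathrm{init}}/2$ and $y=2\alpha/(5\sqrt m)$; spelled out (an intermediate step not written in the paper):

```latex
\max\Big\{\frac{v_{\mathrm{init}}}{2},\,\frac{2\alpha}{5\sqrt m}\Big\}
\;\ge\; \frac{1}{2}\Big(\frac{v_{\mathrm{init}}}{2}+\frac{2\alpha}{5\sqrt m}\Big)
\;=\; \frac{v_{\mathrm{init}}}{4}+\frac{\alpha}{5\sqrt m}.
```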
If $\langle v_{j_2}^{(1)},z\rangle>0$, we have
$$\begin{split}a_{j_1}\phi\big(\langle v_{j_1}^{(1)},z\rangle\big)+a_{j_2}\phi\big(\langle v_{j_2}^{(1)},z\rangle\big)&=\frac{1}{\sqrt m}\langle v_{j_1}^{(1)}-v_{j_2}^{(1)},z\rangle\\ &=\frac{1}{\sqrt m}\langle v_{j_1}^{(1)}-v_{j_1}^{(0)},z\rangle-\frac{1}{\sqrt m}\langle v_{j_2}^{(1)}-v_{j_2}^{(0)},z\rangle\\ &=\frac{1}{\sqrt m}\Big[\frac{2\alpha}{n\sqrt m}z_3\sum_{i=1}^n y_i l_i^{(0)}\phi'\big(\langle v_{j_1}^{(0)},z_i\rangle\big)z_{i,3}+\sum_{l=1}^2\big(v_{j_1,l}^{(1)}-v_{j_2,l}^{(1)}\big)\big(z_l-\bar z_l\big)\Big]\\ &\ge \frac{1}{\sqrt m}\Big[\frac{4\alpha}{5\sqrt m}-\frac{C\alpha}{\sqrt{m\log n}}\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n\Big]\ge \frac{2\alpha}{5m},\end{split}\tag{11}$$
where the second equation uses $v_{j_1}^{(0)}=v_{j_2}^{(0)}$; the third equation uses (4); the first inequality uses (5); and the second inequality uses Assumption (A1).
Combining (10) and (11), it follows that
$$a_{j_1}\phi\big(\langle v_{j_1}^{(1)},z\rangle\big)+a_{j_2}\phi\big(\langle v_{j_2}^{(1)},z\rangle\big)\ge \frac{2\alpha}{5m}\tag{12}$$
when $v_{j_1}^{(0)}=v_{j_2}^{(0)}=v_{\mathrm{init}}[1,1,1]$ and $x_{\mathrm{signal}}=-\mu_1$. If $x_{\mathrm{signal}}=+\mu_1$, following the same procedure, we obtain $\langle v_j^{(1)},z\rangle<0$ for $a_j>0$.
For $a_j<0$, similar to (5), we have
$$\begin{split}v_{j,3}^{(1)}&=v_{j,3}^{(0)}+\frac{\alpha}{n}a_j\sum_{i\in\mathcal{I}_{-\mu_1}}y_i l_i^{(0)}z_{i,3}+\frac{\alpha}{n}a_j\sum_{i\in\mathcal{I}_{[1,1,1],\mu_2}\cup\,\mathcal{I}_{[1,1,1],-\mu_2}}y_i l_i^{(0)}z_{i,3}\\ &\ge v_{\mathrm{init}}-\frac{2\alpha}{n\sqrt m}\sum_{i\in\mathcal{I}_{-\mu_1}}l_i^{(0)}-O\Big(\frac{\alpha}{\sqrt m}\max_i l_i^{(0)}\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n\Big)\\ &\ge v_{\mathrm{init}}-\frac{2.1\alpha|\mathcal{I}_{-\mu_1}|}{n\sqrt m}\exp(v_{\mathrm{init}}\log n)\ge v_{\mathrm{init}}-\frac{2.1\alpha}{4\sqrt m}\Big(1+\frac{4}{\log n}\Big)\big(1+2v_{\mathrm{init}}\log n\big)\\ &\ge v_{\mathrm{init}}-\frac{3\alpha}{4\sqrt m}\ge \frac{v_{\mathrm{init}}}{4},\end{split}\tag{13}$$
where the last inequality comes from Assumption (A4). Then $\langle v_j^{(1)},z\rangle<0$ also holds for $a_j<0$ by the same analysis. Thus we have
$$a_{j_1}\phi\big(\langle v_{j_1}^{(1)},z\rangle\big)+a_{j_2}\phi\big(\langle v_{j_2}^{(1)},z\rangle\big)=0\tag{14}$$
when $v_{j_1}^{(0)}=v_{j_2}^{(0)}=v_{\mathrm{init}}[1,1,1]$ and $x_{\mathrm{signal}}=+\mu_1$.
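The constants in (13) can likewise be checked numerically (a side calculation; the $o(1)$ factors rely on Assumptions (A3)-(A4)):

```latex
\frac{2.1\alpha}{4\sqrt m}\Big(1+\frac{4}{\log n}\Big)\big(1+2v_{\mathrm{init}}\log n\big)
=0.525\,\frac{\alpha}{\sqrt m}\,\big(1+o(1)\big)
\;\le\;\frac{3\alpha}{4\sqrt m}
```

for sufficiently large $n$; the final step $v_{\mathrm{init}}-\frac{3\alpha}{4\sqrt m}\ge \frac{v_{\mathrm{init}}}{4}$ is equivalent to $\alpha\le v_{\mathrm{init}}\sqrt m$, which is where Assumption (A4) enters.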
If $x_{\mathrm{signal}}\in\{\pm\mu_2\}$, combining (5) and (6), we have
$$\begin{split}\big|\langle v_j^{(1)},z\rangle\big|&\le \big|\langle v_j^{(0)},\bar z\rangle\big|+\big|\langle v_j^{(0)},z-\bar z\rangle\big|+\big|\langle v_j^{(1)}-v_j^{(0)},\bar z\rangle\big|+\big|\langle v_j^{(1)}-v_j^{(0)},z-\bar z\rangle\big|\\ &\le 0+v_{\mathrm{init}}\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n+0+C\frac{\alpha}{\sqrt m}\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n\le 2v_{\mathrm{init}}\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n,\end{split}$$
where the last inequality uses Assumptions (A3) and (A4). Thus
$$\big|a_j\langle v_j^{(1)},z\rangle\big|\le 2v_{\mathrm{init}}\,\varepsilon^2\sqrt{\tfrac{p}{nm}}\log n\le \frac{2v_{\mathrm{init}}}{\sqrt{m\log n}}.\tag{15}$$
Note that neurons initialized with $v_{\mathrm{init}}[i,i,k]$, $i,k\in\{\pm 1\}$, share very similar dynamics and admit the same analysis: if $k=+1$ (resp. $k=-1$), the neurons align well with $-\mu_1$ (resp. $+\mu_1$), and for both $i=+1$ and $i=-1$ the neurons do not align well with $\pm\mu_2$. For brevity, we omit the analysis for $v_j^{(0)}\in\{v_{\mathrm{init}}[i,i,k]:i,k\in\{\pm 1\}\}\setminus\{v_{\mathrm{init}}[1,1,1]\}$. Next, we analyze the one-step update of a neuron $v_j$ with initialization $v_{\mathrm{init}}[1,-1,1]$.

(2) If $v_j^{(0)}=v_{\mathrm{init}}[1,-1,1]$, then according to Lemma A.1, we have
$$\mathcal{I}_{[1,-1,1],+\mu_1}=\emptyset;\quad \mathcal{I}_{[1,-1,1],-\mu_1}=\mathcal{I}_{-\mu_1};\quad \mathcal{I}_{[1,-1,1],\mu_2}=\mathcal{I}_{+\mu_2};\quad \mathcal{I}_{[1,-1,1],-\mu_2}=\emptyset.\tag{16}$$
Similar to (5), we have
\[
\begin{aligned}
\Big|v_{j,3}^{(1)} - \Big(v_{j,3}^{(0)} + \frac{\alpha}{2}a_j\Big)\Big| &= \Big|\frac{\alpha}{n}a_j\sum_{i=1}^n y_i l_i^{(0)}\phi'(\langle v_j^{(0)}, z_i\rangle)\, z_{i,3} - \frac{\alpha}{2}a_j\Big| \\
&= \Big|\frac{\alpha}{n}a_j\Big[\sum_{i\in\mathcal I_{-\mu_1}} y_i l_i^{(0)} z_{i,3} + \sum_{i\in\mathcal I_{+\mu_2}} y_i l_i^{(0)} z_{i,3}\Big] - \frac{\alpha}{2}a_j\Big| \\
&= \Big|\frac{\alpha}{n\sqrt m}\Big[\sum_{i\in\mathcal I_{-\mu_1}} l_i^{(0)} z_{i,3} - \sum_{i\in\mathcal I_{+\mu_2}} l_i^{(0)} z_{i,3}\Big] - \frac{\alpha}{2\sqrt m}\Big| \\
&= \Big|\frac{\alpha}{n\sqrt m}\Big[\sum_{i\in\mathcal I_{-\mu_1}} l_i^{(0)} \bar z_{i,3} + \sum_{i\in\mathcal I_{-\mu_1}} l_i^{(0)}(z_{i,3}-\bar z_{i,3}) - \sum_{i\in\mathcal I_{+\mu_2}} l_i^{(0)}(z_{i,3}-\bar z_{i,3})\Big] - \frac{\alpha}{2\sqrt m}\Big| \\
&\le \frac{2\alpha}{n\sqrt m}\Big|\frac{n}{4} - n_{-\mu_1}\exp(-v_{\mathrm{init}}\log n)\Big| + \frac{\alpha}{n\sqrt m}(n_{-\mu_1}+n_{+\mu_2})\exp(v_{\mathrm{init}}\log n)\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n \\
&= O\Big(\frac{\alpha}{\sqrt m}\Big(\varepsilon^2\sqrt{\tfrac{p}{n}}+v_{\mathrm{init}}\Big)\log n\Big) = O\Big(\frac{\alpha}{\sqrt{m\log n}}\Big),
\end{aligned} \tag{17}
\]
where the first equation comes from the GD update; the second equation uses (16); the third equation uses $|a_j| = 1/\sqrt m$; the fourth equation uses $\bar z_{i,3} = 0$ for $i\in\mathcal I_{+\mu_2}$; the first inequality uses $\bar z_{i,3} = 2$ for $i\in\mathcal I_{-\mu_1}$, (2) and the definition of $\mathcal G$; the fifth equation uses $|n_\mu - n/4| \le n/\log n$ and $|\exp(-v_{\mathrm{init}}\log n) - 1| \le 2v_{\mathrm{init}}\log n \le 2/\sqrt{\log n}$ by Assumption (A3); and the last equation uses Assumptions (A1) and (A3).
Further, for the first entry of $v_j$, we have
\[
\begin{aligned}
\Big|v_{j,1}^{(1)} - \Big(v_{j,1}^{(0)} - \frac{\alpha}{2}a_j\Big)\Big| &= \Big|\frac{\alpha}{n}a_j\sum_{i=1}^n y_i l_i^{(0)}\phi'(\langle v_j^{(0)}, z_i\rangle)\, z_{i,1} + \frac{\alpha}{2}a_j\Big| \\
&= \Big|\frac{\alpha}{n}a_j\Big[\sum_{i\in\mathcal I_{-\mu_1}} y_i l_i^{(0)} z_{i,1} + \sum_{i\in\mathcal I_{+\mu_2}} y_i l_i^{(0)} z_{i,1}\Big] + \frac{\alpha}{2}a_j\Big| \\
&= \Big|\frac{\alpha}{n\sqrt m}\Big[\sum_{i\in\mathcal I_{-\mu_1}} l_i^{(0)} z_{i,1} - \sum_{i\in\mathcal I_{+\mu_2}} l_i^{(0)} z_{i,1}\Big] + \frac{\alpha}{2\sqrt m}\Big| \\
&= \Big|\frac{\alpha}{n\sqrt m}\Big[-\sum_{i\in\mathcal I_{+\mu_2}} l_i^{(0)} \bar z_{i,1} + \sum_{i\in\mathcal I_{-\mu_1}} l_i^{(0)}(z_{i,1}-\bar z_{i,1}) - \sum_{i\in\mathcal I_{+\mu_2}} l_i^{(0)}(z_{i,1}-\bar z_{i,1})\Big] + \frac{\alpha}{2\sqrt m}\Big| \\
&\le \frac{2\alpha}{n\sqrt m}\Big|\frac{n}{4} - n_{+\mu_2}\exp(-v_{\mathrm{init}}\log n)\Big| + \frac{\alpha}{n\sqrt m}(n_{-\mu_1}+n_{+\mu_2})\exp(v_{\mathrm{init}}\log n)\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n \\
&= O\Big(\frac{\alpha}{\sqrt m}\Big(\varepsilon^2\sqrt{\tfrac{p}{n}}+v_{\mathrm{init}}\Big)\log n\Big) = O\Big(\frac{\alpha}{\sqrt{m\log n}}\Big),
\end{aligned} \tag{18}
\]
where the inequality uses $\bar z_{i,1} = 2$ for $i\in\mathcal I_{+\mu_2}$. For the second entry of $v_j$, it follows similarly that
\[
\Big|v_{j,2}^{(1)} - \Big(v_{j,2}^{(0)} + \frac{\alpha}{2}a_j\Big)\Big| = \Big|\frac{\alpha}{n\sqrt m}\Big[\sum_{i\in\mathcal I_{-\mu_1}} l_i^{(0)} z_{i,2} - \sum_{i\in\mathcal I_{+\mu_2}} l_i^{(0)} z_{i,2}\Big] - \frac{\alpha}{2\sqrt m}\Big| = O\Big(\frac{\alpha}{\sqrt{m\log n}}\Big). \tag{19}
\]
end_CELL end_ROW (19) Unifying (17), (17) and (18), we obtain |vj,l(1)−(vj,l(0)+α2ajsgn(vj,l(0))ξl)|=O(αmlogn)superscriptsubscript1superscriptsubscript02subscriptsgnsuperscriptsubscript0subscript |v_j,l^(1)- (v_j,l^(0)+ α2a_j % sgn(v_j,l^(0)) _l ) |=O( α m n)| vitalic_j , l( 1 ) - ( vitalic_j , l( 0 ) + divide start_ARG α end_ARG start_ARG 2 end_ARG aitalic_j sgn ( vitalic_j , l( 0 ) ) ξitalic_l ) | = O ( divide start_ARG α end_ARG start_ARG square-root start_ARG m log n end_ARG end_ARG ) (20) for l=1,2,3123l=1,2,3l = 1 , 2 , 3. Here ξlsubscript\ _l\ ξitalic_l are defined as ξl=−1,l=1,2formulae-sequencesubscript112 _l=-1,l=1,2ξitalic_l = - 1 , l = 1 , 2 and ξ3=1subscript31 _3=1ξ3 = 1. For a datapoint (x,y)∼Psimilar-to(x,y) P( x , y ) ∼ P with z=[z1,z2,z3]⊤=U⊤xsuperscriptsubscript1subscript2subscript3topsuperscripttopz=[z_1,z_2,z_3] =U xz = [ z1 , z2 , z3 ]⊤ = U⊤ x. We condition on the event ‖z−z¯‖≤ε2pnlogn.norm¯superscript2\|z- z\|≤ ^2 pn n.∥ z - over¯ start_ARG z end_ARG ∥ ≤ ε2 square-root start_ARG divide start_ARG p end_ARG start_ARG n end_ARG end_ARG log n . 
If $x_{\mathrm{signal}} = -\mu_1$: for each pair $j_1, j_2$ with $v_{j_1}^{(0)} = v_{j_2}^{(0)} = v_{\mathrm{init}}[1,-1,1]$ and $a_{j_1} > 0$, $a_{j_2} < 0$, we have $\langle v_{j_l}^{(1)}, z\rangle > 0$ for $l = 1,2$, and
\[
\Big\|v_{j_1}^{(1)} - v_{j_2}^{(1)} - \frac{\alpha}{\sqrt m}[-1,1,1]^\top\Big\| = \Big\|(v_{j_1}^{(1)} - v_{j_1}^{(0)}) - (v_{j_2}^{(1)} - v_{j_2}^{(0)}) - \frac{\alpha}{\sqrt m}[-1,1,1]^\top\Big\| = O\Big(\frac{\alpha}{\sqrt{m\log n}}\Big)
\]
by (20).
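The sign pattern behind this pairwise displacement can be sanity-checked numerically. The following sketch (numpy; the values of $\alpha$, $m$, $v_{\mathrm{init}}$ are hypothetical placeholders) applies the leading-order one-step update from (20), $v_j^{(1)} \approx v_j^{(0)} + \frac{\alpha}{2}a_j\,\mathrm{sgn}(v_j^{(0)})\odot\xi$, to a pair of neurons with $a_{j_1} = 1/\sqrt m$, $a_{j_2} = -1/\sqrt m$ and common initialization $v_{\mathrm{init}}[1,-1,1]$, and confirms that their difference is $\frac{\alpha}{\sqrt m}[-1,1,1]^\top$:

```python
import numpy as np

m, alpha, v_init = 64, 0.1, 1e-3          # hypothetical values; any positive choices work
xi = np.array([-1.0, -1.0, 1.0])          # xi_1 = xi_2 = -1, xi_3 = 1, as defined after (20)
v0 = v_init * np.array([1.0, -1.0, 1.0])  # common initialization v_init[1, -1, 1]
a1, a2 = 1 / np.sqrt(m), -1 / np.sqrt(m)  # a_{j1} > 0, a_{j2} < 0

def one_step(v0, a):
    # leading-order one-step GD update from (20), dropping the O(alpha/sqrt(m log n)) error
    return v0 + (alpha / 2) * a * np.sign(v0) * xi

diff = one_step(v0, a1) - one_step(v0, a2)
expected = (alpha / np.sqrt(m)) * np.array([-1.0, 1.0, 1.0])
assert np.allclose(diff, expected)
```

Since the two neurons share sgn$(v^{(0)}) = [1,-1,1]$, the difference is $\frac{\alpha}{2}(a_{j_1}-a_{j_2})\,\mathrm{sgn}(v^{(0)})\odot\xi = \frac{\alpha}{\sqrt m}[-1,1,1]^\top$, matching the display above.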
It follows that
\[
\begin{aligned}
a_{j_1}\phi(\langle v_{j_1}^{(1)}, z\rangle) + a_{j_2}\phi(\langle v_{j_2}^{(1)}, z\rangle) &= \frac{1}{\sqrt m}\langle v_{j_1}^{(1)} - v_{j_2}^{(1)}, z\rangle \\
&= \frac{1}{\sqrt m}\Big(\Big\langle \frac{\alpha}{\sqrt m}[-1,1,1], \bar z\Big\rangle + \Big\langle v_{j_1}^{(1)} - v_{j_2}^{(1)} - \frac{\alpha}{\sqrt m}[-1,1,1], \bar z\Big\rangle + \langle v_{j_1}^{(1)} - v_{j_2}^{(1)}, z - \bar z\rangle\Big) \\
&\ge \frac{1}{\sqrt m}\Big(\frac{2\alpha}{\sqrt m} - O\Big(\frac{\alpha}{\sqrt{m\log n}}\Big) - O\Big(\frac{\alpha}{\sqrt{m\log n}}\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n\Big)\Big) \ge \frac{\alpha}{m}.
\end{aligned} \tag{21}
\]
If $x_{\mathrm{signal}} = +\mu_1$: we have $\langle v_{j_l}^{(1)}, z\rangle < 0$ for $l = 1,2$, thus $a_{j_1}\phi(\langle v_{j_1}^{(1)}, z\rangle) = a_{j_2}\phi(\langle v_{j_2}^{(1)}, z\rangle) = 0$.

If $x_{\mathrm{signal}} = +\mu_2$: we have $\langle v_{j_l}^{(0)}, z\rangle > 0$ for $l = 1,2$. Applying (20) and Assumption (A4), we have $\langle v_{j_l}^{(1)}, z\rangle > 0$ for $l = 1,2$.
It follows that
\[
\begin{aligned}
a_{j_1}\phi(\langle v_{j_1}^{(1)}, z\rangle) + a_{j_2}\phi(\langle v_{j_2}^{(1)}, z\rangle) &= \frac{1}{\sqrt m}\langle v_{j_1}^{(1)} - v_{j_2}^{(1)}, z\rangle \\
&= \frac{1}{\sqrt m}\Big(\Big\langle \frac{\alpha}{\sqrt m}[-1,1,1], \bar z\Big\rangle + \Big\langle v_{j_1}^{(1)} - v_{j_2}^{(1)} - \frac{\alpha}{\sqrt m}[-1,1,1], \bar z\Big\rangle + \langle v_{j_1}^{(1)} - v_{j_2}^{(1)}, z - \bar z\rangle\Big) \\
&\le \frac{1}{\sqrt m}\Big(-\frac{2\alpha}{\sqrt m} + O\Big(\frac{\alpha}{\sqrt{m\log n}}\Big) + O\Big(\frac{\alpha}{\sqrt{m\log n}}\,\varepsilon^2\sqrt{\tfrac{p}{n}}\log n\Big)\Big) \le -\frac{\alpha}{m}
\end{aligned} \tag{22}
\]
for sufficiently large $n$. Here the last inequality uses Assumption (A1).

If $x_{\mathrm{signal}} = -\mu_2$: we have $\langle v_{j_l}^{(0)}, z\rangle < 0$ for $l = 1,2$. Applying (20) and Assumption (A4), we have $\langle v_{j_l}^{(1)}, z\rangle < 0$ for $l = 1,2$. It follows that
\[
a_{j_1}\phi(\langle v_{j_1}^{(1)}, z\rangle) + a_{j_2}\phi(\langle v_{j_2}^{(1)}, z\rangle) = 0. \tag{23}
\]
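The first equality in (21) and (22) rests on a simple fact: once both pre-activations share a sign, a ReLU pair with opposite output weights $a_{j_1} = -a_{j_2} = 1/\sqrt m$ collapses to a linear function of the difference $v_{j_1}^{(1)} - v_{j_2}^{(1)}$. A minimal numerical check (numpy; the vectors below are arbitrary illustrative values, not quantities from the proof):

```python
import numpy as np

m = 64
a1, a2 = 1 / np.sqrt(m), -1 / np.sqrt(m)  # opposite output weights of equal magnitude
relu = lambda u: np.maximum(u, 0.0)       # phi in the two-layer network

# arbitrary illustrative vectors whose inner products with z are both positive
v1 = np.array([0.3, 0.5, 0.9])
v2 = np.array([0.1, 0.2, 0.4])
z = np.array([0.2, 0.1, 2.0])

lhs = a1 * relu(v1 @ z) + a2 * relu(v2 @ z)
rhs = (v1 - v2) @ z / np.sqrt(m)  # linear in the difference, as used in (21) and (22)
assert np.isclose(lhs, rhs)
```

When both pre-activations are negative instead, both ReLU terms vanish, which is exactly the $+\mu_1$ and $-\mu_2$ cases above.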
In conclusion, for a datapoint $(x,y)$ with $x_{\mathrm{signal}} = -\mu_1$, conditioning on (7), the output of $f^{(1)}$ is
\[
\begin{aligned}
f^{(1)}(x) &= \sum_{j=1}^m a_j\phi(\langle v_j^{(1)}, z\rangle) = \sum_{e\in\{\pm1\}^3}\sum_{j\in\mathcal J_e} a_j\phi(\langle v_j^{(1)}, z\rangle) \\
&= \sum_{e:\,e_3=1}\Big[\sum_{j\in\mathcal J_{\mathrm{Pos},e}} a_j\phi(\langle v_j^{(1)}, z\rangle) - \sum_{j\in\mathcal J_{\mathrm{Neg},e}} a_j\phi(\langle v_j^{(1)}, z\rangle)\Big] \\
&\ge \sum_{e:\,e_3=1}\Big[\min\{|\mathcal J_{\mathrm{Pos},e}|, |\mathcal J_{\mathrm{Neg},e}|\}\frac{2\alpha}{5m} - \frac{4\sqrt m}{\log n}\Big(v_{\mathrm{init}} + \frac{\alpha}{\sqrt m}\Big)\Big] \\
&\ge \sum_{e:\,e_3=1}\Big[\frac{\alpha}{40} - \frac{4\sqrt m}{\log n}\Big(C\frac{\alpha}{\sqrt m} + \frac{\alpha}{\sqrt m}\Big)\Big] > 0
\end{aligned} \tag{24}
\]
for sufficiently large $n$. Here the first inequality uses (12) and (21), the property that $\big||\mathcal J_{\mathrm{Pos},e}| - |\mathcal J_{\mathrm{Neg},e}|\big| \le 2m/\log n$ from (B3) in Lemma A.4, and the property that $\phi(\langle v_j^{(1)}, z\rangle) \le 2(v_{\mathrm{init}} + \alpha/\sqrt m)$; the second inequality uses (B3) and Assumption (A4). Similarly, for a datapoint $(x,y)$ with $x_{\mathrm{signal}} = +\mu_1$, conditioning on (7), the output of $f^{(1)}$ is
\[
f^{(1)}(x) = \sum_{e:\,e_3=-1}\Big[\sum_{j\in\mathcal J_{\mathrm{Pos},e}} a_j\phi(\langle v_j^{(1)}, z\rangle) - \sum_{j\in\mathcal J_{\mathrm{Neg},e}} a_j\phi(\langle v_j^{(1)}, z\rangle)\Big] > 0. \tag{25}
\]
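The step from $\min\{|\mathcal J_{\mathrm{Pos},e}|,|\mathcal J_{\mathrm{Neg},e}|\}\frac{2\alpha}{5m}$ to $\frac{\alpha}{40}$ in (24) is plain arithmetic once each $(e, \mathrm{sign}\ \text{of}\ a_j)$ group holds roughly $m/16$ neurons — an assumption here (eight sign patterns of $v^{(0)}$, each split by the sign of $a_j$), consistent with the near-balance $\big||\mathcal J_{\mathrm{Pos},e}|-|\mathcal J_{\mathrm{Neg},e}|\big|\le 2m/\log n$ from (B3). A quick exact check with rational arithmetic:

```python
from fractions import Fraction

m = Fraction(16)              # illustrative width; the ratio is independent of m
alpha = Fraction(1)           # set alpha = 1 for exact symbolic-style arithmetic
group = m / 16                # assumed size of each (sign pattern, sign of a_j) group
lower = group * 2 * alpha / (5 * m)  # min{|J_Pos,e|, |J_Neg,e|} * 2*alpha/(5m)
assert lower == alpha / 40    # matches the alpha/40 appearing in (24)
```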
For a datapoint $(x,y)$ with $x_{\mathrm{signal}} = +\mu_2$, conditioning on (7), the output of $f^{(1)}$ is
\[
\begin{aligned}
f^{(1)}(x) &= \sum_{j=1}^m a_j\phi(\langle v_j^{(1)}, z\rangle) = \sum_{e\in\{\pm1\}^3}\sum_{j\in\mathcal J_e} a_j\phi(\langle v_j^{(1)}, z\rangle) \\
&= \Big(\sum_{e:\,[e_1,e_2]=[1,-1]} + \sum_{e:\,e_1=e_2}\Big)\Big[\sum_{j\in\mathcal J_{\mathrm{Pos},e}} a_j\phi(\langle v_j^{(1)}, z\rangle) - \sum_{j\in\mathcal J_{\mathrm{Neg},e}} a_j\phi(\langle v_j^{(1)}, z\rangle)\Big] \\
&\le \sum_{e:\,[e_1,e_2]=[1,-1]}\Big[-\min\{|\mathcal J_{\mathrm{Pos},e}|, |\mathcal J_{\mathrm{Neg},e}|\}\frac{\alpha}{m} + \frac{4\sqrt m}{\log n}\Big(v_{\mathrm{init}} + \frac{\alpha}{\sqrt m}\Big)\Big] + \sum_{e:\,e_1=e_2}\frac{2v_{\mathrm{init}}|\mathcal J_{\mathrm{Neg},e}|}{\sqrt{m\log n}} \\
&\le 2\Big(-\frac{\alpha}{16} + \frac{5\sqrt m}{\log n}\Big(v_{\mathrm{init}} + \frac{\alpha}{\sqrt m}\Big)\Big) + \frac{v_{\mathrm{init}}\sqrt m}{\sqrt{\log n}} \le -\frac{\alpha}{8} + \frac{10(C+1)\alpha}{\log n} + \frac{C\alpha}{\sqrt{\log n}} < 0,
\end{aligned} \tag{26}
\]
where the first inequality uses (15), (22), (B3) and the property that $\phi(\langle v_j^{(1)}, z\rangle) \le 2(v_{\mathrm{init}} + \alpha/\sqrt m)$; the second inequality uses (B3); and the third inequality uses Assumption (A4).
Similarly, for a datapoint $(x,y)$ with $x_{\mathrm{signal}} = -\mu_2$, conditioning on (7), we have $f^{(1)}(x) < 0$, which combined with (24), (25) and (26) yields $\mathrm{sgn}(f^{(1)}(x)) = y$ for any $(x,y)\sim P$, $z = U^\top x$ satisfying $\|z - \bar z\| \le \varepsilon^2\sqrt{p/n}\log n$. According to the definition of $\mathcal G_{\mathrm{data}}$, all $(x_i, y_i)$ satisfy this condition. Thus, conditioning on the event $\mathcal G$, the model $f^{(1)}$ correctly classifies all training data points. Applying the law of total probability, the test error is bounded by
\[
\begin{aligned}
\mathbb P_{(x,y)\sim P}\big(y\ne\mathrm{sgn}(f^{(1)}(x))\big) &\le \mathbb P_{(x,y)\sim P}\Big(y\ne\mathrm{sgn}(f^{(1)}(x)) \,\Big|\, \|z - \bar z\| \le \varepsilon^2\sqrt{\tfrac{p}{n}}\log n\Big) \\
&\quad + \mathbb P_{(x,y)\sim P}\Big(\|z - \bar z\| > \varepsilon^2\sqrt{\tfrac{p}{n}}\log n\Big) \\
&= \mathbb P_{(x,y)\sim P}\Big(\|z - \bar z\| > \varepsilon^2\sqrt{\tfrac{p}{n}}\log n\Big) \le \exp\big(-\Omega(\log^2 n)\big),
\end{aligned}
\]
where the last inequality uses Lemma A.3. ∎

Lemma A.1. Suppose that Assumption (A2) holds. With probability at least $1 - O\big(\frac{1}{n^2}\big)$, the following conditions hold:
\[
\begin{aligned}
&\mathcal I_{[i,j,-1],+\mu_1} = \mathcal I_{+\mu_1};\quad \mathcal I_{[i,j,-1],-\mu_1} = \emptyset, && i,j\in\{\pm1\};\\
&\mathcal I_{[i,j,+1],+\mu_1} = \emptyset;\quad \mathcal I_{[i,j,+1],-\mu_1} = \mathcal I_{-\mu_1}, && i,j\in\{\pm1\};\\
&\mathcal I_{[+1,-1,k],+\mu_2} = \mathcal I_{+\mu_2};\quad \mathcal I_{[+1,-1,k],-\mu_2} = \emptyset, && k\in\{\pm1\};
\end{aligned} \tag{27}
\]
\[
\begin{aligned}
&\mathcal I_{[-1,+1,k],+\mu_2} = \emptyset;\quad \mathcal I_{[-1,+1,k],-\mu_2} = \mathcal I_{-\mu_2}, && k\in\{\pm1\};\\
&\Big||\mathcal I_{[i,i,k],\mu}| - \frac{n_\mu}{2}\Big| \le \sqrt{n\log n}, && i,k\in\{\pm1\},\ \mu\in\{\pm\mu_2\}.
\end{aligned} \tag{28}
\]

Proof. For simplicity, we denote $\mathbb P\big(\cdot \mid \{(x_i,y_i)\}_{i=1}^n \in \mathcal G_{\mathrm{data}}\big)$ as $\mathbb P(\cdot)$ in the proof below. For $v_j^{(0)} = [v_{\mathrm{init}}, v_{\mathrm{init}}, v_{\mathrm{init}}]$, we first show that for $\{(x_i,y_i)\}_{i=1}^n \in \mathcal G_{\mathrm{data}}$,
\[
\langle v_j^{(0)}, z_i\rangle > 0, \quad \forall i\in\mathcal I_{-\mu_1}.
\]
According to the definition of $\mathcal G_{\mathrm{data}}$, $\|z_i - \bar z_i\| \le \varepsilon^2\sqrt{p/n}\log n$ for all $i\in\mathcal I_{-\mu_1}$. Thus
\[
\langle v_j^{(0)}, z_i\rangle = \langle v_j^{(0)}, \bar z_i\rangle + \langle v_j^{(0)}, z_i - \bar z_i\rangle \ge v_{\mathrm{init}}\big(2 - \|z_i - \bar z_i\|\big) > 0,
\]
where the first inequality uses $\bar z_i = [0,0,2]$ when $i\in\mathcal I_{-\mu_1}$ and the second inequality uses Assumption (A1).
Similarly, we have $\langle v_j^{(0)}, z_i \rangle < 0$ for all $i \in \mathcal{I}_{+\mu_1}$. Thus, conditioning on $\{(x_i,y_i)\}_{i=1}^n \in G_{\mathrm{data}}$, we have
$$\mathcal{I}_{[1,1,1],-\mu_1} = \mathcal{I}_{-\mu_1}; \quad \mathcal{I}_{[1,1,1],+\mu_1} = \emptyset.$$
For $i \in \mathcal{I}_{\pm\mu_2}$, recall that $x_i = [x_{i,\mathrm{signal}}^\top, x_{i,\mathrm{noise}}^\top]^\top$ with $x_{i,\mathrm{signal}} = [x_{i,1}, x_{i,2}]^\top$ and $x_{i,\mathrm{noise}} = [x_{i,3}, \cdots, x_{i,p}]^\top$.
We have
$$\langle v_j^{(0)}, z_i \rangle = v_{\mathrm{init}} \sum_{l=1}^{3} z_{i,l} = v_{\mathrm{init}} \Big(\sum_{l=1}^{3} u_l\Big)^\top x_i = v_{\mathrm{init}} \Big[-\mu_1^\top, \Big(\sum_{l=1}^{3} \delta_l\Big)^\top\Big] x_i = v_{\mathrm{init}} \Big(\sum_{l=1}^{3} \delta_l\Big)^\top x_{i,\mathrm{noise}}.$$
It follows that $\mathbb{P}\big(\langle v_j^{(0)}, z_i \rangle > 0\big) = \frac{1}{2}$. Applying Hoeffding's inequality, we obtain
$$\mathbb{P}\Big(\Big|\,|\mathcal{I}_{[1,1,1],+\mu_2}| - \frac{|\mathcal{I}_{+\mu_2}|}{2}\Big| > t\Big) \le 2\exp\Big(-\frac{2t^2}{n}\Big).$$
Similarly, we have
$$\mathbb{P}\Big(\Big|\,|\mathcal{I}_{[1,1,1],-\mu_2}| - \frac{|\mathcal{I}_{-\mu_2}|}{2}\Big| > t\Big) \le 2\exp\Big(-\frac{2t^2}{n}\Big).$$
Let $t = \sqrt{n \log n}$.
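Plugging this choice of $t$ into the two tail bounds makes the failure probability explicit:
$$2\exp\Big(-\frac{2t^2}{n}\Big) = 2\exp(-2\log n) = \frac{2}{n^2},$$
so a union bound over the two events fails with probability at most $4/n^2$.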
We conclude that
$$\Big|\,|\mathcal{I}_{[1,1,1],\mu}| - \frac{n_\mu}{2}\Big| \le \sqrt{n \log n}, \quad \mu \in \{\pm \mu_2\}$$
holds with probability at least $1 - 4/n^2$. Following a similar argument, the identities
$$\begin{aligned}
&\mathcal{I}_{[i,j,-1],+\mu_1} = \mathcal{I}_{+\mu_1}; \quad \mathcal{I}_{[i,j,-1],-\mu_1} = \emptyset, \quad i,j \in \{\pm 1\}; \\
&\mathcal{I}_{[i,j,+1],+\mu_1} = \emptyset; \quad \mathcal{I}_{[i,j,+1],-\mu_1} = \mathcal{I}_{-\mu_1}, \quad i,j \in \{\pm 1\}; \\
&\mathcal{I}_{[+1,-1,k],+\mu_2} = \mathcal{I}_{+\mu_2}; \quad \mathcal{I}_{[+1,-1,k],-\mu_2} = \emptyset, \quad k \in \{\pm 1\}; \\
&\mathcal{I}_{[-1,+1,k],+\mu_2} = \emptyset; \quad \mathcal{I}_{[-1,+1,k],-\mu_2} = \mathcal{I}_{-\mu_2}, \quad k \in \{\pm 1\}
\end{aligned}$$
hold with probability $1$ given $\{(x_i,y_i)\}_{i=1}^n \in G_{\mathrm{data}}$, and
$$\Big|\,|\mathcal{I}_{[i,i,k],\mu}| - \frac{n_\mu}{2}\Big| \le \sqrt{n \log n}, \quad i,k \in \{\pm 1\},\ \mu \in \{\pm \mu_2\}$$
holds with probability at least $1 - 16/n^2$. In total, the conditions above hold with probability at least $1 - \exp(-\Omega(\log^2 n)) - O(1/n^2) = 1 - O(1/n^2)$. ∎

Lemma A.2. Suppose that Assumption (A2) holds, and let the training data $\{x_i, y_i\}_{i=1}^n$ for model $f_L$ be sampled i.i.d. from $P$. With probability at least $1 - \exp(-\Omega(\log^2 n))$, we have
$$\|z_i - \bar{z}_i\| \le \varepsilon^2 \sqrt{\frac{p}{n}}\,\log n, \quad \text{for all } i \in [n]. \tag{29}$$

Proof. Applying Lemma A.3 together with a union bound, we obtain
$$\begin{aligned}
\mathbb{P}\Big(\|z_i - \bar{z}_i\| \le \varepsilon^2 \sqrt{\tfrac{p}{n}}\,\log n,\ \forall i \in [n]\Big)
&\ge 1 - \sum_{i=1}^{n} \mathbb{P}\Big(\|z_i - \bar{z}_i\| > \varepsilon^2 \sqrt{\tfrac{p}{n}}\,\log n\Big) \\
&\ge 1 - n \exp\big(-\Omega(\log^2 n)\big) = 1 - \exp\big(-\Omega(\log^2 n)\big).
\end{aligned}$$
∎

Lemma A.3. Suppose that Assumption (A2) holds. For $x = [x_1, x_2, \cdots, x_p] \sim P$ with $[x_1, x_2]^\top = \mu$, $\mu \in \{\pm \mu_1, \pm \mu_2\}$, we have
$$\mathbb{P}\Big(\|U^\top x - \nu\|_{\max} > \varepsilon^2 \sqrt{\frac{p}{n}}\,\log n\Big) \le \exp\big(-\Omega(\log^2 n)\big), \tag{30}$$
where $\nu = [\mu_2, -\mu_2, -\mu_1]^\top \mu \in \mathbb{R}^3$.

Proof. We start with $[x_1, x_2]^\top = \mu_1$. Note that $u_1^\top x = \sum_{i=1}^{p-2} \delta_{1,i}\, x_{i+2}$ is a sum of independent, bounded, zero-mean random variables.
By Hoeffding's inequality, we have
$$\mathbb{P}\big(|u_1^\top x| \ge t\big) \le 2\exp\Big(-\frac{t^2}{2 \sum_{i=1}^{p-2} \delta_{1,i}^2\, \varepsilon^2}\Big) \le 2\exp\Big(-\frac{t^2}{2 \|\delta\|_F^2\, \varepsilon^2}\Big).$$
Similarly, the concentration bounds for $u_2^\top x$ and $u_3^\top x + 2$ are
$$\mathbb{P}\big(|u_2^\top x| \ge t\big) \le 2\exp\Big(-\frac{t^2}{2 \|\delta\|_F^2\, \varepsilon^2}\Big); \qquad \mathbb{P}\big(|u_3^\top x + 2| \ge t\big) \le 2\exp\Big(-\frac{t^2}{2 \|\delta\|_F^2\, \varepsilon^2}\Big).$$
Combining these inequalities yields
$$\mathbb{P}\big(\|U^\top x - \nu_1\|_{\max} > t\big) \le 6\exp\Big(-\frac{t^2}{2 \|\delta\|_F^2\, \varepsilon^2}\Big) \le 6\exp\Big(-\frac{n t^2}{2 C^2 \varepsilon^4 p}\Big), \tag{31}$$
where the last inequality uses Assumption (A2). The proof concludes by letting $t = \varepsilon^2 \sqrt{p/n}\,\log n$; the analysis for the other values of $[x_1, x_2]^\top$ follows similarly. ∎

Lemma A.4. Suppose that Assumption (A5) holds.
Then the following conditions hold with probability at least $1 - O(1/n^4)$:

(B1) $\max_{k \in \{\mathrm{Pos}, \mathrm{Neg}\}} \big|\,|J_k| - \frac{m}{2}\big| \le \frac{m}{\log n}$.

(B2) $\max_{e \in \{\pm 1\}^3} \big|\,|J_e| - \frac{m}{8}\big| \le \frac{m}{\log n}$.

(B3) $\max_{k \in \{\mathrm{Pos}, \mathrm{Neg}\},\, e \in \{\pm 1\}^3} \big|\,|J_{k,e}| - \frac{m}{16}\big| \le \frac{m}{\log n}$.

(B4) $\max_{\mu \in \{\pm \mu_1, \pm \mu_2\}} \big|n_\mu - \frac{n}{4}\big| \le \frac{n}{\log n}$.

Proof. Note that $|J_{\mathrm{Pos}}| \sim \mathrm{Bin}(m, 1/2)$. Applying Hoeffding's inequality, we have
$$\mathbb{P}\Big(\Big|\,|J_{\mathrm{Pos}}| - \frac{m}{2}\Big| > \frac{m}{\log n}\Big) \le 2\exp\Big(-\frac{2m}{\log^2 n}\Big) \le \frac{2}{n^4},$$
where the last inequality comes from Assumption (A5).
Similarly,
$$\mathbb{P}\Big(\Big|\,|J_{\mathrm{Neg}}| - \frac{m}{2}\Big| > \frac{m}{\log n}\Big) \le 2\exp\Big(-\frac{2m}{\log^2 n}\Big) \le \frac{2}{n^4},$$
which completes the proof of (B1). Note that $n_\mu \sim \mathrm{Bin}(n, 1/4)$. Applying Hoeffding's inequality, we have
$$\mathbb{P}\Big(\Big|n_\mu - \frac{n}{4}\Big| > \frac{n}{\log n}\Big) \le 2\exp\Big(-\frac{2n}{\log^2 n}\Big) = O\Big(\frac{1}{n^4}\Big), \quad \forall \mu \in \{\pm \mu_1, \pm \mu_2\}.$$
(B2) and (B3) can be proved following the same procedure; we omit the details. ∎

A.2 Additional Experiments

Figure 9: Training dynamics of the model $f_L$ discussed in Section 3.3.

A.3 Experimental details

All experiments in the paper can be run on a single NVIDIA A100 GPU. The loss function is cross-entropy for the modular arithmetic tasks and logistic loss for the $(40,3)$-parity task. Unless stated otherwise, all models set $d_{\mathrm{mlp}} = 4 d_{\mathrm{embed}}$, where $d_{\mathrm{mlp}}$ is the MLP dimension and $d_{\mathrm{embed}}$ is the embedding dimension. All FNN models used in the paper are homogeneous and have no bias terms. Code is available at https://github.com/zhiweixx/groktransfer.

A.3.1 Experiments in Sections 1 and 2.1

In Figure 1(b), we use a two-layer FNN with a trainable embedding layer as the weak model.
We choose $(d_{\mathrm{embed}}, \mathrm{width}) = (4, 16)$ for the weak model. The target model is a three-layer FNN with a trainable embedding layer, for which we choose $(d_{\mathrm{embed}}, \mathrm{width}) = (128, 512)$. The hyperparameters (init scale, learning rate, weight decay) are selected by the following grid search:

- init scale: $[0.1, 0.2, \cdots, 1.5]$
- learning rate: $[10^{-4}, 5 \times 10^{-4}, 10^{-3}, 5 \times 10^{-3}, 10^{-2}, 10^{-1}]$
- weight decay: $[10^{-4}, 10^{-3}, 10^{-2}, 10^{-1}, 1, 2, 3, 4, 5]$

We select the configuration that first achieves 90% accuracy on the validation set. The best configuration for GrokTransfer is $(0.3, 0.005, 3)$. For standard training, only the learning rate and weight decay are tuned, via the following grid search:

- learning rate: $[10^{-3}, 5 \times 10^{-3}, 10^{-2}, 5 \times 10^{-2}, 10^{-1}]$
- weight decay: $[10^{-2}, 10^{-1}, 1, 2, 3, 4, 5]$

and the optimal configuration is $(0.05, 3)$.

In Figure 2, we set the dimension of the GPT embedding to 128. For the Fourier embedding, we choose $k = 7$ frequencies and let $i_j$ be the $j$-th smallest prime number.
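The selection protocol above ("first configuration to reach 90% validation accuracy") amounts to an exhaustive sweep over the grid. A minimal sketch follows; `epochs_to_target` is a hypothetical stand-in for a full training run and is not part of the paper's code:

```python
from itertools import product

# Hypothetical stand-in for one full training run: returns the number of
# epochs needed to reach 90% validation accuracy (lower is better).
# In the actual experiments this would train the model end to end.
def epochs_to_target(init_scale, lr, wd):
    return abs(init_scale - 0.3) * 100 + abs(lr - 0.005) * 1e4 + abs(wd - 3)

INIT_SCALES = [round(0.1 * k, 1) for k in range(1, 16)]  # 0.1, 0.2, ..., 1.5
LRS = [1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 1e-1]
WDS = [1e-4, 1e-3, 1e-2, 1e-1, 1, 2, 3, 4, 5]

def grid_search():
    # Pick the configuration that reaches the target accuracy first.
    return min(product(INIT_SCALES, LRS, WDS),
               key=lambda cfg: epochs_to_target(*cfg))

print(grid_search())  # (0.3, 0.005, 3)
```

With the toy scoring function above, the sweep recovers the configuration reported for GrokTransfer.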
For each type of embedding, we normalize the embedding of each integer to have unit norm. The FNN used in Figure 2 is a three-layer dense neural network
$$f(x) = W_3\, \phi(W_2\, \phi(W_1 x)),$$
where $W_1 \in \mathbb{R}^{\mathrm{width} \times \mathrm{embed\ dim}}$, $W_2 \in \mathbb{R}^{\mathrm{width} \times \mathrm{width}}$, $W_3 \in \mathbb{R}^{p \times \mathrm{width}}$, and $\mathrm{width} = 512$. The hyperparameters (init scale, learning rate, weight decay) are selected by the following grid search:

- init scale: $[0.1, 0.2, \cdots, 1.5]$
- learning rate: $[10^{-4}, 10^{-3}, 10^{-2}, 10^{-1}, 1]$
- weight decay: $[10^{-4}, 10^{-3}, 10^{-2}, 10^{-1}, 1, 5, 10]$

We select the configuration that first achieves 90% accuracy on the validation set. The best configurations (init, lr, wd) for the four embeddings are:

- One-hot: $(0.2, 0.01, 5)$
- Binary: $(0.3, 0.01, 1)$
- Fourier: $(0.5, 0.1, 0.1)$
- GPT: $(1.3, 0.01, 1)$

In Figure 3, the distance between two empirical NTKs is estimated following the method in Mohamadi et al. (2023).
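The three-layer FNN above (no bias terms, hence positively homogeneous) can be sketched in a few lines of pure Python; the tiny dimensions below are placeholders for the paper's width of 512:

```python
import random

random.seed(0)

def matvec(W, x):
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]

def relu(v):
    return [max(0.0, vi) for vi in v]

def fnn(x, W1, W2, W3):
    # f(x) = W3 * phi(W2 * phi(W1 * x)) with phi = ReLU and no bias terms.
    return matvec(W3, relu(matvec(W2, relu(matvec(W1, x)))))

def rand_mat(rows, cols):
    return [[random.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

embed_dim, width, p = 3, 4, 2  # toy sizes, not the paper's
W1, W2, W3 = rand_mat(width, embed_dim), rand_mat(width, width), rand_mat(p, width)

x = [1.0, -2.0, 0.5]
y1 = fnn(x, W1, W2, W3)
y2 = fnn([2.0 * xi for xi in x], W1, W2, W3)
# Without biases the network is 1-homogeneous: f(2x) = 2 f(x).
```

The homogeneity checked at the end is the property the one-neuron argument in Section A.4.2 relies on: the sign of the output is invariant to positive rescaling of the input.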
We denote by $\hat{\Theta}_t$ the pseudo-NTK of the model at epoch $t$, i.e.,
$$\hat{\Theta}_t(x_1, x_2) = \Big[\nabla_\theta \sum_{i=1}^{p} f_\theta^{(i)}(x_1)\Big]^\top \Big[\nabla_\theta \sum_{i=1}^{p} f_\theta^{(i)}(x_2)\Big] \Big/ p \in \mathbb{R}.$$
We estimate the distance between the empirical NTKs at steps $t$ and $t-1$ by $\|\hat{\Theta}_t - \hat{\Theta}_{t-1}\|_F$.

A.3.2 Experiments in Section 3

For the experiments in Section 3, we set the sample size $n = 400$, feature dimension $p = 80000$, and noise level $\varepsilon = 0.05$. For Figure 4(a), the model is a two-layer neural network with width 2048, trained by full-batch gradient descent with learning rate 0.1 and weight decay 0.1. In Figure 4(b), we train a small model with only three neurons; the hidden layer is initialized i.i.d. from $N(0, 0.01)$, the second layer i.i.d. from $N(0, 10^{-4})$, and we use learning rate 0.1 and weight decay 0.1. Figure 4(c) visualizes the hidden layer of that small model after training. Figure 5(a) generates 4000 i.i.d. datapoints from the distribution $P$ and visualizes $Ux$ for each $x$. Figure 5(b) fixes $n = 1000$ and trains the weak model for $p \in \{4 \times 10^4, 8 \times 10^4, 16 \times 10^4\}$ and $\epsilon \in \{1/40, 1/80, 1/160\}$.
Figure 5(c) fixes $p = 8 \times 10^4$ and trains the weak model for $n \in \{400, 800, 1600, 3200\}$ and $\epsilon \in \{1/40, 1/80, 1/160\}$. Figure 9 takes $v_{\mathrm{init}} = 0.4$, learning rate 2.0, and zero weight decay.

A.3.3 Experiments in Section 4

The attention layer used in this paper follows the same structure as that in Nanda et al. (2023). While Nanda et al. (2023) also suggested setting the precision to float64 to mitigate the Slingshot phenomenon (Thilak et al., 2022), a fluctuation of accuracy and loss during training, we use float32 to control the computation cost. All modular tasks set $p = 113$, with 25% of the data used for training. For the $(40,3)$-parity task, we set the sample size $n = 1000$. Unless otherwise specified, we use the AdamW optimizer (Loshchilov & Hutter, 2019) for all experiments, and we initialize the weights using the default PyTorch initialization scaled by a factor $\text{init scale} > 0$ to control the initial weight norm, as proposed by Liu et al. (2023).
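The init-scale trick is simply a scalar multiplier on top of the framework's default initialization. A minimal sketch, where `default_init` is a hypothetical stand-in for PyTorch's default uniform initialization:

```python
import random

random.seed(1)

def default_init(fan_in, n):
    # Stand-in for a default uniform init: U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
    bound = fan_in ** -0.5
    return [random.uniform(-bound, bound) for _ in range(n)]

def scaled_init(fan_in, n, init_scale):
    # Multiply every default-initialized weight by init_scale, so the
    # initial weight norm is controlled by a single scalar.
    return [init_scale * w for w in default_init(fan_in, n)]

weights = scaled_init(fan_in=64, n=10, init_scale=0.3)
```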
The hyperparameters (init scale, learning rate, weight decay) are selected by the following grid search:

- init scale: $[0.05, 0.1, 0.2, 0.3, \cdots, 1.5]$
- learning rate: $[10^{-4}, 5 \times 10^{-4}, 10^{-3}, 5 \times 10^{-3}, 10^{-2}, 10^{-1}, 0.5, 1.0]$
- weight decay: $[10^{-4}, 10^{-3}, 10^{-2}, 10^{-1}, 1, 2, 3, 4, 5]$

In Figures 6(a) and 6(b), the structures of the weak and target models are the same as those in Figure 1(b). In Figure 6(c), the weak model is a three-layer FNN with width 16, and the target model is a three-layer FNN with width 512 and $d_{\mathrm{embed}} = 128$. In Figure 6(a), the optimal configuration for GrokTransfer is $(0.3, 0.001, 1)$ and the optimal one for training from scratch is $(0.1, 0.1, 2)$. In Figure 6(b), the optimal configuration for GrokTransfer is $(0.3, 0.005, 3)$ and the optimal one for training from scratch is $(0.1, 0.1, 2)$. In Figure 6(b), the number of training samples is $n = 1000$; the model trained via GrokTransfer uses learning rate $10^{-3}$ and weight decay $10^{-3}$, while the model trained from scratch uses learning rate $10^{-2}$ and weight decay 1.
In Figure 8(a), the weak model is a two-layer width-4 FNN, and the target model is an 8-layer transformer with $d_{\mathrm{embed}} = 512$, $d_{\mathrm{mlp}} = 512$, $n_{\mathrm{head}} = 4$, $d_{\mathrm{head}} = 128$. The optimal configuration for the target model trained via GrokTransfer is $(0.7, 0.001, 1)$; for the target model trained from scratch it is $(0.4, 0.0005, 1)$. In Figure 8(b), the weak model remains the same, and the target model becomes a 2-layer transformer with $d_{\mathrm{embed}} = 128$, $d_{\mathrm{mlp}} = 128$, $n_{\mathrm{head}} = 4$, $d_{\mathrm{head}} = 32$. The optimal configuration for the target model trained via GrokTransfer is $(0.6, 0.005, 0.1)$, and the optimal configuration for GrokFast is $(\mathrm{lr}, \mathrm{wd}) = (0.01, 1.0)$.

A.4 Additional Discussion

A.4.1 Forward-Pass FLOPs Estimation for Models in Figure 8 (Left)

For the weak model, a two-layer width-4 MLP:
$$C_{\mathrm{forward}} \approx 2 \times (8 \times 4 + 8 + 4 \times 113 + 226) = 1436 \sim 10^3.$$
For the target model, an 8-layer transformer, following Table 1 in Kaplan et al.
(2020), we have
$$N = 2\, d_{\mathrm{embed}}\, n_{\mathrm{layer}} (2 d_{\mathrm{attn}} + d_{\mathrm{mlp}}) = 2 \times 512 \times 8 \times (256 + 512) = 6291456, \qquad C_{\mathrm{forward}} = 2(N + 8 \times 2 \times 128) \sim 10^7.$$
The two-layer FNN takes around 2000 epochs to generalize; the target model trained via GrokTransfer takes around 1000 epochs, while the target model trained from scratch takes around 10000 epochs. Thus, the total FLOPs of GrokTransfer is around $10^{10} + 2 \times 10^6$, while that of training from scratch is around $10^{11}$.

A.4.2 Additional Experiments

Figure 10: Left: Training dynamics of the one-neuron weak model. Middle: Visualization of the neuron in the weak model; it has learned the feature $[1, -1]$. Right: Training dynamics of the target model with the embedding transferred from the one-neuron weak model.

Figure 11: Left: Visualization of the distribution $P$ with the embedding from the one-neuron weak model. Middle: Visualization of the distribution $P$ with the embedding from the two-neuron weak model. Right: Training dynamics of the target model with the embedding transferred from a two-neuron weak model.

When the weak model has only one neuron, the target model takes the form
$$f_L(x) = \sum_{j=1}^{m} a_j\, \phi(v_j \cdot u^\top x).$$
Denote $z = u^\top x$. Equivalently, we have
$$f_L(z) = \sum_{j=1}^{m} a_j\, \phi(v_j \cdot z).$$
Note that $\mathrm{sign}(f_L(z_1)) = \mathrm{sign}(f_L(z_2))$ whenever $\mathrm{sign}(z_1) = \mathrm{sign}(z_2)$, since $f_L$ is positively homogeneous in $z$. Because both the positive and the negative class lie on both sides of the origin (Figure 11, Left), it is impossible for $f_L$ to correctly classify these points (Figure 10, Right).

Figure 12: 2D visualization of the 4D embedding from the weak model trained on the modular addition task.

Figure 13: Left: Training dynamics of an 8-layer transformer on the modular multiplication task. Right: Training dynamics of a 2-layer transformer and a 4-layer transformer trained on the $(40,3)$-parity task. Neither exhibits delayed generalization.

Figure 14: Training dynamics of a four-layer MLP on the MNIST dataset. 200 samples are randomly selected as the training data, and 400 samples as the test data. The weak model is a two-layer MLP with 40% test accuracy.

Figure 15: (a) Training dynamics of the target models used in Figure 6 and Figure 8 (Right), but with a low-rank embedding $A \cdot B$, where both $A$ and $B$ are randomly initialized. MLP denotes a three-layer MLP and TF denotes a two-layer transformer. (b) Merge $A$ and $B$ into $E_T$ at the 100th epoch and keep training. (c) Set the embedding dimension of the weak model to be the same as that of the target model.

Figure 16: Dynamics of the target model (a two-layer Transformer) trained via GrokTransfer, and of the target model trained via GrokFast, on the modular multiplication task.
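The FLOPs arithmetic in Section A.4.1 can be double-checked in a few lines; the constants are taken verbatim from the expressions there, and this is a sanity check rather than the paper's code:

```python
# Weak model (two-layer width-4 MLP), forward-pass FLOPs:
weak_flops = 2 * (8 * 4 + 8 + 4 * 113 + 226)

# Target model (8-layer transformer), parameter count per Kaplan et al. (2020):
d_embed, n_layer, d_mlp, two_d_attn = 512, 8, 512, 256
N = 2 * d_embed * n_layer * (two_d_attn + d_mlp)
target_flops = 2 * (N + 8 * 2 * 128)

# Rough totals: ~2000 weak epochs + ~1000 target epochs for GrokTransfer,
# versus ~10000 target epochs for training from scratch.
groktransfer_total = 1000 * target_flops + 2000 * weak_flops
scratch_total = 10000 * target_flops

print(weak_flops, N)  # 1436 6291456
```

Under these per-epoch counts, training from scratch costs roughly an order of magnitude more total compute than GrokTransfer, matching the $10^{11}$ versus $10^{10} + 2 \times 10^6$ estimate above.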