Paper deep dive
JEDI: Jointly Embedded Inference of Neural Dynamics
Anirudh Jamkhandi, Ali Korojy, Olivier Codol, Guillaume Lajoie, Matthew G. Perich
Abstract
Animal brains flexibly and efficiently achieve many behavioral tasks with a single neural network. A core goal in modern neuroscience is to map the mechanisms of the brain's flexibility onto the dynamics underlying neural populations. However, identifying task-specific dynamical rules from limited, noisy, and high-dimensional experimental neural recordings remains a major challenge, as experimental data often provide only partial access to brain states and dynamical mechanisms. While recurrent neural networks (RNNs) directly constrained by neural data have been effective in inferring underlying dynamical mechanisms, they are typically limited to single-task domains and struggle to generalize across behavioral conditions. Here, we introduce JEDI, a hierarchical model that captures neural dynamics across tasks and contexts by learning a shared embedding space over RNN weights. This model recapitulates individual samples of neural dynamics while scaling to arbitrarily large and complex datasets, uncovering shared structure across conditions in a single, unified model. Using simulated RNN datasets, we demonstrate that JEDI accurately learns robust, generalizable, condition-specific embeddings. By reverse-engineering the weights learned by JEDI, we show that it recovers ground truth fixed point structures and unveils key features of the underlying neural dynamics in the eigenspectra. Finally, we apply JEDI to motor cortex recordings during monkey reaching to extract mechanistic insight into the neural dynamics of motor control. Our work shows that joint learning of contextual embeddings and recurrent weights provides scalable and generalizable inference of brain dynamics from recordings alone.
Links
- Source: https://arxiv.org/abs/2603.10489v1
- Canonical: https://arxiv.org/abs/2603.10489v1
PDF not stored locally. Use the link above to view on the source site.
Full Text
JEDI: Jointly Embedded Inference of Neural Dynamics

Anirudh Jamkhandi 1,2, Ali Korojy 1,2, Olivier Codol 1,2, Guillaume Lajoie *1,2, Matthew G. Perich *1,2
* Equal contribution. 1 Mila - Quebec AI Institute, 2 Université de Montréal. Correspondence to: Anirudh Jamkhandi <anirudh.jamkhandi@umontreal.ca>, Matthew G. Perich <matthew.perich@umontreal.ca>. Preprint. March 12, 2026.

1. Introduction

Brains are complex and nonlinear dynamical systems composed of a network of highly specialized neural circuits. The neural computations that govern complex behaviors (e.g., playing guitar or writing machine learning papers) ultimately arise from the time-varying interactions of neural populations distributed across the brain (Duncker & Sahani, 2021). With modern advances in recording technologies, experimentalists can record the activity of ever-larger populations of neurons, providing an unprecedented window into neural computation. A critical question then arises: how can we obtain mechanistic insights into the neural dynamics underlying these behaviorally-relevant neural computations?

The structure of neural interactions plays a central role in systems neuroscience, where the patterns and strength of connectivity between neurons shape the resulting dynamics and computation (Liu et al., 2024; Schuessler et al., 2024; Ostojic & Fusi, 2024). Analysis of these interaction weights can reveal how connectivity constrains and enables flexible computations (Braun et al., 2022; Raman & O'Leary, 2021; Canatar et al., 2021), but these weights are not readily accessible in typical experimental settings. Over the past couple of decades, recurrent neural networks (RNNs), which incorporate many canonical features of neural circuits such as recurrence and nonlinear interactions, have emerged as a powerful model for studying neural population activity (Sussillo & Abbott, 2009; Hess et al., 2023; Dinc et al., 2023; Valente et al., 2022; Perich et al., 2021; Durstewitz et al., 2023).
RNNs trained to perform neuroscience-inspired tasks can be compared to animals performing similar tasks, allowing for indirect exploration of mechanisms in the biological brain. However, a growing class of data-constrained RNNs (dRNNs) trained to recapitulate neural recordings provides a complementary approach for direct inference of brain computations. Most existing dRNNs assume a fixed set of weights corresponding to one dynamical system for the observed data (e.g., each behavior). In contrast, biological networks flexibly operate in different dynamical regimes even within a single task (Huang et al., 2024; Turner et al., 2021). Thus, dRNN models should have similar dynamical flexibility to accurately capture the neural computations necessary for complex and evolving behavioral demands.

arXiv:2603.10489v1 [q-bio.NC] 11 Mar 2026

Fixed-weight dRNN models lack the flexibility to account for changes in behavioral tasks or contexts. This limitation underscores the need for dRNN approaches that can accommodate context-dependent variations in neural computations and generalize robustly across behavioral tasks and conditions.

To address this limitation, we introduce JEDI, Jointly Embedded Dynamics Inference from neural data. JEDI uses a hierarchical, hypernetwork-based framework to infer neural dynamics directly from time series of neural population recordings (Fig. 1). It leverages a shared embedding space to learn the parameters of arbitrary dRNNs that recapitulate the time series of recorded neural data (e.g., one trial of a behavioral task) using only a contextual input to the embedding. This shared model captures common dynamics across datasets, tasks, or contexts, while providing flexibility for sample-specific dynamical features.
Further, by jointly learning RNN weights and embeddings, we can infer interpretable links between weight structure and neural computations. Our contributions can be summarized as follows:

- We propose a novel hypernetwork framework to model complex neural data from varying behavioral conditions and contexts in a single unified model.
- Using the learned task embeddings, we demonstrate accurate classification of dynamical regimes and generalization to unseen samples from learned tasks.
- We show robust inference of dynamical rules from RNN weights through the eigenspectra, Lyapunov exponents, and fixed point structure of learned weights.

2. Methods

2.1. Low-rank RNNs

RNNs are universal approximators for dynamical systems, whose internal computations can be reverse-engineered through the recurrent weights. We use a rate-based recurrent network of N units, in which the variable r follows the dynamics:

τ dr/dt = −r(t) + J φ(r(t)) + ξ(t),  (1)

with neuron activity r(t) ∈ R^N, time constant τ ∈ R_{>0}, recurrent weights J ∈ R^{N×N}, element-wise nonlinearity φ(x) = tanh(x), and private white noise ξ(t) ∈ R^N provided to each neuron.

Figure 1. JEDI leverages a hypernetwork-based framework to flexibly generate RNN weights based on contextual inputs, all learned directly from a loss computed against the RNN output and time series neural recordings.

The set of matrices J that can reproduce a given sample of neural activity is potentially large. To reduce this solution degeneracy (Huang et al., 2024), we constrain J with a low-rank penalty, compressing the recurrent weight matrix into a few dominant modes. Hence, we are interested in the case where the recurrent weight matrix has rank R ≤ N, i.e., it can be written as J = M N^T, with M, N ∈ R^{N×R} (Mastrogiuseppe & Ostojic, 2018; Schuessler et al., 2020; Beiran et al., 2021; Dubreuil et al., 2022).
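Equation (1) with the low-rank factorization J = M N^T can be integrated with a simple Euler scheme. The sketch below is our illustrative reconstruction (function names, sizes, and the time step are ours, not the authors' code); note that the low-rank form lets J φ(r) be computed in O(NR) rather than O(N²):

```python
import numpy as np

def simulate_lowrank_rnn(M, N_mat, T=200, dt=0.01, tau=0.1,
                         noise_std=0.0, r0=None, rng=None):
    """Euler-integrate Eq. (1), tau dr/dt = -r + J phi(r) + xi,
    with the low-rank parameterization J = M N^T and phi = tanh."""
    rng = np.random.default_rng() if rng is None else rng
    n_units = M.shape[0]
    r = np.zeros(n_units) if r0 is None else np.asarray(r0, dtype=float).copy()
    traj = np.empty((T, n_units))
    for t in range(T):
        xi = noise_std * rng.standard_normal(n_units)     # private white noise
        # J phi(r) evaluated as M (N^T phi(r)): O(N R) instead of O(N^2)
        drdt = (-r + M @ (N_mat.T @ np.tanh(r)) + xi) / tau
        r = r + dt * drdt
        traj[t] = r
    return traj

# Example: a rank-5 network of 200 units, echoing the sizes used in the paper
rng = np.random.default_rng(0)
n, R = 200, 5
M = rng.standard_normal((n, R)) / np.sqrt(n)
N_mat = rng.standard_normal((n, R)) / np.sqrt(n)
traj = simulate_lowrank_rnn(M, N_mat, r0=0.5 * rng.standard_normal(n), rng=rng)
```

With these weakly scaled factors the dynamics decay toward rest; larger gains or structured M, N would produce the richer regimes studied below.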
In the case where the recurrent weights are unconstrained, we call the model "full rank".

2.2. Context-informed Hypernetwork

Drawing inspiration from the top-down modulation of circuits in the neocortex, we introduce conditioning through a hypernetwork (Ha et al., 2016), a secondary network that dynamically generates the parameters of a primary network (here, a dRNN). The hypernetwork acts as a controller that receives a specific context signal (like a task instruction or a sensory goal) and uses it to dynamically reconfigure the primary network. Instead of the primary network having a single set of static weights, it is conditioned to adapt its internal dynamics on the fly. Our implementation uses a feedforward layer that learns context embeddings to flexibly parameterize the downstream RNN responsible for reconstructing the dynamics of neural recordings:

J = f_h(c),  (2)

where f_h is a nonlinear function parameterized by Θ_h, and c ∈ R^E is the context vector, with E the size of the context vector. Hence, Equation (1) can be reformulated as:

τ dr/dt = −r(t) + f_h(c) φ(r(t)) + ξ(t).  (3)

From Equation (3), it is clear that the context vector can significantly affect the overall dynamics for appropriate choices of f_h(·). Modeling choices for the context c can influence the degree of granularity of the hypernetwork. Indeed, c could be made to encode any variable, e.g., the identity of a specific trial, subject, or behavioral task. Rather than learn recurrent weights for each trial of neural activity data separately, we generate the trial-specific weights from a single, learned hypernetwork that efficiently shares information across trials. When the recurrent weights are constrained to have low-rank structure, the outputs of the hypernetwork are the matrices M, N. In our experiments, we initialize a new context c with zero entries for each trial.

2.3.
Training Objective

The goal is to estimate recurrent weights that reproduce the recorded neural population activity through minimization of a loss function. We use backpropagation through time to minimize the squared difference between the predicted and ground truth trajectories. The mean squared error (MSE) loss on each trial can be expressed as:

L(Θ_h, c) = Σ_{i=1}^{N} Σ_{t=1}^{T} [ φ(r_i^{(c)}(t)) − φ(r̂_i^{(c)}(t)) ]²,  (4)

where r_i^{(c)}(t) represents the target trajectory in condition c for neuron i at timestep t, and r̂_i^{(c)}(t) the corresponding predicted trajectory of the model. Similar to Von Oswald et al. (2019), we treat the context vector as a differentiable, deterministic parameter that can be optimized alongside Θ_h. At each learning step, the current context embedding c^(j) is updated along with the hypernetwork weights Θ_h to minimize the reconstruction loss. After training, the resulting context is stored and added to the collection {c^(j)}. The nature of this collection depends on the granularity chosen for the context embedding. In the following work, this collection represents trials from a set of tasks, dynamical contexts, or behavioral conditions, but for other applications it could comprise experimental subjects or even brain regions.

3. Empirical Results

In the following sections, we validate JEDI and demonstrate its ability to infer the dynamical structure of recorded neural population activity. Since JEDI jointly learns contextual embeddings and neural dynamics, we designed a series of experiments to explore these two properties. We first use a ground truth dataset modeling multi-task scenarios to show that JEDI learns robust and generalizable embeddings. We use this experiment to explore how architectural decisions, such as low-rank weight structure, impact model performance. We then leverage two further datasets to show that JEDI can accurately infer dynamical features through spectral analysis of the eigenvalues and fixed point analysis.
Lastly, we demonstrate JEDI's efficacy in modeling real neural recordings.

3.1. Teacher Setup for generating synthetic data

To evaluate the performance of the model, we studied a teacher-student paradigm (Saad & Solla, 1995; Seung et al., 1992; Beiran & Litwin-Kumar, 2024) in which a teacher network is used as a proxy for a neural system with known recurrent connectivity. The student (our model) is trained to mimic the teacher. We generated synthetic data from a chaotic teacher RNN with pre-defined recurrent connectivity and input structure (Fig. 2.A). The RNN was composed of N = 200 neurons, in which the activity (firing rate) h_i(t) of neuron i evolved according to the dynamical update:

τ dh/dt = −h(t) + g J φ(h(t)) + W_ext u(t),  (5)

where h(t) ∈ R^N is the RNN state at time t, φ is the nonlinear activation function, τ is the neural time constant, and u(t) ∈ R^N is the time-varying external input. J ∈ R^{N×N} is the recurrent connectivity matrix sampled from a Gaussian distribution N(0, σ²). To capture the low-dimensional structure inherent in neural activity, we impose a low-rank constraint on the connectivity matrix J (Perich et al., 2025) by initializing it with rank R = 5. The gain parameter g determines the strength of the recurrent connections, and thus whether (g > 1) or not (g < 1) the network produces spontaneous activity with non-trivial dynamics (Rajan et al., 2010). We set g = 1.8 to produce chaotic dynamics shaped by the external inputs. W_ext is the input matrix that maps input signals onto the recurrent network.

To probe JEDI's ability to capture diverse dynamical regimes in a single model, we drove the teacher RNN with a range of external input signals u(t) that shaped the temporal evolution of the RNN (Fig. 2.A). Inputs included oscillatory (sine, cosine, square), ramping, decaying, and moving fixed-point (step-like) patterns.
These inputs were directly applied to a subset of neurons in the chaotic RNN (50% of the population), selected via an input weight matrix W_ext ∈ R^N. We refer to the input types as tasks, and simulated variability across trials by repeating each input from different random initial RNN states. We simulated the activity of each neuron for 2 seconds following a 0.1 second burn-in period. The resulting activity serves as ground-truth trials of a known dynamical system, providing a testbed for assessing the generalization and interpretability of our model.

We trained JEDI models to recapitulate the teacher RNN's activity across all tasks. For our first experiment, we aimed to demonstrate robust and generalizable learned embeddings capturing relationships in neural dynamics across tasks. We compared JEDI against three common approaches for learning low-dimensional embeddings of neural dynamics across trials and tasks. We trained Variational Autoencoders, both feedforward (VAE) and recurrent (RNN-VAE) variants, to reconstruct neural activity on each trial from a low-dimensional embedding space. Note that these alternative methods are intended to contextualize the utility and performance of JEDI's embeddings, but are not fully comparable to JEDI since they do not provide mechanistic insight into neural dynamics through interaction weights.
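The pipeline of Sections 2.2 and 2.3 can be sketched compactly: an MLP hypernetwork maps a per-trial context to the low-rank factors (M, N), the generated RNN is rolled out, and the MSE of Eq. (4) is accumulated. Everything below (names, layer sizes, the Euler rollout) is our illustrative reconstruction, not the authors' code, and we show only the forward pass and loss; in practice both Θ_h and c are updated by backpropagation through time:

```python
import numpy as np

# Illustrative sizes (ours, not the paper's): a small dRNN and context space
n_units, rank, ctx_dim, hidden = 50, 3, 8, 64
dt, tau = 0.01, 0.1

def hypernet_forward(params, c):
    """3-layer MLP hypernetwork: context vector c -> low-rank factors (M, N)."""
    h = c
    for W, b in params[:-1]:
        h = np.tanh(W @ h + b)
    W_out, b_out = params[-1]
    flat = W_out @ h + b_out
    M_flat, N_flat = np.split(flat, 2)
    return M_flat.reshape(n_units, rank), N_flat.reshape(n_units, rank)

def reconstruction_loss(params, c, target):
    """Roll out the generated RNN and accumulate the MSE of Eq. (4)
    between tanh-transformed target and predicted trajectories."""
    M, N_mat = hypernet_forward(params, c)
    r = np.zeros(n_units)
    loss = 0.0
    for t in range(target.shape[0]):
        r = r + dt / tau * (-r + M @ (N_mat.T @ np.tanh(r)))  # noiseless Euler step
        loss += np.sum((np.tanh(target[t]) - np.tanh(r)) ** 2)
    return loss

rng = np.random.default_rng(1)
sizes = [(hidden, ctx_dim), (hidden, hidden), (2 * n_units * rank, hidden)]
params = [(0.1 * rng.standard_normal(s), np.zeros(s[0])) for s in sizes]
c = np.zeros(ctx_dim)              # each trial's context starts at zero (Sec. 2.2)
target = rng.standard_normal((100, n_units))
loss = reconstruction_loss(params, c, target)
# In the paper, the hypernetwork parameters and c are then jointly updated by
# backpropagation through time; an autodiff framework would supply the gradients.
```

Because c is the only trial-specific quantity, all structure shared across trials must be absorbed into the hypernetwork weights, which is what yields the shared embedding space.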
Figure 2. Quantifying the quality of the embeddings. A) Synthetic multi-task data generation setup. B) 2D PCA visualization of context embeddings; each point corresponds to a trial, color-coded by input signal type. C) Training reconstruction accuracy R² across different methods. D) Accuracy of task classification from the learned embeddings. E) Generalization accuracy R² from the center of the learned embedding for each task. F) Confusion matrix of generalization R² scores, expanding the results in Panel E; rows represent training tasks, and columns indicate test tasks.

3.2. JEDI learns generalizable embeddings across tasks

We trained JEDI to recapitulate the activity generated by the teacher RNN using two model configurations: one generating full-rank RNN weights, and one constrained to produce RNN weights with rank r (set to 5 to match the ground truth data). We configured the hypernetworks as 3-layer feedforward networks (MLPs) and compared the performance of these JEDI models against the VAE and RNN-VAE methods. Using the coefficient of determination (R²) as our metric, we found that all models achieve high reconstruction accuracy of the teacher RNN activity (Fig. 2.C).
Notably, JEDI and VAE gave more consistently good fits (smaller variance across trials) than JEDI-full and RNN-VAE.

A good contextual embedding should learn separable and interpretable structure across tasks. To evaluate whether the embeddings learned the multi-task structure, we trained a Naive Bayes classifier on the learned representations and evaluated its accuracy on held-out trials. All models achieved near-perfect classification accuracy (Fig. 2.D), indicating that the dynamical regimes induced by different input types are readily separable in latent space. However, a more stringent test is whether these embeddings are generalizable and robust. We evaluated across-task generalization by sampling the mean embedding for each task from training data and using this embedding to generate predictions on held-out test data across all tasks. A model that captures only task-specific identity will have high values along the diagonal (generalization within task), whereas a model that uncovers shared structure across tasks will show off-diagonal generalization. While some tasks were dynamically distinct and should show no generalization (e.g., ramp vs. sine), others have common structure that could be leveraged by the embedding (e.g., the oscillatory tasks). JEDI outperformed the other models in across-task generalization (Fig. 2.E). For example, embeddings derived from sine generalize to cosine, and embeddings from ramp transfer to decay (Fig. 2.F). Interestingly, JEDI-full performed substantially worse, indicating that structural low-rank constraints in the RNN weight outputs of JEDI enable learning generalizable representations in the embedding space.

We next tested the robustness of the embeddings by adding Gaussian noise of increasing variance to the learned values for each trial and decoding the resulting activity.
Because JEDI's embeddings were not constrained to Gaussian distributions like the VAEs', we scaled the perturbation magnitude according to the specific variance of each model's task embeddings. JEDI-full and RNN-VAE exhibit sharp performance declines in the presence of the perturbations (Fig. 3). While VAEs retain some performance at very high noise levels, JEDI outperforms all other models under low-to-moderate perturbations, indicating that the model learned robust, structured representations in the embedding.

Figure 3. Impact of embedding noise on model performance, comparing reconstruction R² as increasing noise is applied to the embeddings. The added noise was scaled according to the standard deviations of embeddings for each model.

3.3. JEDI uncovers ground truth spectral properties of neural dynamics

We next explore JEDI's ability to infer dynamical properties of neural data by reverse-engineering the RNN weights. We devised a variant of the previous experiment in which the teacher RNN was driven by sinusoidal inputs of increasing frequencies (Fig. 4.A). After training on multiple samples of each frequency, we analyzed the eigenspectrum of the learned JEDI RNN weights. The imaginary part of the eigenvalues indicates rotational velocity, while the real part indicates stability in the associated eigendirections: smaller than zero corresponds to stability, and greater than zero to instability.

We observed that eigenvalues formed distinct clusters within the complex plane based on the input frequency of the associated trials. Crucially, as the input frequency increased, the eigenspectra exhibited a corresponding expansion along the imaginary axis, directly capturing the higher oscillatory content inherent in the driving signal.
In contrast, the real components remained relatively invariant, suggesting that JEDI selectively adapts the rotational dynamics of the network to match the data's spectral properties without compromising global stability. Notably, this precise spectral alignment was absent in models trained via standard full-rank initialization (see Appendix S6), which failed to recover the underlying dynamical structure. Furthermore, while we focus here on controlled sinusoidal inputs, JEDI demonstrates similar spectral consistency in more complex settings; for a comparison with the original eigenspectra derived from multi-task learning paradigms, see Appendix S4.

3.4. JEDI can recover ground truth fixed point structure present in the neural dynamics

We explored the ability of JEDI to infer mechanisms of neural computation using the framework proposed by Yang et al. (2019), who trained RNNs to perform a range of neuroscience-like tasks. Following the protocol of Versteeg et al. (2025), we used the multi-task trained RNN to simulate trials for the MemoryPro task across four distinct periods: context, stimulus, memory, and response (Fig. 5.A). The fixed-point structure of the MemoryPro task was previously characterized by Driscoll et al. (2024), providing a clear ground truth target of dynamical mechanisms for JEDI to infer (see Appendix A.1.1 for comprehensive details). After training on the large number of trials and contexts, JEDI yielded a low reconstruction loss and high R² (0.94). We first examined whether the learned context embeddings preserved the functional organization of the task. We found that JEDI's embeddings capture the task's logical structure. Specifically, we found distinct, phase-specific rings organized according to stimulus direction (Fig. 5.B left), confirming that JEDI recovers the geometric relationships between input conditions (Fig. 5.B right).
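Fixed points of a learned dRNN can be located numerically by descending the state-speed function, in the spirit of Sussillo & Barak (2013). The sketch below is our minimal reconstruction (names, learning rate, and tolerances are ours) for the dynamics dr/dt = −r + J tanh(r), using the analytic Jacobian of the velocity field:

```python
import numpy as np

def find_fixed_points(J, n_inits=20, n_steps=2000, lr=0.05, tol=1e-6, rng=None):
    """Locate fixed points of dr/dt = -r + J tanh(r) by gradient descent
    on the speed q(r) = 0.5 * ||-r + J tanh(r)||^2 from random initial states."""
    rng = np.random.default_rng() if rng is None else rng
    n = J.shape[0]

    def velocity(x):
        return -x + J @ np.tanh(x)

    fps = []
    for _ in range(n_inits):
        r = 0.5 * rng.standard_normal(n)
        for _ in range(n_steps):
            F = velocity(r)
            if 0.5 * F @ F < tol:          # converged: speed below tolerance
                break
            # Jacobian of the velocity field: -I + J diag(1 - tanh(r)^2)
            Jac = -np.eye(n) + J * (1.0 - np.tanh(r) ** 2)
            r = r - lr * (Jac.T @ F)       # gradient step on q(r)
        F = velocity(r)
        if 0.5 * F @ F < tol:              # keep only converged candidates
            fps.append(r)
    return np.array(fps)

# Example: a weakly coupled random network, where the origin is the unique
# fixed point, so all converged candidates land near zero.
rng = np.random.default_rng(2)
n = 30
J = 0.5 * rng.standard_normal((n, n)) / np.sqrt(n)
fps = find_fixed_points(J, rng=rng)
```

Candidates from different initializations are typically deduplicated and then classified as stable or unstable from the eigenvalues of the Jacobian at each point.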
Next, to reverse-engineer the trial-specific computations, we identified slow-moving points within the model's hidden state space, commonly referred to as fixed points (Sussillo & Barak, 2013), derived from JEDI's trained weights. These fixed points were projected into separate low-dimensional principal component (PC) spaces and visualized collectively for analysis. We observed that the fixed-point structure extracted by JEDI closely mirrored the shape of the original task-trained RNN (Fig. 5.C&D). RNN states evolved from a single, central fixed point that cued the context towards geometrically related, target-specific fixed points, forming an approximate ring attractor. This is consistent with the fixed-point characteristics previously reported in Driscoll et al. (2024) for RNNs performing this task. This alignment with established structures confirms that JEDI's embeddings and weights effectively capture meaningful features of the computations underlying neural activity.

Figure 4. A) We drove the chaotic teacher RNN with sinusoidal inputs at different frequencies (1-10 Hz). B) The eigenspectra of the weights inferred with JEDI exhibit a characteristic expansion along the imaginary axis as input frequency increases.

Figure 5. JEDI identifies fixed point structure in task-trained networks. A) We fit JEDI to a network trained to perform the MemoryPro task. This task has four contextual trial phases. B) Context embedding learned by JEDI, colored by trial phase (left) and response direction (right). C) Fixed point structure for the four trial phases for the ground-truth task-trained network. D) Fixed point structure inferred by JEDI.

3.5. JEDI flexibly models population recordings from monkey motor cortex

We then applied JEDI to infer dynamical properties from real neural recordings. We fit JEDI models on motor cortical recordings from macaque monkeys performing a center-out reaching task (Perich et al., 2018). Neural population activity was recorded simultaneously from the primary motor and premotor cortex with two electrode arrays. The monkey was trained to perform an instructed delay task, which involved a movement preparation phase followed by a go cue instructing the monkey to reach to one of eight targets (movement phase) (Fig. 6.A). After training, JEDI learned a clear ring-like structure in the embedding corresponding to the eight reach directions (Fig. 6.B), with performance in fit quality, classification accuracy, and generalization comparable to the other, embedding-specific methods (Fig. 6.C-E).

3.6. Spectral analysis reveals reorganization of dynamics during movement in monkey reaching

While interpretable embeddings are a useful component of our model that helps contextualize the learned dynamics, the ultimate goal of JEDI is to directly infer dynamical mechanisms from neural data. To explore this, we compared movement preparation (a covert process without explicit motor output) to movement execution (an overt process that moves the arm). These two processes have been shown to involve distinct neural computations (Elsayed et al., 2016), and we hypothesized that they should consequently have distinct dynamical properties, which JEDI can uncover. Spectral analysis of the trained recurrent matrices generated by JEDI (but not JEDI-full; see Appendix S5) revealed clear changes in dynamical structure across the two behavioral phases.
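One concrete way to perform such a spectral readout on low-rank weights J = M N^T is to examine the eigenvalues of the continuous-time Jacobian of Eq. (1) at the origin, (J − I)/τ (using tanh′(0) = 1): negative real parts indicate stable directions, and imaginary parts set rotation speed. This numpy sketch uses illustrative shapes of our choosing, not the authors' code:

```python
import numpy as np

def eigenspectrum_summary(M, N_mat, tau=0.1):
    """Eigenvalues of the Jacobian (M N^T - I) / tau of Eq. (1) at the origin.

    Negative real parts mean stable directions; |Im(lambda)| / (2 pi)
    gives the rotation frequency (in Hz) of the associated mode."""
    n = M.shape[0]
    A = (M @ N_mat.T - np.eye(n)) / tau
    eig = np.linalg.eigvals(A)
    return eig, eig.real.max(), np.abs(eig.imag).max() / (2.0 * np.pi)

rng = np.random.default_rng(3)
n, R = 100, 4
M = rng.standard_normal((n, R)) / np.sqrt(n)
N_mat = rng.standard_normal((n, R)) / np.sqrt(n)
eig, max_real, max_freq_hz = eigenspectrum_summary(M, N_mat)
# With rank-R weights, all but R eigenvalues sit at exactly -1/tau,
# so only the R low-rank modes can carry slow or rotational structure.
```

Comparing such spectra across trial-specific weight matrices (e.g., preparation versus execution) is what reveals the reorganization described in this section.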
During preparation, the eigenvalue distributions were largely clustered within the unit circle, except for prominent excursions largely along the real axis (Fig. 6.G top), consistent with the need to generate ramping dynamics that take the brain from a quiescent state to one ready to produce behavior. During movement, however, groups of eigenvalues emerged that were tightly clustered along the unit circle near the zone of marginal stability (Fig. 6.G bottom), a desirable regime for efficient neural computation (Legenstein & Maass, 2007). Reassuringly, these dynamics were identical across all trials, consistent with the assumption that the underlying dynamical system driving movement does not change for different reach directions.

We further explored whether execution-related dynamics could represent a shift towards criticality (the edge of stability). We analyzed the Lyapunov exponents (Vogt et al., 2022) of the inferred weights, which quantify the long-horizon behavior of the system, between preparation and execution. Values above zero indicate a tendency towards chaos, values below zero indicate stability, and values near zero are consistent with criticality or the edge of chaos. We found that the transition between preparation and movement corresponded to an increase in the maximum Lyapunov exponent towards zero (Fig. 6.J&K). This marginal stability reflects the motor system's trade-off between robustness and expressivity (Russo et al., 2018).

3.7. Uncovering fixed points in monkey motor cortex

We lastly assessed the fixed point structure inferred by JEDI to test possible mechanisms governing the evolution of neural trajectories that underlie reaching movements in the motor cortex (Gallego et al., 2017).
Using the procedures described in the task-trained RNN experiment, we identified stable fixed points in the neural population activity (Fig. 6.L). These fixed points were clustered by reach direction, consistent with different attractor zones for different reaching conditions. Intriguingly, we found that neural trajectories identified by PCA ended at these stable fixed points; this end point corresponds to the end of the monkey's reach. The fixed points uncovered by JEDI, together with the eigenspectra and Lyapunov exponents, suggest that neural dynamics for reaching are governed by marginally-stable trajectories towards stable attractors. This mechanistic insight highlights the strength of JEDI for analyzing neural data.

Figure 6. JEDI applied to monkey motor cortex data during reaching. A) We apply JEDI to recordings of the motor and premotor cortex during a center-out reaching task. Schematic adapted from Gallego et al. (2020). B) 2D PCA projections of the learned embeddings for each model; each point corresponds to a single trial, color-coded by reach direction. C) Reconstruction R² across all trials for each model. D) Classification accuracy for reach direction from each learned embedding. E) Generalization R² using the center of each embedding cluster for the different models. F) Confusion matrix for generalization across different contexts for each model, further visualizing the results in Panel E. G) Eigenspectra of learned JEDI weights during preparation (top) and movement (bottom), colored by the 8 reach directions. During movement, eigenvalues clustered close to the unit circle, indicating dynamics approaching the edge of stability. H) Quantification of eigenspectra changes by the magnitude of each eigenvalue for preparation (gray) and execution (green). I) Quantification of eigenspectra changes by the phase of each eigenvalue. J) Lyapunov exponents across all learned weights for preparation and execution. K) Distributions of maximum Lyapunov exponents across all learned trials for preparation and execution. L) Stable fixed points during execution inferred by JEDI, colored by reach direction; single-trial neural trajectories for each corresponding reach are also plotted.

4. Related Work

RNNs are ubiquitous in computational neuroscience, with the aim of explaining neural computations through the lens of underlying dynamical principles (Vyas et al., 2020). Prior work has highlighted the role of contexts and task variables in shaping neural computations (Yang et al., 2019; Driscoll et al., 2024; Costacurta et al., 2024; Williams et al., 2025). Typically, these studies use fixed (not learned) contextual inputs to probe cognitive and decision-making tasks. These works share an emphasis on contextual inputs with ours, but focus on training networks de novo to learn specific tasks.
They do not attempt to infer weight structures directly from experimentally recorded neural activity, which is our primary goal.

Another common application of neural networks in neuroscience is latent state inference, including LFADS (Pandarinath et al., 2018) and XFADS (Dowling et al., 2024). CEBRA (Schneider et al., 2023), another latent state inference tool, learns latent embeddings jointly from neural and behavioral data, paralleling our use of context to align across conditions. Unlike these methods, ours aims to infer interpretable weight structure to reproduce the full time series of neural population recordings.

Existing dRNN methods focus on fixed dynamical systems. Methods such as CURBD (Perich et al., 2021) and CORNN (Dinc et al., 2023) directly learn from time series neural data. Others model neural activity in the latent space with low-rank RNNs (Pals et al., 2024; Valente et al., 2022). Our method shares this low-rank assumption and its structural implications, but extends them in two ways: (1) we fit data at full neural resolution to enable the study of individual neural interactions; and (2) we explicitly link weight structure to learned context embeddings. Lastly, recent work on motor adaptation through low-tensor-rank RNNs (Pellegrino et al., 2023) captures trial-level variation in weights, which is conceptually similar to our context-dependent variations, though our work is more flexible, e.g., for modeling variations in behavioral tasks and even subjects.

Recent work has combined learned contexts with neural data modeling. Hierarchical state-space models (SSMs) (Vermani et al., 2024) and meta-learning frameworks (Cotler et al., 2023) integrate recordings across contexts, akin to our contextual embeddings. Hierarchical models for time series (Brenner et al., 2024) and Bayesian models of decision criteria (Vloeberghs et al., 2025) share our goal of extracting contextual structure (Kirchmeyer et al., 2022).
Relatedly, an SSM neural decoding architecture showed that task-specific embedding layers enable shared dynamics models (Ryoo et al., 2025), though these lack the interpretability of JEDI. Lastly, context-informed dynamics models that generalize across physical systems (Nzoyem et al., 2025) share common goals with our work, as they explicitly study recurrent models in weight space. However, unlike their linearized architectures, our work applies to nonlinear, data-constrained RNNs. Ultimately, the above methods focus on using context learning to improve generalization and do not address our core aim: to mechanistically probe RNN weights.

5. Discussion

Summary. In this work, we introduced JEDI, a hierarchical hypernetwork model that generates low-rank autonomous RNNs to fit neural data. Our core objective was to confirm that contextual embeddings that parameterize recurrent weight matrices could yield interpretable and generalizable models of neural computations via dynamical systems. JEDI simultaneously reveals contextual relationships in neural activity between experimental conditions (e.g., behavioral tasks) and underlying dynamical mechanisms.

We demonstrated that the low-rank hypernetworks used in JEDI simultaneously: (i) produced embeddings that reliably classified varying dynamical regimes and generalized to new data; and (ii) generated recurrent interaction matrices whose eigenvalue spectra aligned with known dynamical signatures in synthetic systems and revealed new motifs in real neural recordings. Our results indicate that context-conditioned weight generation uncovers invariant and equivariant structure in neural activity. Low-rank constraints act as an inductive bias that improves robustness and generalization. Importantly, spectral analysis of the learned interactions offers mechanistic insight into the dynamical rules governing computation. Prior low-rank RNN studies established that structured connectivity shapes computation.
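To make the context-to-weights construction concrete, here is a minimal sketch of a hypothetical MLP hypernetwork mapping a trial's context embedding to low-rank recurrent weight factors. The sizes echo those reported in the appendix, but the network here is random and untrained, so only the information flow is meaningful:

```python
import numpy as np

rng = np.random.default_rng(1)
N, rank, d_embed, hidden = 117, 5, 16, 128  # illustrative sizes (neurons, rank, embedding, hidden)

# Hypothetical MLP hypernetwork: context embedding -> hidden layer -> flattened (U, V)
W1 = rng.normal(size=(hidden, d_embed)) * 0.1
W2 = rng.normal(size=(2 * N * rank, hidden)) * 0.1

def hypernet(z):
    """Generate a low-rank recurrent weight matrix J = U V^T from a context embedding z."""
    h = np.tanh(W1 @ z)
    theta = W2 @ h
    U = theta[: N * rank].reshape(N, rank)
    V = theta[N * rank :].reshape(N, rank)
    return U @ V.T

def rnn_step(x, J, dt=0.1):
    # one Euler step of the autonomous low-rank RNN: x' = -x + J tanh(x)
    return x + dt * (-x + J @ np.tanh(x))

z = rng.normal(size=d_embed)      # per-trial context embedding (learned jointly in JEDI)
J = hypernet(z)                   # trial-specific recurrent weights; rank(J) <= rank
x = rnn_step(rng.normal(size=N), J)
```

In JEDI, the embeddings and the hypernetwork weights are optimized jointly against the recorded activity; the spectral analyses in this paper then operate directly on the per-trial matrices J that the hypernetwork emits.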
Our work extends this perspective by linking context embeddings to recurrent weight organization through hypernetworks, a first step towards bridging latent variable approaches with weight-space analyses. JEDI thus provides a scalable route to integrate heterogeneous datasets, generate hypotheses about neural dynamics, and probe conserved motifs across tasks. Future work should pursue adaptive regularization schemes to relax rank specification, incorporate priors that capture biologically plausible connectivity when these variables are known a priori, and extend the framework to settings with incomplete or multimodal observations. Applying the model to larger-scale neural datasets across tasks and species should reveal conserved organizational principles and further validate the capacity of hypernetwork-driven architectures to capture the structure of neural computation.

Limitations. The explicit specification of rank imposes a manual design choice; adaptive strategies such as nuclear norm regularization (Scarvelis & Solomon, 2024) may alleviate this. The models remain susceptible to vanishing and exploding gradients, suggesting the need for stabilized architectures (e.g., coRNN (Rusch & Mishra, 2020)). Finally, the reliance on partial observations in neural recordings may introduce mechanistic mismatches, limiting generalizability to unobserved populations.

References

Beiran, M. and Litwin-Kumar, A. Prediction of neural activity in connectome-constrained recurrent networks. bioRxiv, 2024.
Beiran, M., Dubreuil, A., Valente, A., Mastrogiuseppe, F., and Ostojic, S. Shaping dynamics with multiple populations in low-rank recurrent networks. Neural Computation, 33(6):1572–1615, 2021.
Braun, L., Dominé, C., Fitzgerald, J., and Saxe, A. Exact learning dynamics of deep linear networks with prior knowledge. Advances in Neural Information Processing Systems, 35:6615–6629, 2022.
Brenner, M., Weber, E., Koppe, G., and Durstewitz, D. Learning interpretable hierarchical dynamical systems models from time series data. arXiv preprint arXiv:2410.04814, 2024.
Canatar, A., Bordelon, B., and Pehlevan, C. Spectral bias and task-model alignment explain generalization in kernel regression and infinitely wide neural networks. Nature Communications, 12(1):2914, 2021.
Costacurta, J., Bhandarkar, S., Zoltowski, D., and Linderman, S. Structured flexibility in recurrent neural networks via neuromodulation. Advances in Neural Information Processing Systems, 37:1954–1972, 2024.
Cotler, J., Tai, K. S., Hernández, F., Elias, B., and Sussillo, D. Analyzing populations of neural networks via dynamical model embedding. arXiv preprint arXiv:2302.14078, 2023.
Dinc, F., Shai, A., Schnitzer, M., and Tanaka, H. CORNN: Convex optimization of recurrent neural networks for rapid inference of neural dynamics. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
Dowling, M., Zhao, Y., and Park, M. Exponential family dynamical systems (XFADS): Large-scale nonlinear Gaussian state-space modeling. Advances in Neural Information Processing Systems, 37:13458–13488, 2024.
Driscoll, L. N., Shenoy, K., and Sussillo, D. Flexible multitask computation in recurrent networks utilizes shared dynamical motifs. Nature Neuroscience, 27(7):1349–1363, 2024.
Dubreuil, A., Valente, A., Beiran, M., Mastrogiuseppe, F., and Ostojic, S. The role of population structure in computations through neural dynamics. Nature Neuroscience, 25(6):783–794, 2022.
Duncker, L. and Sahani, M. Dynamics on the manifold: Identifying computational dynamical activity from neural population recordings. Current Opinion in Neurobiology, 70:163–170, 2021.
Durstewitz, D., Koppe, G., and Thurm, M. I. Reconstructing computational system dynamics from neural data with recurrent neural networks. Nature Reviews Neuroscience, 24(11):693–710, 2023.
Elsayed, G. F., Lara, A. H., Kaufman, M. T., Churchland, M. M., and Cunningham, J. P. Reorganization between preparatory and movement population responses in motor cortex. Nature Communications, 7(1):13239, 2016.
Engelken, R. and Wolf, F. Lyapunov spectra of chaotic recurrent neural networks. Physical Review Research, 5(4):043044, 2023.
Gallego, J. A., Perich, M. G., Miller, L. E., and Solla, S. A. Neural manifolds for the control of movement. Neuron, 94(5):978–984, 2017.
Gallego, J. A., Perich, M. G., Chowdhury, R. H., Solla, S. A., and Miller, L. E. Long-term stability of cortical population dynamics underlying consistent behavior. Nature Neuroscience, 23(2):260–270, 2020.
Golub, M. D. and Sussillo, D. FixedPointFinder: A TensorFlow toolbox for identifying and characterizing fixed points in recurrent neural networks. Journal of Open Source Software, 3(31):1003, 2018.
Ha, D., Dai, A., and Le, Q. V. Hypernetworks. arXiv preprint arXiv:1609.09106, 2016.
Hess, F., Monfared, Z., Brenner, M., and Durstewitz, D. Generalized teacher forcing for learning chaotic dynamics. In Proceedings of the 40th International Conference on Machine Learning, ICML'23, 2023.
Huang, A., Singh, S. H., Martinelli, F., and Rajan, K. Measuring and controlling solution degeneracy across task-trained recurrent neural networks. arXiv preprint arXiv:2410.03972, 2024.
Kirchmeyer, M., Yin, Y., Donà, J., Baskiotis, N., Rakotomamonjy, A., and Gallinari, P. Generalizing to new physical systems via context-informed dynamics model. In International Conference on Machine Learning, pp. 11283–11301. PMLR, 2022.
Legenstein, R. and Maass, W. Edge of chaos and prediction of computational performance for neural circuit models. Neural Networks, 20(3):323–334, 2007.
Liu, Y. H., Baratin, A., Cornford, J., Mihalas, S., Shea-Brown, E., and Lajoie, G. How connectivity structure shapes rich and lazy learning in neural circuits. ICLR, 2024.
Mastrogiuseppe, F. and Ostojic, S. Linking connectivity, dynamics, and computations in low-rank recurrent neural networks. Neuron, 99(3):609–623.e29, 2018.
Nzoyem, R. D., Keshtmand, N., Tsayem, I., Barton, D. A., and Deakin, T. Weight-space linear recurrent neural networks. arXiv preprint arXiv:2506.01153, 2025.
Ostojic, S. and Fusi, S. Computational role of structure in neural activity and connectivity. Trends in Cognitive Sciences, 28(7):677–690, 2024.
Pals, M., Sağtekin, A. E., Pei, F., Gloeckler, M., and Macke, J. H. Inferring stochastic low-rank recurrent neural networks from neural data. Advances in Neural Information Processing Systems, 37:18225–18264, 2024.
Pandarinath, C., O'Shea, D. J., Collins, J., Jozefowicz, R., Stavisky, S. D., Kao, J. C., Trautmann, E. M., Kaufman, M. T., Ryu, S. I., Hochberg, L. R., Henderson, J. M., Shenoy, K. V., Abbott, L. F., and Sussillo, D. Inferring single-trial neural population dynamics using sequential auto-encoders. Nature Methods, 15(10):805–815, 2018.
Pellegrino, A., Cayco Gajic, N. A., and Chadwick, A. Low tensor rank learning of neural dynamics. Advances in Neural Information Processing Systems, 36:11674–11702, 2023.
Perich, M. G., Gallego, J. A., and Miller, L. E. A neural population mechanism for rapid learning. Neuron, 100(4):964–976, 2018.
Perich, M. G., Arlt, C., Soares, S., Young, M. E., Mosher, C. P., Minxha, J., Carter, E., Rutishauser, U., Rudebeck, P. H., Harvey, C. D., and Rajan, K. Inferring brain-wide interactions using data-constrained recurrent neural network models. bioRxiv:2020.12.18.423348, 2021.
Perich, M. G., Narain, D., and Gallego, J. A. A neural manifold view of the brain. Nature Neuroscience, 28(8):1582–1597, 2025.
Rajan, K., Abbott, L., and Sompolinsky, H. Stimulus-dependent suppression of chaos in recurrent neural networks. Physical Review E, 82(1):011903, 2010.
Raman, D. V. and O'Leary, T. Frozen algorithms: how the brain's wiring facilitates learning. Current Opinion in Neurobiology, 67:207–214, 2021.
Rusch, T. K. and Mishra, S. Coupled oscillatory recurrent neural network (coRNN): An accurate and (gradient) stable architecture for learning long time dependencies. arXiv preprint arXiv:2010.00951, 2020.
Russo, A. A., Bittner, S. R., Perkins, S. M., Seely, J. S., London, B. M., Lara, A. H., Miri, A., Marshall, N. J., Kohn, A., Jessell, T. M., et al. Motor cortex embeds muscle-like commands in an untangled population response. Neuron, 97(4):953–966, 2018.
Ryoo, A. H.-W., Krishna, N. H., Mao, X., Azabou, M., Dyer, E. L., Perich, M. G., and Lajoie, G. Generalizable, real-time neural decoding with hybrid state-space models, 2025. URL https://arxiv.org/abs/2506.05320.
Saad, D. and Solla, S. A. Exact solution for on-line learning in multilayer neural networks. Physical Review Letters, 74(21):4337, 1995.
Scarvelis, C. and Solomon, J. M. Nuclear norm regularization for deep learning. Advances in Neural Information Processing Systems, 37:116223–116253, 2024.
Schneider, S., Lee, J. H., and Mathis, M. W. Learnable latent embeddings for joint behavioural and neural analysis. Nature, 617(7960):360–368, 2023.
Schuessler, F., Dubreuil, A., Mastrogiuseppe, F., Ostojic, S., and Barak, O. Dynamics of random recurrent networks with correlated low-rank structure. Physical Review Research, 2(1):013111, 2020.
Schuessler, F., Mastrogiuseppe, F., Ostojic, S., and Barak, O. Aligned and oblique dynamics in recurrent neural networks. eLife, 13:RP93060, 2024.
Seung, H. S., Sompolinsky, H., and Tishby, N. Statistical mechanics of learning from examples. Physical Review A, 45(8):6056, 1992.
Sussillo, D. and Abbott, L. F. Generating coherent patterns of activity from chaotic neural networks. Neuron, 63:544–557, 2009.
Sussillo, D. and Barak, O. Opening the black box: low-dimensional dynamics in high-dimensional recurrent neural networks. Neural Computation, 25(3):626–649, 2013.
Turner, E., Dabholkar, K. V., and Barak, O. Charting and navigating the space of solutions for recurrent neural networks. Advances in Neural Information Processing Systems, 34:25320–25333, 2021.
Valente, A., Pillow, J. W., and Ostojic, S. Extracting computational mechanisms from neural data using low-rank RNNs. In Advances in Neural Information Processing Systems, volume 35, 2022.
Vermani, A., Nassar, J., Jeon, H., Dowling, M., and Park, I. M. Meta-dynamical state space models for integrative neural data analysis. arXiv preprint arXiv:2410.05454, 2024.
Versteeg, C., McCart, J. D., Ostrow, M., Zoltowski, D. M., Washington, C. B., Driscoll, L., Codol, O., Michaels, J. A., Linderman, S. W., Sussillo, D., and Pandarinath, C. Computation-through-dynamics benchmark: Simulated datasets and quality metrics for dynamical models of neural activity. bioRxiv, 2025.
Vloeberghs, R., Urai, A. E., Desender, K., and Linderman, S. W. A Bayesian hierarchical model of trial-to-trial fluctuations in decision criterion. PLOS Computational Biology, 21(7):e1013291, 2025.
Vogt, R., Puelma Touzel, M., Shlizerman, E., and Lajoie, G. On Lyapunov exponents for RNNs: Understanding information propagation using dynamical systems tools. Frontiers in Applied Mathematics and Statistics, 8:818799, 2022.
Von Oswald, J., Henning, C., Grewe, B. F., and Sacramento, J. Continual learning with hypernetworks. arXiv preprint arXiv:1906.00695, 2019.
Vyas, S., Golub, M. D., Sussillo, D., and Shenoy, K. V. Computation through neural population dynamics. Annual Review of Neuroscience, 43(1):249–275, 2020.
Williams, E., Payeur, A., Ryoo, A. H.-W., Jiralerspong, T., Perich, M. G., Mazzucato, L., and Lajoie, G. Expressivity of neural networks with random weights and learned biases. In The Thirteenth International Conference on Learning Representations, 2025. URL https://openreview.net/forum?id=5xwx1Myosu.
Yang, G. R., Joglekar, M. R., Song, H. F., Newsome, W. T., and Wang, X.-J.
Task representations in neural networks trained to perform many cognitive tasks. Nature Neuroscience, 22(2):297–306, 2019.

A. Appendix

A.1. Additional details on Datasets

A.1.1. SYNTHETIC MEMORY PRO TASK

We followed the training procedure outlined in the Computation through Dynamics benchmark (Versteeg et al., 2025) and selected the MemoryPro task from the suite (Yang et al., 2019), as its fixed-point structure was previously characterized by Driscoll et al. (2024). In this task, the RNN learned to respond in the same direction as a stimulus after a memory period by minimizing the squared loss between its 3-dimensional output and the target using backpropagation through time. The input space consisted of 18 dimensions, including a 1-d fixation signal, 2-d stimulus vectors encoding the circular variable θ as A sin(θ) and A cos(θ), and 15-d rule inputs that remained active throughout the trial to indicate the task type. We used a tanh activation function for the RNN and treated each task period (context, stimulus, memory, and response) as an autonomous dynamical system with a distinct set of fixed points induced by the piecewise constant inputs.

We structured the experimental trials by initiating a 75-timestep context period where only fixation and rule inputs were active, followed by the stimulus period and a subsequent memory period that shared identical inputs with the context phase. When the fixation input dropped to zero during the response period, the network generated its directional output, which we monitored across 36 distinct stimulus angles. To ensure a robust dataset, we generated 20 variations for each angle through random sampling, resulting in a final collection of 2,880 trials. We then fit our model to these activations by initializing 2,880 context embeddings and training them together to capture the underlying task dynamics and specific patterns of activity induced by the trial phases.

A.1.2.
MOTOR CORTEX RECORDINGS

Two monkeys were trained to perform a two-dimensional center-out reaching task using a planar manipulandum that controlled a cursor on a screen. In each trial, the monkey moved to a central start position, waited through a variable delay, and then reached toward one of eight randomly selected targets arranged uniformly in a circle. A go cue signaled movement initiation, and successful trials required reaching the target within 1 second and holding for 0.5 seconds to receive a reward. We tested the applicability of the proposed approach on this data, containing c = 160 trials of T = 150 timesteps of recordings from N = 117 neurons.

A.1.3. MULTIPLE FREQUENCY NOISY SINE DATA

Similar to the multi-task synthetic teacher task framework, we generated chaotic sinusoidal datasets by driving a chaotic teacher RNN with 10 increasing frequencies across 20 unique initializations of initial states. This procedure produced a comprehensive dataset of 200 trials characterized by complex, non-linear dynamics. We subsequently used these trajectories to train and evaluate both JEDI and JEDI-full, assessing their capacity to capture and reconstruct high-dimensional chaotic representations.

A.2. Model Architecture and Training details

The latent dimension for all methods is set to 16.

A.2.1. VAE

We used a two-layer Multi-Layer Perceptron (MLP) for both the encoder and decoder. The encoder transforms the flattened input of shape (T × features) into a latent vector of dimension 16 through a hidden layer of 128 units. The decoder mirrors this architecture, decoding from the latent space through a hidden layer of 128 units back to the input dimension. The reconstruction loss is computed using the reparameterized latent vector, and the total loss includes a standard KL divergence term. During training, we set the batch size to 30 and trained the model for 1500 epochs, selecting the best model using validation loss. We used the Adam optimizer with an initial learning rate of 10^-3.
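As a sketch of this VAE baseline's forward pass and loss, the following uses randomly initialized weights; the sizes match the text, and the reparameterization and KL term follow the standard formulation rather than the authors' exact code:

```python
import numpy as np

rng = np.random.default_rng(0)
T, features, hidden, latent = 150, 117, 128, 16   # sizes from the text
d_in = T * features                               # flattened (T x features) input

# Randomly initialized two-layer MLP encoder/decoder (illustrative, untrained)
enc_W1 = rng.normal(size=(hidden, d_in)) * 0.01
enc_mu = rng.normal(size=(latent, hidden)) * 0.1
enc_lv = rng.normal(size=(latent, hidden)) * 0.1
dec_W1 = rng.normal(size=(hidden, latent)) * 0.1
dec_W2 = rng.normal(size=(d_in, hidden)) * 0.01

def vae_forward(x):
    h = np.tanh(enc_W1 @ x)
    mu, logvar = enc_mu @ h, enc_lv @ h
    z = mu + np.exp(0.5 * logvar) * rng.normal(size=latent)   # reparameterization trick
    x_hat = np.tanh(dec_W2 @ np.tanh(dec_W1 @ z))             # tanh on the decoder output
    recon = np.mean((x - x_hat) ** 2)                          # reconstruction loss
    kl = -0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar))    # KL(q(z|x) || N(0, I))
    return x_hat, recon + kl

x = rng.normal(size=d_in)
x_hat, loss = vae_forward(x)
```

In the actual baseline these weights are trained with Adam as described above; here they are random, so only the shapes and the composition of the loss are meaningful.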
Tanh activation was used in the decoder output.

A.2.2. RNN VAE

For the RNN-based VAE model (RNN-VAE), we used a sequence-to-sequence architecture built upon an LSTM encoder and decoder. The encoder comprises a single-layer LSTM with 32 hidden units, followed by a fully connected layer projecting into a latent space of size 16, where the mean and log-variance are estimated through parallel linear layers. The decoder reconstructs the sequence by transforming the latent vector through a linear projection and passing it through an LSTM layer. The final output is mapped back to the input dimensionality through a fully connected layer, with tanh activation on the decoder output. This model operated on a sequence length of 200 time steps. The total loss was the mean squared error between the decoder output and the ground truth trajectory plus a standard KL divergence term. During training, we set the batch size to 150 and trained the model for 5000 epochs, selecting the best model using validation loss. We used the Adam optimizer with a learning rate of 10^-2.

A.3. Fixed point finding and visualization

To evaluate the model's ability to capture ground-truth dynamical structures, we used a task-trained RNN as a benchmark. We trained this model on the MemoryPro task from the Computation through Dynamics repository (Versteeg et al., 2025) and simulated trials across four distinct phases: context, stimulus, memory, and response. We qualitatively assessed the true and inferred dynamics by analyzing their fixed-point structures. Fixed points represent regions in the state space where dynamics are slow enough to permit linear approximation, revealing the system's local stability and behavior. To identify these points, we located coordinates in the hidden state space (z ∈ R^{d_z}) that minimized the system's kinetic energy, defined as q ≈ ||Δh||².
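The kinetic-energy minimization can be sketched for a small autonomous tanh RNN. Plain gradient descent stands in for the toolkit's optimizer here, and the weak random coupling (spectral norm below 1) guarantees a single stable fixed point:

```python
import numpy as np

rng = np.random.default_rng(2)
N = 20
# Weakly coupled random weights (spectral norm < 1), so a unique fixed point exists
J = rng.normal(size=(N, N)) * 0.3 / np.sqrt(N)

def delta(h):
    # one-step update direction of the autonomous RNN: Delta h = -h + J tanh(h)
    return -h + J @ np.tanh(h)

def find_fixed_point(h0, lr=0.1, iters=5000):
    """Minimize the kinetic energy q = 0.5 * ||Delta h||^2 by gradient descent."""
    h = h0.copy()
    for _ in range(iters):
        dh = delta(h)
        # Jacobian of Delta h is -I + J diag(1 - tanh(h)^2); grad q = Jac^T Delta h
        jac = -np.eye(N) + J * (1.0 - np.tanh(h) ** 2)[None, :]
        h = h - lr * (jac.T @ dh)
    dh = delta(h)
    return h, 0.5 * float(dh @ dh)

h_star, q = find_fixed_point(rng.normal(size=N))
# q should be ~0 at a fixed point; linearizing the Jacobian there reveals local stability
```

At a candidate point that passes the speed tolerance, the eigenvalues of the local Jacobian then classify it as a stable attractor, saddle, or unstable point.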
We performed this analysis using a modified version of the fixed-point finding toolkit developed by Golub & Sussillo (2018), adapting it for autonomous RNNs. For visualization, we applied Principal Component Analysis (PCA) to the hidden state trajectories and projected the identified fixed points into this low-dimensional subspace. We used a speed tolerance of 1e-15 for the MemoryPro task and 1e-9 for the motor cortex recordings.

A.4. Details of Lyapunov Exponents calculation

To calculate the Lyapunov exponents, we followed the algorithm proposed by Engelken & Wolf (2023). The algorithm tracks how tiny perturbations to the network's state grow or shrink over time by evolving a set of test vectors alongside the main simulation. It uses the Jacobian matrix to update these vectors, measuring how the network activations amplify or dampen small differences. By averaging the logarithms of the stretching factors (the diagonal values from the QR decomposition) over the entire simulation, we determine the final exponents. A positive exponent indicates that the network is chaotic, while stable dynamics are characterized by negative exponents.

A.5. Supplementary Experiments and Details

A.5.1. MODEL RECONSTRUCTIONS

Figure S1. Reconstructions of neural trajectories from various trained models. A) Reconstructions of the synthetic data. B) Reconstructions of the monkey motor cortex data.

A.5.2. HYPERPARAMETER SWEEP

Figure S2.
JEDI performance on the synthetic data generated by the teacher RNN. The heatmaps show reconstruction accuracy as measured by the coefficient of determination (R^2). The red circular dot highlights the hyperparameter setting chosen for experiments. A) R^2 heatmap for model rank vs. hypernetwork hidden layer size. B) R^2 heatmap for model rank vs. embedding size of the hypernetwork. C) R^2 heatmap for embedding size vs. hidden size of the hypernetwork.

A.5.3. EFFECT OF DATA RANK ON PERFORMANCE OF JEDI

Figure S3. Performance of JEDI on the multi-task synthetic data generated by the teacher RNN with different ranks of the connectivity matrix J. The bar plot shows reconstruction accuracy as measured by the coefficient of determination R^2 (with 1 s.e.m.). JEDI fit well across varied data ranks with the generated weights set to rank = 5.

A.6. Additional results on spectral analysis

Figure S4. Eigenvalue spectra of learned recurrent weights on multi-task synthetic data. JEDI learned consistent structure across tasks, while JEDI-full exhibited a dense, isotropic cloud without task separation.

Figure S5. Eigenvalue spectra of JEDI-full weights on monkey motor cortical activity. JEDI-full exhibited a dense, isotropic cloud without task/direction separation.

Figure S6. Eigenvalue spectra of learned recurrent weights on the multiple frequency dataset. The legend indicates the specific sine frequencies (1–10 Hz) used to train the models. JEDI-full consistently exhibits a dense, isotropic spectral cloud; this distribution lacks any discernible structure or cluster formation corresponding to the underlying task frequencies.

A.7. Additional results on monkey reaching data (Preparation phase)

Figure S7. Assessing embedding quality on monkey reaching data during preparation. a) 2D PCA visualization of embeddings from different models. Each point represents a sample trial, color-coded by reach direction. b) Confusion matrices of direction classification accuracies. Rows correspond to the true direction, and columns to the predicted direction. c) Confusion matrices showing generalization performance across reach directions. Each cell indicates the R^2 score when decoding one direction (column) using the mean embeddings of another (row).

A.8. Additional results on generalization

Figure S8. Assessing embedding quality on monkey reaching data during preparation.
a) 2D PCA visualization of embeddings from different models. Each point represents a sample trial, color-coded by reach direction. b) Confusion matrices of direction classification accuracies. Rows correspond to the true direction, and columns to the predicted direction. c) Confusion matrices showing generalization performance across reach directions. Each cell indicates the R^2 score when decoding one direction (column) using the mean embeddings of another (row).

A.9. Additional fixed points results

Figure S9. Stable fixed points during preparation inferred by JEDI, colored by reach direction. Single-trial neural trajectories for each corresponding reach are also plotted.

Figure S10. 3D PCA projection of fixed points across the 4 periods of the MemoryPro task (context, stimulus, memory, response), shown for the ground truth task-trained RNN, JEDI-full, and JEDI. The fixed-point structure inferred by JEDI closely resembled that of the task-trained RNN.

Figure S11. Effect of training only the JEDI-full recurrent weights while freezing the context embeddings. When the context embeddings remain frozen during training, the eigenvalue spectra of the learned recurrent weights fail to capture task-specific signatures. Consequently, joint optimization of both contexts and weights is essential to develop meaningful, trial-specific representations.

A.10. Compute Resources

We list below the compute resources used per experiment:

1. Multi-Task Teacher experiment: Results were computed on an external cluster equipped with Nvidia L40S GPUs. Training wall-clock time was 6 hrs. Inference converges in a few seconds.
2. Multi-Frequency Sine experiment: Results were computed on an external cluster equipped with Nvidia L40S GPUs. Training wall-clock time was 6 hrs. Inference converges in a few seconds.
3. Task-trained RNN experiment: Results were computed on an external cluster equipped with Nvidia L40S GPUs. Training wall-clock time was 1.5–2 days. Inference converges in a few seconds.
4. Monkey reaching experiment: Results were computed on an external cluster equipped with Nvidia L40S GPUs. Training wall-clock time was 6 hrs. Inference converges in a few seconds.