<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://jainagaraj.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://jainagaraj.github.io/" rel="alternate" type="text/html" /><updated>2026-03-03T21:58:57+00:00</updated><id>https://jainagaraj.github.io/feed.xml</id><title type="html">Jai Vivek Nagaraj</title><subtitle>When you say normal... what does that even mean?</subtitle><author><name>Jai Nagaraj</name></author><entry><title type="html">Online learning with a memory-efficient Kalman Filter</title><link href="https://jainagaraj.github.io/blog/2025/12/16/lofi/" rel="alternate" type="text/html" title="Online learning with a memory-efficient Kalman Filter" /><published>2025-12-16T14:00:00+00:00</published><updated>2025-12-16T14:00:00+00:00</updated><id>https://jainagaraj.github.io/blog/2025/12/16/lofi</id><content type="html" xml:base="https://jainagaraj.github.io/blog/2025/12/16/lofi/"><![CDATA[<p class="post-meta-updated">Dec 16, 2025 | Updated Feb 8, 2026</p>

<h2 id="introduction-the-kalman-filter">Introduction: The Kalman Filter</h2>

<p>Recently, I had the opportunity to learn about the Kalman filter: a powerful, versatile tool that captures uncertainty in a time-varying system. Given a series of observations and an underlying dynamical model of how the state <em>should</em> change over time, the Kalman filter updates its estimate of the state by combining a <strong>probability distribution over the state</strong> with <strong>external observation data</strong>. It is no surprise, then, that NASA used these properties in aerospace navigation and control systems, including the <a href="https://www.lancaster.ac.uk/stor-i-student-sites/jack-trainer/how-nasa-used-the-kalman-filter-in-the-apollo-program/" target="_blank">famous Apollo missions to the moon</a>!</p>

<p>Another use of the Kalman filter, outside of control theory, is Bayesian machine learning. Rather than optimizing a model in the traditional way, we can frame model learning in terms of uncertainty: we start off <em>highly uncertain</em> about what the optimal model parameters are, and use a variant of the Kalman filter to shrink that uncertainty (which naturally optimizes the model as part of the state-estimation process). As it turns out, this is not only reasonable but also quite performant, since the calculations incorporate approximate second-order information<a href="#appendix">${}^1$</a> when computing the next state. The variant suited to nonlinear models such as neural networks is the <a href="https://en.wikipedia.org/wiki/Extended_Kalman_filter" target="_blank">Extended Kalman Filter (EKF)</a>, which linearizes the model around the current estimate using its Jacobian.</p>

<h3 id="the-ekf-algorithm">The EKF algorithm</h3>
<p>We can roughly describe the algorithm in the following steps:</p>

<ol>
  <li>“A priori” predictions<a href="#appendix">${}^2$</a>: a simple prediction of what the next <em>EKF state</em> (in our case, the model parameters) will be. $\theta$ is the vector of model parameters, $\Sigma$ is their covariance matrix, $\mathbf{Q}$ is the process noise, and $\mathbf{f}_\theta$ is the output of the model with parameters $\theta$. To frame this in control-theory terms, suppose the system has <em>system state</em> $\mathbf{x}$ and controls $\mathbf{u}$, and our underlying baseline model predicts an observation $\mathbf{y}$.</li>
</ol>

\[\begin{aligned}
&amp; \theta_{t+1 \mid t}=\theta_t \\
&amp; \Sigma_{t+1 \mid t} = \Sigma_t + \mathbf{Q}\\
&amp; \mathbf{y}_{t+1 \mid t}=\mathbf{f}_{\theta_{t+1 \mid t}}(\mathbf{x}_t, \mathbf{u}_t)
\end{aligned}\]

<ol start="2">
  <li>Kalman gain computation, which helps scale how much to change the parameters (almost like a “learning rate”). $\mathbf{R}_t$ is the observation noise, and crucially, $\mathbf{K}_t$ is the <strong>Kalman gain</strong> matrix.</li>
</ol>

\[\begin{aligned}
&amp; \mathbf{F}_t=\frac{\partial \mathbf{f}}{\partial \theta} (\mathbf{x}_t, \mathbf{u}_t, \theta_t) \\
&amp; \mathbf{S}_t=\mathbf{F}_t \Sigma_{t+1 \mid t} \mathbf{F}^\top_t + \mathbf{R}_t \\
&amp; \mathbf{K}_t = \Sigma_{t+1 \mid t}\mathbf{F}^\top_t\mathbf{S}_t^{-1}
\end{aligned}\]

<ol start="3">
  <li>Posterior computation, which uses the Kalman gain to update the state distribution. $\mathbf{I}$ is the identity matrix, and $\mathbf{s}_t$ is the <em>innovation</em>: the difference between the actual observation $\mathbf{y}_{t+1}$ and the predicted observation $\mathbf{y}_{t+1 \mid t}$.</li>
</ol>

\[\begin{aligned}
&amp; \mathbf{s}_t = \mathbf{y}_{t+1} - \mathbf{y}_{t+1 \mid t} \\
&amp; \theta_{t+1} = \theta_{t+1 \mid t} + \mathbf{K}_t\mathbf{s}_t \\
&amp; \Sigma_{t+1} = (\mathbf{I}-\mathbf{K}_t\mathbf{F}_t) \Sigma_{t+1 \mid t} \\
&amp; \textbf{return } \theta_{t+1}, \Sigma_{t+1} \\
\end{aligned}\]

<p>This process is repeated for each timestep $t$.</p>
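<p>To see how the Kalman gain behaves like an adaptive learning rate, here is one step worked by hand for a toy case of my own choosing (not taken from the experiments below): a single parameter, the linear model $f_\theta(x) = \theta x$, $\theta_t = 0$, $\Sigma_t = 1$, $\mathbf{Q} = 0$, $\mathbf{R}_t = 1$, input $x_t = 2$, and observation $\mathbf{y}_{t+1} = 3$. Because the model is linear, $\mathbf{F}_t = x_t = 2$.</p>

\[\begin{aligned}
&amp; \Sigma_{t+1 \mid t} = 1 + 0 = 1, \qquad \mathbf{y}_{t+1 \mid t} = 0 \cdot 2 = 0 \\
&amp; \mathbf{S}_t = 2 \cdot 1 \cdot 2 + 1 = 5, \qquad \mathbf{K}_t = \frac{1 \cdot 2}{5} = 0.4 \\
&amp; \mathbf{s}_t = 3 - 0 = 3, \qquad \theta_{t+1} = 0 + 0.4 \cdot 3 = 1.2 \\
&amp; \Sigma_{t+1} = (1 - 0.4 \cdot 2) \cdot 1 = 0.2
\end{aligned}\]

<p>The prediction moves from $0$ to $1.2 \cdot 2 = 2.4$, most of the way toward the target of $3$, and the variance drops from $1$ to $0.2$, so the next observation will move $\theta$ less. A larger $\mathbf{R}_t$ (noisier observations) would enlarge $\mathbf{S}_t$, shrink $\mathbf{K}_t$, and make the step more cautious.</p>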

<h3 id="in-practice">In practice</h3>
<p>Using the <a href="https://docs.jax.dev/en/latest/" target="_blank">JAX</a> library in Python, the above algorithm can be implemented as follows:</p>
<details>
<summary> <strong> Click to show code </strong> </summary>


<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">jax</span>
<span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="n">jnp</span>

<span class="c1"># Any functions not defined can be assumed to work as the signature implies.
</span><span class="k">def</span> <span class="nf">ekf_step</span><span class="p">(</span><span class="n">mean_t</span><span class="p">,</span> <span class="n">cov_t</span><span class="p">,</span> <span class="n">control_t</span><span class="p">,</span> <span class="n">obs_tp1</span><span class="p">):</span>
    <span class="c1"># A priori predictions
</span>    <span class="n">mean_tp1_apriori</span> <span class="o">=</span> <span class="n">mean_t</span>
    <span class="n">cov_tp1_apriori</span> <span class="o">=</span> <span class="n">cov_t</span> <span class="o">+</span> <span class="n">process_cov_fn</span><span class="p">(</span>
        <span class="n">mean_tp1_apriori</span>
    <span class="p">)</span>
    <span class="n">obs_tp1_apriori</span> <span class="o">=</span> <span class="n">observation_fn</span><span class="p">(</span>
        <span class="n">mean_tp1_apriori</span><span class="p">,</span> <span class="n">control_t</span>
    <span class="p">)</span>

    <span class="c1"># Kalman gain calculation via Jacobian
</span>    <span class="n">jac_obs</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">jacrev</span><span class="p">(</span><span class="n">observation_fn</span><span class="p">,</span> <span class="n">argnums</span><span class="o">=</span><span class="mi">0</span><span class="p">)(</span>
        <span class="n">mean_tp1_apriori</span><span class="p">,</span> <span class="n">control_t</span>
    <span class="p">)</span>
    <span class="n">R_t</span> <span class="o">=</span> <span class="n">observation_cov_fn</span><span class="p">(</span><span class="n">mean_tp1_apriori</span><span class="p">)</span>
    <span class="n">S_t</span> <span class="o">=</span> <span class="n">jac_obs</span> <span class="o">@</span> <span class="n">cov_tp1_apriori</span> <span class="o">@</span> <span class="n">jac_obs</span><span class="p">.</span><span class="n">T</span> <span class="o">+</span> <span class="n">R_t</span>
    <span class="n">kalman_gain</span> <span class="o">=</span> <span class="p">(</span>
        <span class="n">cov_tp1_apriori</span> <span class="o">@</span> <span class="n">jac_obs</span><span class="p">.</span><span class="n">T</span> <span class="o">@</span> <span class="n">jnp</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">inv</span><span class="p">(</span><span class="n">S_t</span><span class="p">)</span>
    <span class="p">)</span>

    <span class="c1"># Posterior calculation
</span>    <span class="n">innovation</span> <span class="o">=</span> <span class="n">obs_tp1</span> <span class="o">-</span> <span class="n">obs_tp1_apriori</span>
    <span class="n">mean_tp1</span> <span class="o">=</span> <span class="n">mean_tp1_apriori</span> <span class="o">+</span> <span class="n">kalman_gain</span> <span class="o">@</span> <span class="n">innovation</span>
    <span class="n">eye_cov</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="n">cov_t</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    <span class="n">cov_tp1</span> <span class="o">=</span> <span class="p">(</span><span class="n">eye_cov</span> <span class="o">-</span> <span class="n">kalman_gain</span> <span class="o">@</span> <span class="n">jac_obs</span><span class="p">)</span> <span class="o">@</span> <span class="n">cov_tp1_apriori</span>

    <span class="k">return</span> <span class="n">mean_tp1</span><span class="p">,</span> <span class="n">cov_tp1</span></code></pre></figure>


</details>
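<p>The snippet above leaves <code>observation_fn</code>, <code>process_cov_fn</code>, and <code>observation_cov_fn</code> undefined. As a sanity check, here is a minimal, self-contained sketch of the same step with hypothetical stand-ins: a linear model $y = \theta^\top x$ with scalar output, process noise $q\mathbf{I}$ with $q = 0$, and observation noise $\mathbf{R} = \mathbf{I}$. One design change is my own choice rather than part of the original: the Kalman gain is computed with <code>jnp.linalg.solve</code> instead of explicitly inverting $\mathbf{S}_t$, which is generally more numerically stable.</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the helpers left undefined above: a linear model
# y = theta . x with scalar output, constant process noise q*I, and R = I.
PROCESS_Q = 0.0
OBS_COV = jnp.eye(1)

def observation_fn(theta, x):
    return jnp.atleast_1d(jnp.dot(theta, x))

@jax.jit
def ekf_step(mean_t, cov_t, x_t, obs_tp1):
    # A priori predictions (identity dynamics for the parameters)
    mean_pred = mean_t
    cov_pred = cov_t + PROCESS_Q * jnp.eye(cov_t.shape[0])
    obs_pred = observation_fn(mean_pred, x_t)

    # Kalman gain: solve S K^T = F Sigma instead of forming inv(S);
    # this relies on S and Sigma being symmetric
    F = jax.jacrev(observation_fn)(mean_pred, x_t)   # shape (obs_dim, num_params)
    S = F @ cov_pred @ F.T + OBS_COV
    K = jnp.linalg.solve(S, F @ cov_pred).T

    # Posterior update
    innovation = obs_tp1 - obs_pred
    mean_tp1 = mean_pred + K @ innovation
    cov_tp1 = (jnp.eye(cov_t.shape[0]) - K @ F) @ cov_pred
    return mean_tp1, cov_tp1

mean, cov = ekf_step(jnp.zeros(1), jnp.eye(1), jnp.array([2.0]), jnp.array([3.0]))</code></pre></figure>

<p>With these numbers, a single step from $\theta = 0$, $\Sigma = 1$, input $x = 2$, and observation $y = 3$ gives $\theta \approx 1.2$ and $\Sigma \approx 0.2$.</p>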

<p><br />
The EKF works remarkably well for regression problems, which fit the algorithm above once we frame them as state estimation: each input plays the role of $\mathbf{x}_t$, the control is held constant at $\mathbf{u}_t=\vec{0}$ for all $t$, and the filter adjusts $\theta$ so that the innovation between $\mathbf{f}_\theta$’s prediction and the observed target shrinks. Here, I demonstrate the results of training a simple 641-parameter multilayer perceptron<a href="#appendix">${}^3$</a> with two hidden layers:</p>

<p><img src="/assets/images/online_ekf_demo.png" alt="EKF sinusoid performance" /></p>

<p>The Adam optimizer is used as a baseline. The function to be learned online is $f(x) = \text{sin}(10x)$.</p>
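<p>To plug a neural network into the EKF, its parameters have to live in a single vector $\theta$. The sketch below shows one way to do that with <code>jax.flatten_util.ravel_pytree</code>; it is an illustration, not the code behind the plot. The hidden-layer sizes follow footnote 3, but the two-dimensional input is my assumption (for example, $x$ concatenated with the zero control $u$), as is the linear output layer. With that layout the network has exactly $2 \cdot 32 + 32 + 32 \cdot 16 + 16 + 16 \cdot 1 + 1 = 641$ parameters.</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Assumed layout: 2-D input (x and the zero control u), tanh hidden layers of 32 and 16 units,
# and a linear scalar output. Only the hidden sizes come from the post.
LAYER_SIZES = (2, 32, 16, 1)

def init_mlp(key, sizes=LAYER_SIZES):
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)  # scaled Gaussian init
        params.append((w, jnp.zeros(n_out)))
    return params

def mlp_apply(params, inputs):
    h = inputs
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)
    w, b = params[-1]
    return h @ w + b

# Flatten the list of (weight, bias) pairs into one vector theta, keeping the inverse map
theta0, unravel = ravel_pytree(init_mlp(jax.random.PRNGKey(0)))

def observation_fn(theta, x, u=0.0):
    """Hypothetical EKF observation function: the network's prediction at input x."""
    return mlp_apply(unravel(theta), jnp.array([x, u]))

jac = jax.jacrev(observation_fn)(theta0, 0.3)   # plays the role of F_t, shape (1, 641)</code></pre></figure>

<p>The <code>unravel</code> function maps an updated $\theta$ back into layer weights whenever the network needs to be evaluated, so the filter itself only ever sees flat vectors.</p>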

<p>As expected, the EKF performs much better than gradient-descent optimizers such as Adam, likely thanks to the approximate second-order information it has access to. The shaded bands around the EKF output show the predictive variance, computed as
$\text{Var}(\mathbf{y} \mid \mathbf{x}) = \mathbf{S} = \mathbf{J}_\theta \Sigma \mathbf{J}_\theta^\top + \mathbf{R}$, where $\mathbf{J}_\theta = \frac{\partial \mathbf{f}}{\partial \theta}(\mathbf{x}, \theta)$.
This is the same quantity as $\mathbf{S}_t$ in the EKF algorithm, so we already compute it at every step.</p>

<p>Notice that the bands only span $\pm 0.1$ standard deviations from the mean, which hints at how large the uncertainty really is. I displayed only $\pm 0.1\sigma$ to preserve the detail of the output functions; otherwise, Matplotlib would have squished them into a straight line!</p>
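<p>Since $\mathbf{S}$ falls out of the Jacobian computation, the bands are cheap to compute over a grid of test inputs. Here is a minimal sketch, assuming a hypothetical three-parameter toy model $a\tanh(bx) + c$ in place of the MLP and made-up values for $\theta$, $\Sigma$, and $\mathbf{R}$; <code>jax.vmap</code> maps the single-input computation over the whole grid.</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">import jax
import jax.numpy as jnp

def model(theta, x):
    # Hypothetical toy model standing in for the MLP: a * tanh(b * x) + c
    a, b, c = theta
    return jnp.atleast_1d(a * jnp.tanh(b * x) + c)

def predictive(theta, Sigma, R, x):
    """Predictive mean and variance at one input, using S = J Sigma J^T + R."""
    mean = model(theta, x)
    J = jax.jacrev(model)(theta, x)      # shape (1, num_params)
    S = J @ Sigma @ J.T + R              # shape (1, 1)
    return mean[0], S[0, 0]

# Map over a batch of inputs while sharing theta, Sigma, and R
predictive_batch = jax.jit(jax.vmap(predictive, in_axes=(None, None, None, 0)))

theta = jnp.array([1.0, 2.0, 0.0])   # made-up parameter mean
Sigma = 0.5 * jnp.eye(3)             # made-up parameter covariance
R = jnp.array([[0.1]])               # made-up observation noise
xs = jnp.linspace(-1.0, 1.0, 5)
means, variances = predictive_batch(theta, Sigma, R, xs)

k = 0.1   # band half-width in standard deviations, as in the plot
lower = means - k * jnp.sqrt(variances)
upper = means + k * jnp.sqrt(variances)</code></pre></figure>

<p>The resulting <code>lower</code> and <code>upper</code> arrays can be passed straight to Matplotlib's <code>fill_between</code>.</p>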

<hr />

<h2 id="optimization-lofi">Optimization: LOFI</h2>

<h3 id="motivation">Motivation</h3>
<p>There is but one issue with the Kalman filter: it’s expensive in both computation and memory. The covariance matrix grows <em>quadratically</em> with the number of model parameters, making training large models impractical unless you have an impractical amount of RAM (and, with the current AI boom raising prices, an impractical amount of money for it). Even with enough memory, every update performs matrix products the size of $\Sigma$, so the per-step cost also grows at least quadratically with the parameter count (and cubically for a dense product such as $(\mathbf{I}-\mathbf{K}_t\mathbf{F}_t)\Sigma_{t+1 \mid t}$), making training slow even on a GPU.</p>
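<p>To put numbers on that quadratic growth, here is a quick back-of-the-envelope calculation, assuming 32-bit floats and a dense covariance matrix: the 641-parameter MLP above needs $641^2 \times 4 \approx 1.6$ MB for $\Sigma$, which is harmless, but a modest one-million-parameter network would need $10^{12} \times 4$ bytes, or about 4 TB.</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">def dense_cov_bytes(num_params, bytes_per_float=4):
    """Memory needed for a full P x P covariance matrix (float32 by default)."""
    return num_params ** 2 * bytes_per_float

for p in (641, 100_000, 1_000_000):
    gigabytes = dense_cov_bytes(p) / 1e9
    print(f"{p:11,} parameters need {gigabytes:,.3f} GB for the covariance")</code></pre></figure>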

<p>The EKF results above took minutes to train, while Adam needed just a few seconds on the same data. Despite the EKF’s accuracy, that surely wouldn’t fly in real-time systems!</p>

<h3 id="using-decomposition-to-approximate-ekf">Using decomposition to approximate EKF</h3>
<p>In 2023, <a href="https://arxiv.org/abs/2305.19535" target="_blank">this paper</a><a href="#appendix">${}^4$</a> presented LOFI, a way to approximate the EKF by storing the <em>inverse</em> covariance (precision) matrix as the sum of two components: a diagonal matrix and a low-rank matrix. Using singular value decomposition (SVD) and the Woodbury identity, LOFI still mimics the Kalman filter process of updating the state distribution from Jacobians, without ever forming the full matrix. Covariance updates are therefore much faster, and the covariance information we store grows only linearly with the number of parameters! We can then adjust how much covariance information to capture by changing the maximum rank $L$ of the low-rank matrix.</p>
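<p>The key trick is that, writing the precision matrix as $\mathbf{\Upsilon} + \mathbf{W}\mathbf{W}^\top$ (in the notation I'll use here, $\mathbf{\Upsilon}$ is diagonal and $\mathbf{W}$ has $P$ rows and $L$ columns), the Woodbury identity</p>

\[(\mathbf{\Upsilon} + \mathbf{W}\mathbf{W}^\top)^{-1} = \mathbf{\Upsilon}^{-1} - \mathbf{\Upsilon}^{-1}\mathbf{W}\left(\mathbf{I}_L + \mathbf{W}^\top\mathbf{\Upsilon}^{-1}\mathbf{W}\right)^{-1}\mathbf{W}^\top\mathbf{\Upsilon}^{-1}\]

<p>lets us apply the covariance to a vector while only ever solving an $L \times L$ system. Storage is $P(L+1)$ numbers, e.g. $641 \times 9 = 5{,}769$ floats (about 23 KB) for a 641-parameter model with $L = 8$. The sketch below is a generic illustration of the identity, not code from the paper, and checks itself against a dense solve.</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">import jax
import jax.numpy as jnp

def woodbury_solve(upsilon, W, v):
    """Compute (diag(upsilon) + W W^T)^{-1} v without forming a P x P matrix.

    upsilon: (P,) positive diagonal, W: (P, L) low-rank factor, v: (P,) vector.
    Cost is O(P L^2) instead of O(P^3) for a dense solve.
    """
    inv_diag_v = v / upsilon                    # Upsilon^{-1} v
    inv_diag_W = W / upsilon[:, None]           # Upsilon^{-1} W, shape (P, L)
    rank = W.shape[1]
    capacitance = jnp.eye(rank) + W.T @ inv_diag_W   # small L x L system
    correction = inv_diag_W @ jnp.linalg.solve(capacitance, W.T @ inv_diag_v)
    return inv_diag_v - correction

# Check against a dense solve on a small random example
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
P, L = 50, 4
upsilon = 1.0 + jax.random.uniform(k1, (P,))
W = 0.5 * jax.random.normal(k2, (P, L))
v = jax.random.normal(k3, (P,))
fast = woodbury_solve(upsilon, W, v)
dense = jnp.linalg.solve(jnp.diag(upsilon) + W @ W.T, v)</code></pre></figure>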

<p>Because the covariance is stored in this factored form, LOFI splits prediction and update into separate algorithms. Moreover, although the diagonal and low-rank components are stored separately, the update step must treat them as approximating the <em>same matrix</em> when combined, so any change to one component must be reflected in the other. This is where SVD comes in during the update step: folding in a new observation pushes the low-rank part past rank $L$, and a truncated SVD brings it back down, with the discarded information pushed into the diagonal. Still, if you look beyond the fancy matrix computations, many of the steps mirror the standard EKF procedure: the a priori predictions of the parameters and the covariance matrix are the same (at least for the diagonal component), and the parameters are updated using the innovation and a modified Kalman gain matrix.</p>
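<p>Here is a minimal sketch of that truncation idea, written from my reading of the paper rather than taken from its code, and covering only the precision factors (not the mean update or the prediction step). Store the precision as a diagonal $\mathbf{\Upsilon}$ plus $\mathbf{W}\mathbf{W}^\top$, where $\mathbf{W}$ has $L$ columns. A new observation adds $\mathbf{F}_t^\top \mathbf{R}_t^{-1} \mathbf{F}_t$ to the precision, which can be appended to $\mathbf{W}$ as extra columns; an SVD then keeps the top $L$ directions, and the diagonal of everything discarded is added to $\mathbf{\Upsilon}$. The function name and signature are hypothetical, and it assumes isotropic observation noise with variance <code>obs_var</code>.</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">import jax
import jax.numpy as jnp

def fold_in_and_truncate(upsilon, W, jac, obs_var):
    """Sketch of a diagonal-plus-low-rank precision update with SVD truncation.

    upsilon: (P,) diagonal, W: (P, L) low-rank factor,
    jac: (C, P) Jacobian of the model outputs, obs_var: scalar observation variance.
    """
    # Information-form update: precision += J^T R^{-1} J, appended as C new columns
    new_cols = jac.T / jnp.sqrt(obs_var)
    W_ext = jnp.concatenate([W, new_cols], axis=1)     # rank is now up to L + C
    rank = W.shape[1]
    U, s, _ = jnp.linalg.svd(W_ext, full_matrices=False)
    W_kept = U[:, :rank] * s[:rank]                     # top-L directions
    W_dropped = U[:, rank:] * s[rank:]                  # discarded directions
    # Move the discarded part's diagonal into upsilon so the precision's diagonal is preserved
    upsilon_new = upsilon + jnp.sum(W_dropped ** 2, axis=1)
    return upsilon_new, W_kept</code></pre></figure>

<p>By construction, the diagonal of $\mathbf{\Upsilon} + \mathbf{W}\mathbf{W}^\top$ after truncation matches that of the exact updated precision; only the off-diagonal structure outside the top $L$ directions is lost.</p>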

<h3 id="performance">Performance</h3>
<p>When evaluating the tradeoffs of an approximation like this, it’s important that the accuracy sacrificed for speed does not leave us with an optimizer that is more complicated than a standard one but no better (i.e., one that offers little performance benefit). Fortunately, this is not the case for LOFI! The following results pit LOFI<a href="#appendix">${}^5$</a> and Adam against each other as they learn a nonlinear function online over 50,000 observations. Every 500 observations, each model is evaluated on the test data as well as on the previously seen observations, and the results are plotted.</p>

<p>Evaluated on $f(x)=\text{sin}(10x)$:
<img src="/assets/images/online_sinusoid_demo.png" alt="LOFI vs. Adam for sinusoid" /></p>

<p>Evaluated on $f(x)=e^{\frac{-(x+1.5)^2}{0.18}} + 0.7e^{-2x^2} + 0.9e^{\frac{-(x-1.5)^2}{0.32}}$:
<img src="/assets/images/online_gaussian_demo.png" alt="LOFI vs. Adam for gaussian mixture" /></p>

<p>Evaluated on $f(x)=x^3-x$:
<img src="/assets/images/online_poly_demo.png" alt="LOFI vs. Adam for simple polynomial" /></p>

<p>Evaluated on $f(x) = \text{sgn}(\text{sin}(10x))$ (essentially a square wave):
<img src="/assets/images/online_square_demo.png" alt="LOFI vs. Adam for square wave" /></p>

<p>In each case, LOFI not only approximates the function better than Adam (sometimes by a large margin), but also improves more quickly as observations arrive. More importantly, LOFI’s efficiency makes the difference in training time between the two optimizers nearly negligible.</p>

<p>We can also see the accuracy-for-speed tradeoff more clearly when pitting LOFI against EKF:
<img src="/assets/images/online_ekf_vs_lofi_demo.png" alt="EKF vs LOFI" /></p>

<p>Both optimizers learned the function quite well, and surprisingly, LOFI exceeds the performance of the EKF with the same shared hyperparameters<a href="#appendix">${}^6$</a>! This may be due to implicit regularization from the low-rank approximation, since the full EKF is the more expressive of the two. However, look at the uncertainty bands: the EKF is slightly more certain, and its bands follow the shape of the function more closely. (I suspect this is because LOFI’s low-rank approximation encodes less covariance information.) What the plot does not show is training time: the EKF trained for <strong>~3 minutes</strong>, while LOFI took a mere <strong>18 seconds</strong>, roughly a tenfold speedup. Given how close their performance is, LOFI is a remarkable optimization of the EKF.</p>

<hr />

<h2 id="my-experience-with-online-learning">My experience with online learning</h2>

<p>I should note that before this project, I had little experience with traditional machine learning and neural networks beyond a couple of online Google Colab notebooks that abstracted away many of the intricacies and charms of the field. Having seen machine learning as nothing more than a tool where the magic was already done for me, I showed little interest in it. But with this more technical approach, where I had to build an unconventional trainer from scratch, I realized the art of training a neural network is far from boring. Of the many things I learned, here are a few of the major points that I’ll certainly keep in mind for the future:</p>

<h3 id="normalization-matters">Normalization matters</h3>
<p>It’s incredible what normalizing an input dataset to be centered around 0 can do! While it isn’t necessary in every case, it helps stability and convergence during training, since neural nets are rather sensitive to large differences in input magnitude. In fact, I found that all models, regardless of optimizer, were much better at predicting function values for inputs in $[-1, 1]$ than for other inputs.</p>
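<p>In an online setting, the whole dataset isn’t available up front, so normalization statistics have to come from a known input range or be estimated as data streams in. The sketch below shows both options as generic techniques, not necessarily what was used for the plots: <code>minmax_scale</code> maps a known range onto $[-1, 1]$, and the hypothetical <code>RunningStandardizer</code> class uses Welford’s algorithm to track a running mean and variance.</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">def minmax_scale(x, lo, hi):
    """Affinely map values from a known range [lo, hi] onto [-1, 1]."""
    return 2.0 * (x - lo) / (hi - lo) - 1.0

class RunningStandardizer:
    """Welford's algorithm: streaming mean and variance for z-scoring online inputs."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0   # running sum of squared deviations from the mean

    def update(self, x):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def variance(self):
        return self.m2 / self.count if self.count else 0.0   # population variance

    def transform(self, x, eps=1e-8):
        return (x - self.mean) / (self.variance() ** 0.5 + eps)</code></pre></figure>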

<h3 id="training-is-extremely-sensitive-to-hyperparameter-selection">Training is extremely sensitive to hyperparameter selection</h3>
<p>I don’t recall exactly which paper demonstrated the role of good hyperparameter selection in convergence, but one of its featured graphs was shocking: when convergence success was plotted against just two hyperparameters, the regions of success were completely erratic and seemingly impossible to explain! There were small islands where the model converged amid a sea of failed runs corresponding to bad hyperparameter pairs. And that was with only two knobs; for LOFI, I had to juggle several:</p>

<ul>
  <li>What should the maximum rank $L$ of the low-rank matrix be?</li>
  <li>How large should the initial covariance be?</li>
  <li>How much process noise should I add?</li>
  <li>How much observation noise should I add?</li>
</ul>

<p>I must have spent a couple hours in total trying out different combinations and increments before getting a stable model. For this, the <a href="https://wandb.ai/site" target="_blank">Weights and Biases</a> API was quite helpful in documenting past runs and making sure I didn’t waste my time with redundant runs.</p>
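<p>Even a coarse grid over those four knobs grows quickly. Below is a sketch of such a search space written as a Weights &amp; Biases sweep configuration; the parameter names, values, and metric are hypothetical, not the ones I actually used. Three values per hyperparameter already means $3^4 = 81$ runs, which is why random or Bayesian search (<code>"random"</code> or <code>"bayes"</code> as the sweep method) is often preferred for larger spaces.</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">import itertools
import math

# Hypothetical search space over the four LOFI hyperparameters listed above.
# The dict follows the Weights and Biases sweep-config layout; the metric name is made up.
sweep_config = {
    "method": "grid",
    "metric": {"name": "test_mse", "goal": "minimize"},
    "parameters": {
        "rank_L": {"values": [4, 8, 12]},
        "init_cov_c": {"values": [0.1, 1.0, 10.0]},
        "process_noise_q": {"values": [1e-5, 1e-4, 1e-3]},
        "obs_precision_r": {"values": [0.01, 0.1, 1.0]},
    },
}

value_lists = [spec["values"] for spec in sweep_config["parameters"].values()]
num_runs = math.prod(len(v) for v in value_lists)   # 3 ** 4 runs for a full grid
all_runs = list(itertools.product(*value_lists))

# With wandb installed and logged in, the sweep could be registered with
# wandb.sweep(sweep_config, project="lofi-sweeps")  (hypothetical project name)
# and executed with wandb.agent(sweep_id, function=train).</code></pre></figure>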

<h3 id="jit-and-vectorize-everything">JIT and vectorize EVERYTHING</h3>
<p>In machine learning, many repetitive calculations are done inside loops over large batches of data. In plain Python, this leads to slow, inefficient programs that simply make my laptop fan very angry for very long periods of time. Luckily, JAX comes with tools that make Python code much faster, such as the <code class="language-plaintext highlighter-rouge">@jax.jit</code> decorator, which compiles a function into optimized machine code the first time it is called (the proper term for this is <em>Just-In-Time Compilation</em>, which you may have heard of in the context of Java). While JAX is quite strict about what code can be jitted (for example, array shapes must be known at trace time, and Python <code class="language-plaintext highlighter-rouge">if</code> statements cannot branch on the values of traced arrays), the improvements are as expected: the difference in speed between compiled and interpreted code is night and day.</p>
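<p>To illustrate that strictness, here is a small, generic example (not from the LOFI trainer): a Python <code>if</code> that inspects array values works eagerly but fails under <code>jax.jit</code>, because inside a jitted function the array is an abstract tracer with no concrete value. Rewriting the branch with <code>jnp.where</code> keeps it traceable.</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">import jax
import jax.numpy as jnp

def normalize_python(x):
    # Branches on the *value* of x: fine eagerly, but under jax.jit x is an
    # abstract tracer, so turning `x.sum() == 0` into a Python bool raises an error
    if x.sum() == 0:
        return x
    return x / x.sum()

@jax.jit
def normalize_traceable(x):
    total = x.sum()
    # jnp.where selects between values with array ops, so the function stays traceable
    safe_total = jnp.where(total == 0, 1.0, total)
    return x / safe_total</code></pre></figure>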

<p>Additionally, because data is often handled in batches, <em>vectorizing</em> these operations can also lead to enormous performance gains. JAX’s <code class="language-plaintext highlighter-rouge">jax.vmap</code> transforms a function written for a single example into one that operates over a whole batch, while <code class="language-plaintext highlighter-rouge">jax.lax.scan</code> expresses a sequential loop that carries state from one step to the next (such as the filter state across observations) so that it compiles into a single optimized loop. These are useful when writing training scripts, as tasks like test-set evaluation can easily be batched.</p>
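<p>Online filtering is a natural fit for these two transforms: the filter state is carried from one observation to the next, which is exactly what <code>jax.lax.scan</code> expresses, and test-set evaluation is embarrassingly parallel, which is what <code>jax.vmap</code> is for. The sketch below uses a hypothetical toy problem to keep it short: online linear regression with a plain Kalman filter (the EKF with a linear model) on synthetic data. The same structure should carry over to an EKF or LOFI step.</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">import jax
import jax.numpy as jnp

# Hypothetical toy problem: learn y = w0 * x + w1 online with a linear Kalman filter.
R = 0.01   # assumed observation noise variance

def features(x):
    return jnp.stack([x, jnp.ones_like(x)])   # Jacobian of the linear model, shape (2,)

def kf_step(carry, obs):
    mean, cov = carry
    x, y = obs
    h = features(x)
    innovation = y - h @ mean
    s = h @ cov @ h + R                      # scalar innovation variance
    k = cov @ h / s                          # Kalman gain, shape (2,)
    mean = mean + k * innovation
    cov = cov - jnp.outer(k, h @ cov)
    return (mean, cov), innovation           # carry the filter state, record the innovation

@jax.jit
def run_filter(mean0, cov0, xs, ys):
    (mean, cov), innovations = jax.lax.scan(kf_step, (mean0, cov0), (xs, ys))
    return mean, cov, innovations

def predict(mean, x):
    return features(x) @ mean

predict_batch = jax.vmap(predict, in_axes=(None, 0))   # evaluate a whole test set at once

xs = jax.random.uniform(jax.random.PRNGKey(0), (500,), minval=-1.0, maxval=1.0)
ys = 2.0 * xs - 0.5 + 0.1 * jax.random.normal(jax.random.PRNGKey(1), (500,))
mean, cov, innovations = run_filter(jnp.zeros(2), jnp.eye(2), xs, ys)

test_x = jnp.linspace(-1.0, 1.0, 11)
test_mse = jnp.mean((predict_batch(mean, test_x) - (2.0 * test_x - 0.5)) ** 2)</code></pre></figure>

<p>Because <code>run_filter</code> is jitted and the loop lives inside <code>lax.scan</code>, the step is compiled once and then reused for every observation instead of being re-dispatched from Python.</p>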

<h3 id="nans-are-really-really-really-really-annoying">NaNs are really really really really annoying</h3>
<p>Early on, the LOFI trainer would proceed fairly smoothly (albeit suboptimally) until some critical juncture of observations or epochs, after which the loss would become noticeably unstable and spike until three dreaded letters appeared on the screen:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>NaN
</code></pre></div></div>
<p>Unlike traditional errors, NaNs are incredibly difficult to trace back to their source in the learning pipeline, especially with so many matrix operations and inverses. Furthermore, jitting a function effectively means saying goodbye to conventional debugging, as the Python debugger cannot step through compiled code and there was simply too much information to print.</p>
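<p>JAX does offer a few escape hatches that help here. In the minimal sketch below, built around a hypothetical toy covariance update, <code>jax.debug.print</code> prints traced values from inside a jitted function, so one can watch the smallest diagonal entry of the covariance shrink before things go wrong. Beyond that, <code>jax.config.update("jax_debug_nans", True)</code> makes JAX raise an error at the first operation that produces a NaN, and the <code>jax.disable_jit()</code> context manager runs jitted functions eagerly so the usual Python debugger can step through them.</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">import jax
import jax.numpy as jnp

@jax.jit
def checked_update(cov, jac, obs_var):
    # Toy scalar-observation covariance update standing in for a real filter step
    s = jac @ cov @ jac + obs_var
    k = cov @ jac / s
    new_cov = cov - jnp.outer(k, jac @ cov)
    # jax.debug.print runs on every call, even inside jit, unlike a plain print()
    jax.debug.print("min diag {m}, all finite: {ok}",
                    m=jnp.min(jnp.diag(new_cov)), ok=jnp.all(jnp.isfinite(new_cov)))
    return new_cov

out = checked_update(jnp.eye(2), jnp.array([1.0, 0.0]), 1.0)

# Run the same function eagerly, e.g. to use the Python debugger inside it
with jax.disable_jit():
    out_eager = checked_update(jnp.eye(2), jnp.array([1.0, 0.0]), 1.0)</code></pre></figure>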

<p>Eventually, I implemented the following to fix the mathematical issues (as far as I’m aware):</p>

<ul>
  <li>Using Moore-Penrose pseudo-inverses for potentially singular matrices</li>
  <li>Ensuring the returned Jacobian is a matrix to avoid dimension mismatch issues</li>
  <li>Adding an adaptive “jitter” value to maintain a non-zero floor for covariance (minimum 0.01)</li>
</ul>

<p>That last modification was what truly helped, as the covariance would often collapse to near zero a few observations before a NaN appeared. Keeping the algorithm numerically stable let me use more aggressive learning rates, which in turn allowed LOFI to outperform Adam and compete with the EKF.</p>
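<p>Those three fixes might look something like the sketch below. It is a hypothetical reconstruction for illustration, not the trainer’s actual code; in particular, the floor here lifts each diagonal entry of the covariance just enough to reach $0.01$, which is only one way to implement an adaptive jitter.</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">import jax.numpy as jnp

def stable_kalman_gain(cov, jac, obs_cov):
    """Kalman gain that tolerates scalar-output Jacobians and singular S."""
    # A scalar-output model yields a (P,) Jacobian; promote it to a (1, P) matrix
    jac = jnp.atleast_2d(jac)
    S = jac @ cov @ jac.T + obs_cov
    # The Moore-Penrose pseudo-inverse stays finite even when S is singular
    return cov @ jac.T @ jnp.linalg.pinv(S)

def apply_variance_floor(cov, floor=0.01):
    """Hypothetical 'adaptive jitter': lift each diagonal entry just enough to reach floor."""
    diag = jnp.diag(cov)
    lift = jnp.maximum(floor - diag, 0.0)   # zero for entries already above the floor
    return cov + jnp.diag(lift)</code></pre></figure>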

<h2 id="conclusion-and-future-work">Conclusion and Future Work</h2>
<p>Overall, this was a valuable exploration into more cutting-edge machine learning methods beyond simple gradient descent and backprop, though I’d like to learn the fundamentals of machine learning more rigorously through my coursework and more projects like this one. Additionally, the aforementioned paper evaluated LOFI on the MNIST dataset to test its ability to handle classification problems. Unfortunately, I was not able to set that up well, as it would take ages to train a properly sized MLP for the job (perhaps due to an inefficiency in my script). I’m certainly looking forward to more work in this area, especially applying such methods to non-trivial problems.</p>

<hr />
<div id="appendix"></div>

<ol>
  <li>
    <p>The EKF algorithm is an incremental variant of the Gauss-Newton method with a diminishing step size. Both methods approximate the Hessian matrix using only first-order (Jacobian) information, which lets them <a href="https://web.mit.edu/dimitrib/www/ekf.pdf" target="_blank">solve nonlinear least-squares problems</a>.</p>
  </li>
  <li>
    <p>In the general EKF algorithm, the a priori predictions of state and covariance are governed by an underlying baseline model of how the state changes with respect to the controls (common in control theory), which adds computation not presented here. Since our state is instead the <em>parameters</em> of the model, the “baseline model” is simply the identity function, so that extra computation drops out.</p>
  </li>
  <li>
    <p>The hidden layers had 32 and 16 neurons, respectively, with $\tanh(x)$ as the activation function for each layer.</p>
  </li>
  <li>
    <p>Chang, Peter G., et al. “Low-rank extended Kalman filtering for online learning of neural networks from streaming data.” arXiv preprint arXiv:2305.19535 (2023).</p>
  </li>
  <li>LOFI here uses the following hyperparameters:
    <ul>
      <li>$L = 8$</li>
      <li>$\mathbf{Q}=qI$, $q=10^{-4}$</li>
      <li>$\mathbf{R} = \frac{1}{r}I$, $r = 10^{-1}$</li>
      <li>$\Sigma_0 = cI$, $c = 1.0$</li>
    </ul>
  </li>
  <li>LOFI here uses $L = 12$ instead, as the paper suggests that for $L \geq 10$, LOFI performs roughly the same as the EKF. All other hyperparameters are shared between the two and are listed above.</li>
</ol>

<hr />]]></content><author><name>Jai Nagaraj</name></author><category term="blog" /><summary type="html"><![CDATA[Dec 16, 2025 | Updated Feb 8, 2026]]></summary></entry></feed>