Before we start, I should share that I'm extremely new to the field of NeuroAI. I have experience in machine learning, but next to no experience in temporal modeling or computational neuroscience.
Motivation for this endeavor
I believe the next big breakthrough will come from people who are not from a CS background, and my bet is that it will be someone from neuroscience. So I started exploring neural dynamics and stuff, both to keep myself entertained and so that, when the time comes, I'm capable of laying a brick or two.
Direction Of the Research:
In most cases, a self-supervised approach is used and trained on a lot of biosignal datasets (EEG). This approach does wonders and is highly functional. SSL mostly learns to reconstruct patterns in data. But I want to pivot: rather than making the model learn how to reconstruct patterns, I want to make it learn how to understand.
Self-Supervised Learning
When I first learnt that masked language modeling is a form of SSL, it felt like magic. If you are new to self-supervised learning, the BERT paper is a great place to start.
The point of Pivot:
While trying to come up with a different (or weird) approach that pushes the model toward understanding semantics, I switched to the JEPA architecture (Joint-Embedding Predictive Architecture), the closest thing we have to predictive coding.
What does it do? What makes it different?
These kinds of models are built around a different objective: instead of reconstructing the EEG, the model learns to predict latent representations of masked signal regions (more on latent representations in the collapsible sections).
I used I-JEPA (JEPA for images; resources are in the collapsible section), in which the model predicts representations of the missing region in embedding space rather than generating the raw input directly. Btw, this is a 1D version of I-JEPA, since we are using single-channel EEG.
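To make the objective above concrete, here is a minimal NumPy sketch of the data side: cutting a single-channel window into patches, hiding some of them, and scoring a prediction in embedding space. This is a sketch under assumptions, not the training code from this project: the patch length, the mask ratio, the purely random masking (I-JEPA itself samples contiguous blocks), the mean-squared-error loss, and all function names are illustrative choices.

```python
import numpy as np

def patchify(signal, patch_len):
    """Split a 1D signal into non-overlapping patches of length patch_len."""
    n_patches = len(signal) // patch_len
    return signal[: n_patches * patch_len].reshape(n_patches, patch_len)

def split_context_target(n_patches, mask_ratio, rng):
    """Pick which patch indices are hidden (targets) and which stay visible (context).

    Assumption: uniform random masking, chosen only to keep the sketch short.
    """
    n_masked = int(round(n_patches * mask_ratio))
    perm = rng.permutation(n_patches)
    target_idx = np.sort(perm[:n_masked])
    context_idx = np.sort(perm[n_masked:])
    return context_idx, target_idx

def latent_prediction_loss(predicted, target):
    """Mean squared error between predicted and target embeddings, not raw samples."""
    return float(np.mean((predicted - target) ** 2))

rng = np.random.default_rng(0)
eeg = rng.standard_normal(1000)            # placeholder single-channel window
patches = patchify(eeg, patch_len=50)      # shape (20, 50)
context_idx, target_idx = split_context_target(len(patches), mask_ratio=0.5, rng=rng)
```

The key point is the last function: the loss compares embeddings of the hidden patches, so the model is never asked to reproduce the raw waveform.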
Important stuff
Latent Representations; These are the internal features learnt by a network, and they hold information about the input data in a compressed format. You can also think of them as metadata about the input that the neurons have learnt. To learn more about this
I-JEPA; a great piece of work popularized by Yann LeCun
The architecture learns by building an internal model of the outside world, comparing abstract representations of images rather than the pixels themselves. I know this is super confusing, and it was to me at first, but after watching Yannic Kilcher and reading and re-reading this, I got a mental picture of how it works.
From this point on, we will jump into the implementation, how the model was trained, and the simple interpretability analyses I performed on it.
Aspects of Implementation:
Context Encoder: a 1D Vision Transformer (4 layers, 4 attention heads, embedding dimension 256).
Target Encoder: very similar to the context encoder, but its weights are never updated by gradients; instead they track the context encoder via an exponential moving average (EMA). This creates a slowly evolving teacher that provides gooood targets!!
Predictor: a shallow transformer that takes the context encoder (CE) output, concatenates it with mask tokens, and predicts what the target encoder's representations of the masked patches should look like.
The Architecture (Please use dark theme to see the arrows clearly)
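The EMA step that keeps the target encoder trailing the context encoder can be sketched in a few lines. This is a minimal illustration, assuming weights are stored in a plain dictionary of NumPy arrays; the function name `ema_update` and the momentum values are assumptions, not settings reported in this post.

```python
import numpy as np

def ema_update(target_params, context_params, momentum=0.996):
    """Move each target weight a small step toward the matching context weight.

    The momentum default is an illustrative assumption, not a reported setting.
    """
    return {
        name: momentum * target_params[name] + (1.0 - momentum) * context_params[name]
        for name in target_params
    }

# Toy weights: the target starts at zero and drifts toward the context weights.
context = {"w": np.array([1.0, 2.0])}
target = {"w": np.array([0.0, 0.0])}
target = ema_update(target, context, momentum=0.9)
```

With a momentum close to 1 the target changes slowly, which is what makes it a stable teacher; no gradient ever flows into it.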
Courtesy of Claude
At this point we have seen the architecture and the motivation for moving forward with this idea. I believe building foundation models that more closely resemble the brain is the right call, for example replacing memory with a hippocampus simulation, replacing backprop with predictive coding, and going deeper into neuroscience and applied interpretability.
JEPA pre-training loss
Looking at this graph, we can see the loss starts to converge at around epoch 35. Right now I'm working on an RSSM, where the rate of convergence feels better (this is still WIP and I might be completely wrong too).
I'm a firm believer in understanding the geometric representation of world modeling. I think I started out with this approach because it sounded coool, and I eventually fell in love with the interpretability of multiple modalities.
Understanding Geometry of Representations (Truth);
I ran a couple of methods and have about 5 plots, all of which appear in this blog. I think this is the crux of the blog.
Effective rank;
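This is how I understand it: effective rank tells us how many features are actually being used, or put another way, how evenly information is spread across dimensions.
This is the math behind it: normalize the singular values $\sigma_i$ of the embedding matrix into a distribution $p_i = \sigma_i / \sum_j \sigma_j$, then take the exponential of its Shannon entropy, $\mathrm{erank} = \exp\left(-\sum_i p_i \log p_i\right)$.
With four equal singular values, all dims contribute equally, so the effective rank is 4.
When all the info sits in one solo direction (one nonzero singular value out of four), only 1 dim contributes, so the effective rank is 1.
In the graph below, we have an effective rank of 64 out of 256. This means the spectrum is concentrated, AKA low entropy. In case it isn't obvious yet: concentration is bad, because most of the available dimensions go unused.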
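The two extreme cases can be checked with a short NumPy sketch of the SVD-entropy definition of effective rank. The function names are illustrative, and mean-centering the embeddings before the SVD is an assumption about preprocessing, not a description of how the number in this post was computed.

```python
import numpy as np

def effective_rank(singular_values):
    """exp(Shannon entropy) of the normalized singular values."""
    s = np.asarray(singular_values, dtype=float)
    p = s / s.sum()
    p = p[p > 0]                      # treat 0 * log(0) as 0
    entropy = -np.sum(p * np.log(p))
    return float(np.exp(entropy))

def effective_rank_of_embeddings(embeddings):
    """Effective rank of an (n_samples, dim) embedding matrix.

    Assumption: embeddings are mean-centered before the SVD.
    """
    centered = embeddings - embeddings.mean(axis=0, keepdims=True)
    s = np.linalg.svd(centered, compute_uv=False)
    return effective_rank(s)

equal_case = effective_rank([1.0, 1.0, 1.0, 1.0])    # all directions used equally
single_case = effective_rank([1.0, 0.0, 0.0, 0.0])   # one direction carries everything
```

Any real spectrum lands somewhere between these two extremes, which is why the number is read relative to the embedding dimension.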
Graph
Effective Rank (SVD Entropy)
Participation Ratio;
Participation ratio, or PR, helps us understand variance concentration. Put another way: how many directions are meaningfully carrying the variance?
Before jumping further into this, let's figure out what eigenvalues are. After computing the covariance matrix of your embeddings, you get eigenvalues $\lambda_1, \lambda_2, \dots, \lambda_d$. Each of these tells us how much variance exists along the corresponding principal direction.
If only one direction dominates, we call that variance concentration.
Our participation ratio is ~5.6 out of 256 principal components, so we can conclude that the variance is extremely concentrated; this number shows extreme anisotropy (non-uniformity).
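For reference, the participation ratio is commonly defined from the covariance eigenvalues as $\mathrm{PR} = (\sum_i \lambda_i)^2 / \sum_i \lambda_i^2$. Here is a minimal NumPy sketch of that definition; the function names are illustrative, and the use of the sample covariance (dividing by n - 1) is an assumption rather than a confirmed detail of this analysis.

```python
import numpy as np

def participation_ratio_from_eigenvalues(eigvals):
    """(sum of eigenvalues)^2 / (sum of squared eigenvalues)."""
    lam = np.asarray(eigvals, dtype=float)
    return float(lam.sum() ** 2 / np.sum(lam ** 2))

def participation_ratio(embeddings):
    """PR of an (n_samples, dim) embedding matrix via its covariance spectrum."""
    centered = embeddings - embeddings.mean(axis=0, keepdims=True)
    cov = centered.T @ centered / (len(embeddings) - 1)
    # eigvalsh is for symmetric matrices; clip tiny negative round-off to zero.
    eigvals = np.clip(np.linalg.eigvalsh(cov), 0.0, None)
    return participation_ratio_from_eigenvalues(eigvals)

# Toy data where only the first coordinate varies: one direction carries all variance.
line = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
pr_line = participation_ratio(line)
```

PR ranges from 1 (one dominant direction) up to the embedding dimension (perfectly isotropic), so a value of a few units in a 256-dim space points to strong anisotropy.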
Graph
EigenSpectrum graph
Trajectory Curvature;
I think of this as a measure of how sharply a sequence of embeddings changes direction over time. Think of it as following a path and watching how often it turns.
In our case the mean trajectory curvature comes out high (a lower angle in degrees is always better), so we can say our representations are not smooth at all, AKA the low-dimensional structure is not smooth.
Another amazing find: the trajectory oscillates among the ~6 dimensions where nearly all of the variance lives.
The representation has structure but not smoothness, which is completely fine and not a bad thing. It would only be an issue if the representation had collapsed.
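One common way to put a number on this is the mean turning angle between consecutive steps of the embedding sequence. The sketch below assumes that definition; the function name and the small epsilon guard are illustrative, and the exact curvature measure used for the plot may differ.

```python
import numpy as np

def mean_turning_angle_deg(trajectory):
    """Mean angle (degrees) between consecutive step vectors of a (T, dim) path."""
    steps = np.diff(trajectory, axis=0)          # (T-1, dim) displacement vectors
    a, b = steps[:-1], steps[1:]
    norms = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    cos = np.sum(a * b, axis=1) / (norms + 1e-12)  # epsilon avoids division by zero
    angles = np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))
    return float(angles.mean())

straight = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]])
zigzag = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [2.0, 1.0]])
angle_straight = mean_turning_angle_deg(straight)   # path never turns
angle_zigzag = mean_turning_angle_deg(zigzag)       # right-angle turn at every step
```

A smooth trajectory stays near 0 degrees, while a path that keeps reversing or turning sharply pushes the mean toward 90 degrees and beyond.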
Graph
Distance Distribution;
Basically, it tells us how similar or different two embeddings are. With this we can study whether a model is relying too much on information from positional embeddings. (There might be other implications too, but so far I have only used it for this purpose.)
My results showed two things (might be interesting):
Same patch position across different windows: the cosine distance is ~0.08 (almost identical)
Different patch positions within the same window: the cosine distance is ~0.8 to 1.2 (very different)
I'm assuming this is because the model tends to place embeddings from the same position very close together, even when they come from different segments.
Again, all of this could indicate the model is relying heavily on positional information or positional embeddings. I won't go much deeper into this because I think it deserves a blog of its own.
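The comparison above can be sketched as two averages of cosine distance over an embedding tensor. This is an illustrative sketch: the `(windows, patches, dim)` layout and the function names are assumptions, and the toy input is built so that embeddings depend only on position, which is the extreme case of the behaviour described.

```python
import numpy as np

def cosine_distance(a, b):
    """1 - cosine similarity; 0 means same direction, 1 means orthogonal."""
    return 1.0 - float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def position_distance_summary(emb):
    """Mean cosine distance for same-position and different-position pairs.

    emb is assumed to have shape (n_windows, n_patches, dim).
    """
    n_windows, n_patches, _ = emb.shape
    same_pos, diff_pos = [], []
    for p in range(n_patches):                   # same position, different windows
        for i in range(n_windows):
            for j in range(i + 1, n_windows):
                same_pos.append(cosine_distance(emb[i, p], emb[j, p]))
    for w in range(n_windows):                   # different positions, same window
        for p in range(n_patches):
            for q in range(p + 1, n_patches):
                diff_pos.append(cosine_distance(emb[w, p], emb[w, q]))
    return float(np.mean(same_pos)), float(np.mean(diff_pos))

# Toy case: every window has identical, position-only embeddings (one-hot per position).
toy = np.tile(np.eye(3)[None, :, :], (2, 1, 1))   # shape (2, 3, 3)
same_mean, diff_mean = position_distance_summary(toy)
```

A small same-position mean next to a large different-position mean is the signature of embeddings that encode where a patch sits more than what it contains.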
Graph
Centered Kernel Alignment;
A metric used to measure the similarity of representations in neural networks. In my context, I will be talking about the internal representations the model has produced.
In the experiment I conducted, the CKA matrix was uniform with a mean of 0.824, so we can say the model has a high degree of similarity between representations.
In a JEPA-based model this much is expected, as the architecture itself promotes representational consistency and prevents the latent space from drifting too much over time.
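For readers who want to compute this themselves, here is a minimal NumPy sketch of linear CKA between two sets of representations of the same inputs. Using the linear variant is an assumption (a kernel version also exists), and which pairs of representations went into the matrix reported here is not specified, so treat this only as an illustration of the metric.

```python
import numpy as np

def linear_cka(x, y):
    """Linear CKA between two (n_samples, dim) representation matrices.

    Rows of x and y must correspond to the same inputs; dims may differ.
    """
    x = x - x.mean(axis=0, keepdims=True)        # center each feature
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, "fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, "fro")
    norm_y = np.linalg.norm(y.T @ y, "fro")
    return float(cross / (norm_x * norm_y))

rng = np.random.default_rng(0)
reps = rng.standard_normal((50, 8))
rotation, _ = np.linalg.qr(rng.standard_normal((8, 8)))   # random orthogonal matrix
cka_self = linear_cka(reps, reps)
cka_rotated = linear_cka(reps, reps @ rotation)
```

The score is 1 for identical representations and is unchanged by rotating or uniformly rescaling either one, which is why it is handy for comparing layers or checkpoints whose axes do not line up.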
Graph
Before we hit the summary section how about we take a min to touch grass :)
Key summary;
The representation lives on a ~12-dimensional anisotropic manifold inside the 256-dim space.
Patch trajectories are jagged (oscillatory, with high curvature), which is consistent with the oscillatory nature of EEG.
But I believe the strongest finding is the distance distribution: the model maps the same position across windows to nearly identical embeddings, suggesting the current representation is dominated by positional embeddings.
After struggling to understand all these concepts, I decided to write a blog about them. It might lack a lot in context or clarity, but this is my first one, so please support it and feel free to criticize my work at raghulchandramouli@gmail.com (happy to hop on a conversation). I also acknowledge that this research direction would not have been possible without the help of AI agents.
Citing/Citation