Conversation
Codecov Report ❌ Patch coverage is
@gemini-cli /review
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This pull request introduces a JAX implementation of the DeepSeek Engram layer, along with comprehensive unit tests that validate its behavior against a PyTorch reference. The code is well-structured and the implementation appears to be correct and thorough. The core logic is sound, and the use of vectorization with nnx.vmap is a good practice for performance.
🔍 General Feedback
- Good Testing: The inclusion of unit tests comparing the JAX implementation to a PyTorch reference is excellent. This provides high confidence in the correctness of the implementation.
- Clear Implementation: The code in `engram.py` is well-commented and organized, making it easy to follow the logic from the original paper.
- TODOs: I've commented on the TODOs left in the code. Addressing them will improve the clarity and robustness of the implementation.
🤖 I'm sorry @RissyRan, but I was unable to process your request. Please see the logs for more details.
RissyRan left a comment
I reviewed the test and `CompressedTokenizer`. I will continue reviewing the rest tomorrow.
RissyRan left a comment
Thanks for the change! I left some initial comments and may need to go over the multi-head embedding and conv parts. It should be quick.
RissyRan left a comment
LGTM in general! Just a few minor comments. I will try integrating this change with a decoder layer tomorrow and see how it goes.
Description
Background
What this PR does
Add Engram layer in `engram.py`:
- `NgramHashMapping` (non-parametric): `CompressedTokenizer` + hashing logic; converts "input_id" to "ngram hash_token_id"
- `CompressedTokenizer` (non-parametric): converts "input_id" to "compressed_input_id"
- `Engram` (multi-branch): inputs are "ngram hash_token_id" and "transformer state"; combines `MultiHeadEmbedding` (embedding lookup using the hash id as static memory) + context-aware gating (dot product of static memory with the contextual state) + `ShortConv` (temporal smoothing)
- `MultiHeadEmbedding`: converts ngram hash_token_id to an ngram embedding vector
- `ShortConv` (multi-branch): depthwise (mixes time steps, does not mix channels), causal; "short" means the kernel size is small

Add unit test:
`tests.unit.engram_vs_reference_test`

Implementation Notes
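The context-aware gating idea above (a static n-gram memory embedding scaled by a gate computed from its dot product with the contextual transformer state) could look roughly like the sketch below. All names, shapes, and the sigmoid/scaling choices are illustrative assumptions, not the actual `Engram` code:

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch of context-aware gating: gate the static n-gram
# memory by its (scaled) dot product with the contextual state.
def gated_memory(memory: jnp.ndarray, state: jnp.ndarray) -> jnp.ndarray:
    """memory, state: [seq, dim]. Returns gated memory, same shape."""
    # Per-position dot product, scaled and squashed into (0, 1).
    gate = jax.nn.sigmoid(
        jnp.sum(memory * state, axis=-1, keepdims=True)
        / jnp.sqrt(memory.shape[-1])
    )
    return gate * memory

mem = jnp.ones((4, 8))    # static memory from embedding lookup
st = jnp.zeros((4, 8))    # contextual transformer state
out = gated_memory(mem, st)
print(out.shape)  # (4, 8)
```

With a zero state the dot product is zero, so the gate is sigmoid(0) = 0.5 and the memory is halved; a state aligned with the memory pushes the gate toward 1.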
Placement of `NgramHashMapping`:
- `NgramHashMapping` converts vanilla token-ids to hashed ngram token-ids, which `Engram` consumes for embedding lookup
- `NgramHashMapping` and hash_input_ids generation can be put in the data input pipeline, which is CPU intensive, just like how we put the tokenizer and input_ids generation in the pipeline.
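Since this mapping is non-parametric and CPU-bound, it can run in the input pipeline much like tokenization. A minimal sketch of hashing trailing n-grams of token ids into a fixed-size id space follows; the multiplier constant, n-gram handling, and table size are illustrative assumptions, not the scheme in `engram.py`:

```python
import numpy as np

# Hypothetical multiplicative n-gram hash; the real NgramHashMapping may
# use different constants, n-gram orders, and table sizes.
def ngram_hash_ids(input_ids: np.ndarray, n: int, table_size: int,
                   multiplier: int = 1000003) -> np.ndarray:
    """Map each position's trailing n-gram of token ids to a hash id.

    input_ids: int array of shape [seq_len] (e.g. compressed token ids).
    Returns an int array of shape [seq_len] with values in [0, table_size).
    """
    seq_len = input_ids.shape[0]
    hashed = np.zeros(seq_len, dtype=np.int64)
    for t in range(seq_len):
        h = 0
        # Combine the n tokens ending at position t (causal: no lookahead).
        for k in range(max(0, t - n + 1), t + 1):
            h = (h * multiplier + int(input_ids[k])) % table_size
        hashed[t] = h
    return hashed

ids = np.array([5, 9, 5, 9, 5])
print(ngram_hash_ids(ids, n=2, table_size=1 << 20))
```

Identical trailing n-grams map to the same hash id, so the downstream embedding lookup treats repeated n-grams as the same static-memory slot.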
Multi-branch
- `Engram` and `ShortConv` handle multi-branch input and multi-branch output (if `mhc_expansion_rate > 1`), using nnx.vmap for an independent norm per branch
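The per-branch norm idea can be illustrated with plain `jax.vmap` over stacked per-branch parameters (the PR itself uses `nnx.vmap`; the RMSNorm definition, branch count, and shapes here are illustrative assumptions):

```python
import jax
import jax.numpy as jnp

# Illustration of an independent norm per branch via vmap over the
# leading branch axis of both the activations and the scale parameters.
def rms_norm(x, scale, eps=1e-6):
    # x: [seq, dim]; scale: [dim]
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return x / rms * scale

# Stacked inputs and scales over the branch axis (mhc_expansion_rate = 3).
branches = jax.random.normal(jax.random.PRNGKey(0), (3, 8, 16))  # [branch, seq, dim]
scales = jnp.ones((3, 16))                                        # one scale per branch

# Each branch is normalized with its own parameters, independently.
per_branch_norm = jax.vmap(rms_norm, in_axes=(0, 0))
out = per_branch_norm(branches, scales)
print(out.shape)  # (3, 8, 16)
```

The vmapped call keeps the branches fully decoupled: no statistics or parameters are shared across the branch axis.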
Tests
- unit test against the reference implementation
- log: https://paste.googleplex.com/5905570101067776
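The depthwise, causal `ShortConv` described above could be sketched as follows; the function signature, parameter shapes, and kernel size are illustrative assumptions, not the code in `engram.py`:

```python
import jax
import jax.numpy as jnp

# Hedged sketch of a depthwise causal short convolution.
def short_conv(x: jnp.ndarray, kernel: jnp.ndarray) -> jnp.ndarray:
    """x: [batch, seq, channels]; kernel: [kernel_size, channels].

    Depthwise: each channel is convolved with its own filter (time steps
    are mixed, channels are not). Causal: only past positions contribute,
    via left padding of kernel_size - 1.
    """
    k, c = kernel.shape
    rhs = kernel[:, None, :]  # [kernel_size, in_per_group=1, out=channels]
    return jax.lax.conv_general_dilated(
        x, rhs,
        window_strides=(1,),
        padding=[(k - 1, 0)],          # left-pad only -> causal
        dimension_numbers=('NWC', 'WIO', 'NWC'),
        feature_group_count=c,         # depthwise: one group per channel
    )

x = jnp.ones((1, 5, 4))
kernel = jnp.ones((3, 4)) / 3.0  # "short": small kernel size (e.g. 3)
y = short_conv(x, kernel)
print(y.shape)  # (1, 5, 4)
```

With all-ones input and an averaging kernel, the first outputs are smaller (they only see zero padding plus the few real past steps), which makes the causal left-padding easy to verify.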
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.