
Add DeepSeek Engram layer#3010

Merged
copybara-service[bot] merged 1 commit into main from shuningjin-engram
Feb 11, 2026

Conversation

@shuningjin
Collaborator

@shuningjin shuningjin commented Jan 26, 2026

Description

Background

What this PR does

Add Engram layer: engram.py

  • NgramHashMapping (non-parametric): CompressedTokenizer + hashing logic; converts "input_id" to "ngram hash_token_id"
    • CompressedTokenizer (non-parametric): converts "input_id" to "compressed_input_id"
  • Engram (multi-branch): inputs are "ngram hash_token_id" and "transformer state"; MultiHeadEmbedding (looks up embeddings using the hash id as static memory) + context-aware gating (dot product of static memory with contextual state) + ShortConv (temporal smoothing)
    • MultiHeadEmbedding: converts "ngram hash_token_id" to an ngram embedding vector
    • ShortConv (multi-branch): depthwise (mixes time steps, not channels), causal; "short" means the kernel size is small
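As a rough illustration of the data flow described above, here is a minimal NumPy sketch of the pipeline: hash each position's trailing n-gram to an id, look it up as static memory, gate it against the contextual transformer state, and smooth with a causal depthwise short conv. The hash function, shapes, and gating form here are illustrative stand-ins, not the actual engram.py implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

VOCAB, HASH_VOCAB, DIM, SEQ = 100, 64, 8, 6
KERNEL = 3  # "short" conv: small kernel size

def ngram_hash(input_ids, n=2, hash_vocab=HASH_VOCAB):
    """Map each position's trailing n-gram to a hashed token id.

    Positions with fewer than n preceding tokens reuse the available prefix.
    A simple polynomial rolling hash stands in for the real hashing logic.
    """
    ids = np.asarray(input_ids)
    out = np.empty_like(ids)
    for t in range(len(ids)):
        gram = ids[max(0, t - n + 1) : t + 1]
        h = 0
        for tok in gram:
            h = (h * 1000003 + int(tok)) % hash_vocab
        out[t] = h
    return out

def causal_depthwise_conv(x, kernel):
    """Depthwise causal conv: mixes time steps per channel, no channel mixing."""
    seq, dim = x.shape
    k = kernel.shape[0]
    padded = np.concatenate([np.zeros((k - 1, dim)), x], axis=0)  # left-pad only
    return np.stack([(padded[t : t + k] * kernel).sum(axis=0) for t in range(seq)])

# static memory: embedding table indexed by hashed n-gram ids
embed_table = rng.standard_normal((HASH_VOCAB, DIM))

input_ids = rng.integers(0, VOCAB, size=SEQ)
hidden = rng.standard_normal((SEQ, DIM))  # "transformer state"

hash_ids = ngram_hash(input_ids)                                  # (SEQ,)
mem = embed_table[hash_ids]                                       # (SEQ, DIM) static memory
gate = 1 / (1 + np.exp(-(mem * hidden).sum(-1, keepdims=True)))   # dot-product gate
kernel = rng.standard_normal((KERNEL, DIM)) / KERNEL
out = causal_depthwise_conv(gate * mem, kernel)                   # temporal smoothing

print(out.shape)  # (6, 8)
```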

Add unit test: tests.unit.engram_vs_reference_test

  • for each component, verify the output matches that from reference code
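The parity-testing pattern used here is, roughly: run identical inputs through the new implementation and the reference, then compare numerically within a tolerance. A toy sketch of that pattern with a stand-in function (softmax, not an actual Engram component):

```python
import numpy as np

def reference_softmax(x):
    """Stand-in for the reference implementation (numerically stabilized)."""
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def candidate_softmax(x):
    """Stand-in for the ported implementation under test."""
    e = np.exp(x)
    return e / e.sum(-1, keepdims=True)

x = np.random.default_rng(0).standard_normal((4, 5))
np.testing.assert_allclose(candidate_softmax(x), reference_softmax(x),
                           rtol=1e-6, atol=1e-8)
```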

Implementation Notes

Placement of NgramHashMapping

  • NgramHashMapping converts vanilla token ids to hashed ngram token ids, which Engram consumes for embedding lookup
  • Future: I would like NgramHashMapping and hash_input_ids generation to move into the data input pipeline, since they are CPU intensive, just as the tokenizer and input_ids generation live in the pipeline.

Multi-branch

  • Engram and ShortConv handle multi-branch input and multi-branch output (if mhc_expansion_rate > 1), using nnx.vmap for an independent norm per branch
  • Future: to be integrated into a multi-branch backbone such as mHC.
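The per-branch independent norm can be pictured as follows: each branch along the expansion axis gets its own scale parameters and is normalized separately, which is what vectorizing with nnx.vmap over that axis achieves. A NumPy sketch that loops where vmap would vectorize (the RMSNorm form and shapes are assumptions for illustration):

```python
import numpy as np

def rms_norm(x, scale, eps=1e-6):
    """Normalize over the feature axis, then apply a learned per-channel scale."""
    return x / np.sqrt((x ** 2).mean(-1, keepdims=True) + eps) * scale

BRANCHES, SEQ, DIM = 2, 4, 8  # e.g. mhc_expansion_rate = 2
rng = np.random.default_rng(0)

# one scale vector per branch, as a vmapped norm module would hold
scales = rng.standard_normal((BRANCHES, DIM))
x = rng.standard_normal((BRANCHES, SEQ, DIM))

# apply the norm independently per branch (the axis nnx.vmap maps over)
y = np.stack([rms_norm(x[b], scales[b]) for b in range(BRANCHES)])
```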

Tests

Unit test against reference:

python3 -m pytest -v --pyargs tests.unit.engram_vs_reference_test -rP -s

log: https://paste.googleplex.com/5905570101067776

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov bot commented Jan 26, 2026

Codecov Report

❌ Patch coverage is 0% with 198 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/MaxText/layers/engram.py 0.00% 198 Missing ⚠️


@shuningjin shuningjin changed the title [DRAFT] do no merge [DRAFT] engram Jan 29, 2026
@shuningjin shuningjin force-pushed the shuningjin-engram branch 2 times, most recently from 93458cf to 21cec5f Compare January 30, 2026 17:52
@shuningjin shuningjin changed the title [DRAFT] engram Add DeepSeek Engram layer Feb 4, 2026
@shuningjin shuningjin marked this pull request as ready for review February 4, 2026 21:48
@shuningjin
Collaborator Author

@gemini-cli /review

Collaborator

@RissyRan RissyRan left a comment


@gemini-cli /review

@github-actions

github-actions bot commented Feb 5, 2026

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions bot left a comment


## 📋 Review Summary

This pull request introduces a JAX implementation of the DeepSeek Engram layer, along with comprehensive unit tests that validate its behavior against a PyTorch reference. The code is well-structured and the implementation appears to be correct and thorough. The core logic is sound, and the use of vectorization with nnx.vmap is a good practice for performance.

🔍 General Feedback

  • Good Testing: The inclusion of unit tests comparing the JAX implementation to a PyTorch reference is excellent. This provides high confidence in the correctness of the implementation.
  • Clear Implementation: The code in engram.py is well-commented and organized, making it easy to follow the logic from the original paper.
  • TODOs: I've commented on the TODOs left in the code. Addressing them will improve the clarity and robustness of the implementation.

@github-actions

github-actions bot commented Feb 5, 2026

🤖 I'm sorry @RissyRan, but I was unable to process your request. Please see the logs for more details.

Collaborator

@RissyRan RissyRan left a comment


I reviewed the test and CompressedTokenizer. Will continue to review the rest tomorrow.

Collaborator

@RissyRan RissyRan left a comment


Thanks for the change! I left some initial comments, and may need to go over multihead embedding and conv parts. It should be quick.

Collaborator

@RissyRan RissyRan left a comment


LGTM in general! Just a few minor comments. I will have a try to integrate this change with a decoder layer tomorrow and see how it goes.

Collaborator

@aireenmei aireenmei left a comment


Thanks for the work!

@copybara-service copybara-service bot merged commit b10c284 into main Feb 11, 2026
104 of 105 checks passed
@copybara-service copybara-service bot deleted the shuningjin-engram branch February 11, 2026 18:28

5 participants