
Conversation

@sbhavani (Collaborator) commented Jan 7, 2026

Description

The te_llama.py example fails with HF transformers 4.57+ due to a breaking change in how decoder layer outputs are handled. In transformers 4.57+, the LlamaModel forward loop no longer unpacks each layer's output, so TELlamaDecoderLayer failed because it returned a tuple (tensor,) instead of the tensor directly (see the sketch below).
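
To make the failure mode concrete, here is a minimal runnable sketch with toy stand-ins for the decoder layer (hypothetical classes and a paraphrased caller loop, not the actual transformers source; the real class is TELlamaDecoderLayer):

import torch

# Toy stand-ins for the decoder layer (illustration only).
class OldLayer:
    def __call__(self, hidden_states):
        return (hidden_states * 2,)      # pre-fix: output wrapped in a tuple

class FixedLayer:
    def __call__(self, hidden_states):
        return hidden_states * 2         # post-fix: bare tensor

hidden = torch.ones(2, 3)

# transformers < 4.57 (paraphrased loop): the caller unwrapped the tuple itself.
out = OldLayer()(hidden)[0]              # [0] extracts the tensor, so (tensor,) worked

# transformers >= 4.57 (paraphrased loop): the caller uses the return value directly.
out = OldLayer()(hidden)
# out.contiguous()                       # AttributeError: 'tuple' object has no attribute 'contiguous'
out = FixedLayer()(hidden).contiguous()  # works once the layer returns the tensor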

Fixes #2567

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Handle case where hidden_states is passed as a tuple (for backward compatibility with older HF versions)
  • Return tensor directly instead of wrapping in tuple (required for HF transformers >= 4.57)
  • Fix regex SyntaxWarning by using raw string prefix (r"model.layers.\d+."); a combined sketch of these changes follows below
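
A hedged sketch combining the three changes (illustrative signature and names such as LAYER_PREFIX_PATTERN; the actual file is docs/examples/te_llama/te_llama.py, which wraps TE's TransformerLayer):

import re

import transformer_engine.pytorch as te

class TELlamaDecoderLayer(te.TransformerLayer):
    def forward(self, hidden_states, *args, **kwargs):
        # Defensive unwrap: accept (tensor,) as well as a bare tensor.
        if isinstance(hidden_states, tuple):
            hidden_states = hidden_states[0]
        # Return the tensor directly: transformers >= 4.57 no longer
        # unpacks a (tensor,) tuple from decoder layers.
        return super().forward(hidden_states, *args, **kwargs)

# Raw-string prefix keeps "\d" from triggering a SyntaxWarning
# (invalid escape sequence) on recent Python versions.
LAYER_PREFIX_PATTERN = re.compile(r"model.layers.\d+.")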

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Testing

Tested with:

  • transformer_engine 2.5.0+f05f12c
  • transformers 4.57.3
  • nvcr.io/nvidia/pytorch:25.08-py3

…= 4.57

The te_llama.py example was failing with HuggingFace transformers 4.57+
due to API changes in how decoder layer outputs are handled.

Changes:
- Handle case where hidden_states is passed as a tuple (older HF versions)
- Return tensor directly instead of wrapped in tuple (HF 4.57+ expects this)
- Fix regex pattern to use raw string (fixes SyntaxWarning)

Error fixed:
  AttributeError: 'tuple' object has no attribute 'contiguous'

Tested with:
- transformer_engine 2.5.0
- transformers 4.57.3
- PyTorch container nvcr.io/nvidia/pytorch:25.08-py3

Signed-off-by: Santosh Bhavani <santosh.bhavani@live.com>
@sbhavani sbhavani force-pushed the fix/te-llama-hf-transformers-457-compat branch from 2ca056c to a65fa49 on January 7, 2026 16:46
greptile-apps bot (Contributor) commented Jan 7, 2026

Greptile Summary

This PR fixes a breaking compatibility issue with HuggingFace transformers >= 4.57 where TELlamaDecoderLayer.forward() was returning a tuple (tensor,) instead of a tensor directly, causing AttributeError when the model loop tried to call .contiguous() on the tuple.

Key Changes:

  • Modified TELlamaDecoderLayer.forward() to return tensor directly (required for transformers >= 4.57)
  • Added defensive tuple unpacking check for hidden_states input
  • Fixed regex SyntaxWarning by using raw string prefix (r"model.layers.\d+.")
  • Added requirements.txt with pinned dependency versions including transformers 4.57.0
  • Updated notebook to reference requirements.txt

Impact:
The changes restore compatibility with transformers 4.57+ while maintaining the fix for the core issue. The regex fix eliminates a Python warning.

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk - it fixes a specific compatibility bug with well-defined scope
  • The core fix (returning tensor directly) correctly addresses the breaking change in transformers 4.57+, and the regex fix is valid. Score of 4 instead of 5 due to the defensive tuple-handling code that may not be necessary based on actual HuggingFace behavior, though it won't cause issues.
  • The tuple handling logic in te_llama.py:75-78 could be reviewed for necessity, but it functions as defensive programming without negative impact

Important Files Changed

  • docs/examples/te_llama/te_llama.py: Fixed compatibility issue with transformers >= 4.57 by returning the tensor directly and handling tuple input defensively; also fixed a regex SyntaxWarning
  • docs/examples/te_llama/requirements.txt: Added a new requirements file with pinned dependency versions for transformers 4.57.0 and related packages

Sequence Diagram

sequenceDiagram
    participant User
    participant LlamaModel
    participant TELlamaDecoderLayer
    participant TransformerLayer
    
    User->>LlamaModel: forward(input_ids)
    
    rect rgb(240, 240, 240)
        Note over LlamaModel,TELlamaDecoderLayer: Transformers < 4.57
        LlamaModel->>TELlamaDecoderLayer: forward(hidden_states_tensor)
        TELlamaDecoderLayer->>TransformerLayer: forward(hidden_states)
        TransformerLayer-->>TELlamaDecoderLayer: output_tensor
        TELlamaDecoderLayer-->>LlamaModel: (output_tensor,)
        LlamaModel->>LlamaModel: hidden_states = layer_outputs[0]
    end
    
    rect rgb(220, 250, 220)
        Note over LlamaModel,TELlamaDecoderLayer: Transformers >= 4.57 (This PR)
        LlamaModel->>TELlamaDecoderLayer: forward(hidden_states_tensor)
        TELlamaDecoderLayer->>TELlamaDecoderLayer: Check isinstance(tuple)
        TELlamaDecoderLayer->>TransformerLayer: forward(hidden_states)
        TransformerLayer-->>TELlamaDecoderLayer: output_tensor
        TELlamaDecoderLayer-->>LlamaModel: output_tensor (not wrapped)
        LlamaModel->>LlamaModel: hidden_states = layer_outputs
    end
    
    LlamaModel-->>User: model_output

greptile-apps bot (Contributor) left a comment


Additional Comments (1)

  1. docs/examples/te_llama/te_llama.py, lines 77-78

    logic: No check for empty tuple before accessing [0]. If hidden_states is an empty tuple, this will raise an IndexError.

1 file reviewed, 1 comment


@sudhakarsingh27 sudhakarsingh27 self-requested a review January 7, 2026 22:35
@sudhakarsingh27 (Collaborator) commented

Thanks for fixing, @sbhavani! LGTM.
(I guess the other example, te_gemma, would also need a change; let me take care of that, along with other pending issues it has.)

@sudhakarsingh27 (Collaborator) left a comment


Actually, fixing this for 4.57+ would break it for previous versions then?

@sbhavani (Collaborator, Author) commented Jan 8, 2026

> Actually, fixing this for 4.57+ would break it for previous versions then?

Nope, it handles both the previous and current versions of transformers. I think we should pin the versions to support the latest transformers, as both the TE and transformers APIs are constantly changing.

@sudhakarsingh27 (Collaborator) left a comment


Okay I see that it correctly handles version dependencies.

I agree with pinning library versions. Would you be open to:

  1. Creating a requirements.txt file with correct versions for the TE, huggingface, accelerate, peft, and datasets libraries (a sketch follows below)?
  2. Adding a small section at the start of the tutorial that mentions installing the prereqs with pip install -r requirements.txt?

:)

(I did that for te_gemma for your reference)
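
For concreteness, a hypothetical sketch of such a requirements.txt; only the transformers==4.57.0 pin is confirmed elsewhere in this thread, the remaining entries are left unpinned as placeholders, and transformer_engine itself ships preinstalled in the NGC PyTorch container used for testing:

# docs/examples/te_llama/requirements.txt (illustrative; pin remaining versions after testing)
transformers==4.57.0
accelerate
peft
datasets

The tutorial can then open with a pip install -r requirements.txt step.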

Signed-off-by: Santosh Bhavani <santosh.bhavani@live.com>
greptile-apps bot (Contributor) left a comment


Greptile Summary

This PR fixes compatibility issues with HuggingFace transformers >= 4.57 by changing how TELlamaDecoderLayer returns outputs. The breaking change in transformers 4.57+ modified the LlamaModel forward loop to expect decoder layers to return tensors directly instead of tuples.

Key changes:

  • Modified TELlamaDecoderLayer.forward() to return tensor directly instead of wrapping in tuple
  • Added defensive tuple unpacking for hidden_states input (backward compatibility safety)
  • Fixed regex SyntaxWarning by using raw string prefix (r"model.layers.\d+.")
  • Added requirements.txt to pin tested dependency versions

Issues found:

  • Minor: Confusing comment about when tuple unpacking is needed
  • The tuple unpacking check appears to be defensive programming rather than necessary for backward compatibility, as the return type change itself handles version compatibility

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk - it addresses a critical compatibility issue with clear, targeted fixes
  • Score reflects that the fix correctly addresses the transformers 4.57+ breaking change by modifying return types. The regex fix is correct. Minor deduction for a misleading comment about the tuple unpacking logic, which doesn't affect functionality but could confuse future maintainers. The changes are well-tested according to PR description and follow a clear pattern.
  • No files require special attention - all changes are straightforward compatibility fixes

Important Files Changed

  • docs/examples/te_llama/te_llama.py (score 4/5): Fixed the forward method to return the tensor directly instead of a tuple for transformers >= 4.57, added input tuple handling for backward compatibility, and fixed a regex SyntaxWarning with a raw string prefix
  • docs/examples/te_llama/requirements.txt (score 4/5): New file pinning dependency versions for the example, including transformers==4.57.0, which contains the breaking change this PR addresses

Sequence Diagram

sequenceDiagram
    participant HF as HuggingFace LlamaModel
    participant TELayer as TELlamaDecoderLayer
    participant TE as TransformerLayer (TE)
    
    Note over HF,TE: Transformers >= 4.57
    HF->>TELayer: forward(hidden_states=tensor)
    TELayer->>TELayer: Check isinstance(hidden_states, tuple)
    Note over TELayer: False - continue
    TELayer->>TE: super().forward(hidden_states, ...)
    TE-->>TELayer: returns tensor
    TELayer-->>HF: returns tensor directly
    
    Note over HF,TE: Transformers < 4.57 (with old code)
    HF->>TELayer: forward(hidden_states=tensor)
    TELayer->>TE: super().forward(hidden_states, ...)
    TE-->>TELayer: returns tensor
    TELayer-->>HF: returns (tensor,)
    Note over HF: Extracts with layer_outputs[0]

Comment on lines +75 to +78
# Handle case where hidden_states might be a tuple (from previous layer output)
# This can happen with older versions of HuggingFace transformers
if isinstance(hidden_states, tuple):
    hidden_states = hidden_states[0]

greptile-apps bot (Contributor):

Comment is misleading about when tuple unpacking is needed.

With the new code that returns tensor directly, hidden_states should never be a tuple when called from HuggingFace's LlamaModel forward loop (in any version). The old code returned (tensor,) for transformers < 4.57, but HF's loop extracted it with layer_outputs[0] before passing to the next layer.

This check appears to be defensive programming rather than addressing a real backward compatibility scenario. Consider clarifying the comment to explain this is a safety check rather than expected behavior.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
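
For illustration, one possible rewording along the lines of this suggestion (hypothetical phrasing, not part of the PR; the two code lines are the ones quoted above):

# Safety check: unwrap a tuple input. HF's forward loop passes a bare tensor
# in all supported versions, so this is defensive programming rather than
# version-specific handling.
if isinstance(hidden_states, tuple):
    hidden_states = hidden_states[0]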

@sudhakarsingh27 (Collaborator) commented

Looks like the documentation wasn't built successfully. @sbhavani, could you look into it?

@sbhavani sbhavani force-pushed the fix/te-llama-hf-transformers-457-compat branch from e3ee699 to ff23db8 on January 23, 2026 00:28
@sbhavani (Collaborator, Author) commented

> Looks like the documentation wasn't built successfully. @sbhavani, could you look into it?

Hopefully fixed now; our Sphinx documentation build uses strict nbformat validation. A local sanity check is sketched below.
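
For anyone reproducing the check locally, a minimal sketch of strict notebook validation with nbformat (the notebook path is a placeholder):

import nbformat

# Read the notebook and validate it against the nbformat schema;
# nbformat.validate raises nbformat.ValidationError on any schema violation.
nb = nbformat.read("path/to/tutorial_notebook.ipynb", as_version=4)  # placeholder path
nbformat.validate(nb)
print("notebook passes nbformat schema validation")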

greptile-apps bot (Contributor) commented Jan 23, 2026

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

@sbhavani sbhavani force-pushed the fix/te-llama-hf-transformers-457-compat branch from ff23db8 to 293ff1b on January 24, 2026 17:37
Signed-off-by: Santosh Bhavani <santosh.bhavani@live.com>
@sbhavani sbhavani force-pushed the fix/te-llama-hf-transformers-457-compat branch from 293ff1b to 3568a8d on January 24, 2026 17:46
greptile-apps bot (Contributor) left a comment


3 files reviewed, 1 comment


Comment on lines +75 to +78
# Handle case where hidden_states might be a tuple (from previous layer output)
# This can happen with older versions of HuggingFace transformers
if isinstance(hidden_states, tuple):
    hidden_states = hidden_states[0]

greptile-apps bot (Contributor):

style: Comment is misleading - hidden_states was never actually a tuple in any HF version since the loop always extracted with layer_outputs[0] before passing to next layer. This is defensive programming rather than backward compatibility. Consider clarifying the comment.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


Development

Successfully merging this pull request may close these issues.

te_llama.py example fails with transformers 4.57+
