Skip to content

Commit cae8f3f

Browse files
degenfabian, bryce13950, and jlarson4
authored
updated loading in llama 2 demo to use transformer bridge (#1019)
* updated loading in llama 2 demo to use transformer bridge
* Updating LLaMA quantized model

Co-authored-by: Bryce Meyer <bryce13950@gmail.com>
Co-authored-by: jlarson4 <jonahalarson@comcast.net>
1 parent f7587dd commit cae8f3f

7 files changed

Lines changed: 246 additions & 429 deletions

File tree

.github/workflows/checks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ jobs:
235235
# - "Head_Detector_Demo"
236236
# - "Interactive_Neuroscope"
237237
# - "LLaMA"
238-
# - "LLaMA2_GPU_Quantized"
238+
# - "LLaMA2_GPU_Quantized" # Requires quantization libs + too slow for CI timeout
239239
- "Main_Demo"
240240
# - "No_Position_Experiment"
241241
- "Othello_GPT"

demos/LLaMA2_GPU_Quantized.ipynb

Lines changed: 153 additions & 414 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
"tabulate>=0.9.0",
6464
]
6565
jupyter=["ipywidgets>=8.1.1", "jupyterlab>=3.5.0"]
66+
quantization=["bitsandbytes>=0.46.1", "optimum-quanto>=0.2.7"]
6667

6768
[tool.poetry.dependencies]
6869
accelerate=">=0.23.0" # Needed for Llama Models

transformer_lens/model_bridge/bridge.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def boot_transformers(
149149
load_weights: bool = True,
150150
trust_remote_code: bool = False,
151151
model_class: Optional[type] = None,
152+
hf_model: Optional[Any] = None,
152153
) -> "TransformerBridge":
153154
"""Boot a model from HuggingFace (alias for sources.transformers.boot).
154155
@@ -162,6 +163,9 @@ def boot_transformers(
162163
trust_remote_code: Whether to trust remote code for custom model architectures.
163164
model_class: Optional HuggingFace model class to use instead of the default
164165
auto-detected class (e.g., BertForNextSentencePrediction).
166+
hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful
167+
for models loaded with custom configurations (e.g., quantization via
168+
BitsAndBytesConfig). When provided, load_weights is ignored.
165169
166170
Returns:
167171
The bridge to the loaded model.
@@ -177,6 +181,7 @@ def boot_transformers(
177181
load_weights=load_weights,
178182
trust_remote_code=trust_remote_code,
179183
model_class=model_class,
184+
hf_model=hf_model,
180185
)
181186

182187
@property

transformer_lens/model_bridge/sources/transformers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def boot(
270270
load_weights: bool = True,
271271
trust_remote_code: bool = False,
272272
model_class: Any | None = None,
273+
hf_model: Any | None = None,
273274
) -> TransformerBridge:
274275
"""Boot a model from HuggingFace.
275276
@@ -283,6 +284,9 @@ def boot(
283284
model_class: Optional HuggingFace model class to use instead of the default auto-detected
284285
class. When the class name matches a key in SUPPORTED_ARCHITECTURES, the corresponding
285286
adapter is selected automatically (e.g., BertForNextSentencePrediction).
287+
hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for
288+
models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig).
289+
When provided, load_weights is ignored.
286290
287291
Returns:
288292
The bridge to the loaded model.
@@ -368,7 +372,10 @@ def boot(
368372
# Default to eager (required for output_attentions hooks)
369373
model_kwargs["attn_implementation"] = "eager"
370374
adapter.prepare_loading(model_name, model_kwargs)
371-
if not load_weights:
375+
if hf_model is not None:
376+
# Use the pre-loaded model as-is (e.g., quantized models with custom device_map)
377+
pass
378+
elif not load_weights:
372379
from_config_kwargs = {}
373380
if trust_remote_code:
374381
from_config_kwargs["trust_remote_code"] = True

transformer_lens/supported_models.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -274,12 +274,9 @@
274274
"bigscience/bloom-3b": ["bloom-3b"],
275275
"bigscience/bloom-560m": ["bloom-560m"],
276276
"bigscience/bloom-7b1": ["bloom-7b1"],
277-
"codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"],
278-
"codellama/CodeLlama-7b-Instruct-hf": [
279-
"CodeLlama-7b-instruct",
280-
"codellama/CodeLlama-7b-Instruct-hf",
281-
],
282-
"codellama/CodeLlama-7b-Python-hf": ["CodeLlama-7b-python", "codellama/CodeLlama-7b-Python-hf"],
277+
"codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b"],
278+
"codellama/CodeLlama-7b-Instruct-hf": ["CodeLlama-7b-instruct"],
279+
"codellama/CodeLlama-7b-Python-hf": ["CodeLlama-7b-python"],
283280
"distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"],
284281
"EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"],
285282
"EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"],
@@ -404,16 +401,16 @@
404401
"EleutherAI/pythia-19m-v0",
405402
"pythia-19m-v0",
406403
],
407-
"facebook/hubert-base-ls960": ["facebook/hubert-base-ls960", "hubert-base-ls960"],
404+
"facebook/hubert-base-ls960": ["hubert-base-ls960"],
408405
"facebook/opt-1.3b": ["opt-1.3b", "opt-medium"],
409406
"facebook/opt-125m": ["opt-125m", "opt-small", "opt"],
410407
"facebook/opt-13b": ["opt-13b", "opt-xxl"],
411408
"facebook/opt-2.7b": ["opt-2.7b", "opt-large"],
412409
"facebook/opt-30b": ["opt-30b", "opt-xxxl"],
413410
"facebook/opt-6.7b": ["opt-6.7b", "opt-xl"],
414411
"facebook/opt-66b": ["opt-66b", "opt-xxxxl"],
415-
"facebook/wav2vec2-base": ["facebook/wav2vec2-base", "wav2vec2-base", "w2v2-base"],
416-
"facebook/wav2vec2-large": ["facebook/wav2vec2-large", "wav2vec2-large", "w2v2-large"],
412+
"facebook/wav2vec2-base": ["wav2vec2-base", "w2v2-base"],
413+
"facebook/wav2vec2-large": ["wav2vec2-large", "w2v2-large"],
417414
"google-bert/bert-base-cased": ["bert-base-cased"],
418415
"google-bert/bert-base-uncased": ["bert-base-uncased"],
419416
"google-bert/bert-large-cased": ["bert-large-cased"],
@@ -450,11 +447,11 @@
450447
"llama-30b-hf": ["llama-30b"],
451448
"llama-65b-hf": ["llama-65b"],
452449
"llama-7b-hf": ["llama-7b"],
453-
"meta-llama/Llama-2-13b-chat-hf": ["Llama-2-13b-chat", "meta-llama/Llama-2-13b-chat-hf"],
454-
"meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"],
450+
"meta-llama/Llama-2-13b-chat-hf": ["Llama-2-13b-chat"],
451+
"meta-llama/Llama-2-13b-hf": ["Llama-2-13b"],
455452
"meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"],
456-
"meta-llama/Llama-2-7b-chat-hf": ["Llama-2-7b-chat", "meta-llama/Llama-2-7b-chat-hf"],
457-
"meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"],
453+
"meta-llama/Llama-2-7b-chat-hf": ["Llama-2-7b-chat"],
454+
"meta-llama/Llama-2-7b-hf": ["Llama-2-7b"],
458455
"microsoft/phi-1": ["phi-1"],
459456
"microsoft/phi-1_5": ["phi-1_5"],
460457
"microsoft/phi-2": ["phi-2"],

uv.lock

Lines changed: 68 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)