diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 1dde6c726..e2b6dfd82 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -223,7 +223,7 @@ jobs: fail-fast: false matrix: notebook: - # - "Activation_Patching_in_TL_Demo" + - "Activation_Patching_in_TL_Demo" # - "Attribution_Patching_Demo" - "ARENA_Content" - "BERT" diff --git a/demos/Activation_Patching_in_TL_Demo.ipynb b/demos/Activation_Patching_in_TL_Demo.ipynb index abc033ad7..98e554445 100644 --- a/demos/Activation_Patching_in_TL_Demo.ipynb +++ b/demos/Activation_Patching_in_TL_Demo.ipynb @@ -40,17 +40,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running as a Jupyter notebook - intended for development only!\n" - ] - } - ], + "outputs": [], "source": [ "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", "DEBUG_MODE = False\n", @@ -58,7 +50,7 @@ " import google.colab\n", " IN_COLAB = True\n", " print(\"Running as a Colab notebook\")\n", - " %pip install git+https://github.com/TransformerLensOrg/TransformerLens.git\n", + " %pip install transformer_lens\n", " # Install my janky personal plotting utils\n", " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", "except:\n", @@ -67,7 +59,7 @@ " from IPython import get_ipython\n", "\n", " ipython = get_ipython()\n", - " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", + " # Code to automatically update the TransformerBridge code as its edited without restarting the kernel\n", " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", " ipython.run_line_magic(\"autoreload\", \"2\")" ] @@ -127,11 +119,7 @@ "source": [ "import transformer_lens\n", "import transformer_lens.utils as utils\n", - "from transformer_lens.hook_points import (\n", - " HookedRootModule,\n", - " HookPoint,\n", - ") # Hooking utilities\n", - "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache" + "from transformer_lens.model_bridge import TransformerBridge" ] }, { @@ -175,7 +163,14 @@ "metadata": {}, "outputs": [], "source": [ - "from neel_plotly import line, imshow, scatter" + "try:\n", + " from neel_plotly import line, imshow, scatter\n", + "except ImportError:\n", + " # neel_plotly is an optional visualization dependency.\n", + " # Define no-op stubs so patching computations still run without it.\n", + " def line(*args, **kwargs): pass\n", + " def imshow(*args, **kwargs): pass\n", + " def scatter(*args, **kwargs): pass" ] }, { @@ -201,22 +196,30 @@ "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using pad_token, but it is not set yet.\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3040c20b4d87433bbdc4897d1e59b8e7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading weights: 0%| | 0/148 [00:00\n", "\n", "
\n", - "
\n", + " }) }; \n", "\n", "" ] @@ -433,7 +436,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "49d0b56fa468408ca15afa42d0d1c91b", + "model_id": "bd2849621acf4ba2a5f6b5cfc6aa68d6", "version_major": 2, "version_minor": 0 }, @@ -451,9 +454,9 @@ "\n", "\n", "
\n", - "
\n", + " }) }; \n", "\n", "" ] @@ -509,7 +512,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "04132888196746f3b16918c63dd6c023", + "model_id": "3b902ce3c31a45e0854b43eadff198d5", "version_major": 2, "version_minor": 0 }, @@ -527,9 +530,9 @@ "\n", "\n", "
\n", - "
\n", + " }) }; \n", "\n", "" ] @@ -590,7 +593,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9dec177dd2f446248b5850e149fab8fc", + "model_id": "440041cdc83b4bc3b1f991cefa1df6f4", "version_major": 2, "version_minor": 0 }, @@ -604,7 +607,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ec9aae8965c84a819fcf6158dbc45fa6", + "model_id": "05b1394ee4a54d8e94050197279d2a11", "version_major": 2, "version_minor": 0 }, @@ -618,7 +621,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "caaf63fc68224f4babc8f492d55785e2", + "model_id": "4f5b66701221464580ccc06f25d257e0", "version_major": 2, "version_minor": 0 }, @@ -636,9 +639,9 @@ "\n", "\n", "
\n", - "
\n", + " }) }; \n", "\n", "" ] @@ -689,7 +692,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "997da4ffc80c4a87a827ab16aa1b76d7", + "model_id": "fd46410602d648ecb4d8e5c603f599c7", "version_major": 2, "version_minor": 0 }, @@ -703,7 +706,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c71892621747426b8f69dc36dcee104e", + "model_id": "898ce7b338ec42b692974fb4af3d3d3a", "version_major": 2, "version_minor": 0 }, @@ -717,7 +720,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1df77d7824f44ba58a08f37129cc7722", + "model_id": "a202d7fa0e894969b6093823c3a07900", "version_major": 2, "version_minor": 0 }, @@ -731,7 +734,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fd3be25d4cd04a3bbbbec2c24a7b9d6b", + "model_id": "07f6b5419b414d739882be1135a0ac26", "version_major": 2, "version_minor": 0 }, @@ -745,7 +748,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fc48b4abdc40412387bbea701ce028ed", + "model_id": "487f335aa1fb4f0a93c04944fcae2918", "version_major": 2, "version_minor": 0 }, @@ -763,9 +766,9 @@ "\n", "\n", "
\n", - "
\n", + " }) }; \n", "\n", "" ] @@ -805,120 +808,12 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "92709c396fdd48a7b169b87cc40d4ac2", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/2160 [00:00\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ + "# 10,800 forward passes of GPT-2; too slow for CI without GPU\n", + "# NBVAL_SKIP\n", "if DO_SLOW_RUNS:\n", " every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n", " every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n", @@ -956,7 +851,8 @@ } ], "source": [ - "attn_only = HookedTransformer.from_pretrained(\"attn-only-2l\")\n", + "attn_only = TransformerBridge.boot_transformers(\"attn-only-2l\")\n", + "attn_only.enable_compatibility_mode()\n", "batch = 4\n", "seq_len = 20\n", "rand_tokens_A = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)\n", diff --git a/demos/conftest.py b/demos/conftest.py index 2cd3d5a3e..755a20446 100644 --- a/demos/conftest.py +++ b/demos/conftest.py @@ -5,4 +5,5 @@ def pytest_collectstart(collector): "text/html", "application/javascript", "application/vnd.plotly.v1+json", # Plotly + "application/vnd.jupyter.widget-view+json", # Jupyter widgets (random model_id) )