6 changes: 3 additions & 3 deletions intervention/circle_probe_interventions.py
@@ -41,7 +41,7 @@
choices=["llama", "mistral"],
help="Choose 'llama' or 'mistral' model",
)
parser.add_argument("--device", type=int, default=4, help="CUDA device number")
parser.add_argument("--device", type=str, default="4", help="CUDA device number, or full device string")
parser.add_argument(
"--use_inverse_regression_probe",
action="store_true",
@@ -73,7 +73,7 @@
help="Probe on linear representation with center of 0.",
)
args = parser.parse_args()
device = f"cuda:{args.device}"
device = (f"cuda:{args.device}" if torch.cuda.is_available() else "cpu") if args.device.isnumeric() else args.device
day_month_choice = args.problem_type
circle_letter = args.intervene_on
model_name = args.model
@@ -100,7 +100,7 @@
# use_inverse_regression_probe = False
# intervention_pca_k = 5

device = "cuda:4"
device = "cuda:4" if torch.cuda.is_available() else "cpu"
circle_letter = "c"
day_month_choice = "day"
model_name = "mistral"
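For reference, the new `--device` handling resolves the argument to a device string roughly as sketched below. This is not part of the diff; `resolve_device` is a hypothetical helper that mirrors the inline expression added above.

```python
import torch

def resolve_device(arg: str) -> str:
    # Mirrors the inline expression in the diff: numeric strings keep the old
    # behaviour ("4" -> "cuda:4") but fall back to CPU when no CUDA device is
    # present; anything else is passed through as a full device string.
    if arg.isnumeric():
        return f"cuda:{arg}" if torch.cuda.is_available() else "cpu"
    return arg

# Examples (on a machine with CUDA available):
#   resolve_device("4")      -> "cuda:4"
#   resolve_device("cuda:1") -> "cuda:1"
#   resolve_device("mps")    -> "mps"
```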
2 changes: 1 addition & 1 deletion intervention/days_of_week_task.py
@@ -11,7 +11,7 @@
from task import activation_patching


device = "cuda:4"
device = "cuda:4" if torch.cuda.is_available() else "cpu"
#
# %%

3 changes: 2 additions & 1 deletion intervention/months_of_year_task.py
@@ -5,13 +5,14 @@

setup_notebook()

+import torch
import numpy as np
import transformer_lens
from task import Problem, get_acts, plot_pca, get_all_acts, get_acts_pca
from task import activation_patching


device = "cuda:4"
device = "cuda:4" if torch.cuda.is_available() else "cpu"
#
# %%

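The CUDA fallback only works if `torch` is in scope, which is why the months_of_year_task.py hunk also adds the import. A minimal sketch of the shared pattern (not from the diff):

```python
import torch  # required: the fallback below calls torch.cuda.is_available()

# Same fallback the PR applies to the hard-coded notebook device in each task script.
device = "cuda:4" if torch.cuda.is_available() else "cpu"
print(device)  # "cuda:4" on a machine with CUDA, "cpu" otherwise
```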