diff --git a/intervention/days_of_week_task.py b/intervention/days_of_week_task.py
index cbb5fa8..0819ab3 100644
--- a/intervention/days_of_week_task.py
+++ b/intervention/days_of_week_task.py
@@ -7,6 +7,7 @@
 import numpy as np
 
 import transformer_lens
+import torch
 
 from task import Problem, get_acts, plot_pca, get_all_acts, get_acts_pca
 from task import activation_patching
@@ -148,7 +149,11 @@ def generate_problems(self):
 
     def get_model(self):
         if self.n_devices is None:
-            self.n_devices = 2 if "llama" == self.model_name else 1
+            self.n_devices = (
+                min(2, max(1, torch.cuda.device_count()))
+                if "llama" == self.model_name
+                else 1
+            )
         if self._lazy_model is None:
             if self.model_name == "mistral":
                 self._lazy_model = transformer_lens.HookedTransformer.from_pretrained(
diff --git a/intervention/months_of_year_task.py b/intervention/months_of_year_task.py
index 161ef29..49e9e2b 100644
--- a/intervention/months_of_year_task.py
+++ b/intervention/months_of_year_task.py
@@ -7,6 +7,7 @@
 import numpy as np
 
 import transformer_lens
+import torch
 
 from task import Problem, get_acts, plot_pca, get_all_acts, get_acts_pca
 from task import activation_patching
@@ -159,7 +160,11 @@ def generate_problems(self):
 
     def get_model(self):
         if self.n_devices is None:
-            self.n_devices = 2 if "llama" == self.model_name else 1
+            self.n_devices = (
+                min(2, max(1, torch.cuda.device_count()))
+                if "llama" == self.model_name
+                else 1
+            )
         if self._lazy_model is None:
             if self.model_name == "mistral":
                 self._lazy_model = transformer_lens.HookedTransformer.from_pretrained(