diff --git a/tests/unit/pretrained_weight_conversions/test_apertus.py b/tests/unit/pretrained_weight_conversions/test_apertus.py index a9fc7f05f..c343780d9 100644 --- a/tests/unit/pretrained_weight_conversions/test_apertus.py +++ b/tests/unit/pretrained_weight_conversions/test_apertus.py @@ -183,7 +183,7 @@ def test_zero_biases_have_correct_device(self): "blocks.0.mlp.b_out", "unembed.b_U", ]: - assert sd[key].device.type == cfg.device.type, f"{key} on wrong device" + assert sd[key].device.type == cfg.device, f"{key} on wrong device" def test_unembed_shapes(self): cfg = make_cfg() diff --git a/tests/unit/utilities/test_devices.py b/tests/unit/utilities/test_devices.py index f208f3e51..a8036fafd 100644 --- a/tests/unit/utilities/test_devices.py +++ b/tests/unit/utilities/test_devices.py @@ -92,7 +92,7 @@ def test_get_device_returns_cpu_when_mps_available(mock_built, mock_avail, mock_ os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None) result = get_device() - assert result == torch.device("cpu") + assert result == "cpu" @patch.dict("os.environ", {"TRANSFORMERLENS_ALLOW_MPS": "1"}) @@ -102,7 +102,7 @@ def test_get_device_returns_cpu_when_mps_available(mock_built, mock_avail, mock_ def test_get_device_returns_mps_when_env_var_set(mock_built, mock_avail, mock_cuda): """get_device() should return MPS when TRANSFORMERLENS_ALLOW_MPS=1 is set.""" result = get_device() - assert result == torch.device("mps") + assert result == "mps" @patch.dict("os.environ", {}, clear=False) diff --git a/transformer_lens/train.py b/transformer_lens/train.py index ec9537600..1ad298814 100644 --- a/transformer_lens/train.py +++ b/transformer_lens/train.py @@ -5,7 +5,7 @@ """ from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union import torch import torch.optim as optim @@ -50,7 +50,7 @@ class HookedTransformerTrainConfig: max_grad_norm: Optional[float] = None weight_decay: Optional[float] = None optimizer_name: str = "Adam" - device: Optional[str] = None + device: Optional[Union[str, torch.device]] = None warmup_steps: int = 0 save_every: Optional[int] = None save_dir: Optional[str] = None diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index 881811833..d9a5826f8 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -1015,9 +1015,9 @@ def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False print(f"column mismatch: {col_mismatch}") -def get_device(): +def get_device() -> str: if torch.cuda.is_available(): - return torch.device("cuda") + return "cuda" if torch.backends.mps.is_available() and torch.backends.mps.is_built(): major_version = int(torch.__version__.split(".")[0]) if major_version >= 2: @@ -1026,9 +1026,9 @@ def get_device(): _MPS_MIN_SAFE_TORCH_VERSION is not None and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION ): - return torch.device("mps") + return "mps" if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1": - return torch.device("mps") + return "mps" logging.info( "MPS device available but not auto-selected due to known correctness issues " "(PyTorch %s). Set TRANSFORMERLENS_ALLOW_MPS=1 to override. See: " @@ -1036,7 +1036,7 @@ def get_device(): torch.__version__, ) - return torch.device("cpu") + return "cpu" _mps_warned = False @@ -1051,7 +1051,7 @@ def _torch_version_tuple() -> tuple[int, ...]: return tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2]) -def warn_if_mps(device): +def warn_if_mps(device: Union[str, torch.device]) -> None: """Emit a one-time warning if device is MPS and TRANSFORMERLENS_ALLOW_MPS is not set. Automatically suppressed when the installed PyTorch version meets or exceeds