Skip to content

Commit a36f23e

Browse files
committed
Fix type of HookedTransformerConfig.device
This is either a torch.device or a string like "cpu", but it was typed as just `Optional[str]`. This fixes it to be `Optional[Union[str, torch.device]]` and all of the downstream places that need to be updated. Found while working on #1219
1 parent 589acd4 commit a36f23e

3 files changed

Lines changed: 5 additions & 5 deletions

File tree

transformer_lens/HookedTransformerConfig.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -231,7 +231,7 @@ class HookedTransformerConfig:
231231
attn_types: Optional[List] = None
232232
init_mode: str = "gpt2"
233233
normalization_type: Optional[str] = "LN"
234-
device: Optional[str] = None
234+
device: Optional[Union[str, torch.device]] = None
235235
n_devices: int = 1
236236
attention_dir: str = "causal"
237237
attn_only: bool = False

transformer_lens/train.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
"""
66

77
from dataclasses import dataclass
8-
from typing import Optional
8+
from typing import Optional, Union
99

1010
import torch
1111
import torch.optim as optim
@@ -50,7 +50,7 @@ class HookedTransformerTrainConfig:
5050
max_grad_norm: Optional[float] = None
5151
weight_decay: Optional[float] = None
5252
optimizer_name: str = "Adam"
53-
device: Optional[str] = None
53+
device: Optional[Union[str, torch.device]] = None
5454
warmup_steps: int = 0
5555
save_every: Optional[int] = None
5656
save_dir: Optional[str] = None

transformer_lens/utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1015,7 +1015,7 @@ def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False
10151015
print(f"column mismatch: {col_mismatch}")
10161016

10171017

1018-
def get_device():
1018+
def get_device() -> torch.device:
10191019
if torch.cuda.is_available():
10201020
return torch.device("cuda")
10211021
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
@@ -1051,7 +1051,7 @@ def _torch_version_tuple() -> tuple[int, ...]:
10511051
return tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
10521052

10531053

1054-
def warn_if_mps(device):
1054+
def warn_if_mps(device: Union[str, torch.device]) -> None:
10551055
"""Emit a one-time warning if device is MPS and TRANSFORMERLENS_ALLOW_MPS is not set.
10561056
10571057
Automatically suppressed when the installed PyTorch version meets or exceeds

0 commit comments

Comments (0)