Skip to content

Commit 6697253

Browse files
committed
add test_int8_tensor.py
1 parent d1515fa commit 6697253

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from torchao.quantization.utils import compute_error, get_block_size
2727
from torchao.testing.model_architectures import ToyTwoLinearModel
2828
from torchao.testing.utils import TorchAOIntegrationTestCase
29-
from torchao.utils import torch_version_at_least
29+
from torchao.utils import get_current_accelerator_device, torch_version_at_least
3030

3131
INT8_TEST_CONFIGS = [
3232
Int8WeightOnlyConfig(version=2, granularity=PerTensor()),
@@ -38,9 +38,10 @@
3838
version=2, granularity=PerRow(), act_mapping_type=MappingType.SYMMETRIC
3939
),
4040
]
41+
_DEVICE = get_current_accelerator_device()
4142

4243

43-
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
44+
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
4445
@common_utils.instantiate_parametrized_tests
4546
class TestInt8Tensor(TorchAOIntegrationTestCase):
4647
def setUp(self):
@@ -60,7 +61,7 @@ def test_creation_and_attributes(self, config):
6061
self.test_shape[0],
6162
bias=False,
6263
dtype=self.dtype,
63-
device="cuda",
64+
device=_DEVICE,
6465
)
6566
quantize_(linear, config)
6667

@@ -99,8 +100,8 @@ def test_int8_linear_variants(
99100
torch.compiler.reset()
100101

101102
M, N, K = sizes
102-
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
103-
model = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval()
103+
input_tensor = torch.randn(*M, K, dtype=dtype, device=_DEVICE)
104+
model = ToyTwoLinearModel(K, N, K, dtype=dtype, device=_DEVICE).eval()
104105
model_q = copy.deepcopy(model)
105106

106107
quantize_(model_q, config)
@@ -128,7 +129,7 @@ def test_int8_linear_variants(
128129
)
129130

130131
@common_utils.parametrize("config", INT8_TEST_CONFIGS)
131-
@common_utils.parametrize("device", ["cpu", "cuda"])
132+
@common_utils.parametrize("device", ["cpu", _DEVICE])
132133
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
133134
def test_slice(self, config, device, dtype):
134135
"""Test tensor slicing with per-row quantization"""
@@ -159,8 +160,8 @@ def test_slice(self, config, device, dtype):
159160
def test_index_select(self, config):
160161
"""test that `x_0 = x[0]` works when `x` is a 2D quantized tensor."""
161162
N, K = 256, 512
162-
x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
163-
linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
163+
x = torch.randn(N, K, device=_DEVICE, dtype=torch.bfloat16)
164+
linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device=_DEVICE)
164165
linear.weight.data = x
165166

166167
quantize_(linear, config)
@@ -187,7 +188,7 @@ def test_index_select(self, config):
187188
def test_dequantization_accuracy(self, config):
188189
"""Test dequantization accuracy separately"""
189190
linear = torch.nn.Linear(
190-
256, 512, bias=False, dtype=torch.bfloat16, device="cuda"
191+
256, 512, bias=False, dtype=torch.bfloat16, device=_DEVICE
191192
)
192193
weight_fp = copy.deepcopy(linear.weight)
193194
quantize_(linear, config)
@@ -208,14 +209,14 @@ def test_available_gpu_kernels(self):
208209

209210
M, K, N = 128, 256, 512
210211
m = torch.nn.Sequential(
211-
torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
212+
torch.nn.Linear(K, N, device=_DEVICE, dtype=torch.bfloat16)
212213
)
213214

214215
config = Int8DynamicActivationInt8WeightConfig(version=2)
215216
quantize_(m, config)
216217

217218
m = torch.compile(m)
218-
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
219+
x = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16)
219220

220221
out, code = run_and_get_code(m, x)
221222

@@ -248,7 +249,7 @@ def test_pin_memory(self, config):
248249
)
249250

250251

251-
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
252+
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
252253
@common_utils.instantiate_parametrized_tests
253254
class TestInt8StaticQuant(TorchAOIntegrationTestCase):
254255
@common_utils.parametrize("granularity", [PerRow(), PerTensor()])
@@ -257,9 +258,9 @@ def test_static_activation_per_row_int8_weight(self, granularity, dtype):
257258
torch.compiler.reset()
258259

259260
M, N, K = 32, 32, 32
260-
input_tensor = torch.randn(M, K, dtype=dtype, device="cuda")
261+
input_tensor = torch.randn(M, K, dtype=dtype, device=_DEVICE)
261262

262-
model = torch.nn.Linear(K, N, bias=False).eval().to(device="cuda", dtype=dtype)
263+
model = torch.nn.Linear(K, N, bias=False).eval().to(device=_DEVICE, dtype=dtype)
263264
model_static_quant = copy.deepcopy(model)
264265
model_dynamic_quant = copy.deepcopy(model)
265266

0 commit comments

Comments
 (0)