
Commit 989bc8e

add test_int8_tensor.py
1 parent f3342a0 commit 989bc8e

File tree

1 file changed: +15 -14 lines


test/quantization/quantize_/workflows/int8/test_int8_tensor.py

@@ -26,7 +26,7 @@
 from torchao.quantization.utils import compute_error, get_block_size
 from torchao.testing.model_architectures import ToyTwoLinearModel
 from torchao.testing.utils import TorchAOIntegrationTestCase
-from torchao.utils import torch_version_at_least
+from torchao.utils import get_current_accelerator_device, torch_version_at_least

 INT8_TEST_CONFIGS = [
     Int8WeightOnlyConfig(version=2, granularity=PerTensor()),
@@ -38,9 +38,10 @@
         version=2, granularity=PerRow(), act_mapping_type=MappingType.SYMMETRIC
     ),
 ]
+_DEVICE = get_current_accelerator_device()


-@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
 @common_utils.instantiate_parametrized_tests
 class TestInt8Tensor(TorchAOIntegrationTestCase):
     def setUp(self):
@@ -60,7 +61,7 @@ def test_creation_and_attributes(self, config):
             self.test_shape[0],
             bias=False,
             dtype=self.dtype,
-            device="cuda",
+            device=_DEVICE,
         )
         quantize_(linear, config)

@@ -99,8 +100,8 @@ def test_int8_linear_variants(
         torch.compiler.reset()

         M, N, K = sizes
-        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
-        model = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval()
+        input_tensor = torch.randn(*M, K, dtype=dtype, device=_DEVICE)
+        model = ToyTwoLinearModel(K, N, K, dtype=dtype, device=_DEVICE).eval()
         model_q = copy.deepcopy(model)

         quantize_(model_q, config)
@@ -128,7 +129,7 @@ def test_int8_linear_variants(
         )

     @common_utils.parametrize("config", INT8_TEST_CONFIGS)
-    @common_utils.parametrize("device", ["cpu", "cuda"])
+    @common_utils.parametrize("device", ["cpu", _DEVICE])
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
     def test_slice(self, config, device, dtype):
         """Test tensor slicing with per-row quantization"""
@@ -159,8 +160,8 @@ def test_slice(self, config, device, dtype):
     def test_index_select(self, config):
         """test that `x_0 = x[0]` works when `x` is a 2D quantized tensor."""
         N, K = 256, 512
-        x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
-        linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
+        x = torch.randn(N, K, device=_DEVICE, dtype=torch.bfloat16)
+        linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device=_DEVICE)
         linear.weight.data = x

         quantize_(linear, config)
@@ -187,7 +188,7 @@ def test_index_select(self, config):
     def test_dequantization_accuracy(self, config):
         """Test dequantization accuracy separately"""
         linear = torch.nn.Linear(
-            256, 512, bias=False, dtype=torch.bfloat16, device="cuda"
+            256, 512, bias=False, dtype=torch.bfloat16, device=_DEVICE
         )
         weight_fp = copy.deepcopy(linear.weight)
         quantize_(linear, config)
@@ -208,14 +209,14 @@ def test_available_gpu_kernels(self):

         M, K, N = 128, 256, 512
         m = torch.nn.Sequential(
-            torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
+            torch.nn.Linear(K, N, device=_DEVICE, dtype=torch.bfloat16)
         )

         config = Int8DynamicActivationInt8WeightConfig(version=2)
         quantize_(m, config)

         m = torch.compile(m)
-        x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+        x = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16)

         out, code = run_and_get_code(m, x)

@@ -225,7 +226,7 @@ def test_available_gpu_kernels(self):
         ).check_count("triton_poi_fused", 1).run(code[0])


-@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
 @common_utils.instantiate_parametrized_tests
 class TestInt8StaticQuant(TorchAOIntegrationTestCase):
     @common_utils.parametrize("granularity", [PerRow(), PerTensor()])
@@ -234,9 +235,9 @@ def test_static_activation_per_row_int8_weight(self, granularity, dtype):
         torch.compiler.reset()

         M, N, K = 32, 32, 32
-        input_tensor = torch.randn(M, K, dtype=dtype, device="cuda")
+        input_tensor = torch.randn(M, K, dtype=dtype, device=_DEVICE)

-        model = torch.nn.Linear(K, N, bias=False).eval().to(device="cuda", dtype=dtype)
+        model = torch.nn.Linear(K, N, bias=False).eval().to(device=_DEVICE, dtype=dtype)
         model_static_quant = copy.deepcopy(model)
         model_dynamic_quant = copy.deepcopy(model)

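The entire change is one pattern: replace hard-coded "cuda" strings with a module-level device resolved once at import time, so the suite runs on any PyTorch accelerator backend. Below is a minimal, self-contained sketch of that pattern. The body of get_current_accelerator_device is an assumption about what torchao's helper plausibly does (the real implementation lives in torchao.utils and may differ), and ExampleDeviceAgnosticTest is a hypothetical stand-in for the test classes in the diff, not code from the repo.

import unittest

import torch


def get_current_accelerator_device() -> torch.device:
    # Assumed body, not copied from torchao: torch.accelerator (PyTorch 2.6+)
    # abstracts over CUDA, XPU, MPS, etc., which is what lets the tests in
    # this commit drop their hard-coded "cuda" strings.
    if torch.accelerator.is_available():
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")


_DEVICE = get_current_accelerator_device()


# Hypothetical stand-in for TestInt8Tensor / TestInt8StaticQuant above.
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
class ExampleDeviceAgnosticTest(unittest.TestCase):
    def test_linear_runs_on_accelerator(self):
        # device=_DEVICE instead of device="cuda": the same test now runs on
        # whichever accelerator backend the host provides.
        linear = torch.nn.Linear(16, 8, device=_DEVICE, dtype=torch.bfloat16)
        x = torch.randn(4, 16, device=_DEVICE, dtype=torch.bfloat16)
        self.assertEqual(linear(x).shape, (4, 8))


if __name__ == "__main__":
    unittest.main()

The same reasoning explains the test_slice hunk: parametrizing with ["cpu", _DEVICE] keeps the CPU path covered while the second entry resolves to whatever accelerator the host exposes.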