2626from torchao .quantization .utils import compute_error , get_block_size
2727from torchao .testing .model_architectures import ToyTwoLinearModel
2828from torchao .testing .utils import TorchAOIntegrationTestCase
29- from torchao .utils import torch_version_at_least
29+ from torchao .utils import get_current_accelerator_device , torch_version_at_least
3030
3131INT8_TEST_CONFIGS = [
3232 Int8WeightOnlyConfig (version = 2 , granularity = PerTensor ()),
3838 version = 2 , granularity = PerRow (), act_mapping_type = MappingType .SYMMETRIC
3939 ),
4040]
41+ _DEVICE = get_current_accelerator_device ()
4142
4243
43- @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
44+ @unittest .skipIf (not torch .accelerator .is_available (), "Need GPU available" )
4445@common_utils .instantiate_parametrized_tests
4546class TestInt8Tensor (TorchAOIntegrationTestCase ):
4647 def setUp (self ):
@@ -60,7 +61,7 @@ def test_creation_and_attributes(self, config):
6061 self .test_shape [0 ],
6162 bias = False ,
6263 dtype = self .dtype ,
63- device = "cuda" ,
64+ device = _DEVICE ,
6465 )
6566 quantize_ (linear , config )
6667
@@ -99,8 +100,8 @@ def test_int8_linear_variants(
99100 torch .compiler .reset ()
100101
101102 M , N , K = sizes
102- input_tensor = torch .randn (* M , K , dtype = dtype , device = "cuda" )
103- model = ToyTwoLinearModel (K , N , K , dtype = dtype , device = "cuda" ).eval ()
103+ input_tensor = torch .randn (* M , K , dtype = dtype , device = _DEVICE )
104+ model = ToyTwoLinearModel (K , N , K , dtype = dtype , device = _DEVICE ).eval ()
104105 model_q = copy .deepcopy (model )
105106
106107 quantize_ (model_q , config )
@@ -128,7 +129,7 @@ def test_int8_linear_variants(
128129 )
129130
130131 @common_utils .parametrize ("config" , INT8_TEST_CONFIGS )
131- @common_utils .parametrize ("device" , ["cpu" , "cuda" ])
132+ @common_utils .parametrize ("device" , ["cpu" , _DEVICE ])
132133 @common_utils .parametrize ("dtype" , [torch .bfloat16 , torch .float16 ])
133134 def test_slice (self , config , device , dtype ):
134135 """Test tensor slicing with per-row quantization"""
@@ -159,8 +160,8 @@ def test_slice(self, config, device, dtype):
159160 def test_index_select (self , config ):
160161 """test that `x_0 = x[0]` works when `x` is a 2D quantized tensor."""
161162 N , K = 256 , 512
162- x = torch .randn (N , K , device = "cuda" , dtype = torch .bfloat16 )
163- linear = torch .nn .Linear (K , N , bias = False , dtype = torch .bfloat16 , device = "cuda" )
163+ x = torch .randn (N , K , device = _DEVICE , dtype = torch .bfloat16 )
164+ linear = torch .nn .Linear (K , N , bias = False , dtype = torch .bfloat16 , device = _DEVICE )
164165 linear .weight .data = x
165166
166167 quantize_ (linear , config )
@@ -187,7 +188,7 @@ def test_index_select(self, config):
187188 def test_dequantization_accuracy (self , config ):
188189 """Test dequantization accuracy separately"""
189190 linear = torch .nn .Linear (
190- 256 , 512 , bias = False , dtype = torch .bfloat16 , device = "cuda"
191+ 256 , 512 , bias = False , dtype = torch .bfloat16 , device = _DEVICE
191192 )
192193 weight_fp = copy .deepcopy (linear .weight )
193194 quantize_ (linear , config )
@@ -208,14 +209,14 @@ def test_available_gpu_kernels(self):
208209
209210 M , K , N = 128 , 256 , 512
210211 m = torch .nn .Sequential (
211- torch .nn .Linear (K , N , device = "cuda" , dtype = torch .bfloat16 )
212+ torch .nn .Linear (K , N , device = _DEVICE , dtype = torch .bfloat16 )
212213 )
213214
214215 config = Int8DynamicActivationInt8WeightConfig (version = 2 )
215216 quantize_ (m , config )
216217
217218 m = torch .compile (m )
218- x = torch .randn (M , K , device = "cuda" , dtype = torch .bfloat16 )
219+ x = torch .randn (M , K , device = _DEVICE , dtype = torch .bfloat16 )
219220
220221 out , code = run_and_get_code (m , x )
221222
@@ -225,7 +226,7 @@ def test_available_gpu_kernels(self):
225226 ).check_count ("triton_poi_fused" , 1 ).run (code [0 ])
226227
227228
228- @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
229+ @unittest .skipIf (not torch .accelerator .is_available (), "Need GPU available" )
229230@common_utils .instantiate_parametrized_tests
230231class TestInt8StaticQuant (TorchAOIntegrationTestCase ):
231232 @common_utils .parametrize ("granularity" , [PerRow (), PerTensor ()])
@@ -234,9 +235,9 @@ def test_static_activation_per_row_int8_weight(self, granularity, dtype):
234235 torch .compiler .reset ()
235236
236237 M , N , K = 32 , 32 , 32
237- input_tensor = torch .randn (M , K , dtype = dtype , device = "cuda" )
238+ input_tensor = torch .randn (M , K , dtype = dtype , device = _DEVICE )
238239
239- model = torch .nn .Linear (K , N , bias = False ).eval ().to (device = "cuda" , dtype = dtype )
240+ model = torch .nn .Linear (K , N , bias = False ).eval ().to (device = _DEVICE , dtype = dtype )
240241 model_static_quant = copy .deepcopy (model )
241242 model_dynamic_quant = copy .deepcopy (model )
242243
0 commit comments