@@ -647,20 +647,20 @@ def test_module_fqn_to_config_module_name(self):
         assert isinstance(model.linear2.weight, AffineQuantizedTensor)
         assert isinstance(model.linear2.weight._layout, PlainLayout)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_module_fqn_to_config_regex_basic(self):
         config1 = Int4WeightOnlyConfig(
             group_size=32, int4_packing_format="tile_packed_to_4d"
         )
         config = ModuleFqnToConfig({"re:linear.": config1})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
         quantize_(model, config, filter_fn=None)
         model(*example_inputs)
         assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
         assert isinstance(model.linear2.weight, Int4TilePackedTo4dTensor)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_module_fqn_to_config_regex_precedence(self):
         """Testing that full path config takes precedence over
         regex config in ModuleFqnToConfig
@@ -670,14 +670,14 @@ def test_module_fqn_to_config_regex_precedence(self):
         )
         config2 = IntxWeightOnlyConfig()
         config = ModuleFqnToConfig({"linear1": config1, "re:linear.": config2})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
         quantize_(model, config, filter_fn=None)
         model(*example_inputs)
         assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
         assert isinstance(model.linear2.weight, IntxUnpackedToInt8Tensor)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_module_fqn_to_config_regex_precedence2(self):
         """Testing that full path config takes precedence over
         regex config in ModuleFqnToConfig, swapping
@@ -689,14 +689,14 @@ def test_module_fqn_to_config_regex_precedence2(self):
         )
         config2 = IntxWeightOnlyConfig()
         config = ModuleFqnToConfig({"re:linear.": config2, "linear1": config1})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
         quantize_(model, config, filter_fn=None)
         model(*example_inputs)
         assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
         assert isinstance(model.linear2.weight, IntxUnpackedToInt8Tensor)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_module_fqn_to_config_regex_fullmatch(self):
         """Testing that we will only match the fqns that fully
         matches the regex
@@ -735,7 +735,7 @@ def example_inputs(self):
735735 "linear3_full_match.bias" : None ,
736736 }
737737 )
738- model = M (dtype = torch .bfloat16 , device = "cuda" )
738+ model = M (dtype = torch .bfloat16 , device = _DEVICE )
739739 example_inputs = model .example_inputs ()
740740 quantize_ (model , config , filter_fn = None )
741741 model (* example_inputs )
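
Note on `_DEVICE`: its definition sits outside these hunks, so it is not shown here. A minimal sketch of what such a module-level helper could look like, assuming it resolves the test device once via the `torch.accelerator` API introduced in PyTorch 2.6 (the fallback logic below is an assumption, not the file's actual code):

```python
import torch

# Sketch only (assumed definition): pick whichever accelerator backend
# is present (CUDA, XPU, MPS, ...) and fall back to CPU otherwise.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    _DEVICE = torch.accelerator.current_accelerator().type  # e.g. "cuda"
else:
    _DEVICE = "cpu"
```

With a helper like this, `.to(_DEVICE)` and `example_inputs(device=_DEVICE, ...)` behave identically to the old `.cuda()` calls on CUDA machines, while the same tests can run on other accelerators.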
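As a usage note, the two precedence tests pin down `ModuleFqnToConfig`'s lookup order: an exact module FQN entry wins over a `re:` entry even when the regex is listed first, and regex entries use full-match semantics, so `"re:linear."` matches `linear1` and `linear2` but not `linear3_full_match`. A hedged sketch of such a mapping, reusing the APIs from the diff (import paths assume a recent torchao release; the commented call mirrors the tests):

```python
from torchao.quantization import (
    Int4WeightOnlyConfig,
    IntxWeightOnlyConfig,
    ModuleFqnToConfig,
    quantize_,
)

# "linear1" is an exact FQN, so it overrides the regex entry for that
# module regardless of insertion order; "re:linear." full-matches the
# remaining single-suffix names such as "linear2".
config = ModuleFqnToConfig(
    {
        "re:linear.": IntxWeightOnlyConfig(),
        "linear1": Int4WeightOnlyConfig(group_size=32),
    }
)
# quantize_(model, config, filter_fn=None)  # as in the tests above
```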