1717from torchao .core .config import AOBaseConfig
1818from torchao .dtypes import (
1919 CutlassInt4PackedLayout ,
20- Int4CPULayout ,
21- Int4XPULayout ,
2220 PlainLayout ,
2321 SemiSparseLayout ,
2422 to_affine_quantized_intx ,
2826 Float8WeightOnlyConfig ,
2927 GemliteUIntXWeightOnlyConfig ,
3028 Int4DynamicActivationInt4WeightConfig ,
31- Int4WeightOnlyConfig ,
3229 Int8DynamicActivationInt4WeightConfig ,
3330 Int8DynamicActivationInt8WeightConfig ,
3431 Int8WeightOnlyConfig ,
3532 quantize_ ,
3633)
37- from torchao .quantization .quant_primitives import MappingType , ZeroPointDomain
38- from torchao .testing .utils import skip_if_no_cuda , skip_if_no_gemlite , skip_if_rocm
34+ from torchao .quantization .quant_primitives import MappingType
35+ from torchao .testing .utils import skip_if_no_gemlite , skip_if_rocm
3936from torchao .utils import (
4037 check_cpu_version ,
4138 check_xpu_version ,
@@ -62,24 +59,10 @@ def get_quantization_functions(
6259 ]
6360 if do_int4 :
6461 if check_cpu_version (device ):
65- base_functions .append (
66- Int4WeightOnlyConfig (group_size = 32 , layout = Int4CPULayout (), version = 1 )
67- )
62+ pass
6863 elif check_xpu_version (device ):
69- base_functions .append (
70- Int4WeightOnlyConfig (group_size = 32 , layout = Int4XPULayout (), version = 1 )
71- )
72- if int4_zp_int :
73- base_functions .append (
74- Int4WeightOnlyConfig (
75- group_size = 32 ,
76- layout = Int4XPULayout (),
77- zero_point_domain = ZeroPointDomain .INT ,
78- version = 1 ,
79- )
80- )
64+ pass
8165 else :
82- base_functions .append (Int4WeightOnlyConfig (group_size = 32 , version = 1 ))
8366 if device == "cuda" and not is_ROCM ():
8467 base_functions .append (
8568 Int8DynamicActivationInt4WeightConfig (
@@ -107,26 +90,6 @@ class TestAffineQuantized(TestCase):
10790 ["xpu" ] if torch .xpu .is_available () else []
10891 )
10992
110- @unittest .skipIf (not torch .accelerator .is_available (), "Need GPU available" )
111- def test_tensor_core_layout_transpose (self ):
112- linear = torch .nn .Linear (128 , 256 , dtype = torch .bfloat16 , device = _DEVICE )
113- t = linear .weight
114- shape = t .shape
115- apply_int4_weight_only_quant = Int4WeightOnlyConfig (group_size = 32 , version = 1 )
116- quantize_ (linear , apply_int4_weight_only_quant )
117- ql = linear
118- aqt = ql .weight
119- aqt_shape = aqt .shape
120- self .assertEqual (aqt_shape , shape )
121-
122- # transpose shape test
123- for _ in range (10 ):
124- t = t .t ()
125- aqt = aqt .t ()
126- shape = t .shape
127- aqt_shape = aqt .shape
128- self .assertEqual (aqt_shape , shape )
129-
13093 @unittest .skipIf (len (GPU_DEVICES ) == 0 , "Need GPU available" )
13194 def test_weights_only (self ):
13295 for device in self .GPU_DEVICES :
@@ -338,20 +301,6 @@ def test_alias(self, device, dtype):
338301 quantize_ (dummy , Int8DynamicActivationInt8WeightConfig ())
339302 _ = dummy .weight [...]
340303
341- @common_utils .parametrize ("device" , [_DEVICE ])
342- @common_utils .parametrize ("dtype" , [torch .bfloat16 ])
343- @skip_if_no_cuda ()
344- @skip_if_rocm ("ROCm enablement in progress" )
345- def test_slice_int4wo (self , device , dtype ):
346- # in_feature not divisible by 1024
347- # out_feature not divisible by 8
348- # to test slice + padding for int4 weight only quantization
349- dummy = nn .Linear (256 , 321 , dtype = dtype , device = device )
350- quantize_ (dummy , Int4WeightOnlyConfig (version = 1 ))
351- # make sure these run without error
352- _ = dummy .weight .narrow (0 , 0 , 64 )
353- _ = dummy .weight .narrow (1 , 0 , 128 )
354-
355304 @common_utils .parametrize ("device" , [_DEVICE ])
356305 @common_utils .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
357306 @unittest .skipIf (not torch .accelerator .is_available (), "Need GPU available" )
@@ -452,58 +401,6 @@ def test_matmul(self, device, dtype):
452401 # make sure it runs
453402 torch .matmul (x , w .t ())
454403
455- @common_utils .parametrize ("device" , [_DEVICE ])
456- @common_utils .parametrize ("dtype" , [torch .bfloat16 ])
457- @skip_if_no_cuda ()
458- @skip_if_rocm ("ROCm enablement in progress" )
459- def test_slice_and_copy_int4wo (self , device , dtype ):
460- l = torch .nn .Linear (1024 , 1024 ).to (_DEVICE ).to (torch .bfloat16 )
461- l .weight = torch .nn .Parameter (
462- torch .zeros (1024 , 1024 , dtype = torch .bfloat16 , device = _DEVICE )
463- )
464- quantize_ (l , Int4WeightOnlyConfig (version = 1 ))
465- param = l .weight
466- param_data = param .data
467- param_data = param_data .narrow (0 , 0 , 512 )
468- assert (
469- param .data .tensor_impl .packed_weight .data_ptr ()
470- == param_data .tensor_impl .packed_weight .data_ptr ()
471- )
472- assert (
473- param .data .tensor_impl .scale_and_zero .data_ptr ()
474- == param_data .tensor_impl .scale_and_zero .data_ptr ()
475- )
476- assert param .data .dequantize ()[0 ][0 ] == 0
477-
478- # dummy_l has random input (shouldn't be 0)
479- dummy_l = torch .nn .Linear (1024 , 1024 ).to (_DEVICE ).to (torch .bfloat16 )
480- quantize_ (dummy_l , Int4WeightOnlyConfig (version = 1 ))
481- quantized = dummy_l .weight
482- quantized = quantized .narrow (0 , 0 , 512 )
483-
484- param_data .copy_ (quantized )
485-
486- # making sure param.data is updated
487- assert param .data .dequantize ()[0 ][0 ] != 0
488-
489- @common_utils .parametrize ("device" , [_DEVICE ])
490- @common_utils .parametrize ("dtype" , [torch .bfloat16 ])
491- @unittest .skipIf (not torch .accelerator .is_available (), "Need GPU available" )
492- @skip_if_rocm ("ROCm enablement in progress" )
493- def test_mm_int4wo (self , device , dtype ):
494- weight = torch .randn (512 , 1024 ).to (device ).to (dtype )
495- weight = weight .t ()
496-
497- l = torch .nn .Linear (512 , 1024 ).to (device ).to (dtype )
498- l .weight = torch .nn .Parameter (weight )
499- quantize_ (l , Int4WeightOnlyConfig (version = 1 ))
500- # weight shape: 1024 x 512
501- weight = l .weight
502-
503- input = torch .randn (1 , 512 , device = device , dtype = dtype )
504- # make sure it runs
505- torch .nn .functional .linear (input , weight )
506-
507404
508405common_utils .instantiate_parametrized_tests (TestAffineQuantized )
509406common_utils .instantiate_parametrized_tests (TestAffineQuantizedBasic )
0 commit comments