@@ -78,27 +78,25 @@ namespace ggml_cuda_mma {
7878 // MIRRORED == Each data value is held exactly once per thread subgroup.
7979 DATA_LAYOUT_I_MAJOR = 0 , // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
8080 DATA_LAYOUT_J_MAJOR = 10 , // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
81- DATA_LAYOUT_I_MAJOR_MIRRORED = 20 ,
81+ DATA_LAYOUT_I_MAJOR_MIRRORED = 20 , // Volta, matrix A&B for RDNA3.
8282 DATA_LAYOUT_J_MAJOR_MIRRORED = 30 ,
83- DATA_LAYOUT_I_MAJOR_DUAL = 40 , // Matrix A&B for RDNA3.
8483 };
8584 // Implemented mma combinations are:
8685 // - (I_MAJOR, I_MAJOR) -> I_MAJOR
8786 // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
8887 // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
8988
90- constexpr bool is_i_major (const data_layout dl) {
89+ static constexpr bool is_i_major (const data_layout dl) {
9190 return dl == DATA_LAYOUT_I_MAJOR ||
92- dl == DATA_LAYOUT_I_MAJOR_MIRRORED ||
93- dl == DATA_LAYOUT_I_MAJOR_DUAL;
91+ dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
9492 }
9593
96- constexpr data_layout get_input_data_layout () {
97- #if defined(RDNA3)
98- return DATA_LAYOUT_I_MAJOR_DUAL ;
94+ static constexpr __device__ data_layout get_input_data_layout () {
95+ #if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
96+ return DATA_LAYOUT_I_MAJOR_MIRRORED ;
9997#else
10098 return DATA_LAYOUT_I_MAJOR;
101- #endif // defined(RDNA3)
99+ #endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
102100 }
103101
104102 template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
@@ -462,31 +460,35 @@ namespace ggml_cuda_mma {
462460 }
463461 };
464462
465- template <int I_, int J_>
466- struct tile <I_, J_, half2 , DATA_LAYOUT_I_MAJOR_MIRRORED> {
463+ template <int I_, int J_, typename T >
464+ struct tile <I_, J_, T , DATA_LAYOUT_I_MAJOR_MIRRORED> {
467465 static constexpr int I = I_;
468466 static constexpr int J = J_;
469467 static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
470- static constexpr int ne = I * J / (WARP_SIZE/4 );
471468
472- half2 x[ne] = {{0 .0f , 0 .0f }};
469+ // RDNA3
470+ static constexpr int ne = I * J / 32 * 2 ;
471+
472+ T x[ne] = {0 };
473473
474474 static constexpr __device__ bool supported () {
475- if (I == 8 && J == 4 ) return true ;
475+ if (I == 16 && J == 16 ) return true ;
476+ if (I == 16 && J == 8 ) return true ;
477+ if (I == 16 && J == 4 ) return true ;
476478 return false ;
477479 }
478480
479481 static __device__ __forceinline__ int get_i (const int /* l*/ ) {
480- if constexpr (I == 8 && J == 4 ) {
481- return (( threadIdx .x / 16 ) * 4 ) + ( threadIdx . x % 4 ) ;
482+ if constexpr (supported () ) {
483+ return threadIdx .x % 16 ;
482484 } else {
483485 NO_DEVICE_CODE;
484486 return -1 ;
485487 }
486488 }
487489
488490 static __device__ __forceinline__ int get_j (const int l) {
489- if constexpr (I == 8 && J == 4 ) {
491+ if constexpr (supported () ) {
490492 return l;
491493 } else {
492494 NO_DEVICE_CODE;
@@ -496,10 +498,27 @@ namespace ggml_cuda_mma {
496498 };
497499
498500 template <int I_, int J_>
499- struct tile <I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED > {
501+ struct tile <I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED > {
500502 static constexpr int I = I_;
501503 static constexpr int J = J_;
502- static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
504+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
505+ #if defined(RDNA3)
506+ static constexpr int ne = tile<I_, J_, float , DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
507+
508+ half2 x[ne] = {{0 .0f , 0 .0f }};
509+
510+ static constexpr __device__ bool supported () {
511+ return tile<I_, J_, float , DATA_LAYOUT_I_MAJOR_MIRRORED>::supported ();
512+ }
513+
514+ static __device__ __forceinline__ int get_i (const int l) {
515+ return tile<I_, J_, float , DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i (l);
516+ }
517+
518+ static __device__ __forceinline__ int get_j (const int l) {
519+ return tile<I_, J_, float , DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j (l);
520+ }
521+ #else // Volta
503522 static constexpr int ne = I * J / (WARP_SIZE/4 );
504523
505524 half2 x[ne] = {{0 .0f , 0 .0f }};
@@ -509,9 +528,9 @@ namespace ggml_cuda_mma {
509528 return false ;
510529 }
511530
512- static __device__ __forceinline__ int get_i (const int l ) {
531+ static __device__ __forceinline__ int get_i (const int /* l */ ) {
513532 if constexpr (I == 8 && J == 4 ) {
514- return ((l / 2 ) * 4 ) + (threadIdx .x % 4 );
533+ return ((threadIdx . x / 16 ) * 4 ) + (threadIdx .x % 4 );
515534 } else {
516535 NO_DEVICE_CODE;
517536 return -1 ;
@@ -520,43 +539,63 @@ namespace ggml_cuda_mma {
520539
521540 static __device__ __forceinline__ int get_j (const int l) {
522541 if constexpr (I == 8 && J == 4 ) {
523- return (( threadIdx . x / 16 ) * 2 ) + (l % 2 ) ;
542+ return l ;
524543 } else {
525544 NO_DEVICE_CODE;
526545 return -1 ;
527546 }
528547 }
548+ #endif // defined(RDNA3)
529549 };
530550
531- template <int I_, int J_, typename T >
532- struct tile <I_, J_, T, DATA_LAYOUT_I_MAJOR_DUAL > {
551+ template <int I_, int J_>
552+ struct tile <I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED > {
533553 static constexpr int I = I_;
534554 static constexpr int J = J_;
535- static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
555+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
556+ static constexpr int ne = tile<I_, J_, float , DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
536557
537- static constexpr int ne = I * J / 32 * 2 ;
558+ nv_bfloat162 x[ne] = {{ 0 . 0f , 0 . 0f }} ;
538559
539- T x[ne] = {0 };
560+ static constexpr __device__ bool supported () {
561+ return tile<I_, J_, float , DATA_LAYOUT_I_MAJOR_MIRRORED>::supported ();
562+ }
563+
564+ static __device__ __forceinline__ int get_i (const int l) {
565+ return tile<I_, J_, float , DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i (l);
566+ }
567+
568+ static __device__ __forceinline__ int get_j (const int l) {
569+ return tile<I_, J_, float , DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j (l);
570+ }
571+ };
572+
573+ template <int I_, int J_>
574+ struct tile <I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
575+ static constexpr int I = I_;
576+ static constexpr int J = J_;
577+ static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
578+ static constexpr int ne = I * J / (WARP_SIZE/4 );
579+
580+ half2 x[ne] = {{0 .0f , 0 .0f }};
540581
541582 static constexpr __device__ bool supported () {
542- if (I == 16 && J == 16 ) return true ;
543- if (I == 16 && J == 8 ) return true ;
544- if (I == 16 && J == 4 ) return true ;
583+ if (I == 8 && J == 4 ) return true ;
545584 return false ;
546585 }
547586
548587 static __device__ __forceinline__ int get_i (const int l) {
549- if constexpr (supported () ) {
550- return threadIdx .x % 16 ;
588+ if constexpr (I == 8 && J == 4 ) {
589+ return ((l / 2 ) * 4 ) + ( threadIdx .x % 4 ) ;
551590 } else {
552591 NO_DEVICE_CODE;
553592 return -1 ;
554593 }
555594 }
556595
557596 static __device__ __forceinline__ int get_j (const int l) {
558- if constexpr (supported () ) {
559- return l ;
597+ if constexpr (I == 8 && J == 4 ) {
598+ return (( threadIdx . x / 16 ) * 2 ) + (l % 2 ) ;
560599 } else {
561600 NO_DEVICE_CODE;
562601 return -1 ;
0 commit comments