Skip to content

Commit 54189c0

Browse files
zhang-hui-yulozhang hui
andauthored
remove i_major_dual (#18157)
Co-authored-by: zhang hui <[email protected]>
1 parent 9ce64ae commit 54189c0

File tree

1 file changed

+73
-34
lines changed

1 file changed

+73
-34
lines changed

ggml/src/ggml-cuda/mma.cuh

Lines changed: 73 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -78,27 +78,25 @@ namespace ggml_cuda_mma {
7878
// MIRRORED == Each data value is held exactly once per thread subgroup.
7979
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
8080
DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
81-
DATA_LAYOUT_I_MAJOR_MIRRORED = 20,
81+
DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3.
8282
DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
83-
DATA_LAYOUT_I_MAJOR_DUAL = 40, // Matrix A&B for RDNA3.
8483
};
8584
// Implemented mma combinations are:
8685
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
8786
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
8887
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
8988

90-
constexpr bool is_i_major(const data_layout dl) {
89+
static constexpr bool is_i_major(const data_layout dl) {
9190
return dl == DATA_LAYOUT_I_MAJOR ||
92-
dl == DATA_LAYOUT_I_MAJOR_MIRRORED ||
93-
dl == DATA_LAYOUT_I_MAJOR_DUAL;
91+
dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
9492
}
9593

96-
constexpr data_layout get_input_data_layout() {
97-
#if defined(RDNA3)
98-
return DATA_LAYOUT_I_MAJOR_DUAL;
94+
static constexpr __device__ data_layout get_input_data_layout() {
95+
#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
96+
return DATA_LAYOUT_I_MAJOR_MIRRORED;
9997
#else
10098
return DATA_LAYOUT_I_MAJOR;
101-
#endif // defined(RDNA3)
99+
#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
102100
}
103101

104102
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
@@ -462,31 +460,35 @@ namespace ggml_cuda_mma {
462460
}
463461
};
464462

465-
template <int I_, int J_>
466-
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
463+
template <int I_, int J_, typename T>
464+
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> {
467465
static constexpr int I = I_;
468466
static constexpr int J = J_;
469467
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
470-
static constexpr int ne = I * J / (WARP_SIZE/4);
471468

472-
half2 x[ne] = {{0.0f, 0.0f}};
469+
// RDNA3
470+
static constexpr int ne = I * J / 32 * 2;
471+
472+
T x[ne] = {0};
473473

474474
static constexpr __device__ bool supported() {
475-
if (I == 8 && J == 4) return true;
475+
if (I == 16 && J == 16) return true;
476+
if (I == 16 && J == 8) return true;
477+
if (I == 16 && J == 4) return true;
476478
return false;
477479
}
478480

479481
static __device__ __forceinline__ int get_i(const int /*l*/) {
480-
if constexpr (I == 8 && J == 4) {
481-
return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
482+
if constexpr (supported()) {
483+
return threadIdx.x % 16;
482484
} else {
483485
NO_DEVICE_CODE;
484486
return -1;
485487
}
486488
}
487489

488490
static __device__ __forceinline__ int get_j(const int l) {
489-
if constexpr (I == 8 && J == 4) {
491+
if constexpr (supported()) {
490492
return l;
491493
} else {
492494
NO_DEVICE_CODE;
@@ -496,10 +498,27 @@ namespace ggml_cuda_mma {
496498
};
497499

498500
template <int I_, int J_>
499-
struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
501+
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
500502
static constexpr int I = I_;
501503
static constexpr int J = J_;
502-
static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
504+
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
505+
#if defined(RDNA3)
506+
static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
507+
508+
half2 x[ne] = {{0.0f, 0.0f}};
509+
510+
static constexpr __device__ bool supported() {
511+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
512+
}
513+
514+
static __device__ __forceinline__ int get_i(const int l) {
515+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
516+
}
517+
518+
static __device__ __forceinline__ int get_j(const int l) {
519+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
520+
}
521+
#else // Volta
503522
static constexpr int ne = I * J / (WARP_SIZE/4);
504523

505524
half2 x[ne] = {{0.0f, 0.0f}};
@@ -509,9 +528,9 @@ namespace ggml_cuda_mma {
509528
return false;
510529
}
511530

512-
static __device__ __forceinline__ int get_i(const int l) {
531+
static __device__ __forceinline__ int get_i(const int /*l*/) {
513532
if constexpr (I == 8 && J == 4) {
514-
return ((l / 2) * 4) + (threadIdx.x % 4);
533+
return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
515534
} else {
516535
NO_DEVICE_CODE;
517536
return -1;
@@ -520,43 +539,63 @@ namespace ggml_cuda_mma {
520539

521540
static __device__ __forceinline__ int get_j(const int l) {
522541
if constexpr (I == 8 && J == 4) {
523-
return ((threadIdx.x / 16) * 2) + (l % 2);
542+
return l;
524543
} else {
525544
NO_DEVICE_CODE;
526545
return -1;
527546
}
528547
}
548+
#endif // defined(RDNA3)
529549
};
530550

531-
template <int I_, int J_, typename T>
532-
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_DUAL> {
551+
template <int I_, int J_>
552+
struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> {
533553
static constexpr int I = I_;
534554
static constexpr int J = J_;
535-
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
555+
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
556+
static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
536557

537-
static constexpr int ne = I * J / 32 * 2;
558+
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
538559

539-
T x[ne] = {0};
560+
static constexpr __device__ bool supported() {
561+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
562+
}
563+
564+
static __device__ __forceinline__ int get_i(const int l) {
565+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
566+
}
567+
568+
static __device__ __forceinline__ int get_j(const int l) {
569+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
570+
}
571+
};
572+
573+
template <int I_, int J_>
574+
struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
575+
static constexpr int I = I_;
576+
static constexpr int J = J_;
577+
static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
578+
static constexpr int ne = I * J / (WARP_SIZE/4);
579+
580+
half2 x[ne] = {{0.0f, 0.0f}};
540581

541582
static constexpr __device__ bool supported() {
542-
if (I == 16 && J == 16) return true;
543-
if (I == 16 && J == 8) return true;
544-
if (I == 16 && J == 4) return true;
583+
if (I == 8 && J == 4) return true;
545584
return false;
546585
}
547586

548587
static __device__ __forceinline__ int get_i(const int l) {
549-
if constexpr (supported()) {
550-
return threadIdx.x % 16;
588+
if constexpr (I == 8 && J == 4) {
589+
return ((l / 2) * 4) + (threadIdx.x % 4);
551590
} else {
552591
NO_DEVICE_CODE;
553592
return -1;
554593
}
555594
}
556595

557596
static __device__ __forceinline__ int get_j(const int l) {
558-
if constexpr (supported()) {
559-
return l;
597+
if constexpr (I == 8 && J == 4) {
598+
return ((threadIdx.x / 16) * 2) + (l % 2);
560599
} else {
561600
NO_DEVICE_CODE;
562601
return -1;

0 commit comments

Comments
 (0)