diff --git a/common_level3.h b/common_level3.h
index 39abe3016c..64d206a6d1 100644
--- a/common_level3.h
+++ b/common_level3.h
@@ -110,6 +110,31 @@ void ssyrk_direct_alpha_betaLT(BLASLONG N, BLASLONG K,
                               float beta, float * C, BLASLONG strideC);
 
+void ssyr2k_direct_alpha_betaUN(BLASLONG N, BLASLONG K,
+                                float alpha,
+                                float * A, BLASLONG strideA,
+                                float * B, BLASLONG strideB,
+                                float beta,
+                                float * R, BLASLONG strideR);
+void ssyr2k_direct_alpha_betaUT(BLASLONG N, BLASLONG K,
+                                float alpha,
+                                float * A, BLASLONG strideA,
+                                float * B, BLASLONG strideB,
+                                float beta,
+                                float * R, BLASLONG strideR);
+void ssyr2k_direct_alpha_betaLN(BLASLONG N, BLASLONG K,
+                                float alpha,
+                                float * A, BLASLONG strideA,
+                                float * B, BLASLONG strideB,
+                                float beta,
+                                float * R, BLASLONG strideR);
+void ssyr2k_direct_alpha_betaLT(BLASLONG N, BLASLONG K,
+                                float alpha,
+                                float * A, BLASLONG strideA,
+                                float * B, BLASLONG strideB,
+                                float beta,
+                                float * R, BLASLONG strideR);
+
 int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
 
 int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
diff --git a/common_param.h b/common_param.h
index 92bde3b3d7..b123092ef4 100644
--- a/common_param.h
+++ b/common_param.h
@@ -268,6 +268,10 @@ int (*shgemv_t) (BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BL
   void (*ssyrk_direct_alpha_betaUT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
   void (*ssyrk_direct_alpha_betaLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
   void (*ssyrk_direct_alpha_betaLT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
+  void (*ssyr2k_direct_alpha_betaUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float *, BLASLONG);
+  void (*ssyr2k_direct_alpha_betaUT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float *, BLASLONG);
+  void (*ssyr2k_direct_alpha_betaLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float *, BLASLONG);
+  void (*ssyr2k_direct_alpha_betaLT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float *, BLASLONG);
 #endif
diff --git a/common_s.h b/common_s.h
index df61125f6e..f51d8d4520 100644
--- a/common_s.h
+++ b/common_s.h
@@ -60,6 +60,10 @@
 #define SSYRK_DIRECT_ALPHA_BETA_UT ssyrk_direct_alpha_betaUT
 #define SSYRK_DIRECT_ALPHA_BETA_LN ssyrk_direct_alpha_betaLN
 #define SSYRK_DIRECT_ALPHA_BETA_LT ssyrk_direct_alpha_betaLT
+#define SSYR2K_DIRECT_ALPHA_BETA_UN ssyr2k_direct_alpha_betaUN
+#define SSYR2K_DIRECT_ALPHA_BETA_UT ssyr2k_direct_alpha_betaUT
+#define SSYR2K_DIRECT_ALPHA_BETA_LN ssyr2k_direct_alpha_betaLN
+#define SSYR2K_DIRECT_ALPHA_BETA_LT ssyr2k_direct_alpha_betaLT
 
 #define SGEMM_ONCOPY sgemm_oncopy
 #define SGEMM_OTCOPY sgemm_otcopy
@@ -240,6 +244,10 @@
 #define SSYRK_DIRECT_ALPHA_BETA_UT gotoblas -> ssyrk_direct_alpha_betaUT
 #define SSYRK_DIRECT_ALPHA_BETA_LN gotoblas -> ssyrk_direct_alpha_betaLN
 #define SSYRK_DIRECT_ALPHA_BETA_LT gotoblas -> ssyrk_direct_alpha_betaLT
+#define SSYR2K_DIRECT_ALPHA_BETA_UN gotoblas -> ssyr2k_direct_alpha_betaUN
+#define SSYR2K_DIRECT_ALPHA_BETA_UT gotoblas -> ssyr2k_direct_alpha_betaUT
+#define SSYR2K_DIRECT_ALPHA_BETA_LN gotoblas -> ssyr2k_direct_alpha_betaLN
+#define SSYR2K_DIRECT_ALPHA_BETA_LT gotoblas -> ssyr2k_direct_alpha_betaLT
 #endif
 
 #define SGEMM_ONCOPY gotoblas -> sgemm_oncopy
diff --git a/interface/syr2k.c b/interface/syr2k.c
index 47df7f89f0..635dd3ed0d 100644
--- a/interface/syr2k.c
+++ b/interface/syr2k.c
@@ -345,6 +345,35 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Tr
     return;
   }
 
+#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
+#if defined(ARCH_ARM64) && (defined(USE_SSYR2K_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
+#if defined(DYNAMIC_ARCH)
+  if (support_sme1())
+#endif
+  {
+    if (args.n == 0) return;
+    if (order == CblasRowMajor && n == ldc) {
+      if (Trans == CblasNoTrans && k == lda && k == ldb) {
+        if (Uplo == CblasUpper) {
+          SSYR2K_DIRECT_ALPHA_BETA_UN(n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+          return;
+        } else if (Uplo == CblasLower) {
+          SSYR2K_DIRECT_ALPHA_BETA_LN(n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+          return;
+        }
+      } else if (Trans == CblasTrans && n == lda && n == ldb) {
+        if (Uplo == CblasUpper) {
+          SSYR2K_DIRECT_ALPHA_BETA_UT(n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+          return;
+        } else if (Uplo == CblasLower) {
+          SSYR2K_DIRECT_ALPHA_BETA_LT(n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+          return;
+        }
+      }
+    }
+  }
+#endif
+#endif
+
 #endif
 
   if (args.n == 0) return;
diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt
index 3a638376c0..a8849400ee 100644
--- a/kernel/CMakeLists.txt
+++ b/kernel/CMakeLists.txt
@@ -249,6 +249,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
   if (ARM64)
     set(USE_DIRECT_SSYRK true)
   endif()
+  set(USE_DIRECT_SSYR2K false)
+  if (ARM64)
+    set(USE_DIRECT_SSYR2K true)
+  endif()
   set(USE_DIRECT_SGEMM false)
   if (X86_64 OR ARM64)
     set(USE_DIRECT_SGEMM true)
@@ -311,6 +315,16 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
     endif ()
   endif()
 
+  if (USE_DIRECT_SSYR2K)
+    if (ARM64)
+      set (SSYR2KDIRECTKERNEL_ALPHA_BETA ssyr2k_direct_alpha_beta_arm64_sme1.c)
+      GenerateNamedObjects("${KERNELDIR}/${SSYR2KDIRECTKERNEL_ALPHA_BETA}" "" "syr2k_direct_alpha_betaUN" false "" "" false SINGLE)
+      GenerateNamedObjects("${KERNELDIR}/${SSYR2KDIRECTKERNEL_ALPHA_BETA}" "" "syr2k_direct_alpha_betaUT" false "" "" false SINGLE)
+      GenerateNamedObjects("${KERNELDIR}/${SSYR2KDIRECTKERNEL_ALPHA_BETA}" "" "syr2k_direct_alpha_betaLN" false "" "" false SINGLE)
+      GenerateNamedObjects("${KERNELDIR}/${SSYR2KDIRECTKERNEL_ALPHA_BETA}" "" "syr2k_direct_alpha_betaLT" false "" "" false SINGLE)
+    endif ()
+  endif()
+
   foreach (float_type SINGLE DOUBLE)
     string(SUBSTRING ${float_type} 0 1 float_char)
     GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})
diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3
index 6df9d78b1a..3329472b9b 100644
--- a/kernel/Makefile.L3
+++ b/kernel/Makefile.L3
@@ -55,6 +55,7 @@ USE_DIRECT_SGEMM = 1
 USE_DIRECT_SSYMM = 1
 USE_DIRECT_STRMM = 1
 USE_DIRECT_SSYRK = 1
+USE_DIRECT_SSYR2K = 1
 endif
 
 ifeq ($(ARCH), riscv64)
@@ -173,6 +174,17 @@ endif
 endif
 endif
 
+ifdef USE_DIRECT_SSYR2K
+ifndef SSYR2KDIRECTKERNEL_ALPHA_BETA
+ifeq ($(ARCH), arm64)
+ifeq ($(TARGET_CORE), ARMV9SME)
+HAVE_SME = 1
+endif
+SSYR2KDIRECTKERNEL_ALPHA_BETA = ssyr2k_direct_alpha_beta_arm64_sme1.c
+endif
+endif
+endif
+
 ifeq ($(BUILD_BFLOAT16), 1)
 ifndef BGEMMKERNEL
 BGEMM_BETA = ../generic/gemm_beta.c
@@ -280,6 +292,16 @@ SKERNELOBJS += \
 endif
 endif
 
+ifdef USE_DIRECT_SSYR2K
+ifeq ($(ARCH), arm64)
+SKERNELOBJS += \
+	ssyr2k_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) \
+	ssyr2k_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) \
+	ssyr2k_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) \
+	ssyr2k_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX)
+endif
+endif
+
 ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
 DKERNELOBJS += \
 	dgemm_beta$(TSUFFIX).$(SUFFIX) \
@@ -1193,6 +1215,20 @@ $(KDIR)ssyrk_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIREC
 endif
 endif
 
+ifdef USE_DIRECT_SSYR2K
+ifeq ($(ARCH), arm64)
+$(KDIR)ssyr2k_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYR2KDIRECTKERNEL_ALPHA_BETA)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -UTRANSA $< -o $@
+$(KDIR)ssyr2k_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYR2KDIRECTKERNEL_ALPHA_BETA)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -DTRANSA $< -o $@
+$(KDIR)ssyr2k_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYR2KDIRECTKERNEL_ALPHA_BETA)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -UTRANSA $< -o $@
+$(KDIR)ssyr2k_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYR2KDIRECTKERNEL_ALPHA_BETA)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -DTRANSA $< -o $@
+
+endif
+endif
+
 ifdef USE_TRMM
 $(KDIR)strmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMKERNEL)
 ifeq ($(OS), AIX)
diff --git a/kernel/arm64/ssyr2k_direct_alpha_beta_arm64_sme1.c b/kernel/arm64/ssyr2k_direct_alpha_beta_arm64_sme1.c
new file mode 100755
index 0000000000..7cb715717c
--- /dev/null
+++ b/kernel/arm64/ssyr2k_direct_alpha_beta_arm64_sme1.c
@@ -0,0 +1,286 @@
+/*
+   Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+   SPDX-License-Identifier: BSD-3-Clause-Clear
+*/
+
+#include "common.h"
+#include <stdint.h>
+#include <stdlib.h>
+#include <math.h>
+#if defined(HAVE_SME)
+
+#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16
+#include <arm_sme.h>
+#endif
+
+/* Function prototypes */
+extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\
+            const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess");
+
+/* Function Definitions */
+static uint64_t sve_cntw() {
+    uint64_t cnt;
+    asm volatile(
+        "rdsvl  %[res], #1\n"
+        "lsr    %[res], %[res], #2\n"
+        : [res] "=r" (cnt) ::
+    );
+    return cnt;
+}
+
+#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16
+// Outer product kernel.
+// Computes a 2SVL x 2SVL block of C, utilizing all four FP32 tiles of ZA.
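+// ZA tile layout, as implied by the loads/stores below (row/col indices are local to the block):
+//   tile 0: rows [0, svl),        cols [0, svl)      tile 1: rows [0, svl),        cols [svl, 2*svl)
+//   tile 2: rows [svl, 2*svl),    cols [0, svl)      tile 3: rows [svl, 2*svl),    cols [svl, 2*svl)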
+__attribute__((always_inline)) inline void
+kernel_2x2(const float *A, float *B_T, const float *B, float *A_T, float *C, size_t shared_dim,
+           size_t ldc, size_t block_rows, size_t block_cols, float alpha,
+           float beta, uint64_t row_idx, uint64_t col_idx)
+           __arm_out("za") __arm_streaming {
+
+    const uint64_t svl = svcntw();
+    size_t ldb = ldc;
+    // Predicate set-up
+    svbool_t pg = svptrue_b32();
+    svbool_t pg_a_0 = svwhilelt_b32_u64(0, block_rows);
+    svbool_t pg_a_1 = svwhilelt_b32_u64(svl, block_rows);
+
+    svbool_t pg_b_0 = svwhilelt_b32_u64(0, block_cols);
+    svbool_t pg_b_1 = svwhilelt_b32_u64(svl, block_cols);
+
+#define pg_c_0 pg_b_0
+#define pg_c_1 pg_b_1
+
+    svzero_za();
+    svfloat32_t beta_vec = svdup_f32(beta);
+
+    // Load C to ZA
+    for (size_t i = 0; i < MIN(svl, block_rows); i++) {
+        svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]);
+        row_c_0 = svmul_x(pg, beta_vec, row_c_0);
+        svwrite_hor_za32_f32_m(/*tile*/0, /*slice*/i, pg_c_0, row_c_0);
+
+        svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]);
+        row_c_1 = svmul_x(pg, beta_vec, row_c_1);
+        svwrite_hor_za32_f32_m(/*tile*/1, /*slice*/i, pg_c_1, row_c_1);
+    }
+    for (size_t i = svl; i < block_rows; i++) {
+        svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]);
+        row_c_0 = svmul_x(pg, beta_vec, row_c_0);
+        svwrite_hor_za32_f32_m(/*tile*/2, /*slice*/i, pg_c_0, row_c_0);
+
+        svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]);
+        row_c_1 = svmul_x(pg, beta_vec, row_c_1);
+        svwrite_hor_za32_f32_m(/*tile*/3, /*slice*/i, pg_c_1, row_c_1);
+    }
+
+    svfloat32_t alpha_vec = svdup_f32(alpha);
+    // Iterate through shared dimension (K)
+    for (size_t k = 0; k < shared_dim; k++) {
+#if !defined(TRANSA)
+        // Computes alpha*A*B**T
+        // Load column of A
+        svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * svl]);
+        col_a_0 = svmul_x(pg, alpha_vec, col_a_0);
+        svfloat32_t col_a_1 = svld1(pg_a_1, &A[(k + shared_dim) * svl]);
+        col_a_1 = svmul_x(pg, alpha_vec, col_a_1);
+
+        // Load row of B**T
+        svfloat32_t row_b_0 = svld1(pg_b_0, &B_T[k * svl]);
+        svfloat32_t row_b_1 = svld1(pg_b_1, &B_T[(k + shared_dim) * svl]);
+#else
+        // Computes alpha*A**T*B
+        // Load column of A**T
+        svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * ldb]);
+        col_a_0 = svmul_x(pg, alpha_vec, col_a_0);
+
+        svfloat32_t col_a_1 = svld1(pg_a_1, &A[k * ldb + svl]);
+        col_a_1 = svmul_x(pg, alpha_vec, col_a_1);
+
+        // Load row of B
+        svfloat32_t row_b_0 = svld1(pg_b_0, &B_T[k * ldb]);
+        svfloat32_t row_b_1 = svld1(pg_b_1, &B_T[k * ldb + svl]);
+#endif
+        // Perform outer product
+        svmopa_za32_m(/*tile*/0, pg, pg, col_a_0, row_b_0);
+        svmopa_za32_m(/*tile*/1, pg, pg, col_a_0, row_b_1);
+        svmopa_za32_m(/*tile*/2, pg, pg, col_a_1, row_b_0);
+        svmopa_za32_m(/*tile*/3, pg, pg, col_a_1, row_b_1);
+
+#if !defined(TRANSA)
+        // Computes alpha*B*A**T
+        // Load column of B
+        col_a_0 = svld1(pg_a_0, &B[k * svl]);
+        col_a_0 = svmul_x(pg, alpha_vec, col_a_0);
+        col_a_1 = svld1(pg_a_1, &B[(k + shared_dim) * svl]);
+        col_a_1 = svmul_x(pg, alpha_vec, col_a_1);
+
+        // Load row of A**T
+        row_b_0 = svld1(pg_b_0, &A_T[k * svl]);
+        row_b_1 = svld1(pg_b_1, &A_T[(k + shared_dim) * svl]);
+#else
+        // Computes alpha*B**T*A
+        // Load column of B**T
+        col_a_0 = svld1(pg_a_0, &B[k * ldb]);
+        col_a_0 = svmul_x(pg, alpha_vec, col_a_0);
+
+        col_a_1 = svld1(pg_a_1, &B[k * ldb + svl]);
+        col_a_1 = svmul_x(pg, alpha_vec, col_a_1);
+
+        // Load row of A
+        row_b_0 = svld1(pg_b_0, &A_T[k * ldb]);
+        row_b_1 = svld1(pg_b_1, &A_T[k * ldb + svl]);
+#endif
+        // Perform outer product
+        svmopa_za32_m(/*tile*/0, pg, pg, col_a_0, row_b_0);
+        svmopa_za32_m(/*tile*/1, pg, pg, col_a_0, row_b_1);
+        svmopa_za32_m(/*tile*/2, pg, pg, col_a_1, row_b_0);
+        svmopa_za32_m(/*tile*/3, pg, pg, col_a_1, row_b_1);
+    }
+
+#if defined(UPPER)
+#define pg_c_0_full pg_c_0
+#define pg_c_1_full pg_c_1
+
+    bool need_update_pg_b = true;
+    size_t last_invalid_index = col_idx - row_idx;
+    // For the upper triangle: if col_idx - row_idx >= 2*svl, the whole block lies
+    // above the diagonal, so the store predicates never need to be masked.
+    if (col_idx - row_idx >= 2*svl) {
+        need_update_pg_b = false;
+    }
+    // Store to C from ZA
+    for (size_t i = 0; i < MIN(svl, block_rows); i++, last_invalid_index++) {
+        if (need_update_pg_b) {
+            pg_c_0 = svnot_b_z(pg_c_0_full, svwhilelt_b32_u64(0, last_invalid_index));
+            pg_c_1 = svnot_b_z(pg_c_1_full, svwhilelt_b32_u64(svl, last_invalid_index));
+        }
+        svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]);
+        svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
+    }
+    for (size_t i = svl; i < block_rows; i++, last_invalid_index++) {
+        if (need_update_pg_b) {
+            pg_c_0 = svnot_b_z(pg_c_0_full, svwhilelt_b32_u64(0, last_invalid_index));
+            pg_c_1 = svnot_b_z(pg_c_1_full, svwhilelt_b32_u64(svl, last_invalid_index));
+        }
+        svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]);
+        svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
+    }
+#else
+    // Store to C from ZA
+    size_t valid_index = row_idx - col_idx + 1;
+    for (size_t i = 0; i < MIN(svl, block_rows); i++, valid_index++) {
+        pg_c_0 = svwhilelt_b32_u64(0, MIN(valid_index, block_cols));
+        pg_c_1 = svwhilelt_b32_u64(svl, MIN(valid_index, block_cols));
+        svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]);
+        svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
+    }
+    for (size_t i = svl; i < block_rows; i++, valid_index++) {
+        pg_c_0 = svwhilelt_b32_u64(0, MIN(valid_index, block_cols));
+        pg_c_1 = svwhilelt_b32_u64(svl, MIN(valid_index, block_cols));
+        svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]);
+        svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
+    }
+#endif
+}
+
+__arm_new("za") __arm_locally_streaming
+static void ssyr2k_direct_sme1_2VLx2VL(uint64_t n, uint64_t k, const float* alpha,\
+                   const float *ba, const float *bb, const float* beta, float *restrict bc) {
+    const uint64_t num_rows = n;
+    const uint64_t num_cols = n;
+
+    const float *restrict a_ptr = ba;
+    const float *restrict b_ptr = bb;
+    float *restrict c_ptr = bc;
+
+    const uint64_t svl = svcntw();
+    const uint64_t ldc = n;
+
+    // Block over rows of C (panels of A)
+    uint64_t row_idx = 0;
+
+    // 2x2 loop
+    uint64_t row_batch = 2*svl;
+
+    // Block over row dimension of C
+    for (; row_idx < num_rows; row_idx += row_batch) {
+        row_batch = MIN(row_batch, num_rows - row_idx);
+        uint64_t col_batch = 2*svl;
+#if defined(UPPER)
+        // For UPLO = upper, start at col_idx = row_idx so that only the upper
+        // triangle is processed (col_idx >= row_idx).
+        for (uint64_t col_idx = row_idx; col_idx < num_cols; col_idx += col_batch) {
+#else
+        // For UPLO = lower, only the lower triangle is processed (col_idx <= row_idx).
+        for (uint64_t col_idx = 0; col_idx < num_cols && col_idx <= row_idx; col_idx += col_batch) {
+#endif
+            col_batch = MIN(col_batch, num_cols - col_idx);
+#if !defined(TRANSA)
+            kernel_2x2(&a_ptr[row_idx * k], &b_ptr[col_idx * k], &b_ptr[row_idx * k], &a_ptr[col_idx * k],
+                       &c_ptr[row_idx * ldc + col_idx], k,
+                       ldc, row_batch, col_batch, *alpha, *beta, row_idx, col_idx);
+#else
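+            // TRANSA: A and B are K x N (row-major, leading dimension n == ldc), so each
+            // panel is addressed by its starting column and walked with stride ldb
+            // inside kernel_2x2.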
+            kernel_2x2(&a_ptr[row_idx], &b_ptr[col_idx], &b_ptr[row_idx], &a_ptr[col_idx],
+                       &c_ptr[row_idx * ldc + col_idx], k,
+                       ldc, row_batch, col_batch, *alpha, *beta, row_idx, col_idx);
+#endif
+
+        }
+    }
+    return;
+}
+
+#else
+static void ssyr2k_direct_sme1_2VLx2VL(uint64_t n, uint64_t k, const float* alpha,\
+                   const float *ba, const float *bb, const float* beta, float *restrict bc){}
+#endif
+
+void CNAME (BLASLONG N, BLASLONG K, float alpha, float * __restrict A, \
+            BLASLONG strideA, float * __restrict B, BLASLONG strideB, \
+            float beta, float * __restrict R, BLASLONG strideR)
+{
+#if !defined(TRANSA)
+    uint64_t n_mod, vl_elms;
+
+    vl_elms = sve_cntw();
+
+    n_mod = ceil((double)N/(double)vl_elms) * vl_elms;
+
+    float *A_mod = (float *) malloc(n_mod*K*sizeof(float));
+    float *B_mod = (float *) malloc(n_mod*K*sizeof(float));
+
+    /* Clobber all SVE vector (z) and predicate registers so the compiler
+     * re-reads data from memory instead of keeping it live in registers
+     * across the hand-written preprocessing routine.
+     */
+    asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7",
+                 "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15",
+                 "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7",
+                 "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15",
+                 "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23",
+                 "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");
+
+    /* Pre-process A and B into the layout expected by the
+       matrix sum-of-outer-products calculation.
+     */
+    sgemm_direct_sme1_preprocess(N, K, A, A_mod);
+    sgemm_direct_sme1_preprocess(N, K, B, B_mod);
+
+    asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7",
+                 "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15",
+                 "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7",
+                 "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15",
+                 "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23",
+                 "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");
+
+    ssyr2k_direct_sme1_2VLx2VL(N, K, &alpha, A_mod, B_mod, &beta, R);
+
+    free(A_mod);
+    free(B_mod);
+#else
+    ssyr2k_direct_sme1_2VLx2VL(N, K, &alpha, A, B, &beta, R);
+#endif
+}
+#else
+
+void CNAME (BLASLONG N, BLASLONG K, float alpha, float * __restrict A, \
+            BLASLONG strideA, float * __restrict B, BLASLONG strideB, \
+            float beta, float * __restrict C, BLASLONG strideC){}
+
+#endif
diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c
index e60a5d65b3..e0699a1f9f 100644
--- a/kernel/setparam-ref.c
+++ b/kernel/setparam-ref.c
@@ -227,6 +227,10 @@ gotoblas_t TABLE_NAME = {
   ssyrk_direct_alpha_betaUTTS,
   ssyrk_direct_alpha_betaLNTS,
   ssyrk_direct_alpha_betaLTTS,
+  ssyr2k_direct_alpha_betaUNTS,
+  ssyr2k_direct_alpha_betaUTTS,
+  ssyr2k_direct_alpha_betaLNTS,
+  ssyr2k_direct_alpha_betaLTTS,
 #endif
 
   sgemm_kernelTS,
   sgemm_betaTS,
diff --git a/param.h b/param.h
index 8e598d8a01..574cd45053 100644
--- a/param.h
+++ b/param.h
@@ -3869,6 +3869,7 @@ Until then, just keep it different than DGEMM_DEFAULT_UNROLL_N to keep copy rout
 #define USE_SSYMM_KERNEL_DIRECT 1
 #define USE_STRMM_KERNEL_DIRECT 1
 #define USE_SSYRK_KERNEL_DIRECT 1
+#define USE_SSYR2K_KERNEL_DIRECT 1
 #endif /* ARMv9 SME */
 
 #if defined(ARMV5)
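
Reviewer note: the sketch below is a minimal caller that satisfies the dispatch conditions added in interface/syr2k.c (single precision, row-major, CblasNoTrans, k == lda == ldb, n == ldc). It is illustrative only: whether the direct SME path is actually taken still depends on building for an SME-capable target (or DYNAMIC_ARCH with runtime support_sme1() detection), and the sizes and values here are arbitrary.

#include <stdio.h>
#include <stdlib.h>
#include <cblas.h>

int main(void) {
    const blasint n = 8, k = 4;   /* C is n x n, A and B are n x k (row-major) */
    float *A = malloc((size_t)n * k * sizeof(float));
    float *B = malloc((size_t)n * k * sizeof(float));
    float *C = calloc((size_t)n * n, sizeof(float));

    for (blasint i = 0; i < n * k; i++) { A[i] = 1.0f; B[i] = 0.5f; }

    /* Row-major, Upper, NoTrans with lda = ldb = k and ldc = n matches the
       SSYR2K_DIRECT_ALPHA_BETA_UN conditions in the new dispatch block. */
    cblas_ssyr2k(CblasRowMajor, CblasUpper, CblasNoTrans,
                 n, k, 1.0f, A, k, B, k, 0.0f, C, n);

    printf("C[0][0] = %f\n", C[0]);
    free(A); free(B); free(C);
    return 0;
}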