Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions common_level3.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,31 @@ void ssyrk_direct_alpha_betaLT(BLASLONG N, BLASLONG K,
float beta,
float * C, BLASLONG strideC);

void ssyr2k_direct_alpha_betaUN(BLASLONG N, BLASLONG K,
float alpha,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB,
float beta,
float * R, BLASLONG strideR);
void ssyr2k_direct_alpha_betaUT(BLASLONG N, BLASLONG K,
float alpha,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB,
float beta,
float * R, BLASLONG strideR);
void ssyr2k_direct_alpha_betaLN(BLASLONG N, BLASLONG K,
float alpha,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB,
float beta,
float * R, BLASLONG strideR);
void ssyr2k_direct_alpha_betaLT(BLASLONG N, BLASLONG K,
float alpha,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB,
float beta,
float * R, BLASLONG strideR);

int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);

int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
Expand Down
4 changes: 4 additions & 0 deletions common_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ int (*shgemv_t) (BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BL
void (*ssyrk_direct_alpha_betaUT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
void (*ssyrk_direct_alpha_betaLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
void (*ssyrk_direct_alpha_betaLT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
void (*ssyr2k_direct_alpha_betaUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float *, BLASLONG);
void (*ssyr2k_direct_alpha_betaUT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float *, BLASLONG);
void (*ssyr2k_direct_alpha_betaLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float *, BLASLONG);
void (*ssyr2k_direct_alpha_betaLT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float *, BLASLONG);
#endif


Expand Down
8 changes: 8 additions & 0 deletions common_s.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@
#define SSYRK_DIRECT_ALPHA_BETA_UT ssyrk_direct_alpha_betaUT
#define SSYRK_DIRECT_ALPHA_BETA_LN ssyrk_direct_alpha_betaLN
#define SSYRK_DIRECT_ALPHA_BETA_LT ssyrk_direct_alpha_betaLT
#define SSYR2K_DIRECT_ALPHA_BETA_UN ssyr2k_direct_alpha_betaUN
#define SSYR2K_DIRECT_ALPHA_BETA_UT ssyr2k_direct_alpha_betaUT
#define SSYR2K_DIRECT_ALPHA_BETA_LN ssyr2k_direct_alpha_betaLN
#define SSYR2K_DIRECT_ALPHA_BETA_LT ssyr2k_direct_alpha_betaLT

#define SGEMM_ONCOPY sgemm_oncopy
#define SGEMM_OTCOPY sgemm_otcopy
Expand Down Expand Up @@ -240,6 +244,10 @@
#define SSYRK_DIRECT_ALPHA_BETA_UT gotoblas -> ssyrk_direct_alpha_betaUT
#define SSYRK_DIRECT_ALPHA_BETA_LN gotoblas -> ssyrk_direct_alpha_betaLN
#define SSYRK_DIRECT_ALPHA_BETA_LT gotoblas -> ssyrk_direct_alpha_betaLT
#define SSYR2K_DIRECT_ALPHA_BETA_UN gotoblas -> ssyr2k_direct_alpha_betaUN
#define SSYR2K_DIRECT_ALPHA_BETA_UT gotoblas -> ssyr2k_direct_alpha_betaUT
#define SSYR2K_DIRECT_ALPHA_BETA_LN gotoblas -> ssyr2k_direct_alpha_betaLN
#define SSYR2K_DIRECT_ALPHA_BETA_LT gotoblas -> ssyr2k_direct_alpha_betaLT
#endif

#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy
Expand Down
29 changes: 29 additions & 0 deletions interface/syr2k.c
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,35 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Tr
return;
}

#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
#if defined(ARCH_ARM64) && (defined(USE_SSYR2K_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
#if defined(DYNAMIC_ARCH)
if (support_sme1())
#endif
if (args.n == 0) return;
if (order == CblasRowMajor && n == ldc) {
if (Trans == CblasNoTrans && k == lda && k == ldb) {
if (Uplo == CblasUpper) {
SSYR2K_DIRECT_ALPHA_BETA_UN(n, k, alpha, a, lda, b, ldb, beta, c, ldc);
return;
}else if (Uplo == CblasLower) {
SSYR2K_DIRECT_ALPHA_BETA_LN(n, k, alpha, a, lda, b, ldb, beta, c, ldc);
return;
}
}
else if (Trans == CblasTrans && n == lda && n ==ldb) {
if (Uplo == CblasUpper) {
SSYR2K_DIRECT_ALPHA_BETA_UT(n, k, alpha, a, lda, b, ldb, beta, c, ldc);
return;
}else if (Uplo == CblasLower) {
SSYR2K_DIRECT_ALPHA_BETA_LT(n, k, alpha, a, lda, b, ldb, beta, c, ldc);
return;
}
}
}
#endif
#endif

#endif

if (args.n == 0) return;
Expand Down
14 changes: 14 additions & 0 deletions kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
if (ARM64)
set(USE_DIRECT_SSYRK true)
endif()
set(USE_DIRECT_SSYR2K false)
if (ARM64)
set(USE_DIRECT_SSYR2K true)
endif()
set(USE_DIRECT_SGEMM false)
if (X86_64 OR ARM64)
set(USE_DIRECT_SGEMM true)
Expand Down Expand Up @@ -311,6 +315,16 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
endif ()
endif()

if (USE_DIRECT_SSYR2K)
if (ARM64)
set (SSYR2KDIRECTKERNEL_ALPHA_BETA ssyr2k_direct_alpha_beta_arm64_sme1.c)
GenerateNamedObjects("${KERNELDIR}/${SSYR2KDIRECTKERNEL_ALPHA_BETA}" "" "syr2k_direct_alpha_betaUN" false "" "" false SINGLE)
GenerateNamedObjects("${KERNELDIR}/${SSYR2KDIRECTKERNEL_ALPHA_BETA}" "" "syr2k_direct_alpha_betaUT" false "" "" false SINGLE)
GenerateNamedObjects("${KERNELDIR}/${SSYR2KDIRECTKERNEL_ALPHA_BETA}" "" "syr2k_direct_alpha_betaLN" false "" "" false SINGLE)
GenerateNamedObjects("${KERNELDIR}/${SSYR2KDIRECTKERNEL_ALPHA_BETA}" "" "syr2k_direct_alpha_betaLT" false "" "" false SINGLE)
endif ()
endif()

foreach (float_type SINGLE DOUBLE)
string(SUBSTRING ${float_type} 0 1 float_char)
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})
Expand Down
36 changes: 36 additions & 0 deletions kernel/Makefile.L3
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ USE_DIRECT_SGEMM = 1
USE_DIRECT_SSYMM = 1
USE_DIRECT_STRMM = 1
USE_DIRECT_SSYRK = 1
USE_DIRECT_SSYR2K = 1
endif

ifeq ($(ARCH), riscv64)
Expand Down Expand Up @@ -173,6 +174,17 @@ endif
endif
endif

ifdef USE_DIRECT_SSYR2K
ifndef SSYR2KDIRECTKERNEL_ALPHA_BETA
ifeq ($(ARCH), arm64)
ifeq ($(TARGET_CORE), ARMV9SME)
HAVE_SME = 1
endif
SSYR2KDIRECTKERNEL_ALPHA_BETA = ssyr2k_direct_alpha_beta_arm64_sme1.c
endif
endif
endif

ifeq ($(BUILD_BFLOAT16), 1)
ifndef BGEMMKERNEL
BGEMM_BETA = ../generic/gemm_beta.c
Expand Down Expand Up @@ -280,6 +292,16 @@ SKERNELOBJS += \
endif
endif

ifdef USE_DIRECT_SSYR2K
ifeq ($(ARCH), arm64)
SKERNELOBJS += \
ssyr2k_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) ssyr2k_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) \
ssyr2k_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) ssyr2k_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) \
ssyr2k_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) ssyr2k_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) \
ssyr2k_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX) ssyr2k_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX)
endif
endif

ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
DKERNELOBJS += \
dgemm_beta$(TSUFFIX).$(SUFFIX) \
Expand Down Expand Up @@ -1193,6 +1215,20 @@ $(KDIR)ssyrk_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIREC
endif
endif

ifdef USE_DIRECT_SSYR2K
ifeq ($(ARCH), arm64)
$(KDIR)ssyr2k_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYR2KDIRECTKERNEL_ALPHA_BETA)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -UTRANSA $< -o $@
$(KDIR)ssyr2k_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYR2KDIRECTKERNEL_ALPHA_BETA)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -DTRANSA $< -o $@
$(KDIR)ssyr2k_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYR2KDIRECTKERNEL_ALPHA_BETA)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -UTRANSA $< -o $@
$(KDIR)ssyr2k_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYR2KDIRECTKERNEL_ALPHA_BETA)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -DTRANSA $< -o $@

endif
endif

ifdef USE_TRMM
$(KDIR)strmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMKERNEL)
ifeq ($(OS), AIX)
Expand Down
Loading
Loading