diff --git a/kutacc/test/test_bgemm_ex.cpp b/kutacc/test/test_bgemm_ex.cpp index 7f7bc67b8a32bbb6c9086851ff3e48eba521b332..bd17bcb975883e8e183e25214223d35be1377ad5 100644 --- a/kutacc/test/test_bgemm_ex.cpp +++ b/kutacc/test/test_bgemm_ex.cpp @@ -11,6 +11,10 @@ #define MATRIX_COL 15 #define BF16_EPS 5.0 +static uint32_t tests_count = 0; +static uint32_t passed_count = 0; +static uint32_t failed_count = 0; + static void cal_leading_dimension(char transA, char transB, BLASINT *lda, BLASINT *ldb) { BLASINT m = MATRIX_ROW; @@ -48,7 +52,7 @@ static void init_expect_c_matrix(char transA, char transB, BLASINT m, BLASINT n, } } -static void test_bgemm_ex(char transA, char transB, uint32_t prepack_mask) +static bool test_bgemm_ex(char transA, char transB, uint32_t prepack_mask) { __bf16 bgemm_a[MATRIX_ROW * MATRIX_COL]; __bf16 bgemm_b[MATRIX_COL * MATRIX_ROW]; @@ -83,7 +87,7 @@ static void test_bgemm_ex(char transA, char transB, uint32_t prepack_mask) if (prepack_mask & BLAS_EXTEND_PREPACK_A_MASK) { size_t size_a = kutacc_core_bgemm_pack_get_size('A', m, n, k); if (size_a != MATRIX_ROW * MATRIX_COL) { - printf("[FAILED] 1 tests\n"); + return false; } sa = (__bf16 *)malloc(size_a * sizeof(__bf16)); @@ -92,7 +96,7 @@ static void test_bgemm_ex(char transA, char transB, uint32_t prepack_mask) } else if (prepack_mask & BLAS_EXTEND_PREPACK_B_MASK) { size_t size_b = kutacc_core_bgemm_pack_get_size('B', m, n, k); if (size_b != MATRIX_ROW * MATRIX_COL) { - printf("[FAILED] 1 tests\n"); + return false; } sb = (__bf16 *)malloc(size_b * sizeof(__bf16)); @@ -100,13 +104,13 @@ static void test_bgemm_ex(char transA, char transB, uint32_t prepack_mask) kutacc_core_bgemm_ex(transA, transB, m, n, k, alpha, bgemm_a, lda, sb, ldb, beta, bgemm_c, ldc, &extend_param); } else { printf("input param prepack_mask error!\n"); - return; + return false; } // compare bgemm_ex result for (int i = 0; i < MATRIX_ROW * MATRIX_ROW; i++) { if (fabs(vcvtah_f32_bf16(bgemm_c[i]) - vcvtah_f32_bf16(expect_c[i])) > BF16_EPS) { - printf("[FAILED] 1 tests\n"); + return false; } } @@ -116,11 +120,28 @@ static void test_bgemm_ex(char transA, char transB, uint32_t prepack_mask) if (sb) { free(sb); } - printf("[PASSED] 1 tests\n"); + + return true; +} + +void test_kutacc_bgemm_ex() +{ + tests_count++; + printf("start test test_kutacc.kutacc_bgemm_ex_001\n"); + bool result = test_bgemm_ex('N', 'N', BLAS_EXTEND_PREPACK_A_MASK); + if (result == true) { + passed_count++; + printf("end test test_kutacc.kutacc_bgemm_ex_001 PASSED\n"); + } else { + failed_count++; + printf("end test test_kutacc.kutacc_bgemm_ex_001 FAILED\n"); + } } int main() { - test_bgemm_ex('N', 'N', BLAS_EXTEND_PREPACK_A_MASK); + test_kutacc_bgemm_ex(); + printf("%u tests ran, %u tests PASSED, %u test FAILED.\n", tests_count, passed_count, failed_count); + return 0; }