diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index a8116441caa..864169c19ad 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -889,11 +889,10 @@ void create_arm_gemm_dequant(std::unique_ptr // Create arm_gemm fallback auto fallback = std::make_unique>(); - // Configure requantization info - const GEMMLowpOutputStageInfo os_info = info.output_stage; - arm_gemm::DequantizeFloat gemm_dequant_info{}; - gemm_dequant_info = arm_gemm::DequantizeFloat(d->quantization_info().uniform().scale); + gemm_dequant_info.scale = a->quantization_info().uniform().scale * b->quantization_info().uniform().scale; + gemm_dequant_info.a_offset = info.dequant_a_offset; + gemm_dequant_info.b_offset = info.dequant_b_offset; fallback->configure(a, b, c, d, args, info, gemm_dequant_info); arm_gemm = std::move(fallback); @@ -1020,6 +1019,13 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected {})), "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output"); } + else if (d->data_type() == DataType::F32) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + !(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, + args, {})), + "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and F32 output"); + } else { ARM_COMPUTE_RETURN_ERROR_ON_MSG( @@ -1130,6 +1136,11 @@ Status CpuGemmAssemblyDispatch::validate( a->data_type() == DataType::QASYMM8 && (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32 && d->data_type() != DataType::F32), "Only QASYMM8/S32/F32 output supported for QASYMM8 input"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + a->data_type() == DataType::QASYMM8_SIGNED && + (d->data_type() != DataType::QASYMM8_SIGNED && d->data_type() != DataType::S32 && + d->data_type() != DataType::F32), + "Only QASYMM8_SIGNED/S32/F32 output supported for QASYMM8_SIGNED input"); arm_compute::WeightFormat expected_weight_format = arm_compute::WeightFormat::UNSPECIFIED; const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info); if (bool(ret) && expected_weight_format != arm_compute::WeightFormat::ANY) diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h index 9b4b15d0dbc..4131c07a7d3 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2025 Arm Limited. + * Copyright (c) 2018-2026 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -65,7 +65,9 @@ struct AsmGemmInfo * @note This flag will be silently ignored (assumed to be false) when the weight_format is a fixed format. Because * fixed format kernels do not accept weights (B) with any prior transformations */ - bool transpose_b{false}; + bool transpose_b{false}; + int32_t dequant_a_offset{0}; // input zero-point for DequantizeFloat path (handled in kernel) + int32_t dequant_b_offset{0}; // weight zero-point for DequantizeFloat path (handled in kernel) }; /** Assembly kernel glue */