Skip to content

Commit

Permalink
src: cpu: conv: jit_avx512_core_x8s8s32x: dst scale: reuse vmm register
Browse files Browse the repository at this point in the history
  • Loading branch information
igorsafo committed May 25, 2023
1 parent bb3ecc4 commit 7fa3b6f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
10 changes: 3 additions & 7 deletions src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
* Copyright 2016-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -973,9 +973,7 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::generate() {

if (jcp.is_depthwise) {
bool is_zero_point = jcp.src_zero_point || jcp.dst_zero_point;
// dst zero point and dst scale reuse the same register
int idx = jcp.max_regs_ur - 1
+ nstl::max(2 * is_zero_point, static_cast<int>(jcp.dst_scale));
int idx = jcp.max_regs_ur - 1 + 2 * is_zero_point;
if (!jcp.is_resrc_depthwise) zmm_src = Zmm(++idx);
if (!jcp.has_vnni) zmm_tmp = Zmm(++idx);
if (jcp.is_fast_depthwise) zmm_permute = Zmm(++idx);
Expand All @@ -984,8 +982,7 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::generate() {
// and/or saturation, we increment by one more
if (jcp.signed_input || jcp.need_saturation) ++idx;

assert(IMPLICATION(!jcp.dst_scale && !is_zero_point
&& jcp.dst_dt != data_type::bf16,
assert(IMPLICATION(!is_zero_point && jcp.dst_dt != data_type::bf16,
idx == ker_dw_reg_base_idx));
}
if (!jcp.is_depthwise && (!jcp.has_vnni)) {
Expand Down Expand Up @@ -1498,7 +1495,6 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
// TODO: re-implement so that the JIT Kernel uses the least amount of
// registers. Currently, there are issues because of compile and run time
// definitions.
if (jcp.dst_scale) jcp.max_regs_ur = 26;
if (jcp.src_zero_point || jcp.dst_zero_point) jcp.max_regs_ur = 25;

auto set_or_check_wei_format = [&]() {
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
* Copyright 2016-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -131,7 +131,7 @@ struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
const Vmm vmm_zp_one = Vmm(26);
const Vmm vmm_zp_tmp = vmm_zp;

const Vmm vmm_dst_scale = Vmm(26);
const Vmm vmm_dst_scale = Vmm(31);

/* bf16 emulation */
Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26);
Expand Down

0 comments on commit 7fa3b6f

Please sign in to comment.