diff --git a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp index 14f6f2e72a8..b181d2cd334 100644 --- a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp @@ -72,7 +72,7 @@ jit_avx2_1x1_conv_kernel_f32::jit_avx2_1x1_conv_kernel_f32( } void jit_avx2_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk) { - mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux1_reg_bcast_data, ptr[rsp + reg_bcast_data_off]); mov(aux_reg_output_data, reg_output_data); mov(bcast_loop_iter, reg_bcast_loop_work); @@ -383,7 +383,6 @@ void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop( apply_postops(load_loop_blk, ur, load_dim_tail); if (jcp.prop_kind == backward_weights && load_dim_tail > 0) { - push(reg_bcast_data); push(aux_reg_bcast_data); } @@ -433,7 +432,6 @@ void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop( if (jcp.prop_kind == backward_weights && load_dim_tail > 0) { pop(aux_reg_bcast_data); - pop(reg_bcast_data); } }; @@ -576,6 +574,7 @@ void jit_avx2_1x1_conv_kernel_f32::generate() { } mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(ptr[rsp + reg_bcast_data_off], reg_bcast_data); mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); if (jcp.with_bias) { diff --git a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp index 3c2f0d70325..984b6641189 100644 --- a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp +++ b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp @@ -81,7 +81,8 @@ struct jit_avx2_1x1_conv_kernel_f32 : public jit_generator { constexpr static int reg_diff_bias_data_stack_offt = 0; constexpr static int reg_binary_post_op_acc_off = 1 * reg64_size_; constexpr static int reg_abi_param1_backup = 2 * reg64_size_; - constexpr static int stack_space_needed = 3 * reg64_size_; + constexpr static int reg_bcast_data_off = 3 * reg64_size_; + constexpr static int stack_space_needed = 4 * reg64_size_; ymm_t vreg_bcast = ymm_t(15); ymm_t vtmp = ymm_t(14);