Skip to content

Commit

Permalink
cpu: x64: jit_avx2_1x1_conv: fix invalid imm32 values in kernel
Browse the repository at this point in the history
  • Loading branch information
kwiersch authored and vpirogov committed Nov 18, 2022
1 parent b60633f commit 2ba2523
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
21 changes: 11 additions & 10 deletions src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop(
auto bcast_ptr = [=](int u, int j) {
assert(j < jcp.ur);
assert(u <= jcp.reduce_loop_unroll);
- return ptr[aux_reg_bcast_data + get_bcast_offset(jcp, u, j)];
+ const size_t offset = get_bcast_offset(jcp, u, j);
+ return make_safe_addr(aux_reg_bcast_data, offset, reg_long_offt);
};

auto get_load_offset_bwd_w = [=](int u, int i) {
Expand Down Expand Up @@ -278,10 +279,8 @@ void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop(
: 0) // TODO: Xbyak should allow 0 scale
+ sizeof(float) * jcp.oc_block * j];
default:
- return ptr[aux_reg_output_data
- + (i * get_output_i_offset(jcp)
- + j * get_output_j_offset(jcp))
- * sizeof(float)];
+ const size_t off = get_output_offset(i, j);
+ return make_safe_addr(aux_reg_output_data, off, reg_long_offt);
}
};

Expand Down Expand Up @@ -492,7 +491,7 @@ void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop(
L(reduce_loop);
{
fma_block(false);
- add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
+ safe_add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step, reg_long_offt);
add(aux_reg_load_data, jcp.reduce_loop_load_step);
sub(reduce_loop_iter, jcp.reduce_loop_unroll);
jg(reduce_loop, T_NEAR);
Expand Down Expand Up @@ -599,8 +598,9 @@ void jit_avx2_1x1_conv_kernel_f32::generate() {
case forward_inference:
add(reg_bias_data,
load_loop_blk * jcp.oc_block * sizeof(float));
- add(reg_output_data,
- get_load_loop_output_fwd_offset(jcp, load_loop_blk));
+ safe_add(reg_output_data,
+ get_load_loop_output_fwd_offset(jcp, load_loop_blk),
+ reg_long_offt);
if (jcp.with_binary) {
mov(aux_reg_load_data,
ptr[rsp + reg_binary_post_op_acc_off]);
Expand All @@ -610,8 +610,9 @@ void jit_avx2_1x1_conv_kernel_f32::generate() {
}
break;
case backward_data:
- add(reg_output_data,
- get_load_loop_output_bwd_d_offset(jcp, load_loop_blk));
+ safe_add(reg_output_data,
+ get_load_loop_output_bwd_d_offset(jcp, load_loop_blk),
+ reg_long_offt);
break;
case backward_weights:
for (int i = 0; i < load_loop_blk; i++)
Expand Down
1 change: 1 addition & 0 deletions src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ struct jit_avx2_1x1_conv_kernel_f32 : public jit_generator {
reg64_t reg_tmp_output_stride = reg_bcast_data;
reg64_t reg_tmp = aux_reg_bcast_data;
reg64_t reg_output_stride_scale = load_loop_iter;
+ reg64_t reg_long_offt = reg_bcast_data;

constexpr static int reg64_size_ = sizeof(int64_t);
constexpr static int reg_diff_bias_data_stack_offt = 0;
Expand Down

0 comments on commit 2ba2523

Please sign in to comment.