backend: dnnl: remove pd in cache when fusing reorder
xinyu-intel authored and TaoLv committed Nov 1, 2022
1 parent 5925c1c commit 0887652
Showing 2 changed files with 134 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/backend/dnnl/passes/lower_down.cpp
@@ -2916,6 +2916,10 @@ impl::status_t fuse_adjacent_reorders(std::shared_ptr<subgraph_t> &sg) {
out_val->set_producer(*fused_op);

auto scratchpad_val = insert_empty_scratchpad(fused_op);
// remove the pd from pd_cache since the fused op shares the same id
if (pd_cache.find(fused_op.get()) != pd_cache.end()) {
pd_cache.erase(fused_op.get());
}
const auto &pd
= create_reorder_pd(fused_op, *p_engine, mgr, pd_cache)
.first;
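
The fix above hinges on the fused reorder op keeping the same id as the op it replaces: a primitive descriptor cached under that id describes the pre-fusion op and must be evicted before create_reorder_pd builds a fresh one. A minimal sketch of the pattern, assuming placeholder names (op_t, primitive_desc_t, pd_cache_t and make_pd below are illustrative, not the backend's actual internals):

#include <unordered_map>

// Placeholder types standing in for the real backend internals.
struct op_t {};
struct primitive_desc_t {};
using pd_cache_t = std::unordered_map<const op_t *, primitive_desc_t>;

// Hypothetical factory building a descriptor from the op's current attributes.
primitive_desc_t make_pd(const op_t & /*op*/) { return primitive_desc_t {}; }

primitive_desc_t &get_or_create_pd(pd_cache_t &cache, const op_t *fused_op) {
    // The fused op reuses the key of the op it replaced, so a cached
    // descriptor under that key reflects the pre-fusion attributes.
    auto it = cache.find(fused_op);
    if (it != cache.end()) cache.erase(it);
    // Recreate the descriptor so it matches the fused op.
    return cache.emplace(fused_op, make_pd(*fused_op)).first->second;
}
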
130 changes: 130 additions & 0 deletions tests/cpp/unit/backend/dnnl/test_bmm.cpp
@@ -141,6 +141,136 @@ TEST(ExecuteSubgraphInt8, BmmU8u8f32) {
}
}

TEST(ExecuteSubgraphInt8, BmmU8u8f32NonContiguous) {
impl::engine_t &engine = get_engine();
impl::stream_t &strm = get_stream();

// u8 to s8 shift by using a reorder with 128 zps is not supported on gpu
SKIP_IF(engine.kind() == impl::engine_kind::gpu, "skip on gpu");

std::string qtype = "per_tensor";
// prepare fp32 data
std::vector<int64_t> src_shape = {1, 1, 50, 50};
std::vector<int64_t> weight_shape = {1, 1, 50, 32};
std::vector<int64_t> weight_stride = {4800, 4800, 96, 1};
std::vector<int64_t> dst_shape = {1, 1, 50, 32};
// simulate non-contiguous case
std::vector<int64_t> weight_shape_large = {1, 1, 50, 96};
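// the 50 x 32 weight view sits inside a 50 x 96 buffer: the row stride of
// 96 (and 50 * 96 = 4800 for the two outer dims) skips the unused tail of
// each row, which is what makes the weight tensor non-contiguous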

test::vector<uint8_t> src_data(product(src_shape));
test::vector<float> weight_data(product(weight_shape_large));

// randomly generate src and weight data
// random seed = 7
std::default_random_engine generator(7);
std::uniform_real_distribution<float> u8_distribution(0.0f, 255.0f);
std::generate(src_data.begin(), src_data.end(),
[&]() { return static_cast<uint8_t>(u8_distribution(generator)); });
std::uniform_real_distribution<float> f32_distribution(0.0f, 1.0f);
std::generate(weight_data.begin(), weight_data.end(),
[&]() { return static_cast<float>(f32_distribution(generator)); });
float scale_src = 1 / 255.f; // map to 0~255
int64_t zp_src = 0;

size_t scales_wei_sizes = qtype == "per_tensor" ? 1 : dst_shape.back();
std::vector<float> scale_wei(scales_wei_sizes, 1 / 255.f);
std::vector<int64_t> zp_wei(scales_wei_sizes, 114);

// -------------------------case 1----------------------------------
impl::op_t dqdata_op(1, impl::op_kind::Dequantize, "dqdata_op");
dqdata_op.set_attr<std::string>(impl::op_attr::qtype, "per_tensor");
dqdata_op.set_attr<std::vector<int64_t>>(impl::op_attr::zps, {zp_src});
dqdata_op.set_attr<std::vector<float>>(impl::op_attr::scales, {scale_src});
dqdata_op.set_attr<int64_t>(impl::op_attr::axis, 0);

impl::op_t qweight_op(2, impl::op_kind::Quantize, "qweight_op");
qweight_op.set_attr<std::string>(impl::op_attr::qtype, qtype);
qweight_op.set_attr<std::vector<int64_t>>(impl::op_attr::zps, zp_wei);
qweight_op.set_attr<std::vector<float>>(impl::op_attr::scales, scale_wei);
qweight_op.set_attr<int64_t>(impl::op_attr::axis, 1);

impl::op_t dqweight_op(3, impl::op_kind::Dequantize, "dqweight_op");
dqweight_op.set_attr<std::string>(impl::op_attr::qtype, qtype);
dqweight_op.set_attr<std::vector<int64_t>>(impl::op_attr::zps, zp_wei);
dqweight_op.set_attr<std::vector<float>>(impl::op_attr::scales, scale_wei);
dqweight_op.set_attr<int64_t>(impl::op_attr::axis, 1);

impl::op_t matmul_op(4, impl::op_kind::MatMul, "matmul_op");
matmul_op.set_attr<bool>(impl::op_attr::transpose_a, false);
matmul_op.set_attr<bool>(impl::op_attr::transpose_b, false);

// prepare logical tensor
impl::logical_tensor_t src_u8
= utils::logical_tensor_init(1, src_shape, impl::data_type::u8);
impl::logical_tensor_t src_f32_dq
= utils::logical_tensor_init(2, src_shape, impl::data_type::f32);
impl::logical_tensor_t weight_f32 = utils::logical_tensor_init(
3, weight_shape, weight_stride, impl::data_type::f32);
impl::logical_tensor_t weight_u8
= utils::logical_tensor_init(4, weight_shape, impl::data_type::u8);
impl::logical_tensor_t weight_f32_dq
= utils::logical_tensor_init(5, weight_shape, impl::data_type::f32);
impl::logical_tensor_t dst_f32
= utils::logical_tensor_init(6, dst_shape, impl::data_type::f32);

dqdata_op.add_input(src_u8);
dqdata_op.add_output(src_f32_dq);

qweight_op.add_input(weight_f32);
qweight_op.add_output(weight_u8);

dqweight_op.add_input(weight_u8);
dqweight_op.add_output(weight_f32_dq);

matmul_op.add_input(src_f32_dq);
matmul_op.add_input(weight_f32_dq);
matmul_op.add_output(dst_f32);

impl::graph_t g(engine.kind());
g.add_op(&dqdata_op);
g.add_op(&qweight_op);
g.add_op(&dqweight_op);
g.add_op(&matmul_op);
g.build_graph();

impl::tensor_t src_u8_ts(src_u8, &engine, src_data.data());
impl::tensor_t weight_f32_ts(weight_f32, &engine, weight_data.data());
// -------------------------case 1----------------------------------
test::vector<float> case1_out_data(product(dst_shape));
impl::tensor_t dst_f32_ts(dst_f32, &engine, case1_out_data.data());
ASSERT_EQ(run_graph(g, {src_u8_ts, weight_f32_ts}, {dst_f32_ts}, engine,
strm),
impl::status::success);
// -------------------------case 2----------------------------------
impl::pass::pass_base_ptr apass
= get_pass("int8_matmul_post_ops_fusion_cpu");
apass->run(g);
ASSERT_EQ(g.get_num_partitions(), 1U);
auto part = g.get_partitions()[0];

// compile
impl::partition_t p;
p.init(part);

impl::compiled_partition_t cp(p);

std::vector<const impl::logical_tensor_t *> lt_ins {&weight_f32, &src_u8};
std::vector<const impl::logical_tensor_t *> lt_outs {&dst_f32};

p.compile(&cp, lt_ins, lt_outs, &engine);

test::vector<float> case2_out_data(product(dst_shape));
impl::tensor_t dst_f32_case2_ts(dst_f32, &engine, case2_out_data.data());
cp.execute(&strm, {weight_f32_ts, src_u8_ts}, {dst_f32_case2_ts});
strm.wait();

static auto isa = dnnl_get_effective_cpu_isa();
if (isa >= dnnl_cpu_isa_avx512_core_vnni) {
ASSERT_TRUE(allclose(case1_out_data, case2_out_data, /*rtol*/ 0.01f,
/*atol*/ 1.f));
}
}

TEST(ExecuteSubgraphInt8, BmmDivU8u8f32) {
impl::engine_t &engine = get_engine();
impl::stream_t &strm = get_stream();
