Skip to content

Commit

Permalink
gpu: jit: gemm: remove unnecessary type conversions with sum post-ops
Browse files Browse the repository at this point in the history
  • Loading branch information
petercad committed Feb 8, 2023
1 parent dbb7c28 commit a1e6bc5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
18 changes: 16 additions & 2 deletions src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
* Copyright 2019-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -6582,7 +6582,7 @@ void gemm_kernel_generator_t<hw>::outerProductSystolic(int h, int ha, int hb,
// Decide whether to use the legacy post-op injector inside C update.
// The legacy injector is needed when C cannot be converted to f32 in-place;
// note, however, that it does not support binary post-ops.
//
// The result is true only when the problem carries post-ops and Tc.size() is
// below 4 (presumably the element size of the C type in bytes -- TODO confirm).
//
// NOTE(review): this listing appears to be a diff rendering whose +/- markers
// were stripped, so the two consecutive return statements below are presumably
// the removed line (hasPostOp) followed by the added line (hasNonSum1PostOp).
// Confirm against the repository; as literally written, only the first return
// can ever execute.
static inline bool useEltwiseInjector(const GEMMProblem &problem) {
return problem.hasPostOp() && (problem.Tc.size() < 4);
return problem.hasNonSum1PostOp() && (problem.Tc.size() < 4);
}

// Perform C update operation on C_acc, given original C data in C_load.
Expand Down Expand Up @@ -14546,6 +14546,12 @@ bool gemm_kernel_generator_t<hw>::gemmBodyInternal(
subproblem.beta_real = 1;
subproblem.beta_imag = 0;

if (subproblem.postOps.len() > 0) {
auto &lastPO = subproblem.postOps
.entry_[subproblem.postOps.len() - 1];
if (lastPO.kind == primitive_kind::sum) lastPO.sum.scale = 1.0f;
}

if (!gemmUpdateC(subproblem, strategy, substate)) return false;

if (checkBeta0) {
Expand All @@ -14566,6 +14572,14 @@ bool gemm_kernel_generator_t<hw>::gemmBodyInternal(
subproblem.beta_real = 0;
subproblem.beta_imag = 0;

if (subproblem.postOps.len() > 0) {
auto &lastPO = subproblem.postOps
.entry_[subproblem.postOps.len() - 1];
if (lastPO.kind == primitive_kind::sum)
subproblem.postOps.entry_.resize(
subproblem.postOps.len() - 1);
}

substrategy.C.atomic = false;

if (!gemmUpdateC(subproblem, substrategy, substate)) return false;
Expand Down
9 changes: 7 additions & 2 deletions src/gpu/jit/gemm/gen_gemm_kernel_generator.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
* Copyright 2019-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -817,6 +817,11 @@ struct GEMMProblem : public CommonProblem {
std::vector<bool> binaryBatch;

// True when at least one post-op entry is attached to this problem.
bool hasPostOp() const {
    const auto postOpCount = postOps.len();
    return postOpCount > 0;
}
// Reports whether any post-op entry fails the is_sum() check. The result is
// false only when the chain is empty or every entry satisfies is_sum().
// NOTE(review): the "Sum1" in the name suggests is_sum() also requires a unit
// scale by default -- confirm against the post-op entry definition.
bool hasNonSum1PostOp() const {
    const int count = postOps.len();
    for (int idx = 0; idx < count; idx++) {
        if (postOps.entry_[idx].is_sum()) continue;
        return true;
    }
    return false;
}
bool hasBinaryPostOp() const {
for (int idx = 0; idx < postOps.len(); idx++)
if (postOps.entry_[idx].is_binary()) return true;
Expand All @@ -840,7 +845,7 @@ struct GEMMProblem : public CommonProblem {
if (!(alpha1() || alphaM1())) return true;
if (!(beta0() || beta1())) return true;
if (beta1() && !Tc_ext.isSubsetOf(Tc)) return true;
if (hasPostOp()) return true;
if (hasNonSum1PostOp()) return true;
return false;
}

Expand Down

0 comments on commit a1e6bc5

Please sign in to comment.