Skip to content

Commit

Permalink
add xtensa bit width 4 decompression code
Browse files Browse the repository at this point in the history
  • Loading branch information
ddavis-2015 committed Oct 4, 2024
1 parent d1a281e commit eb85180
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions tensorflow/lite/micro/micro_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ limitations under the License.
#include "tensorflow/lite/micro/micro_utils.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"

#ifdef HIFI5
#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
#endif // HIFI5

namespace tflite {
namespace {

Expand Down Expand Up @@ -57,6 +61,10 @@ struct DecompressionState {
template <typename T>
T* DecompressToBuffer(void* buffer);

#ifdef HIFI5
void DecompressToBufferWidth4_Xtensa(int8_t* buffer);
#endif // HIFI5

void DecompressToBufferWidth4_16(int8_t* buffer);
void DecompressToBufferWidth3_32(int8_t* buffer);
void DecompressToBufferWidth2_16(int8_t* buffer);
Expand Down Expand Up @@ -476,6 +484,52 @@ void DecompressionState::DecompressToBufferWidthAny(T* buffer) {
}
}

#ifdef HIFI5
void DecompressionState::DecompressToBufferWidth4_Xtensa(int8_t* buffer) {
MicroProfiler* profiler =
static_cast<MicroProfiler*>(micro_context_->external_context());
ScopedMicroProfiler scoped_profiler(__func__, profiler);

char shuffle_pattern_1[8] = {0x08, 0x19, 0x2A, 0x3B, 0x4C, 0x5D, 0x6E, 0x7F};
ae_int8x8 d_shuffle_t = *(ae_int8x8*)&shuffle_pattern_1[0];

char shuffle_pattern_2[8] = {0xFB, 0x73, 0xEA, 0x62, 0xD9, 0x51, 0xC8, 0x40};
ae_int8x8 d_d_shuffle_t2 = *(ae_int8x8*)&shuffle_pattern_2[0];

ae_int8x8 d_out1, d_out2;
ae_int8x8 d_value_0, d_value_1;
ae_int8x8 d_index;

ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_;
ae_int8* p_out_tmp = (ae_int8*)buffer;

const size_t stride = comp_data_.data.lut_data->value_table_channel_stride;
const uint8_t* value_table =
static_cast<const uint8_t*>(comp_data_.data.lut_data->value_table);

#ifdef notdef
MicroPrintf("indices %p buffer %p value_table %p stride %u",
compressed_indices_, buffer, value_table, stride);
#endif

for (size_t i = 0; i < num_channels_; i++) {
ae_int8x8 d_value_0_t = *(ae_int8x8*)&value_table[0];
ae_int8x8 d_value_1_t = *(ae_int8x8*)&value_table[8];

AE_DSEL8X8(d_value_0, d_value_1, d_value_0_t, d_value_1_t, d_shuffle_t);

for (size_t j = 0; j < elements_per_channel_; j += 16) {
AE_L8X8_IP(d_index, pIn_tmp, 8);
AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index);
AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_d_shuffle_t2);
AE_S8X8X2_IP(d_out1, d_out2, (ae_int8x16*)p_out_tmp, 16);
}

value_table += stride;
}
}
#endif // HIFI5

template <typename T>
T* DecompressionState::DecompressToBuffer(void* buffer) {
TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth);
Expand All @@ -484,7 +538,16 @@ T* DecompressionState::DecompressToBuffer(void* buffer) {
if (std::is_same<T, int8_t>::value &&
comp_data_.data.lut_data->compressed_bit_width == 4 &&
!comp_data_.data.lut_data->use_alternate_axis) {
#ifdef HIFI5
if (!(elements_per_channel_ & 0x0F) &&
comp_data_.data.lut_data->value_table_channel_stride == 16) {
DecompressToBufferWidth4_Xtensa(static_cast<int8_t*>(buffer));
} else {
DecompressToBufferWidth4_16(static_cast<int8_t*>(buffer));
}
#else // HIFI5
DecompressToBufferWidth4_16(static_cast<int8_t*>(buffer));
#endif // HIFI5
} else if (std::is_same<T, int8_t>::value &&
comp_data_.data.lut_data->compressed_bit_width == 2 &&
!comp_data_.data.lut_data->use_alternate_axis) {
Expand Down

0 comments on commit eb85180

Please sign in to comment.