Skip to content

Commit

Permalink
add xtensa any bit width decompression code
Browse files Browse the repository at this point in the history
  • Loading branch information
ddavis-2015 committed Oct 4, 2024
1 parent eb85180 commit 487c17a
Showing 1 changed file with 70 additions and 1 deletion.
71 changes: 70 additions & 1 deletion tensorflow/lite/micro/micro_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ struct DecompressionState {

#ifdef HIFI5
// HIFI5-only decompression into an int8_t output buffer. DecompressToBuffer
// selects this routine when value_table_channel_stride == 16 (together with
// further conditions checked at the call site).
void DecompressToBufferWidth4_Xtensa(int8_t* buffer);

// HIFI5-only decompression into an int8_t output buffer for a compile-time
// index width. N is the number of bits read from the compressed index stream
// per output element; DecompressToBuffer instantiates it with the compressed
// bit width (values 1 through 7 are used in this file).
template <size_t N>
void DecompressToBufferWidthAny_Xtensa(int8_t* buffer);
#endif // HIFI5

void DecompressToBufferWidth4_16(int8_t* buffer);
Expand Down Expand Up @@ -485,6 +488,7 @@ void DecompressionState::DecompressToBufferWidthAny(T* buffer) {
}

#ifdef HIFI5

void DecompressionState::DecompressToBufferWidth4_Xtensa(int8_t* buffer) {
MicroProfiler* profiler =
static_cast<MicroProfiler*>(micro_context_->external_context());
Expand Down Expand Up @@ -528,6 +532,38 @@ void DecompressionState::DecompressToBufferWidth4_Xtensa(int8_t* buffer) {
value_table += stride;
}
}

template <size_t N>
void DecompressionState::DecompressToBufferWidthAny_Xtensa(int8_t* buffer) {
MicroProfiler* profiler =
static_cast<MicroProfiler*>(micro_context_->external_context());
ScopedMicroProfiler scoped_profiler(__func__, profiler);

const size_t stride = comp_data_.data.lut_data->value_table_channel_stride;
const uint8_t* value_table =
static_cast<const uint8_t*>(comp_data_.data.lut_data->value_table);

short* p_stream = (short*)compressed_indices_;
uint32_t index;
ae_int8* p_out_tmp = (ae_int8*)buffer;

WUR_AE_BITPTR(0);
WUR_AE_BITHEAD(0);

AE_DBI_IP((const unsigned short*)p_stream, 16);
AE_DBI_IP((const unsigned short*)p_stream, 16);

for (size_t i = 0; i < num_channels_; i++) {
for (size_t j = 0; j < elements_per_channel_; j++) {
AE_LBI_DBI_IP((unsigned short*)p_stream, index, N);
ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index);
AE_S8_0_IP(d_tmp, p_out_tmp, 1);
}

value_table += stride;
}
}

#endif // HIFI5

template <typename T>
Expand All @@ -543,21 +579,54 @@ T* DecompressionState::DecompressToBuffer(void* buffer) {
comp_data_.data.lut_data->value_table_channel_stride == 16) {
DecompressToBufferWidth4_Xtensa(static_cast<int8_t*>(buffer));
} else {
DecompressToBufferWidth4_16(static_cast<int8_t*>(buffer));
DecompressToBufferWidthAny_Xtensa<4>(static_cast<int8_t*>(buffer));
}
#else // HIFI5
DecompressToBufferWidth4_16(static_cast<int8_t*>(buffer));
#endif // HIFI5
} else if (std::is_same<T, int8_t>::value &&
comp_data_.data.lut_data->compressed_bit_width == 2 &&
!comp_data_.data.lut_data->use_alternate_axis) {
#ifdef HIFI5
DecompressToBufferWidthAny_Xtensa<2>(static_cast<int8_t*>(buffer));
#else // HIFI5
DecompressToBufferWidth2_16(static_cast<int8_t*>(buffer));
#endif // HIFI5
} else if (std::is_same<T, int8_t>::value &&
comp_data_.data.lut_data->compressed_bit_width == 3 &&
!comp_data_.data.lut_data->use_alternate_axis) {
#ifdef HIFI5
DecompressToBufferWidthAny_Xtensa<3>(static_cast<int8_t*>(buffer));
#else // HIFI5
DecompressToBufferWidth3_32(static_cast<int8_t*>(buffer));
#endif // HIFI5
} else {
#ifdef HIFI5
if (std::is_same<T, int8_t>::value &&
!comp_data_.data.lut_data->use_alternate_axis) {
switch (compressed_bit_width_) {
case 1:
DecompressToBufferWidthAny_Xtensa<1>(static_cast<int8_t*>(buffer));
break;
case 4:
DecompressToBufferWidthAny_Xtensa<4>(static_cast<int8_t*>(buffer));
break;
case 5:
DecompressToBufferWidthAny_Xtensa<5>(static_cast<int8_t*>(buffer));
break;
case 6:
DecompressToBufferWidthAny_Xtensa<6>(static_cast<int8_t*>(buffer));
break;
case 7:
DecompressToBufferWidthAny_Xtensa<7>(static_cast<int8_t*>(buffer));
break;
}
} else {
DecompressToBufferWidthAny<T>(static_cast<T*>(buffer));
}
#else // HIFI5
DecompressToBufferWidthAny<T>(static_cast<T*>(buffer));
#endif // HIFI5
}

return static_cast<T*>(buffer);
Expand Down

0 comments on commit 487c17a

Please sign in to comment.