Skip to content

Commit

Permalink
fixes for optional bias tensor DCHECK.
Browse files Browse the repository at this point in the history
Fixes for FULLY_CONNECTED optional bias tensor.
  • Loading branch information
ddavis-2015 committed Dec 13, 2024
1 parent b6d8b75 commit 2335754
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 15 deletions.
10 changes: 5 additions & 5 deletions tensorflow/lite/micro/kernels/fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(micro_context, bias, bias_comp_td,
data.bias_scratch_index),
tflite::micro::GetOptionalTensorData<float>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
Expand Down Expand Up @@ -194,7 +194,7 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
micro_context, filter, weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td,
data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
Expand All @@ -214,7 +214,7 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
micro_context, filter, weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td,
data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
Expand Down Expand Up @@ -248,7 +248,7 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int64_t>(
tflite::micro::GetOptionalTensorData<int64_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
Expand Down
22 changes: 17 additions & 5 deletions tensorflow/lite/micro/kernels/kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) {

#ifdef USE_TFLM_COMPRESSION

// Overloads existing GetTensorData. If not compressed, this will return
// Overloads existing GetOptionalTensorData. If not compressed, this will return
// tensor->data.
template <typename T>
const T* GetTensorData(MicroContext* micro_context,
const TfLiteEvalTensor* tensor,
const CompressionTensorData* compression_data,
int scratch_buffer_handle) {
const T* GetOptionalTensorData(MicroContext* micro_context,
const TfLiteEvalTensor* tensor,
const CompressionTensorData* compression_data,
int scratch_buffer_handle) {
if (tensor == nullptr) {
return nullptr;
}
Expand All @@ -128,6 +128,18 @@ const T* GetTensorData(MicroContext* micro_context,
return reinterpret_cast<const T*>(uncompressed_data);
}

// Overloads existing GetTensorData. If not compressed, this will return
// tensor->data.
template <typename T>
const T* GetTensorData(MicroContext* micro_context,
const TfLiteEvalTensor* tensor,
const CompressionTensorData* compression_data,
int scratch_buffer_handle) {
TFLITE_DCHECK(tensor != nullptr);
return GetOptionalTensorData<T>(micro_context, tensor, compression_data,
scratch_buffer_handle);
}

#endif // USE_TFLM_COMPRESSION

// Returns the shape of a TfLiteEvalTensor struct.
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(micro_context, bias, bias_comp_td,
data.bias_scratch_index),
tflite::micro::GetOptionalTensorData<float>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
Expand Down Expand Up @@ -119,7 +119,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int64_t>(
tflite::micro::GetOptionalTensorData<int64_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ TfLiteStatus XtensaEvalFullyConnectedQuantizedInt8(

const int32_t* bias_data =
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int32_t>(micro_context, bias, bias_comp_td,
data.bias_scratch_index);
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index);
#else // USE_TFLM_COMPRESSION
tflite::micro::GetOptionalTensorData<int32_t>(bias);
#endif // USE_TFLM_COMPRESSION
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ TfLiteStatus FullyConnectedPrepareVision(TfLiteContext* context,
const CompressionTensorData* bias_comp_td =
micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor);
if (bias_comp_td != nullptr) {
TFLITE_DCHECK(bias != nullptr);
const size_t bias_data_size =
NumElements(bias) * TfLiteTypeGetSize(kTfLiteInt32);
bias_data = reinterpret_cast<int32_t*>(
Expand All @@ -144,6 +145,7 @@ TfLiteStatus FullyConnectedPrepareVision(TfLiteContext* context,
}
const TfLiteEvalTensor* bias_eval =
tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
TFLITE_DCHECK(bias_eval != nullptr);
bias_data = static_cast<int32_t*>(micro_context->DecompressTensorToBuffer(
*bias_eval, *bias_comp_td, bias_data));
} else {
Expand Down

0 comments on commit 2335754

Please sign in to comment.