Skip to content

Commit

Permalink
Fixes to CONV for optional bias tensor when compression is enabled.
Browse files — browse the repository at this point in the history
  • Loading branch information
ddavis-2015 committed Dec 14, 2024
1 parent 01eb927 commit e54866c
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 21 deletions.
10 changes: 5 additions & 5 deletions tensorflow/lite/micro/kernels/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(micro_context, bias, bias_comp_td,
data.bias_scratch_index),
tflite::micro::GetOptionalTensorData<float>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
Expand All @@ -92,7 +92,7 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
Expand All @@ -118,7 +118,7 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<std::int64_t>(bias),
tflite::micro::GetTensorData<std::int64_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
Expand Down Expand Up @@ -162,7 +162,7 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/lite/micro/kernels/xtensa/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
micro_context, filter, weights_comp_td,
op_data.reference_op_data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(
tflite::micro::GetOptionalTensorData<float>(
micro_context, bias, bias_comp_td,
op_data.reference_op_data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
Expand Down
12 changes: 7 additions & 5 deletions tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -193,12 +193,13 @@ TfLiteStatus ConvEvalHifiInt16(TfLiteContext* context, TfLiteNode* node,
const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(
micro_context, filter, weights_comp_td,
data.reference_op_data.weights_scratch_index);
const int64_t* bias_data = tflite::micro::GetTensorData<int64_t>(
const int64_t* bias_data = tflite::micro::GetOptionalTensorData<int64_t>(
micro_context, bias, bias_comp_td,
data.reference_op_data.bias_scratch_index);
#else // USE_TFLM_COMPRESSION
const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(filter);
const int64_t* bias_data = tflite::micro::GetTensorData<int64_t>(bias);
const int64_t* bias_data =
tflite::micro::GetOptionalTensorData<int64_t>(bias);
#endif // USE_TFLM_COMPRESSION
int16_t* output_data = tflite::micro::GetTensorData<int16_t>(output);

Expand Down Expand Up @@ -307,11 +308,12 @@ TfLiteStatus ConvEvalHifiInt8(TfLiteContext* context, TfLiteNode* node,

const int8_t* input_data = tflite::micro::GetTensorData<int8_t>(input);
#ifdef USE_TFLM_COMPRESSION
const int32_t* bias_data = tflite::micro::GetTensorData<int32_t>(
const int32_t* bias_data = tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td,
data.reference_op_data.bias_scratch_index);
#else // USE_TFLM_COMPRESSION
const int32_t* bias_data = tflite::micro::GetTensorData<int32_t>(bias);
const int32_t* bias_data =
tflite::micro::GetOptionalTensorData<int32_t>(bias);
#endif // USE_TFLM_COMPRESSION
int8_t* output_data = tflite::micro::GetTensorData<int8_t>(output);

Expand Down
8 changes: 4 additions & 4 deletions tensorflow/lite/micro/kernels/xtensa/conv_int16_reference.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -68,8 +68,8 @@ TfLiteStatus ConvReferenceEvalInt16(TfLiteContext* context, TfLiteNode* node) {
weights_comp_td,
op_data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(micro_context, bias, bias_comp_td,
op_data.bias_scratch_index),
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td, op_data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
Expand All @@ -94,7 +94,7 @@ TfLiteStatus ConvReferenceEvalInt16(TfLiteContext* context, TfLiteNode* node) {
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<std::int64_t>(bias),
tflite::micro::GetTensorData<std::int64_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/lite/micro/kernels/xtensa/conv_int8_reference.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -81,8 +81,8 @@ TfLiteStatus ConvReferenceEvalInt8(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(filter), filter_data,
tflite::micro::GetTensorShape(bias),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int32_t>(micro_context, bias, bias_comp_td,
op_data.bias_scratch_index),
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td, op_data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetOptionalTensorData<int32_t>(bias),
#endif // USE_TFLM_COMPRESSION
Expand Down
8 changes: 5 additions & 3 deletions tensorflow/lite/micro/kernels/xtensa/conv_vision.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kConvInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* bias =
micro_context->AllocateTempInputTensor(node, kConvBiasTensor);
TF_LITE_ENSURE(context, bias != nullptr);
const uint32_t input_height = SizeOfDimension(input, 1);
const uint32_t input_width = SizeOfDimension(input, 2);

Expand All @@ -47,8 +49,10 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) {

TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kConvOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TfLiteTensor* filter =
micro_context->AllocateTempInputTensor(node, kConvWeightsTensor);
TF_LITE_ENSURE(context, filter != nullptr);

const uint32_t output_height = SizeOfDimension(output, 1);
const uint32_t output_width = SizeOfDimension(output, 2);
Expand Down Expand Up @@ -212,9 +216,7 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) {
micro_context->DeallocateTempTfLiteTensor(output);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(filter);
if (bias != nullptr) {
micro_context->DeallocateTempTfLiteTensor(bias);
}
micro_context->DeallocateTempTfLiteTensor(bias);
return kTfLiteOk;
}

Expand Down

0 comments on commit e54866c

Please sign in to comment.