feat(compression): implement tensor decompression in op depthwise conv (#3017)

Implement tensor decompression in op depthwise conv. Extend tests
to validate operation on compressed tensors.

BUG=part of #2636
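
For context: this change follows the decompression pattern used by the other compressed-tensor kernels in TFLM. Prepare() reserves a scratch buffer large enough to hold each decompressed tensor, and Eval() passes the tensor's compression metadata plus that scratch-buffer index to an overload of tflite::micro::GetTensorData(), which returns a pointer to plain, decompressed data. A minimal sketch of the Eval-side sequence, using identifiers from the diff below (the surrounding kernel boilerplate is elided, and the overload behavior described in the comments is inferred from the call sites rather than quoted from its implementation):

#ifdef USE_TFLM_COMPRESSION
  MicroContext* micro_context = GetMicroContext(context);

  // Per-tensor compression metadata, nullptr when the tensor is stored
  // uncompressed.
  const CompressionTensorData* filter_comp_td =
      micro_context->GetTensorCompressionData(node,
                                              kDepthwiseConvWeightsTensor);

  // Decompress the filter into the scratch buffer reserved in Prepare()
  // and return a pointer to it (or to the raw data when uncompressed).
  const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(
      micro_context, filter, filter_comp_td, data.weights_scratch_index);
#endif  // USE_TFLM_COMPRESSION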
rkuester authored Dec 16, 2024
1 parent 099774d commit b1d8a08
Showing 7 changed files with 581 additions and 28 deletions.
41 changes: 40 additions & 1 deletion tensorflow/lite/micro/kernels/depthwise_conv.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -52,16 +52,37 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
           ? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor)
           : nullptr;
 
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kDepthwiseConvWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteFloat32: {
       tflite::reference_ops::DepthwiseConv(
           DepthwiseConvParamsFloat(params, data),
           tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<float>(input),
           tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<float>(micro_context, filter,
+                                              filter_comp_td,
+                                              data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<float>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<float>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<float>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output));
       break;
@@ -94,9 +115,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<int8_t>(input),
           tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                               filter_comp_td,
+                                               data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<int32_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<int8_t>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int8_t>(output));
       break;
@@ -118,9 +148,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<int16_t>(input),
           tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                               filter_comp_td,
+                                               data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<int64_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<int8_t>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int16_t>(output));
       break;
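A detail worth noting in the Eval changes above: the bias tensor is optional for depthwise conv (hence the GetEvalInput ternary at the top of the first hunk), so both the compressed and uncompressed branches fetch it with GetOptionalTensorData rather than GetTensorData. A sketch of what the uncompressed overload presumably looks like (the real helper lives in the TFLM kernel utilities; this null-guard reading is an assumption):

// Returns nullptr for an absent optional tensor so reference kernels can
// skip the bias add; otherwise behaves like GetTensorData<T>().
template <typename T>
const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) {
  return tensor == nullptr ? nullptr : GetTensorData<T>(tensor);
}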
23 changes: 21 additions & 2 deletions tensorflow/lite/micro/kernels/depthwise_conv_common.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -127,7 +127,9 @@ TfLiteStatus CalculateOpDataDepthwiseConv(

   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
-  micro_context->DeallocateTempTfLiteTensor(bias);
+  if (has_bias) {
+    micro_context->DeallocateTempTfLiteTensor(bias);
+  }
   micro_context->DeallocateTempTfLiteTensor(output);
 
   return kTfLiteOk;
@@ -209,6 +211,23 @@ TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) {
       context, node, params, input_width, input_height, filter_width,
       filter_height, output_width, output_height, input->type, data));
 
+#ifdef USE_TFLM_COMPRESSION
+
+  // Compression scratch buffers.
+  // These will only be allocated if the tensor is compressed.
+  if (micro_context->IsTensorCompressed(node, kDepthwiseConvWeightsTensor) &&
+      filter->type == kTfLiteInt4) {
+    MicroPrintf("Compression not supported with INT4 tensors");
+    return kTfLiteError;
+  }
+  data->weights_scratch_index =
+      micro_context->AllocateDecompressionScratchBuffer(
+          node, kDepthwiseConvWeightsTensor);
+  data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer(
+      node, kDepthwiseConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   micro_context->DeallocateTempTfLiteTensor(output);
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
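The Prepare-side changes above stash the scratch-buffer indices in the op's per-instance data, which is how Eval later reaches the buffers without re-querying the MicroContext. A sketch of the fields involved (an assumed layout: in TFLM the conv and depthwise-conv kernels share an OpDataConv struct, and the sentinel behavior described below is an assumption):

struct OpDataConv {
  // ... quantization and padding fields elided ...
#ifdef USE_TFLM_COMPRESSION
  // Indices returned by MicroContext::AllocateDecompressionScratchBuffer(),
  // presumably holding a sentinel when the tensor is not compressed so
  // that Eval() can pass them through unconditionally.
  int weights_scratch_index;
  int bias_scratch_index;
#endif  // USE_TFLM_COMPRESSION
};

Note also that the INT4 guard runs once in Prepare rather than on every invocation: rejecting the unsupported filter-type/compression combination early keeps the check off the Eval hot path.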