feat(compression): implement tensor decompression in op concatenation (#3014)

Implement tensor decompression in op concatenation. Extend
tests to validate operation on compressed tensors.

BUG=part of #2636
rkuester authored Dec 16, 2024
1 parent 9a32964 commit 50e7e5d
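
For orientation before the diff: the change follows a simple round trip — ConcatenationPrepare reserves a decompression scratch buffer per input, and the Eval path fetches input data through an overload of GetTensorData that decompresses into that buffer when the tensor is compressed. The condensed sketch below is illustrative only and is not part of the commit; `Prepare` and `GetInputData` are stand-ins for the real ConcatenationPrepare and GetAllInputTensorData shown in the diff, `kMaxInputNum` is the per-kernel input cap defined in concatenation.cc, and error handling plus the concatenation-specific setup are omitted.

#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_context.h"

struct OpData {
#ifdef USE_TFLM_COMPRESSION
  // One scratch-buffer index per input, filled in during Prepare. (The real
  // OpData also holds the ConcatenationParams used by the reference kernel.)
  int scratch_indices[kMaxInputNum];
#endif  // USE_TFLM_COMPRESSION
};

// Prepare: reserve a decompression scratch buffer for each input. The buffer
// is only actually allocated if the corresponding tensor is compressed.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = static_cast<OpData*>(node->user_data);
  MicroContext* micro_context = GetMicroContext(context);
  for (int i = 0; i < NumInputs(node); ++i) {
#ifdef USE_TFLM_COMPRESSION
    data->scratch_indices[i] =
        micro_context->AllocateDecompressionScratchBuffer(node, i);
#endif  // USE_TFLM_COMPRESSION
  }
  return kTfLiteOk;
}

// Eval side: fetch input data. For a compressed tensor, GetTensorData
// decompresses into the scratch buffer reserved above; otherwise it returns
// the raw tensor data unchanged.
template <typename T>
const T* GetInputData(TfLiteContext* context, TfLiteNode* node, int i,
                      const OpData* data) {
  const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i);
#ifdef USE_TFLM_COMPRESSION
  MicroContext* micro_context = GetMicroContext(context);
  const CompressionTensorData* comp_td =
      micro_context->GetTensorCompressionData(node, i);
  return tflite::micro::GetTensorData<T>(micro_context, t, comp_td,
                                         data->scratch_indices[i]);
#else   // USE_TFLM_COMPRESSION
  return tflite::micro::GetTensorData<T>(t);
#endif  // USE_TFLM_COMPRESSION
}

The diff below applies exactly this round trip: scratch indices are stored in OpData during ConcatenationPrepare and consumed in GetAllInputTensorData during Eval, all guarded by USE_TFLM_COMPRESSION so uncompressed builds carry no overhead.
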
Showing 2 changed files with 282 additions and 57 deletions.
112 changes: 62 additions & 50 deletions tensorflow/lite/micro/kernels/concatenation.cc
@@ -1,4 +1,4 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -33,6 +33,13 @@ constexpr int kOutputTensor = 0;

struct OpData {
ConcatenationParams params;

#ifdef USE_TFLM_COMPRESSION

// scratch buffers for compressed tensors
int scratch_indices[kMaxInputNum];

#endif // USE_TFLM_COMPRESSION
};

// Handles negative axis index, coerces to positive index value.
@@ -52,8 +59,6 @@ inline int CalculatePositiveAxis(int axis, const TfLiteTensor* output_tensor) {
inline void GetAllInputTensorShapes(const TfLiteContext* context,
const TfLiteNode* node,
RuntimeShape all_shapes[kMaxInputNum]) {
TFLITE_DCHECK(context != nullptr);
TFLITE_DCHECK(node != nullptr);
for (int i = 0; i < node->inputs->size; ++i) {
const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i);
RuntimeShape shape = tflite::micro::GetTensorShape(t);
@@ -73,12 +78,22 @@ inline void GetShapesPointers(const RuntimeShape* shapes, size_t num,
template <typename T>
inline void GetAllInputTensorData(const TfLiteContext* context,
const TfLiteNode* node,
T* all_data[kMaxInputNum]) {
TFLITE_DCHECK(context != nullptr);
TFLITE_DCHECK(node != nullptr);
const T* all_data[kMaxInputNum]) {
#ifdef USE_TFLM_COMPRESSION
const OpData* data = static_cast<const OpData*>(node->user_data);
MicroContext* micro_context = GetMicroContext(context);
#endif // USE_TFLM_COMPRESSION

for (int i = 0; i < node->inputs->size; ++i) {
const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i);
#ifdef USE_TFLM_COMPRESSION
const CompressionTensorData* comp_td =
micro_context->GetTensorCompressionData(node, i);
all_data[i] = tflite::micro::GetTensorData<T>(micro_context, t, comp_td,
data->scratch_indices[i]);
#else // USE_TFLM_COMPRESSION
all_data[i] = tflite::micro::GetTensorData<T>(t);
#endif // USE_TFLM_COMPRESSION
}
}

@@ -88,16 +103,17 @@ void EvalUnquantized(TfLiteContext* context, TfLiteNode* node) {
RuntimeShape inputs_shape[kMaxInputNum];
const RuntimeShape* inputs_shape_ptr[kMaxInputNum];
const data_type* inputs_data[kMaxInputNum];
TFLITE_DCHECK(context != nullptr);
TFLITE_DCHECK(node != nullptr);
TFLITE_DCHECK(node->user_data != nullptr);
const OpData* data = static_cast<const OpData*>(node->user_data);
GetAllInputTensorShapes(context, node, inputs_shape);
GetShapesPointers(inputs_shape, node->inputs->size, inputs_shape_ptr);
GetAllInputTensorData(context, node, inputs_data);

TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);

TFLITE_DCHECK(node->user_data != nullptr);
const OpData* data = static_cast<const OpData*>(node->user_data);

reference_ops::Concatenation(data->params, inputs_shape_ptr, inputs_data,
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<data_type>(output));
@@ -126,7 +142,6 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteType output_type = output_tensor->type;

micro_context->DeallocateTempTfLiteTensor(input_tensor);
micro_context->DeallocateTempTfLiteTensor(output_tensor);

// Check activation and input type
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
@@ -136,16 +151,22 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
input_type == kTfLiteInt64 || input_type == kTfLiteBool);

// Output type must match input type
TF_LITE_ENSURE_EQ(context, output_type, input_type);
TF_LITE_ENSURE_TYPES_EQ(context, output_type, input_type);

// This implementation does not support a large number of input tensors
const int num_inputs = NumInputs(node);
TF_LITE_ENSURE(context, num_inputs <= kMaxInputNum);

// Shapes with dimensions >4 are not yet supported with static allocation.
// Calculate OpData.
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);

// Shapes with dimensions > kMaxSmallSize are not yet supported with static
// allocation.
for (int i = 0; i < num_inputs; ++i) {
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, i);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, input_type);
int num_dimensions = NumDimensions(input);

if (num_dimensions > RuntimeShape::kMaxSmallSize) {
@@ -155,62 +176,53 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
RuntimeShape::kMaxSmallSize, num_dimensions);
return kTfLiteError;
}

if (input_type == kTfLiteInt8) {
// Make sure there is no re-scaling needed for Int8 quantized kernel. This
// is a restriction we introduced to Int8 kernels.
TF_LITE_ENSURE_EQ(context, static_cast<double>(input->params.scale),
static_cast<double>(output_tensor->params.scale));
TF_LITE_ENSURE_EQ(context, input->params.zero_point,
output_tensor->params.zero_point);
} else if (input_type == kTfLiteInt16) {
// Make sure that all Int16 inputs have a null zero-point.
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
}

#ifdef USE_TFLM_COMPRESSION

// Compression scratch buffers.
// These will only be allocated if the tensor is compressed.
data->scratch_indices[i] =
micro_context->AllocateDecompressionScratchBuffer(node, i);

#endif // USE_TFLM_COMPRESSION

micro_context->DeallocateTempTfLiteTensor(input);
}

// Calculate OpData.
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);

TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
if (input_type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, output_tensor->params.zero_point, 0);
}

switch (output_type) {  // Already know in/out types are the same.
case kTfLiteBool:
case kTfLiteFloat32:
case kTfLiteInt8:
case kTfLiteInt16:
case kTfLiteInt32:
case kTfLiteInt64: {
data->params.axis = CalculatePositiveAxis(params->axis, output);
data->params.inputs_count = node->inputs->size;
break;
}
case kTfLiteInt8: {
data->params.axis = CalculatePositiveAxis(params->axis, output);
data->params.axis = CalculatePositiveAxis(params->axis, output_tensor);
data->params.inputs_count = node->inputs->size;

float* input_scales =
reinterpret_cast<float*>(context->AllocatePersistentBuffer(
context, node->inputs->size * sizeof(float)));

int32_t* input_zero_points =
reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
context, node->inputs->size * sizeof(int32_t)));

// Allocate persistent scale and zeropoint buffers.
// Store input scale and zero point values in OpParams:
for (int i = 0; i < node->inputs->size; ++i) {
TfLiteTensor* t = micro_context->AllocateTempInputTensor(node, i);
TF_LITE_ENSURE(context, t != nullptr);
input_scales[i] = t->params.scale;
input_zero_points[i] = t->params.zero_point;
micro_context->DeallocateTempTfLiteTensor(t);
}

data->params.input_scale = input_scales;
data->params.input_zeropoint = input_zero_points;
data->params.output_zeropoint = output->params.zero_point;
data->params.output_scale = output->params.scale;
break;
}
default:
MicroPrintf("Op Concatenation does not currently support Type '%s'.",
MicroPrintf("Op Concatenation does not currently support type '%s'.",
TfLiteTypeGetName(output_type));
return kTfLiteError;
}

micro_context->DeallocateTempTfLiteTensor(output);
micro_context->DeallocateTempTfLiteTensor(output_tensor);

return kTfLiteOk;
}