Skip to content

Commit

Permalink
feat(compression): add work-in-progress compression and viewer tools
Browse files Browse the repository at this point in the history
  • Loading branch information
rkuester committed Oct 4, 2024
1 parent f6bd486 commit b773428
Show file tree
Hide file tree
Showing 12 changed files with 519 additions and 40 deletions.
26 changes: 26 additions & 0 deletions tensorflow/lite/micro/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,20 @@ cc_library(
],
)

# Library target built from hexdump.cc, exposing the hexdump.h header.
# Depends on the :span and :static_vector targets of this package.
cc_library(
    name = "hexdump",
    srcs = [
        "hexdump.cc",
    ],
    hdrs = [
        "hexdump.h",
    ],
    deps = [
        ":span",
        ":static_vector",
    ],
)

cc_library(
name = "recording_allocators",
srcs = [
Expand Down Expand Up @@ -556,6 +570,18 @@ cc_test(
],
)

# Test target for the :hexdump library, built from hexdump_test.cc and
# linked against the micro_test target from the testing package.
cc_test(
    name = "hexdump_test",
    size = "small",
    srcs = [
        "hexdump_test.cc",
    ],
    deps = [
        ":hexdump",
        "//tensorflow/lite/micro/testing:micro_test",
    ],
)

cc_test(
name = "memory_helpers_test",
srcs = [
Expand Down
23 changes: 11 additions & 12 deletions tensorflow/lite/micro/compression/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ py_binary(
"@absl_py//absl/flags",
"@absl_py//absl/logging",
"@flatbuffers//:runtime_py",
requirement("bitarray"),
requirement("numpy"),
],
)
Expand All @@ -92,30 +93,28 @@ py_test(
)

# Command-line viewer that prints a human-readable visualization of a
# .tflite model (see view.py).
py_binary(
    name = "view",
    srcs = [
        "view.py",
    ],
    deps = [
        ":metadata_py",
        "//tensorflow/lite/python:schema_py",
        "@absl_py//absl:app",
        requirement("bitarray"),
    ],
)

# Unit tests for the :view tool; models under test are generated with
# :test_models.
py_test(
    name = "view_test",
    size = "small",
    srcs = [
        "view_test.py",
    ],
    deps = [
        ":test_models",
        ":view",
        "@absl_py//absl/testing:absltest",
    ],
)

Expand Down
8 changes: 8 additions & 0 deletions tensorflow/lite/micro/compression/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from functools import reduce
from typing import Sequence
import math
import os
import sys
import textwrap

from tflite_micro.tensorflow.lite.micro.compression import (
lib,
Expand Down Expand Up @@ -246,4 +249,9 @@ def main(argv):


if __name__ == "__main__":
  # Build a usage message from the invoked program name and install it as the
  # docstring of the __main__ module before handing control to absl.
  # NOTE(review): presumably absl.app uses the __main__ docstring as its
  # help/usage text -- confirm against the absl.app documentation.
  name = os.path.basename(sys.argv[0])
  # textwrap.dedent strips any common leading whitespace from the text.
  usage = textwrap.dedent(f"""\
Usage: {name} <INPUT> <OUTPUT> [--tensors=<SPEC>] [--alt_axis_tensors=<SPEC>]
Compress a .tflite model.""")
  sys.modules['__main__'].__doc__ = usage
  absl.app.run(main)
5 changes: 4 additions & 1 deletion tensorflow/lite/micro/compression/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def build(spec: dict) -> bytearray:
A tflite flatbuffer.
"""
root = tflite.ModelT()
description = spec.get("description")
if description is not None:
root.description = description

root.operatorCodes = []
for id, operator_code in spec["operator_codes"].items():
Expand Down Expand Up @@ -57,7 +60,7 @@ def build(spec: dict) -> bytearray:
for id, tensor in subgraph["tensors"].items():
assert id == len(subgraph_t.tensors)
tensor_t = tflite.TensorT()
tensor_t.name = tensor.get("name", f"tensor{id}")
tensor_t.name = tensor.get("name", None)
tensor_t.shape = tensor["shape"]
tensor_t.type = tensor["type"]
tensor_t.buffer = tensor["buffer"]
Expand Down
87 changes: 60 additions & 27 deletions tensorflow/lite/micro/compression/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pprint
import bitarray
import bitarray.util
import pprint
import textwrap
import os
import sys

import lib
from tensorflow.lite.micro.compression import metadata_py_generated as compression_schema
Expand Down Expand Up @@ -53,22 +56,41 @@ def unpack_TensorType(type):
return lut[type]


def _decode_name(name):
"""Returns name as a str or 'None'.
The flatbuffer library returns names as bytes objects or None. This function
returns a str, decoded from the bytes object, or None.
"""
if name is None:
return None
else:
return str(name, encoding="utf-8")


def unpack_tensors(tensors):
  """Returns a list of human-readable dictionaries, one per tensor.

  Args:
    tensors: An iterable of tensor objects exposing the attributes name, type,
      shape, buffer, isVariable, and quantization.

  Returns:
    A list of dicts in the same order as tensors. Each dict holds the keys
    "_tensor" (the tensor's index), "name", "type", "shape", and "buffer".
    "is_variable" is present only for variable tensors, and "quantization"
    only when the tensor carries quantization scales.
  """
  result = []
  for index, t in enumerate(tensors):
    d = {
        "_tensor": index,
        # Names may be absent; _decode_name passes None through.
        "name": _decode_name(t.name),
        "type": unpack_TensorType(t.type),
        "shape": unpack_array(t.shape),
        "buffer": t.buffer,
    }

    # Variable tensors are unusual, so the field is shown only when set.
    if t.isVariable:
      d["is_variable"] = True

    # Show quantization only when scales are actually present.
    if t.quantization is not None and t.quantization.scale is not None:
      d["quantization"] = {
          "scale": unpack_array(t.quantization.scale),
          "zero": unpack_array(t.quantization.zeroPoint),
          "dimension": t.quantization.quantizedDimension,
      }
    result.append(d)
  return result

Expand All @@ -78,7 +100,7 @@ def unpack_subgraphs(subgraphs):
for index, s in enumerate(subgraphs):
d = {
"_index": index,
"name": s.name,
"name": _decode_name(s.name),
# "inputs": s.inputs,
# "outputs": s.outputs,
"operators": unpack_operators(s.operators),
Expand All @@ -92,7 +114,7 @@ def unpack_metadata(metadata):
if metadata is None:
return None
return [{
"name": m.name.decode("utf-8"),
"name": _decode_name(m.name),
"buffer": m.buffer
} for m in metadata]

Expand Down Expand Up @@ -157,8 +179,8 @@ def unpack_buffers(model, compression_metadata=None, unpacked_metadata=None):
buffers = model.buffers
result = []
for index, b in enumerate(buffers):
d = {"buffer": index}
d = d | {"bytes": len(b.data) if b.data is not None else 0}
d = {"_buffer": index}
d = d | {"_bytes": len(b.data) if b.data is not None else 0}
d = d | {"data": unpack_array(b.data)}
if index == compression_metadata:
if unpacked_metadata is not None:
Expand All @@ -184,12 +206,20 @@ def get_compression_metadata_buffer(model):
if model.metadata is None:
return None
for item in model.metadata:
if item.name.decode("utf-8") == "COMPRESSION_METADATA":
if _decode_name(item.name) == "COMPRESSION_METADATA":
return item.buffer
return None


def print_model(model, format=None):
def create_dictionary(flatbuffer: memoryview) -> dict:
"""Returns a human-readable dictionary from the provided model flatbuffer.
This function transforms a .tflite model flatbuffer into a Python dictionary.
When pretty-printed, this dictionary offers an easily interpretable view of
the model.
"""
model = tflite_schema.ModelT.InitFromPackedBuf(flatbuffer, 0)

comp_metadata_index = get_compression_metadata_buffer(model)
comp_metadata_unpacked = None
if comp_metadata_index is not None:
Expand All @@ -201,30 +231,33 @@ def print_model(model, format=None):

output = {
"description":
model.description.decode("utf-8"),
model.description,
"version":
model.version,
model.version,
"operator_codes":
unpack_list(model.operatorCodes),
unpack_list(model.operatorCodes),
"metadata":
unpack_metadata(model.metadata),
unpack_metadata(model.metadata),
"subgraphs":
unpack_subgraphs(model.subgraphs),
unpack_subgraphs(model.subgraphs),
"buffers":
unpack_buffers(model, comp_metadata_index, comp_metadata_unpacked),
unpack_buffers(model, comp_metadata_index, comp_metadata_unpacked),
}

pprint.pprint(output, width=90, sort_dicts=False, compact=True)
return output


def main(argv):
  """Pretty-prints the .tflite model named on the command line.

  Args:
    argv: Command-line arguments; argv[0] is the program name and argv[1] is
      the path of the model file to view.

  Raises:
    absl.app.UsageError: If no model path is given.
    OSError: If the model file cannot be opened or read.
  """
  if len(argv) < 2:
    # Report a usage error rather than failing with a bare IndexError.
    raise absl.app.UsageError("Missing required argument: <MODEL>")

  path = argv[1]
  with open(path, 'rb') as file:
    flatbuffer = file.read()

  d = create_dictionary(memoryview(flatbuffer))
  pprint.pprint(d, width=90, sort_dicts=False, compact=True)


if __name__ == "__main__":
  # Build a usage message from the invoked program name and install it as the
  # docstring of the __main__ module before handing control to absl.
  name = os.path.basename(sys.argv[0])
  # textwrap.dedent strips any common leading whitespace from the text.
  usage = textwrap.dedent(f"""\
Usage: {name} <MODEL>
Print a visualization of a .tflite model.""")
  sys.modules['__main__'].__doc__ = usage
  # absl.app.run() exits the process itself with main's return value, so no
  # explicit sys.exit() call follows it.
  absl.app.run(main)
88 changes: 88 additions & 0 deletions tensorflow/lite/micro/compression/view_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest

import test_models
import view

# Specification of a small model, in the dictionary form that the tests below
# pass to test_models.build() to generate a flatbuffer.
_MODEL = {
    "description": "Test model",
    "operator_codes": {
        0: {
            "builtin_code": 0,
        },
        1: {
            "builtin_code": 1,
        },
    },
    "subgraphs": {
        0: {
            # NOTE(review): these operators reference tensor indices 2, 3, and
            # 4, but only tensors 0 and 1 are defined below -- presumably
            # acceptable because the tests only view the model; confirm.
            "operators": {
                0: {
                    "opcode_index": 1,
                    "inputs": (
                        0,
                        1,
                    ),
                    "outputs": (3, ),
                },
                1: {
                    "opcode_index": 0,
                    "inputs": (
                        3,
                        2,
                    ),
                    "outputs": (4, ),
                },
            },
            # Both tensors share buffer 1.
            "tensors": {
                0: {
                    "shape": (16, 1),
                    "type": 1,
                    "buffer": 1,
                },
                1: {
                    "shape": (16, 1),
                    "type": 1,
                    "buffer": 1,
                },
            },
        },
    },
    "buffers": {
        0: bytes(),
        # 15 bytes with the values 1 through 15.
        # NOTE(review): the tensors using this buffer declare 16 elements --
        # confirm the size mismatch is intentional.
        1: bytes(i for i in range(1, 16)),
    }
}


class UnitTests(absltest.TestCase):
  """Smoke tests for view.create_dictionary()."""

  def testHelloWorld(self):
    """Checks that the test harness itself runs."""
    self.assertTrue(True)

  def testSmokeTest(self):
    """Viewing a complete model succeeds and yields a dictionary."""
    flatbuffer = test_models.build(_MODEL)
    result = view.create_dictionary(memoryview(flatbuffer))
    self.assertIsInstance(result, dict)

  def testStrippedDescription(self):
    """Viewing a model that lacks a description must not raise."""
    # Build a new spec instead of mutating the shared module-level _MODEL.
    stripped = {k: v for k, v in _MODEL.items() if k != "description"}
    flatbuffer = test_models.build(stripped)
    result = view.create_dictionary(memoryview(flatbuffer))
    self.assertIsInstance(result, dict)
    # An absent description is reported as None rather than causing an error.
    self.assertIsNone(result["description"])


if __name__ == "__main__":
  # Run the tests in this module when it is executed as a script.
  absltest.main()
Loading

0 comments on commit b773428

Please sign in to comment.