TensorFlow Lite for Microcontrollers 的使用和实现

本文深入解析了TensorFlow Lite Micro中的AI引擎实现,涵盖了模型加载、内存分配、OpResolver、MicroInterpreter的工作原理,以及关键组件如ErrorReporter、Model、OpResolver和MicroInterpreter的内部细节。

使用方法

参考tensorflow/lite/micro/examples/xxx 目录下的使用方法, 以hello_world为例,

文件hello_world_test.cc

1. 创建MicroErrorReporter object

tflite::MicroErrorReporter micro_error_reporter;

2.  有tflite model文件得到 tflite::Modle 结构体

const tflite::Model* model = ::tflite::GetModel(g_model);

3. 创建OpsResolver对象

这里是创建包含所有算子的对象

tflite::AllOpsResolver resolver;

如果是根据实际使用的算子创建OpsResolver, 使用

    static tflite::MicroMutableOpResolver<5> micro_op_resolver; 
    micro_op_resolver.AddConv2D();
    micro_op_resolver.AddDepthwiseConv2D();
    micro_op_resolver.AddFullyConnected();
    micro_op_resolver.AddMaxPool2D();
    micro_op_resolver.AddSoftmax();
 

4. 提供一段连续的内存,用于model的内存(placement new)

uint8_t tensor_arena[tensor_arena_size];

5. 创建MicroInterpreter对象

  MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver,
                   uint8_t* tensor_arena, size_t tensor_arena_size,
                   ErrorReporter* error_reporter,
                   tflite::Profiler* profiler = nullptr);
构造函数的参数类型是基类,而创建对象的参数是派生类的引用或指针,实现了多态性

tflite::MicroInterpreter interpreter(
        model, resolver, tensor_arena, tensor_arena_size, &micro_error_reporter);

6. 为所有的Tensor/ScratchBuffer分配内存,为所有支撑的变量分配内存等

interpreter.AllocateTensors();

//因内存的大小是个超参数常量,可以通过使用情况调整下,如参考arena_used_bytes()

interpreter.arena_used_bytes();

7. 得到输入对应的Tensor

TfLiteTensor* input = interpreter.input(0);

检测input tensor的一些属性

TF_LITE_MICRO_EXPECT_EQ(2, input->dims->size);
// The value of each element gives the length of the corresponding tensor.
// We should expect two single element tensors (one is contained within the

// other).
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
// The input is a 32 bit floating point value
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, input->type);

8. Provide an input value赋值input tensor


input->data.f[0] = 0.;

9. 进行推断

interpreter.Invoke();

10. 得到推断结果

// Obtain a pointer to the output tensor and make sure it has the
// properties we expect. It should be the same as the input tensor.
TfLiteTensor* output = interpreter.output(0);

//检测推断结果的属性
TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[1]);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, output->type);
 
// Obtain the output value from the tensor
float value = output->data.f[0];

按上面的步骤,分析下代码实现, 尽量明白一个 AIEngine reference的实现。

ErrorReporter

ErrorReport class 是用来输出log, 各个平台如PC/ 串口打印log的具体实现不同,使用虚函数继承是合理的,接口类Class ErrorReporter:lite/core/api/error_report.h

虚函数是virtual int Report(const char*, va_list args), 普通成员函数调用虚函数也得到了跨平台的接口。

对外的接口是个宏: 其中... 表示所有输入

#define TF_LITE_REPORT_ERROR(reporter, ...)                             \
  do {                          &nbs

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值