Anyone who tinkers with electronics knows that the Arduino is a microcontroller with only a few dozen KB of memory. At first I also thought running TensorFlow on an Arduino was pure fantasy, but you really have to admire Google: they built a dedicated port specifically for mobile and IoT devices. Follow along with Bobo and learn how to play with TensorFlow on an Arduino.
Documentation: https://tensorflow.google.cn/lite/?hl=zh_cn
I'm sharing the documentation link so that anyone who wants to go deeper can study it there; this post only records the getting-started process and doesn't dig into the details.
Besides the documentation, Google has already packaged the TensorFlow library for us. In the Arduino IDE, open 【Tools】->【Manage Libraries】, as shown below:

Search for 【tensorflow】 and click Install. If you prefer manual installation, you can also download the library from GitHub.
Because of the ever-capable firewall, access to these resources may be slow. You get used to it; just wait patiently......
Once everything is ready, let's start today's HelloWorld. This example uses a model that predicts sine values to blink an LED. The source code is as follows:
```cpp
#include <TensorFlowLite.h>

#include "main_functions.h"
#include "constants.h"
#include "output_handler.h"
#include "sine_model_data.h"
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
#include "tensorflow/lite/experimental/micro/micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"

// Globals, used for compatibility with Arduino-style sketches.
namespace {
tflite::ErrorReporter* error_reporter = nullptr;
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* input = nullptr;
TfLiteTensor* output = nullptr;
int inference_count = 0;

// Create an area of memory to use for input, output, and intermediate arrays.
// Finding the minimum value for your model may require some trial and error.
constexpr int kTensorArenaSize = 2 * 1024;
uint8_t tensor_arena[kTensorArenaSize];
}  // namespace

// The name of this function is important for Arduino compatibility.
void setup() {
  // Set up logging. Google style is to avoid globals or statics because of
  // lifetime uncertainty, but since this has a trivial destructor it's okay.
  // NOLINTNEXTLINE(runtime-global-variables)
  static tflite::MicroErrorReporter micro_error_reporter;
  error_reporter = &micro_error_reporter;

  // Map the model into a usable data structure. This doesn't involve any
  // copying or parsing, it's a very lightweight operation.
  model = tflite::GetModel(g_sine_model_data);
  if (model->version() != TFLITE_SCHEMA_VERSION) {
    error_reporter->Report(
        "Model provided is schema version %d not equal "
        "to supported version %d.",
        model->version(), TFLITE_SCHEMA_VERSION);
    return;
  }

  // This pulls in all the operation implementations we need.
  // NOLINTNEXTLINE(runtime-global-variables)
  static tflite::ops::micro::AllOpsResolver resolver;

  // Build an interpreter to run the model with.
  static tflite::MicroInterpreter static_interpreter(
      model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
  interpreter = &static_interpreter;

  // Allocate memory from the tensor_arena for the model's tensors.
  TfLiteStatus allocate_status = interpreter->AllocateTensors();
  if (allocate_status != kTfLiteOk) {
    error_reporter->Report("AllocateTensors() failed");
    return;
  }

  // Obtain pointers to the model's input and output tensors.
  input = interpreter->input(0);
  output = interpreter->output(0);

  // Keep track of how many inferences we have performed.
  inference_count = 0;
}

// The name of this function is important for Arduino compatibility.
void loop() {
  // Calculate an x value to feed into the model. We compare the current
  // inference_count to the number of inferences per cycle to determine
  // our position within the range of possible x values the model was
  // trained on, and use this to calculate a value.
  float position = static_cast<float>(inference_count) /
                   static_cast<float>(kInferencesPerCycle);
  float x_val = position * kXrange;

  // Place our calculated x value in the model's input tensor
  input->data.f[0] = x_val;

  // Run inference, and report any error
  TfLiteStatus invoke_status = interpreter->Invoke();
  if (invoke_status != kTfLiteOk) {
    error_reporter->Report("Invoke failed on x_val: %f\n",
                           static_cast<double>(x_val));
    return;
  }

  // Read the predicted y value from the model's output tensor
  float y_val = output->data.f[0];

  // Output the results. A custom HandleOutput function can be implemented
  // for each supported hardware target.
  HandleOutput(error_reporter, x_val, y_val);

  // Increment the inference_counter, and reset it if we have reached
  // the total number per cycle
  inference_count += 1;
  if (inference_count >= kInferencesPerCycle) inference_count = 0;
}
```
Of course, given the hardware limits, you can only run simple things: basic speech and image recognition, and other edge inference with small data models. But even that is no small feat, and it opens a door into artificial intelligence for beginner makers.
For more creative ideas, feel free to experiment on your own. That's all from Bobo for now; I'll share my own DIY projects when I find the time.