From d9bfa9412a2468521f24f1f0261035e6db75642c Mon Sep 17 00:00:00 2001 From: Woosung Cho <wscho@ajou.ac.kr> Date: Sun, 22 Dec 2024 21:12:48 +0900 Subject: [PATCH] arduino code --- easy_mnist.ino | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 easy_mnist.ino diff --git a/easy_mnist.ino b/easy_mnist.ino new file mode 100644 index 0000000..b474e61 --- /dev/null +++ b/easy_mnist.ino @@ -0,0 +1,79 @@ +#include <Chirale_TensorFlowLite.h> + +#include "tensorflow/lite/micro/all_ops_resolver.h" +#include "tensorflow/lite/micro/micro_interpreter.h" +#include "tensorflow/lite/schema/schema_generated.h" + +#include "arduino_model.h" +#include "mnist_images.h" + +const tflite::Model* model = nullptr; +tflite::MicroInterpreter* interpreter = nullptr; +TfLiteTensor* input = nullptr; +TfLiteTensor* output = nullptr; + +constexpr int kTensorArenaSize = 20*1024; +alignas(16) uint8_t tensor_arena[kTensorArenaSize]; + +void setup() { + Serial.begin(9600); + while(!Serial); + + model = tflite::GetModel(arduino_model); + + if (model->version() != TFLITE_SCHEMA_VERSION) { + Serial.println("Model provided and schema version are not equal!"); + while(true); + } + + static tflite::AllOpsResolver resolver; + + static tflite::MicroInterpreter static_interpreter( + model, resolver, tensor_arena, kTensorArenaSize); + interpreter = &static_interpreter; + + TfLiteStatus allocate_status = interpreter->AllocateTensors(); + if (allocate_status != kTfLiteOk) { + Serial.println("AllocateTensors() failed"); + while(true); + } + + input = interpreter->input(0); + output = interpreter->output(0); + + Serial.println("Ready for input data."); +} + +void loop() { + + if (Serial.available()){ + String inputValue = Serial.readString(); + + memcpy(input->data.uint8, image2, 28*28); + + TfLiteStatus invoke_status = interpreter->Invoke(); + if (invoke_status != kTfLiteOk) { + Serial.println("Invoke failed!"); + return; + } + + int8_t max_value = output->data.uint8[0]; + int8_t predicted_digit = 0; + for (int i = 1; i < 10; i++) { + if (output->data.uint8[i] > max_value) { + max_value = output->data.uint8[i]; + predicted_digit = i-1; + } + } + + + Serial.print("Predicted digit: "); + Serial.println(predicted_digit); + Serial.print("Actual digit: "); + Serial.println(image2_label); + + } + + + +} -- GitLab