diff --git a/easy_mnist.ino b/easy_mnist.ino new file mode 100644 index 0000000000000000000000000000000000000000..b474e614b74344506439768b22b28b823c118614 --- /dev/null +++ b/easy_mnist.ino @@ -0,0 +1,79 @@ +#include <Chirale_TensorFlowLite.h> + +#include "tensorflow/lite/micro/all_ops_resolver.h" +#include "tensorflow/lite/micro/micro_interpreter.h" +#include "tensorflow/lite/schema/schema_generated.h" + +#include "arduino_model.h" +#include "mnist_images.h" + +const tflite::Model* model = nullptr; +tflite::MicroInterpreter* interpreter = nullptr; +TfLiteTensor* input = nullptr; +TfLiteTensor* output = nullptr; + +constexpr int kTensorArenaSize = 20*1024; +alignas(16) uint8_t tensor_arena[kTensorArenaSize]; + +void setup() { + Serial.begin(9600); + while(!Serial); + + model = tflite::GetModel(arduino_model); + + if (model->version() != TFLITE_SCHEMA_VERSION) { + Serial.println("Model provided and schema version are not equal!"); + while(true); + } + + static tflite::AllOpsResolver resolver; + + static tflite::MicroInterpreter static_interpreter( + model, resolver, tensor_arena, kTensorArenaSize); + interpreter = &static_interpreter; + + TfLiteStatus allocate_status = interpreter->AllocateTensors(); + if (allocate_status != kTfLiteOk) { + Serial.println("AllocateTensors() failed"); + while(true); + } + + input = interpreter->input(0); + output = interpreter->output(0); + + Serial.println("Ready for input data."); +} + +void loop() { + + if (Serial.available()){ + String inputValue = Serial.readString(); + + memcpy(input->data.uint8, image2, 28*28); + + TfLiteStatus invoke_status = interpreter->Invoke(); + if (invoke_status != kTfLiteOk) { + Serial.println("Invoke failed!"); + return; + } + + int8_t max_value = output->data.uint8[0]; + int8_t predicted_digit = 0; + for (int i = 1; i < 10; i++) { + if (output->data.uint8[i] > max_value) { + max_value = output->data.uint8[i]; + predicted_digit = i-1; + } + } + + + Serial.print("Predicted digit: "); + Serial.println(predicted_digit); + Serial.print("Actual digit: "); + Serial.println(image2_label); + + } + + + +}