diff --git a/samples/modules/tensorflow/magic_wand/train/README.md b/samples/modules/tensorflow/magic_wand/train/README.md new file mode 100644 index 00000000000..55f1617c9a9 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/README.md @@ -0,0 +1,191 @@ +# Gesture Recognition Magic Wand Training Scripts + +## Introduction + +The scripts in this directory can be used to train a TensorFlow model that +classifies gestures based on accelerometer data. The code uses Python 3.7 and +TensorFlow 2.0. The resulting model is less than 20KB in size. + +The following document contains instructions on using the scripts to train a +model, and capturing your own training data. + +This project was inspired by the [Gesture Recognition Magic Wand](https://github.com/jewang/gesture-demo) +project by Jennifer Wang. + +## Training + +### Dataset + +Three magic gestures were chosen, and data were collected from 7 +different people. Some random long movement sequences were collected and divided +into shorter pieces, which made up "negative" data along with some other +automatically generated random data. + +The dataset can be downloaded from the following URL: + +[download.tensorflow.org/models/tflite/magic_wand/data.tar.gz](http://download.tensorflow.org/models/tflite/magic_wand/data.tar.gz) + +### Training in Colab + +The following [Google Colaboratory](https://colab.research.google.com) +notebook demonstrates how to train the model. It's the easiest way to get +started: + + + + +
+ Run in Google Colab + + View source on GitHub +
+ +If you'd prefer to run the scripts locally, use the following instructions. + +### Running the scripts + +Use the following command to install the required dependencies: + +```shell +pip install -r requirements.txt +``` + +There are two ways to train the model: + +- Random data split, which mixes different people's data together and randomly + splits them into training, validation, and test sets +- Person data split, which splits the data by person + +#### Random data split + +Using a random split results in higher training accuracy than a person split, +but inferior performance on new data. + +```shell +$ python data_prepare.py + +$ python data_split.py + +$ python train.py --model CNN --person false +``` + +#### Person data split + +Using a person data split results in lower training accuracy but better +performance on new data. + +```shell +$ python data_prepare.py + +$ python data_split_person.py + +$ python train.py --model CNN --person true +``` + +#### Model type + +In the `--model` argument, you can can provide `CNN` or `LSTM`. The CNN +model has a smaller size and lower latency. + +## Collecting new data + +To obtain new training data using the +[SparkFun Edge development board](https://sparkfun.com/products/15170), you can +modify one of the examples in the [SparkFun Edge BSP](https://github.com/sparkfun/SparkFun_Edge_BSP) +and deploy it using the Ambiq SDK. + +### Install the Ambiq SDK and SparkFun Edge BSP + +Follow SparkFun's +[Using SparkFun Edge Board with Ambiq Apollo3 SDK](https://learn.sparkfun.com/tutorials/using-sparkfun-edge-board-with-ambiq-apollo3-sdk/all) +guide to set up the Ambiq SDK and SparkFun Edge BSP. + +#### Modify the example code + +First, `cd` into +`AmbiqSuite-Rel2.2.0/boards/SparkFun_Edge_BSP/examples/example1_edge_test`. + +##### Modify `src/tf_adc/tf_adc.c` + +Add `true` in line 62 as the second parameter of function +`am_hal_adc_samples_read`. + +##### Modify `src/main.c` + +Add the line below in `int main(void)`, just before the line `while(1)`: + +```cc +am_util_stdio_printf("-,-,-\r\n"); +``` + +Change the following lines in `while(1){...}` + +```cc +am_util_stdio_printf("Acc [mg] %04.2f x, %04.2f y, %04.2f z, Temp [deg C] %04.2f, MIC0 [counts / 2^14] %d\r\n", acceleration_mg[0], acceleration_mg[1], acceleration_mg[2], temperature_degC, (audioSample) ); +``` + +to: + +```cc +am_util_stdio_printf("%04.2f,%04.2f,%04.2f\r\n", acceleration_mg[0], acceleration_mg[1], acceleration_mg[2]); +``` + +#### Flash the binary + +Follow the instructions in +[SparkFun's guide](https://learn.sparkfun.com/tutorials/using-sparkfun-edge-board-with-ambiq-apollo3-sdk/all#example-applications) +to flash the binary to the device. + +#### Collect accelerometer data + +First, in a new terminal window, run the following command to begin logging +output to `output.txt`: + +```shell +$ script output.txt +``` + +Next, in the same window, use `screen` to connect to the device: + +```shell +$ screen ${DEVICENAME} 115200 +``` + +Output information collected from accelerometer sensor will be shown on the +screen and saved in `output.txt`, in the format of "x,y,z" per line. + +Press the `RST` button to start capturing a new gesture, then press Button 14 +when it ends. New data will begin with a line "-,-,-". + +To exit `screen`, hit +Ctrl\\+A+, immediately followed by the +K+ key, +then hit the +Y+ key. Then run + +```shell +$ exit +``` + +to stop logging data. Data will be saved in `output.txt`. For compatibility +with the training scripts, change the file name to include person's name and +the gesture name, in the following format: + +``` +output_{gesture_name}_{person_name}.txt +``` + +#### Edit and run the scripts + +Edit the following files to include your new gesture names (replacing +"wing", "ring", and "slope") + +- `data_load.py` +- `data_prepare.py` +- `data_split.py` + +Edit the following files to include your new person names (replacing "hyw", +"shiyun", "tangsy", "dengyl", "jiangyh", "xunkai", "lsj", "pengxl", "liucx", +and "zhangxy"): + +- `data_prepare.py` +- `data_split_person.py` + +Finally, run the commands described earlier to train a new model. diff --git a/samples/modules/tensorflow/magic_wand/train/data_augmentation.py b/samples/modules/tensorflow/magic_wand/train/data_augmentation.py new file mode 100644 index 00000000000..45d08c3fec6 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_augmentation.py @@ -0,0 +1,74 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=g-bad-import-order + +"""Data augmentation that will be used in data_load.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random + +import numpy as np + + +def time_wrapping(molecule, denominator, data): + """Generate (molecule/denominator)x speed data.""" + tmp_data = [[0 + for i in range(len(data[0]))] + for j in range((int(len(data) / molecule) - 1) * denominator)] + for i in range(int(len(data) / molecule) - 1): + for j in range(len(data[i])): + for k in range(denominator): + tmp_data[denominator * i + + k][j] = (data[molecule * i + k][j] * (denominator - k) + + data[molecule * i + k + 1][j] * k) / denominator + return tmp_data + + +def augment_data(original_data, original_label): + """Perform data augmentation.""" + new_data = [] + new_label = [] + for idx, (data, label) in enumerate(zip(original_data, original_label)): # pylint: disable=unused-variable + # Original data + new_data.append(data) + new_label.append(label) + # Sequence shift + for num in range(5): # pylint: disable=unused-variable + new_data.append((np.array(data, dtype=np.float32) + + (random.random() - 0.5) * 200).tolist()) + new_label.append(label) + # Random noise + tmp_data = [[0 for i in range(len(data[0]))] for j in range(len(data))] + for num in range(5): + for i in range(len(tmp_data)): # pylint: disable=consider-using-enumerate + for j in range(len(tmp_data[i])): + tmp_data[i][j] = data[i][j] + 5 * random.random() + new_data.append(tmp_data) + new_label.append(label) + # Time warping + fractions = [(3, 2), (5, 3), (2, 3), (3, 4), (9, 5), (6, 5), (4, 5)] + for molecule, denominator in fractions: + new_data.append(time_wrapping(molecule, denominator, data)) + new_label.append(label) + # Movement amplification + for molecule, denominator in fractions: + new_data.append( + (np.array(data, dtype=np.float32) * molecule / denominator).tolist()) + new_label.append(label) + return new_data, new_label diff --git a/samples/modules/tensorflow/magic_wand/train/data_augmentation_test.py b/samples/modules/tensorflow/magic_wand/train/data_augmentation_test.py new file mode 100644 index 00000000000..8da07e2e580 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_augmentation_test.py @@ -0,0 +1,58 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=g-bad-import-order + +"""Test for data_augmentation.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +import numpy as np + +from data_augmentation import augment_data +from data_augmentation import time_wrapping + + +class TestAugmentation(unittest.TestCase): + + def test_time_wrapping(self): + original_data = np.random.rand(10, 3).tolist() + wrapped_data = time_wrapping(4, 5, original_data) + self.assertEqual(len(wrapped_data), int(len(original_data) / 4 - 1) * 5) + self.assertEqual(len(wrapped_data[0]), len(original_data[0])) + + def test_augment_data(self): + original_data = [ + np.random.rand(128, 3).tolist(), + np.random.rand(66, 2).tolist(), + np.random.rand(9, 1).tolist() + ] + original_label = ["data", "augmentation", "test"] + augmented_data, augmented_label = augment_data(original_data, + original_label) + self.assertEqual(25 * len(original_data), len(augmented_data)) + self.assertIsInstance(augmented_data, list) + self.assertEqual(25 * len(original_label), len(augmented_label)) + self.assertIsInstance(augmented_label, list) + for i in range(len(original_label)): # pylint: disable=consider-using-enumerate + self.assertEqual(augmented_label[25 * i], original_label[i]) + + +if __name__ == "__main__": + unittest.main() diff --git a/samples/modules/tensorflow/magic_wand/train/data_load.py b/samples/modules/tensorflow/magic_wand/train/data_load.py new file mode 100644 index 00000000000..35ff4825a8b --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_load.py @@ -0,0 +1,106 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=g-bad-import-order + +"""Load data from the specified paths and format them for training.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +import numpy as np +import tensorflow as tf + +from data_augmentation import augment_data + +LABEL_NAME = "gesture" +DATA_NAME = "accel_ms2_xyz" + + +class DataLoader(object): + """Loads data and prepares for training.""" + + def __init__(self, train_data_path, valid_data_path, test_data_path, + seq_length): + self.dim = 3 + self.seq_length = seq_length + self.label2id = {"wing": 0, "ring": 1, "slope": 2, "negative": 3} + self.train_data, self.train_label, self.train_len = self.get_data_file( + train_data_path, "train") + self.valid_data, self.valid_label, self.valid_len = self.get_data_file( + valid_data_path, "valid") + self.test_data, self.test_label, self.test_len = self.get_data_file( + test_data_path, "test") + + def get_data_file(self, data_path, data_type): # pylint: disable=no-self-use + """Get train, valid and test data from files.""" + data = [] + label = [] + with open(data_path, "r") as f: + lines = f.readlines() + for idx, line in enumerate(lines): # pylint: disable=unused-variable + dic = json.loads(line) + data.append(dic[DATA_NAME]) + label.append(dic[LABEL_NAME]) + if data_type == "train": + data, label = augment_data(data, label) + length = len(label) + print(data_type + "_data_length:" + str(length)) + return data, label, length + + def pad(self, data, seq_length, dim): # pylint: disable=no-self-use + """Get neighbour padding.""" + noise_level = 20 + padded_data = [] + # Before- Neighbour padding + tmp_data = (np.random.rand(seq_length, dim) - 0.5) * noise_level + data[0] + tmp_data[(seq_length - + min(len(data), seq_length)):] = data[:min(len(data), seq_length)] + padded_data.append(tmp_data) + # After- Neighbour padding + tmp_data = (np.random.rand(seq_length, dim) - 0.5) * noise_level + data[-1] + tmp_data[:min(len(data), seq_length)] = data[:min(len(data), seq_length)] + padded_data.append(tmp_data) + return padded_data + + def format_support_func(self, padded_num, length, data, label): + """Support function for format.(Helps format train, valid and test.)""" + # Add 2 padding, initialize data and label + length *= padded_num + features = np.zeros((length, self.seq_length, self.dim)) + labels = np.zeros(length) + # Get padding for train, valid and test + for idx, (data, label) in enumerate(zip(data, label)): # pylint: disable=redefined-argument-from-local + padded_data = self.pad(data, self.seq_length, self.dim) + for num in range(padded_num): + features[padded_num * idx + num] = padded_data[num] + labels[padded_num * idx + num] = self.label2id[label] + # Turn into tf.data.Dataset + dataset = tf.data.Dataset.from_tensor_slices( + (features, labels.astype("int32"))) + return length, dataset + + def format(self): + """Format data(including padding, etc.) and get the dataset for the model.""" + padded_num = 2 + self.train_len, self.train_data = self.format_support_func( + padded_num, self.train_len, self.train_data, self.train_label) + self.valid_len, self.valid_data = self.format_support_func( + padded_num, self.valid_len, self.valid_data, self.valid_label) + self.test_len, self.test_data = self.format_support_func( + padded_num, self.test_len, self.test_data, self.test_label) diff --git a/samples/modules/tensorflow/magic_wand/train/data_load_test.py b/samples/modules/tensorflow/magic_wand/train/data_load_test.py new file mode 100644 index 00000000000..82864974dad --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_load_test.py @@ -0,0 +1,95 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=g-bad-import-order + +"""Test for data_load.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +from data_load import DataLoader + +import tensorflow as tf + + +class TestLoad(unittest.TestCase): + + def setUp(self): # pylint: disable=g-missing-super-call + self.loader = DataLoader( + "./data/train", "./data/valid", "./data/test", seq_length=512) + + def test_get_data(self): + self.assertIsInstance(self.loader.train_data, list) + self.assertIsInstance(self.loader.train_label, list) + self.assertIsInstance(self.loader.valid_data, list) + self.assertIsInstance(self.loader.valid_label, list) + self.assertIsInstance(self.loader.test_data, list) + self.assertIsInstance(self.loader.test_label, list) + self.assertEqual(self.loader.train_len, len(self.loader.train_data)) + self.assertEqual(self.loader.train_len, len(self.loader.train_label)) + self.assertEqual(self.loader.valid_len, len(self.loader.valid_data)) + self.assertEqual(self.loader.valid_len, len(self.loader.valid_label)) + self.assertEqual(self.loader.test_len, len(self.loader.test_data)) + self.assertEqual(self.loader.test_len, len(self.loader.test_label)) + + def test_pad(self): + original_data1 = [[2, 3], [1, 1]] + expected_data1_0 = [[2, 3], [2, 3], [2, 3], [2, 3], [1, 1]] + expected_data1_1 = [[2, 3], [1, 1], [1, 1], [1, 1], [1, 1]] + original_data2 = [[-2, 3], [-77, -681], [5, 6], [9, -7], [22, 3333], + [9, 99], [-100, 0]] + expected_data2 = [[-2, 3], [-77, -681], [5, 6], [9, -7], [22, 3333]] + padding_data1 = self.loader.pad(original_data1, seq_length=5, dim=2) + padding_data2 = self.loader.pad(original_data2, seq_length=5, dim=2) + for i in range(len(padding_data1[0])): + for j in range(len(padding_data1[0].tolist()[0])): + self.assertLess( + abs(padding_data1[0].tolist()[i][j] - expected_data1_0[i][j]), + 10.001) + for i in range(len(padding_data1[1])): + for j in range(len(padding_data1[1].tolist()[0])): + self.assertLess( + abs(padding_data1[1].tolist()[i][j] - expected_data1_1[i][j]), + 10.001) + self.assertEqual(padding_data2[0].tolist(), expected_data2) + self.assertEqual(padding_data2[1].tolist(), expected_data2) + + def test_format(self): + self.loader.format() + expected_train_label = int(self.loader.label2id[self.loader.train_label[0]]) + expected_valid_label = int(self.loader.label2id[self.loader.valid_label[0]]) + expected_test_label = int(self.loader.label2id[self.loader.test_label[0]]) + for feature, label in self.loader.train_data: # pylint: disable=unused-variable + format_train_label = label.numpy() + break + for feature, label in self.loader.valid_data: + format_valid_label = label.numpy() + break + for feature, label in self.loader.test_data: + format_test_label = label.numpy() + break + self.assertEqual(expected_train_label, format_train_label) + self.assertEqual(expected_valid_label, format_valid_label) + self.assertEqual(expected_test_label, format_test_label) + self.assertIsInstance(self.loader.train_data, tf.data.Dataset) + self.assertIsInstance(self.loader.valid_data, tf.data.Dataset) + self.assertIsInstance(self.loader.test_data, tf.data.Dataset) + + +if __name__ == "__main__": + unittest.main() diff --git a/samples/modules/tensorflow/magic_wand/train/data_prepare.py b/samples/modules/tensorflow/magic_wand/train/data_prepare.py new file mode 100644 index 00000000000..727167f4d05 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_prepare.py @@ -0,0 +1,164 @@ +# Lint as: python3 +# coding=utf-8 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Prepare data for further process. + +Read data from "/slope", "/ring", "/wing", "/negative" and save them +in "/data/complete_data" in python dict format. + +It will generate a new file with the following structure: +├── data +│   └── complete_data +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import csv +import json +import os +import random + +LABEL_NAME = "gesture" +DATA_NAME = "accel_ms2_xyz" +folders = ["wing", "ring", "slope"] +names = [ + "hyw", "shiyun", "tangsy", "dengyl", "zhangxy", "pengxl", "liucx", + "jiangyh", "xunkai" +] + + +def prepare_original_data(folder, name, data, file_to_read): + """Read collected data from files.""" + if folder != "negative": + with open(file_to_read, "r") as f: + lines = csv.reader(f) + data_new = {} + data_new[LABEL_NAME] = folder + data_new[DATA_NAME] = [] + data_new["name"] = name + for idx, line in enumerate(lines): # pylint: disable=unused-variable + if len(line) == 3: + if line[2] == "-" and data_new[DATA_NAME]: + data.append(data_new) + data_new = {} + data_new[LABEL_NAME] = folder + data_new[DATA_NAME] = [] + data_new["name"] = name + elif line[2] != "-": + data_new[DATA_NAME].append([float(i) for i in line[0:3]]) + data.append(data_new) + else: + with open(file_to_read, "r") as f: + lines = csv.reader(f) + data_new = {} + data_new[LABEL_NAME] = folder + data_new[DATA_NAME] = [] + data_new["name"] = name + for idx, line in enumerate(lines): + if len(line) == 3 and line[2] != "-": + if len(data_new[DATA_NAME]) == 120: + data.append(data_new) + data_new = {} + data_new[LABEL_NAME] = folder + data_new[DATA_NAME] = [] + data_new["name"] = name + else: + data_new[DATA_NAME].append([float(i) for i in line[0:3]]) + data.append(data_new) + + +def generate_negative_data(data): + """Generate negative data labeled as 'negative6~8'.""" + # Big movement -> around straight line + for i in range(100): + if i > 80: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative8"} + elif i > 60: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative7"} + else: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative6"} + start_x = (random.random() - 0.5) * 2000 + start_y = (random.random() - 0.5) * 2000 + start_z = (random.random() - 0.5) * 2000 + x_increase = (random.random() - 0.5) * 10 + y_increase = (random.random() - 0.5) * 10 + z_increase = (random.random() - 0.5) * 10 + for j in range(128): + dic[DATA_NAME].append([ + start_x + j * x_increase + (random.random() - 0.5) * 6, + start_y + j * y_increase + (random.random() - 0.5) * 6, + start_z + j * z_increase + (random.random() - 0.5) * 6 + ]) + data.append(dic) + # Random + for i in range(100): + if i > 80: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative8"} + elif i > 60: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative7"} + else: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative6"} + for j in range(128): + dic[DATA_NAME].append([(random.random() - 0.5) * 1000, + (random.random() - 0.5) * 1000, + (random.random() - 0.5) * 1000]) + data.append(dic) + # Stay still + for i in range(100): + if i > 80: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative8"} + elif i > 60: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative7"} + else: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative6"} + start_x = (random.random() - 0.5) * 2000 + start_y = (random.random() - 0.5) * 2000 + start_z = (random.random() - 0.5) * 2000 + for j in range(128): + dic[DATA_NAME].append([ + start_x + (random.random() - 0.5) * 40, + start_y + (random.random() - 0.5) * 40, + start_z + (random.random() - 0.5) * 40 + ]) + data.append(dic) + + +# Write data to file +def write_data(data_to_write, path): + with open(path, "w") as f: + for idx, item in enumerate(data_to_write): # pylint: disable=unused-variable + dic = json.dumps(item, ensure_ascii=False) + f.write(dic) + f.write("\n") + + +if __name__ == "__main__": + data = [] + for idx1, folder in enumerate(folders): + for idx2, name in enumerate(names): + prepare_original_data(folder, name, data, + "./%s/output_%s_%s.txt" % (folder, folder, name)) + for idx in range(5): + prepare_original_data("negative", "negative%d" % (idx + 1), data, + "./negative/output_negative_%d.txt" % (idx + 1)) + generate_negative_data(data) + print("data_length: " + str(len(data))) + if not os.path.exists("./data"): + os.makedirs("./data") + write_data(data, "./data/complete_data") diff --git a/samples/modules/tensorflow/magic_wand/train/data_prepare_test.py b/samples/modules/tensorflow/magic_wand/train/data_prepare_test.py new file mode 100644 index 00000000000..4c0e3f0657b --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_prepare_test.py @@ -0,0 +1,75 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Test for data_prepare.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import csv +import json +import os +import unittest +from data_prepare import generate_negative_data +from data_prepare import prepare_original_data +from data_prepare import write_data + + +class TestPrepare(unittest.TestCase): + + def setUp(self): # pylint: disable=g-missing-super-call + self.file = "./%s/output_%s_%s.txt" % (folders[0], folders[0], names[0]) # pylint: disable=undefined-variable + self.data = [] + prepare_original_data(folders[0], names[0], self.data, self.file) # pylint: disable=undefined-variable + + def test_prepare_data(self): + num = 0 + with open(self.file, "r") as f: + lines = csv.reader(f) + for idx, line in enumerate(lines): # pylint: disable=unused-variable + if len(line) == 3 and line[2] == "-": + num += 1 + self.assertEqual(len(self.data), num) + self.assertIsInstance(self.data, list) + self.assertIsInstance(self.data[0], dict) + self.assertEqual(list(self.data[-1]), ["gesture", "accel_ms2_xyz", "name"]) + self.assertEqual(self.data[0]["name"], names[0]) # pylint: disable=undefined-variable + + def test_generate_negative(self): + original_len = len(self.data) + generate_negative_data(self.data) + self.assertEqual(original_len + 300, len(self.data)) + generated_num = 0 + for idx, data in enumerate(self.data): # pylint: disable=unused-variable + if data["name"] == "negative6" or data["name"] == "negative7" or data[ + "name"] == "negative8": + generated_num += 1 + self.assertEqual(generated_num, 300) + + def test_write_data(self): + data_path_test = "./data/data0" + write_data(self.data, data_path_test) + with open(data_path_test, "r") as f: + lines = f.readlines() + self.assertEqual(len(lines), len(self.data)) + self.assertEqual(json.loads(lines[0]), self.data[0]) + self.assertEqual(json.loads(lines[-1]), self.data[-1]) + os.remove(data_path_test) + + +if __name__ == "__main__": + unittest.main() diff --git a/samples/modules/tensorflow/magic_wand/train/data_split.py b/samples/modules/tensorflow/magic_wand/train/data_split.py new file mode 100644 index 00000000000..5448ed90020 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_split.py @@ -0,0 +1,90 @@ +# Lint as: python3 +# coding=utf-8 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mix and split data. + +Mix different people's data together and randomly split them into train, +validation and test. These data would be saved separately under "/data". +It will generate new files with the following structure: + +├── data +│   ├── complete_data +│   ├── test +│   ├── train +│   └── valid +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import random +from data_prepare import write_data + + +# Read data +def read_data(path): + data = [] + with open(path, "r") as f: + lines = f.readlines() + for idx, line in enumerate(lines): # pylint: disable=unused-variable + dic = json.loads(line) + data.append(dic) + print("data_length:" + str(len(data))) + return data + + +def split_data(data, train_ratio, valid_ratio): + """Splits data into train, validation and test according to ratio.""" + train_data = [] + valid_data = [] + test_data = [] + num_dic = {"wing": 0, "ring": 0, "slope": 0, "negative": 0} + for idx, item in enumerate(data): # pylint: disable=unused-variable + for i in num_dic: + if item["gesture"] == i: + num_dic[i] += 1 + print(num_dic) + train_num_dic = {} + valid_num_dic = {} + for i in num_dic: + train_num_dic[i] = int(train_ratio * num_dic[i]) + valid_num_dic[i] = int(valid_ratio * num_dic[i]) + random.seed(30) + random.shuffle(data) + for idx, item in enumerate(data): + for i in num_dic: + if item["gesture"] == i: + if train_num_dic[i] > 0: + train_data.append(item) + train_num_dic[i] -= 1 + elif valid_num_dic[i] > 0: + valid_data.append(item) + valid_num_dic[i] -= 1 + else: + test_data.append(item) + print("train_length:" + str(len(train_data))) + print("test_length:" + str(len(test_data))) + return train_data, valid_data, test_data + + +if __name__ == "__main__": + data = read_data("./data/complete_data") + train_data, valid_data, test_data = split_data(data, 0.6, 0.2) + write_data(train_data, "./data/train") + write_data(valid_data, "./data/valid") + write_data(test_data, "./data/test") diff --git a/samples/modules/tensorflow/magic_wand/train/data_split_person.py b/samples/modules/tensorflow/magic_wand/train/data_split_person.py new file mode 100644 index 00000000000..8634b2f385a --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_split_person.py @@ -0,0 +1,75 @@ +# Lint as: python3 +# coding=utf-8 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Split data into train, validation and test dataset according to person. + +That is, use some people's data as train, some other people's data as +validation, and the rest ones' data as test. These data would be saved +separately under "/person_split". + +It will generate new files with the following structure: +├──person_split +│   ├── test +│   ├── train +│   └──valid +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import random +from data_split import read_data +from data_split import write_data + + +def person_split(whole_data, train_names, valid_names, test_names): + """Split data by person.""" + random.seed(30) + random.shuffle(whole_data) + train_data = [] + valid_data = [] + test_data = [] + for idx, data in enumerate(whole_data): # pylint: disable=unused-variable + if data["name"] in train_names: + train_data.append(data) + elif data["name"] in valid_names: + valid_data.append(data) + elif data["name"] in test_names: + test_data.append(data) + print("train_length:" + str(len(train_data))) + print("valid_length:" + str(len(valid_data))) + print("test_length:" + str(len(test_data))) + return train_data, valid_data, test_data + + +if __name__ == "__main__": + data = read_data("./data/complete_data") + train_names = [ + "hyw", "shiyun", "tangsy", "dengyl", "jiangyh", "xunkai", "negative3", + "negative4", "negative5", "negative6" + ] + valid_names = ["lsj", "pengxl", "negative2", "negative7"] + test_names = ["liucx", "zhangxy", "negative1", "negative8"] + train_data, valid_data, test_data = person_split(data, train_names, + valid_names, test_names) + if not os.path.exists("./person_split"): + os.makedirs("./person_split") + write_data(train_data, "./person_split/train") + write_data(valid_data, "./person_split/valid") + write_data(test_data, "./person_split/test") diff --git a/samples/modules/tensorflow/magic_wand/train/data_split_person_test.py b/samples/modules/tensorflow/magic_wand/train/data_split_person_test.py new file mode 100644 index 00000000000..25ed8e5fe64 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_split_person_test.py @@ -0,0 +1,54 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Test for data_split_person.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +from data_split_person import person_split +from data_split_person import read_data + + +class TestSplitPerson(unittest.TestCase): + + def setUp(self): # pylint: disable=g-missing-super-call + self.data = read_data("./data/complete_data") + + def test_person_split(self): + train_names = ["dengyl"] + valid_names = ["liucx"] + test_names = ["tangsy"] + dengyl_num = 63 + liucx_num = 63 + tangsy_num = 30 + train_data, valid_data, test_data = person_split(self.data, train_names, + valid_names, test_names) + self.assertEqual(len(train_data), dengyl_num) + self.assertEqual(len(valid_data), liucx_num) + self.assertEqual(len(test_data), tangsy_num) + self.assertIsInstance(train_data, list) + self.assertIsInstance(valid_data, list) + self.assertIsInstance(test_data, list) + self.assertIsInstance(train_data[0], dict) + self.assertIsInstance(valid_data[0], dict) + self.assertIsInstance(test_data[0], dict) + + +if __name__ == "__main__": + unittest.main() diff --git a/samples/modules/tensorflow/magic_wand/train/data_split_test.py b/samples/modules/tensorflow/magic_wand/train/data_split_test.py new file mode 100644 index 00000000000..3ab2f185c57 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_split_test.py @@ -0,0 +1,77 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Test for data_split.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import unittest +from data_split import read_data +from data_split import split_data + + +class TestSplit(unittest.TestCase): + + def setUp(self): # pylint: disable=g-missing-super-call + self.data = read_data("./data/complete_data") + self.num_dic = {"wing": 0, "ring": 0, "slope": 0, "negative": 0} + with open("./data/complete_data", "r") as f: + lines = f.readlines() + self.num = len(lines) + + def test_read_data(self): + self.assertEqual(len(self.data), self.num) + self.assertIsInstance(self.data, list) + self.assertIsInstance(self.data[0], dict) + self.assertEqual( + set(list(self.data[-1])), set(["gesture", "accel_ms2_xyz", "name"])) + + def test_split_data(self): + with open("./data/complete_data", "r") as f: + lines = f.readlines() + for idx, line in enumerate(lines): # pylint: disable=unused-variable + dic = json.loads(line) + for ges in self.num_dic: + if dic["gesture"] == ges: + self.num_dic[ges] += 1 + train_data_0, valid_data_0, test_data_100 = split_data(self.data, 0, 0) + train_data_50, valid_data_50, test_data_0 = split_data(self.data, 0.5, 0.5) + train_data_60, valid_data_20, test_data_20 = split_data(self.data, 0.6, 0.2) + len_60 = int(self.num_dic["wing"] * 0.6) + int( + self.num_dic["ring"] * 0.6) + int(self.num_dic["slope"] * 0.6) + int( + self.num_dic["negative"] * 0.6) + len_50 = int(self.num_dic["wing"] * 0.5) + int( + self.num_dic["ring"] * 0.5) + int(self.num_dic["slope"] * 0.5) + int( + self.num_dic["negative"] * 0.5) + len_20 = int(self.num_dic["wing"] * 0.2) + int( + self.num_dic["ring"] * 0.2) + int(self.num_dic["slope"] * 0.2) + int( + self.num_dic["negative"] * 0.2) + self.assertEqual(len(train_data_0), 0) + self.assertEqual(len(train_data_50), len_50) + self.assertEqual(len(train_data_60), len_60) + self.assertEqual(len(valid_data_0), 0) + self.assertEqual(len(valid_data_50), len_50) + self.assertEqual(len(valid_data_20), len_20) + self.assertEqual(len(test_data_100), self.num) + self.assertEqual(len(test_data_0), (self.num - 2 * len_50)) + self.assertEqual(len(test_data_20), (self.num - len_60 - len_20)) + + +if __name__ == "__main__": + unittest.main() diff --git a/samples/modules/tensorflow/magic_wand/train/netmodels/CNN/weights.h5 b/samples/modules/tensorflow/magic_wand/train/netmodels/CNN/weights.h5 new file mode 100644 index 00000000000..1d825b3aaf7 Binary files /dev/null and b/samples/modules/tensorflow/magic_wand/train/netmodels/CNN/weights.h5 differ diff --git a/samples/modules/tensorflow/magic_wand/train/requirements.txt b/samples/modules/tensorflow/magic_wand/train/requirements.txt new file mode 100644 index 00000000000..c83b8a48eb0 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/requirements.txt @@ -0,0 +1,2 @@ +numpy==1.16.2 +tensorflow==2.0.0-beta1 diff --git a/samples/modules/tensorflow/magic_wand/train/train.py b/samples/modules/tensorflow/magic_wand/train/train.py new file mode 100644 index 00000000000..213c7c792e6 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/train.py @@ -0,0 +1,202 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=g-bad-import-order + +"""Build and train neural networks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import datetime +import os # pylint: disable=duplicate-code +from data_load import DataLoader + +import numpy as np # pylint: disable=duplicate-code +import tensorflow as tf + +logdir = "logs/scalars/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") +tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir) + + +def reshape_function(data, label): + reshaped_data = tf.reshape(data, [-1, 3, 1]) + return reshaped_data, label + + +def calculate_model_size(model): + print(model.summary()) + var_sizes = [ + np.product(list(map(int, v.shape))) * v.dtype.size + for v in model.trainable_variables + ] + print("Model size:", sum(var_sizes) / 1024, "KB") + + +def build_cnn(seq_length): + """Builds a convolutional neural network in Keras.""" + model = tf.keras.Sequential([ + tf.keras.layers.Conv2D( + 8, (4, 3), + padding="same", + activation="relu", + input_shape=(seq_length, 3, 1)), # output_shape=(batch, 128, 3, 8) + tf.keras.layers.MaxPool2D((3, 3)), # (batch, 42, 1, 8) + tf.keras.layers.Dropout(0.1), # (batch, 42, 1, 8) + tf.keras.layers.Conv2D(16, (4, 1), padding="same", + activation="relu"), # (batch, 42, 1, 16) + tf.keras.layers.MaxPool2D((3, 1), padding="same"), # (batch, 14, 1, 16) + tf.keras.layers.Dropout(0.1), # (batch, 14, 1, 16) + tf.keras.layers.Flatten(), # (batch, 224) + tf.keras.layers.Dense(16, activation="relu"), # (batch, 16) + tf.keras.layers.Dropout(0.1), # (batch, 16) + tf.keras.layers.Dense(4, activation="softmax") # (batch, 4) + ]) + model_path = os.path.join("./netmodels", "CNN") + print("Built CNN.") + if not os.path.exists(model_path): + os.makedirs(model_path) + model.load_weights("./netmodels/CNN/weights.h5") + return model, model_path + + +def build_lstm(seq_length): + """Builds an LSTM in Keras.""" + model = tf.keras.Sequential([ + tf.keras.layers.Bidirectional( + tf.keras.layers.LSTM(22), + input_shape=(seq_length, 3)), # output_shape=(batch, 44) + tf.keras.layers.Dense(4, activation="sigmoid") # (batch, 4) + ]) + model_path = os.path.join("./netmodels", "LSTM") + print("Built LSTM.") + if not os.path.exists(model_path): + os.makedirs(model_path) + return model, model_path + + +def load_data(train_data_path, valid_data_path, test_data_path, seq_length): + data_loader = DataLoader( + train_data_path, valid_data_path, test_data_path, seq_length=seq_length) + data_loader.format() + return data_loader.train_len, data_loader.train_data, data_loader.valid_len, \ + data_loader.valid_data, data_loader.test_len, data_loader.test_data + + +def build_net(args, seq_length): + if args.model == "CNN": + model, model_path = build_cnn(seq_length) + elif args.model == "LSTM": + model, model_path = build_lstm(seq_length) + else: + print("Please input correct model name.(CNN LSTM)") + return model, model_path + + +def train_net( + model, + model_path, # pylint: disable=unused-argument + train_len, # pylint: disable=unused-argument + train_data, + valid_len, + valid_data, + test_len, + test_data, + kind): + """Trains the model.""" + calculate_model_size(model) + epochs = 50 + batch_size = 64 + model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"]) + if kind == "CNN": + train_data = train_data.map(reshape_function) + test_data = test_data.map(reshape_function) + valid_data = valid_data.map(reshape_function) + test_labels = np.zeros(test_len) + idx = 0 + for data, label in test_data: # pylint: disable=unused-variable + test_labels[idx] = label.numpy() + idx += 1 + train_data = train_data.batch(batch_size).repeat() + valid_data = valid_data.batch(batch_size) + test_data = test_data.batch(batch_size) + model.fit( + train_data, + epochs=epochs, + validation_data=valid_data, + steps_per_epoch=1000, + validation_steps=int((valid_len - 1) / batch_size + 1), + callbacks=[tensorboard_callback]) + loss, acc = model.evaluate(test_data) + pred = np.argmax(model.predict(test_data), axis=1) + confusion = tf.math.confusion_matrix( + labels=tf.constant(test_labels), + predictions=tf.constant(pred), + num_classes=4) + print(confusion) + print("Loss {}, Accuracy {}".format(loss, acc)) + # Convert the model to the TensorFlow Lite format without quantization + converter = tf.lite.TFLiteConverter.from_keras_model(model) + tflite_model = converter.convert() + + # Save the model to disk + open("model.tflite", "wb").write(tflite_model) + + # Convert the model to the TensorFlow Lite format with quantization + converter = tf.lite.TFLiteConverter.from_keras_model(model) + converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] + tflite_model = converter.convert() + + # Save the model to disk + open("model_quantized.tflite", "wb").write(tflite_model) + + basic_model_size = os.path.getsize("model.tflite") + print("Basic model is %d bytes" % basic_model_size) + quantized_model_size = os.path.getsize("model_quantized.tflite") + print("Quantized model is %d bytes" % quantized_model_size) + difference = basic_model_size - quantized_model_size + print("Difference is %d bytes" % difference) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", "-m") + parser.add_argument("--person", "-p") + args = parser.parse_args() + + seq_length = 128 + + print("Start to load data...") + if args.person == "true": + train_len, train_data, valid_len, valid_data, test_len, test_data = \ + load_data("./person_split/train", "./person_split/valid", + "./person_split/test", seq_length) + else: + train_len, train_data, valid_len, valid_data, test_len, test_data = \ + load_data("./data/train", "./data/valid", "./data/test", seq_length) + + print("Start to build net...") + model, model_path = build_net(args, seq_length) + + print("Start training...") + train_net(model, model_path, train_len, train_data, valid_len, valid_data, + test_len, test_data, args.model) + + print("Training finished!") diff --git a/samples/modules/tensorflow/magic_wand/train/train_magic_wand_model.ipynb b/samples/modules/tensorflow/magic_wand/train/train_magic_wand_model.ipynb new file mode 100644 index 00000000000..d285c522fdc --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/train_magic_wand_model.ipynb @@ -0,0 +1,257 @@ +{ + "cells": [ + { + "source": [ + "Copyright 2020 The TensorFlow Authors. All Rights Reserved.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " http://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1BtkMGSYQOTQ" + }, + "source": [ + "# Train a gesture recognition model for microcontroller use" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BaFfr7DHRmGF" + }, + "source": [ + "This notebook demonstrates how to train a 20kb gesture recognition model for [TensorFlow Lite for Microcontrollers](https://tensorflow.org/lite/microcontrollers/overview). It will produce the same model used in the [magic_wand](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/examples/magic_wand) example application.\n", + "\n", + "The model is designed to be used with [Google Colaboratory](https://colab.research.google.com).\n", + "\n", + "\n", + " \n", + " \n", + "
\n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xXgS6rxyT7Qk" + }, + "source": [ + "Training is much faster using GPU acceleration. Before you proceed, ensure you are using a GPU runtime by going to **Runtime -> Change runtime type** and selecting **GPU**. Training will take around 5 minutes on a GPU runtime." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "LG6ErX5FRIaV" + }, + "source": [ + "## Configure dependencies\n", + "\n", + "Run the following cell to ensure the correct version of TensorFlow is used." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "STNft9TrfoVh" + }, + "source": [ + "We'll also clone the TensorFlow repository, which contains the training scripts, and copy them into our workspace." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "ygkWw73dRNda" + }, + "outputs": [], + "source": [ + "# Clone the repository from GitHub\n", + "!git clone --depth 1 -q https://github.com/tensorflow/tensorflow\n", + "# Copy the training scripts into our workspace\n", + "!cp -r tensorflow/tensorflow/lite/micro/examples/magic_wand/train train" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "pXI7R4RehFdU" + }, + "source": [ + "## Prepare the data\n", + "\n", + "Next, we'll download the data and extract it into the expected location within the training scripts' directory." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "W2Sg2AKzVr2L" + }, + "outputs": [], + "source": [ + "# Download the data we will use to train the model\n", + "!wget http://download.tensorflow.org/models/tflite/magic_wand/data.tar.gz\n", + "# Extract the data into the train directory\n", + "!tar xvzf data.tar.gz -C train 1>/dev/null" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "DNjukI1Sgl2C" + }, + "source": [ + "We'll then run the scripts that split the data into training, validation, and test sets." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "XBqSVpi6Vxss" + }, + "outputs": [], + "source": [ + "# The scripts must be run from within the train directory\n", + "%cd train\n", + "# Prepare the data\n", + "!python data_prepare.py\n", + "# Split the data by person\n", + "!python data_split_person.py" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "5-cmVbFvhTvy" + }, + "source": [ + "## Load TensorBoard\n", + "\n", + "Now, we set up TensorBoard so that we can graph our accuracy and loss as training proceeds." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "CCx6SN9NWRPw" + }, + "outputs": [], + "source": [ + "# Load TensorBoard\n", + "%load_ext tensorboard\n", + "%tensorboard --logdir logs/scalars" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ERC2Cr4PhaOl" + }, + "source": [ + "## Begin training\n", + "\n", + "The following cell will begin the training process. Training will take around 5 minutes on a GPU runtime. You'll see the metrics in TensorBoard after a few epochs." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "DXmQZgbuWQFO" + }, + "outputs": [], + "source": [ + "!python train.py --model CNN --person true" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4gXbVzcXhvGD" + }, + "source": [ + "## Create a C source file\n", + "\n", + "The `train.py` script writes a model, `model.tflite`, to the training scripts' directory.\n", + "\n", + "In the following cell, we convert this model into a C++ source file we can use with TensorFlow Lite for Microcontrollers." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "8wgei4OGe3Nz" + }, + "outputs": [], + "source": [ + "# Install xxd if it is not available\n", + "!apt-get -qq install xxd\n", + "# Save the file as a C source file\n", + "!xxd -i model.tflite > /content/model.cc\n", + "# Print the source file\n", + "!cat /content/model.cc" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "Train a gesture recognition model for microcontroller use", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/samples/modules/tensorflow/magic_wand/train/train_test.py b/samples/modules/tensorflow/magic_wand/train/train_test.py new file mode 100644 index 00000000000..2b785663df7 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/train_test.py @@ -0,0 +1,78 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Test for train.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +import numpy as np +import tensorflow as tf +from train import build_cnn +from train import build_lstm +from train import load_data +from train import reshape_function + + +class TestTrain(unittest.TestCase): + + def setUp(self): # pylint: disable=g-missing-super-call + self.seq_length = 128 + self.train_len, self.train_data, self.valid_len, self.valid_data, \ + self.test_len, self.test_data = \ + load_data("./data/train", "./data/valid", "./data/test", + self.seq_length) + + def test_load_data(self): + self.assertIsInstance(self.train_data, tf.data.Dataset) + self.assertIsInstance(self.valid_data, tf.data.Dataset) + self.assertIsInstance(self.test_data, tf.data.Dataset) + + def test_build_net(self): + cnn, cnn_path = build_cnn(self.seq_length) + lstm, lstm_path = build_lstm(self.seq_length) + cnn_data = np.random.rand(60, 128, 3, 1) + lstm_data = np.random.rand(60, 128, 3) + cnn_prob = cnn(tf.constant(cnn_data, dtype="float32")).numpy() + lstm_prob = lstm(tf.constant(lstm_data, dtype="float32")).numpy() + self.assertIsInstance(cnn, tf.keras.Sequential) + self.assertIsInstance(lstm, tf.keras.Sequential) + self.assertEqual(cnn_path, "./netmodels/CNN") + self.assertEqual(lstm_path, "./netmodels/LSTM") + self.assertEqual(cnn_prob.shape, (60, 4)) + self.assertEqual(lstm_prob.shape, (60, 4)) + + def test_reshape_function(self): + for data, label in self.train_data: + original_data_shape = data.numpy().shape + original_label_shape = label.numpy().shape + break + self.train_data = self.train_data.map(reshape_function) + for data, label in self.train_data: + reshaped_data_shape = data.numpy().shape + reshaped_label_shape = label.numpy().shape + break + self.assertEqual( + reshaped_data_shape, + (int(original_data_shape[0] * original_data_shape[1] / 3), 3, 1)) + self.assertEqual(reshaped_label_shape, original_label_shape) + + +if __name__ == "__main__": + unittest.main()