diff --git a/samples/modules/tensorflow/magic_wand/train/README.md b/samples/modules/tensorflow/magic_wand/train/README.md
new file mode 100644
index 00000000000..55f1617c9a9
--- /dev/null
+++ b/samples/modules/tensorflow/magic_wand/train/README.md
@@ -0,0 +1,191 @@
+# Gesture Recognition Magic Wand Training Scripts
+
+## Introduction
+
+The scripts in this directory can be used to train a TensorFlow model that
+classifies gestures based on accelerometer data. The code uses Python 3.7 and
+TensorFlow 2.0. The resulting model is less than 20KB in size.
+
+The following document contains instructions on using the scripts to train a
+model, and capturing your own training data.
+
+This project was inspired by the [Gesture Recognition Magic Wand](https://github.com/jewang/gesture-demo)
+project by Jennifer Wang.
+
+## Training
+
+### Dataset
+
+Three magic gestures were chosen, and data were collected from 7
+different people. Some random long movement sequences were collected and divided
+into shorter pieces, which made up "negative" data along with some other
+automatically generated random data.
+
+The dataset can be downloaded from the following URL:
+
+[download.tensorflow.org/models/tflite/magic_wand/data.tar.gz](http://download.tensorflow.org/models/tflite/magic_wand/data.tar.gz)
+
+### Training in Colab
+
+The following [Google Colaboratory](https://colab.research.google.com)
+notebook demonstrates how to train the model. It's the easiest way to get
+started:
+
+
+
+If you'd prefer to run the scripts locally, use the following instructions.
+
+### Running the scripts
+
+Use the following command to install the required dependencies:
+
+```shell
+pip install -r requirements.txt
+```
+
+There are two ways to train the model:
+
+- Random data split, which mixes different people's data together and randomly
+ splits them into training, validation, and test sets
+- Person data split, which splits the data by person
+
+#### Random data split
+
+Using a random split results in higher training accuracy than a person split,
+but inferior performance on new data.
+
+```shell
+$ python data_prepare.py
+
+$ python data_split.py
+
+$ python train.py --model CNN --person false
+```
+
+#### Person data split
+
+Using a person data split results in lower training accuracy but better
+performance on new data.
+
+```shell
+$ python data_prepare.py
+
+$ python data_split_person.py
+
+$ python train.py --model CNN --person true
+```
+
+#### Model type
+
+In the `--model` argument, you can can provide `CNN` or `LSTM`. The CNN
+model has a smaller size and lower latency.
+
+## Collecting new data
+
+To obtain new training data using the
+[SparkFun Edge development board](https://sparkfun.com/products/15170), you can
+modify one of the examples in the [SparkFun Edge BSP](https://github.com/sparkfun/SparkFun_Edge_BSP)
+and deploy it using the Ambiq SDK.
+
+### Install the Ambiq SDK and SparkFun Edge BSP
+
+Follow SparkFun's
+[Using SparkFun Edge Board with Ambiq Apollo3 SDK](https://learn.sparkfun.com/tutorials/using-sparkfun-edge-board-with-ambiq-apollo3-sdk/all)
+guide to set up the Ambiq SDK and SparkFun Edge BSP.
+
+#### Modify the example code
+
+First, `cd` into
+`AmbiqSuite-Rel2.2.0/boards/SparkFun_Edge_BSP/examples/example1_edge_test`.
+
+##### Modify `src/tf_adc/tf_adc.c`
+
+Add `true` in line 62 as the second parameter of function
+`am_hal_adc_samples_read`.
+
+##### Modify `src/main.c`
+
+Add the line below in `int main(void)`, just before the line `while(1)`:
+
+```cc
+am_util_stdio_printf("-,-,-\r\n");
+```
+
+Change the following lines in `while(1){...}`
+
+```cc
+am_util_stdio_printf("Acc [mg] %04.2f x, %04.2f y, %04.2f z, Temp [deg C] %04.2f, MIC0 [counts / 2^14] %d\r\n", acceleration_mg[0], acceleration_mg[1], acceleration_mg[2], temperature_degC, (audioSample) );
+```
+
+to:
+
+```cc
+am_util_stdio_printf("%04.2f,%04.2f,%04.2f\r\n", acceleration_mg[0], acceleration_mg[1], acceleration_mg[2]);
+```
+
+#### Flash the binary
+
+Follow the instructions in
+[SparkFun's guide](https://learn.sparkfun.com/tutorials/using-sparkfun-edge-board-with-ambiq-apollo3-sdk/all#example-applications)
+to flash the binary to the device.
+
+#### Collect accelerometer data
+
+First, in a new terminal window, run the following command to begin logging
+output to `output.txt`:
+
+```shell
+$ script output.txt
+```
+
+Next, in the same window, use `screen` to connect to the device:
+
+```shell
+$ screen ${DEVICENAME} 115200
+```
+
+Output information collected from accelerometer sensor will be shown on the
+screen and saved in `output.txt`, in the format of "x,y,z" per line.
+
+Press the `RST` button to start capturing a new gesture, then press Button 14
+when it ends. New data will begin with a line "-,-,-".
+
+To exit `screen`, hit +Ctrl\\+A+, immediately followed by the +K+ key,
+then hit the +Y+ key. Then run
+
+```shell
+$ exit
+```
+
+to stop logging data. Data will be saved in `output.txt`. For compatibility
+with the training scripts, change the file name to include person's name and
+the gesture name, in the following format:
+
+```
+output_{gesture_name}_{person_name}.txt
+```
+
+#### Edit and run the scripts
+
+Edit the following files to include your new gesture names (replacing
+"wing", "ring", and "slope")
+
+- `data_load.py`
+- `data_prepare.py`
+- `data_split.py`
+
+Edit the following files to include your new person names (replacing "hyw",
+"shiyun", "tangsy", "dengyl", "jiangyh", "xunkai", "lsj", "pengxl", "liucx",
+and "zhangxy"):
+
+- `data_prepare.py`
+- `data_split_person.py`
+
+Finally, run the commands described earlier to train a new model.
diff --git a/samples/modules/tensorflow/magic_wand/train/data_augmentation.py b/samples/modules/tensorflow/magic_wand/train/data_augmentation.py
new file mode 100644
index 00000000000..45d08c3fec6
--- /dev/null
+++ b/samples/modules/tensorflow/magic_wand/train/data_augmentation.py
@@ -0,0 +1,74 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=g-bad-import-order
+
+"""Data augmentation that will be used in data_load.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+
+import numpy as np
+
+
+def time_wrapping(molecule, denominator, data):
+ """Generate (molecule/denominator)x speed data."""
+ tmp_data = [[0
+ for i in range(len(data[0]))]
+ for j in range((int(len(data) / molecule) - 1) * denominator)]
+ for i in range(int(len(data) / molecule) - 1):
+ for j in range(len(data[i])):
+ for k in range(denominator):
+ tmp_data[denominator * i +
+ k][j] = (data[molecule * i + k][j] * (denominator - k) +
+ data[molecule * i + k + 1][j] * k) / denominator
+ return tmp_data
+
+
+def augment_data(original_data, original_label):
+ """Perform data augmentation."""
+ new_data = []
+ new_label = []
+ for idx, (data, label) in enumerate(zip(original_data, original_label)): # pylint: disable=unused-variable
+ # Original data
+ new_data.append(data)
+ new_label.append(label)
+ # Sequence shift
+ for num in range(5): # pylint: disable=unused-variable
+ new_data.append((np.array(data, dtype=np.float32) +
+ (random.random() - 0.5) * 200).tolist())
+ new_label.append(label)
+ # Random noise
+ tmp_data = [[0 for i in range(len(data[0]))] for j in range(len(data))]
+ for num in range(5):
+ for i in range(len(tmp_data)): # pylint: disable=consider-using-enumerate
+ for j in range(len(tmp_data[i])):
+ tmp_data[i][j] = data[i][j] + 5 * random.random()
+ new_data.append(tmp_data)
+ new_label.append(label)
+ # Time warping
+ fractions = [(3, 2), (5, 3), (2, 3), (3, 4), (9, 5), (6, 5), (4, 5)]
+ for molecule, denominator in fractions:
+ new_data.append(time_wrapping(molecule, denominator, data))
+ new_label.append(label)
+ # Movement amplification
+ for molecule, denominator in fractions:
+ new_data.append(
+ (np.array(data, dtype=np.float32) * molecule / denominator).tolist())
+ new_label.append(label)
+ return new_data, new_label
diff --git a/samples/modules/tensorflow/magic_wand/train/data_augmentation_test.py b/samples/modules/tensorflow/magic_wand/train/data_augmentation_test.py
new file mode 100644
index 00000000000..8da07e2e580
--- /dev/null
+++ b/samples/modules/tensorflow/magic_wand/train/data_augmentation_test.py
@@ -0,0 +1,58 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=g-bad-import-order
+
+"""Test for data_augmentation.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+import numpy as np
+
+from data_augmentation import augment_data
+from data_augmentation import time_wrapping
+
+
+class TestAugmentation(unittest.TestCase):
+
+ def test_time_wrapping(self):
+ original_data = np.random.rand(10, 3).tolist()
+ wrapped_data = time_wrapping(4, 5, original_data)
+ self.assertEqual(len(wrapped_data), int(len(original_data) / 4 - 1) * 5)
+ self.assertEqual(len(wrapped_data[0]), len(original_data[0]))
+
+ def test_augment_data(self):
+ original_data = [
+ np.random.rand(128, 3).tolist(),
+ np.random.rand(66, 2).tolist(),
+ np.random.rand(9, 1).tolist()
+ ]
+ original_label = ["data", "augmentation", "test"]
+ augmented_data, augmented_label = augment_data(original_data,
+ original_label)
+ self.assertEqual(25 * len(original_data), len(augmented_data))
+ self.assertIsInstance(augmented_data, list)
+ self.assertEqual(25 * len(original_label), len(augmented_label))
+ self.assertIsInstance(augmented_label, list)
+ for i in range(len(original_label)): # pylint: disable=consider-using-enumerate
+ self.assertEqual(augmented_label[25 * i], original_label[i])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/samples/modules/tensorflow/magic_wand/train/data_load.py b/samples/modules/tensorflow/magic_wand/train/data_load.py
new file mode 100644
index 00000000000..35ff4825a8b
--- /dev/null
+++ b/samples/modules/tensorflow/magic_wand/train/data_load.py
@@ -0,0 +1,106 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=g-bad-import-order
+
+"""Load data from the specified paths and format them for training."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+
+import numpy as np
+import tensorflow as tf
+
+from data_augmentation import augment_data
+
+LABEL_NAME = "gesture"
+DATA_NAME = "accel_ms2_xyz"
+
+
+class DataLoader(object):
+ """Loads data and prepares for training."""
+
+ def __init__(self, train_data_path, valid_data_path, test_data_path,
+ seq_length):
+ self.dim = 3
+ self.seq_length = seq_length
+ self.label2id = {"wing": 0, "ring": 1, "slope": 2, "negative": 3}
+ self.train_data, self.train_label, self.train_len = self.get_data_file(
+ train_data_path, "train")
+ self.valid_data, self.valid_label, self.valid_len = self.get_data_file(
+ valid_data_path, "valid")
+ self.test_data, self.test_label, self.test_len = self.get_data_file(
+ test_data_path, "test")
+
+ def get_data_file(self, data_path, data_type): # pylint: disable=no-self-use
+ """Get train, valid and test data from files."""
+ data = []
+ label = []
+ with open(data_path, "r") as f:
+ lines = f.readlines()
+ for idx, line in enumerate(lines): # pylint: disable=unused-variable
+ dic = json.loads(line)
+ data.append(dic[DATA_NAME])
+ label.append(dic[LABEL_NAME])
+ if data_type == "train":
+ data, label = augment_data(data, label)
+ length = len(label)
+ print(data_type + "_data_length:" + str(length))
+ return data, label, length
+
+ def pad(self, data, seq_length, dim): # pylint: disable=no-self-use
+ """Get neighbour padding."""
+ noise_level = 20
+ padded_data = []
+ # Before- Neighbour padding
+ tmp_data = (np.random.rand(seq_length, dim) - 0.5) * noise_level + data[0]
+ tmp_data[(seq_length -
+ min(len(data), seq_length)):] = data[:min(len(data), seq_length)]
+ padded_data.append(tmp_data)
+ # After- Neighbour padding
+ tmp_data = (np.random.rand(seq_length, dim) - 0.5) * noise_level + data[-1]
+ tmp_data[:min(len(data), seq_length)] = data[:min(len(data), seq_length)]
+ padded_data.append(tmp_data)
+ return padded_data
+
+ def format_support_func(self, padded_num, length, data, label):
+ """Support function for format.(Helps format train, valid and test.)"""
+ # Add 2 padding, initialize data and label
+ length *= padded_num
+ features = np.zeros((length, self.seq_length, self.dim))
+ labels = np.zeros(length)
+ # Get padding for train, valid and test
+ for idx, (data, label) in enumerate(zip(data, label)): # pylint: disable=redefined-argument-from-local
+ padded_data = self.pad(data, self.seq_length, self.dim)
+ for num in range(padded_num):
+ features[padded_num * idx + num] = padded_data[num]
+ labels[padded_num * idx + num] = self.label2id[label]
+ # Turn into tf.data.Dataset
+ dataset = tf.data.Dataset.from_tensor_slices(
+ (features, labels.astype("int32")))
+ return length, dataset
+
+ def format(self):
+ """Format data(including padding, etc.) and get the dataset for the model."""
+ padded_num = 2
+ self.train_len, self.train_data = self.format_support_func(
+ padded_num, self.train_len, self.train_data, self.train_label)
+ self.valid_len, self.valid_data = self.format_support_func(
+ padded_num, self.valid_len, self.valid_data, self.valid_label)
+ self.test_len, self.test_data = self.format_support_func(
+ padded_num, self.test_len, self.test_data, self.test_label)
diff --git a/samples/modules/tensorflow/magic_wand/train/data_load_test.py b/samples/modules/tensorflow/magic_wand/train/data_load_test.py
new file mode 100644
index 00000000000..82864974dad
--- /dev/null
+++ b/samples/modules/tensorflow/magic_wand/train/data_load_test.py
@@ -0,0 +1,95 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=g-bad-import-order
+
+"""Test for data_load.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+from data_load import DataLoader
+
+import tensorflow as tf
+
+
+class TestLoad(unittest.TestCase):
+
+ def setUp(self): # pylint: disable=g-missing-super-call
+ self.loader = DataLoader(
+ "./data/train", "./data/valid", "./data/test", seq_length=512)
+
+ def test_get_data(self):
+ self.assertIsInstance(self.loader.train_data, list)
+ self.assertIsInstance(self.loader.train_label, list)
+ self.assertIsInstance(self.loader.valid_data, list)
+ self.assertIsInstance(self.loader.valid_label, list)
+ self.assertIsInstance(self.loader.test_data, list)
+ self.assertIsInstance(self.loader.test_label, list)
+ self.assertEqual(self.loader.train_len, len(self.loader.train_data))
+ self.assertEqual(self.loader.train_len, len(self.loader.train_label))
+ self.assertEqual(self.loader.valid_len, len(self.loader.valid_data))
+ self.assertEqual(self.loader.valid_len, len(self.loader.valid_label))
+ self.assertEqual(self.loader.test_len, len(self.loader.test_data))
+ self.assertEqual(self.loader.test_len, len(self.loader.test_label))
+
+ def test_pad(self):
+ original_data1 = [[2, 3], [1, 1]]
+ expected_data1_0 = [[2, 3], [2, 3], [2, 3], [2, 3], [1, 1]]
+ expected_data1_1 = [[2, 3], [1, 1], [1, 1], [1, 1], [1, 1]]
+ original_data2 = [[-2, 3], [-77, -681], [5, 6], [9, -7], [22, 3333],
+ [9, 99], [-100, 0]]
+ expected_data2 = [[-2, 3], [-77, -681], [5, 6], [9, -7], [22, 3333]]
+ padding_data1 = self.loader.pad(original_data1, seq_length=5, dim=2)
+ padding_data2 = self.loader.pad(original_data2, seq_length=5, dim=2)
+ for i in range(len(padding_data1[0])):
+ for j in range(len(padding_data1[0].tolist()[0])):
+ self.assertLess(
+ abs(padding_data1[0].tolist()[i][j] - expected_data1_0[i][j]),
+ 10.001)
+ for i in range(len(padding_data1[1])):
+ for j in range(len(padding_data1[1].tolist()[0])):
+ self.assertLess(
+ abs(padding_data1[1].tolist()[i][j] - expected_data1_1[i][j]),
+ 10.001)
+ self.assertEqual(padding_data2[0].tolist(), expected_data2)
+ self.assertEqual(padding_data2[1].tolist(), expected_data2)
+
+ def test_format(self):
+ self.loader.format()
+ expected_train_label = int(self.loader.label2id[self.loader.train_label[0]])
+ expected_valid_label = int(self.loader.label2id[self.loader.valid_label[0]])
+ expected_test_label = int(self.loader.label2id[self.loader.test_label[0]])
+ for feature, label in self.loader.train_data: # pylint: disable=unused-variable
+ format_train_label = label.numpy()
+ break
+ for feature, label in self.loader.valid_data:
+ format_valid_label = label.numpy()
+ break
+ for feature, label in self.loader.test_data:
+ format_test_label = label.numpy()
+ break
+ self.assertEqual(expected_train_label, format_train_label)
+ self.assertEqual(expected_valid_label, format_valid_label)
+ self.assertEqual(expected_test_label, format_test_label)
+ self.assertIsInstance(self.loader.train_data, tf.data.Dataset)
+ self.assertIsInstance(self.loader.valid_data, tf.data.Dataset)
+ self.assertIsInstance(self.loader.test_data, tf.data.Dataset)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/samples/modules/tensorflow/magic_wand/train/data_prepare.py b/samples/modules/tensorflow/magic_wand/train/data_prepare.py
new file mode 100644
index 00000000000..727167f4d05
--- /dev/null
+++ b/samples/modules/tensorflow/magic_wand/train/data_prepare.py
@@ -0,0 +1,164 @@
+# Lint as: python3
+# coding=utf-8
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Prepare data for further process.
+
+Read data from "/slope", "/ring", "/wing", "/negative" and save them
+in "/data/complete_data" in python dict format.
+
+It will generate a new file with the following structure:
+├── data
+│ └── complete_data
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import csv
+import json
+import os
+import random
+
+LABEL_NAME = "gesture"
+DATA_NAME = "accel_ms2_xyz"
+folders = ["wing", "ring", "slope"]
+names = [
+ "hyw", "shiyun", "tangsy", "dengyl", "zhangxy", "pengxl", "liucx",
+ "jiangyh", "xunkai"
+]
+
+
+def prepare_original_data(folder, name, data, file_to_read):
+ """Read collected data from files."""
+ if folder != "negative":
+ with open(file_to_read, "r") as f:
+ lines = csv.reader(f)
+ data_new = {}
+ data_new[LABEL_NAME] = folder
+ data_new[DATA_NAME] = []
+ data_new["name"] = name
+ for idx, line in enumerate(lines): # pylint: disable=unused-variable
+ if len(line) == 3:
+ if line[2] == "-" and data_new[DATA_NAME]:
+ data.append(data_new)
+ data_new = {}
+ data_new[LABEL_NAME] = folder
+ data_new[DATA_NAME] = []
+ data_new["name"] = name
+ elif line[2] != "-":
+ data_new[DATA_NAME].append([float(i) for i in line[0:3]])
+ data.append(data_new)
+ else:
+ with open(file_to_read, "r") as f:
+ lines = csv.reader(f)
+ data_new = {}
+ data_new[LABEL_NAME] = folder
+ data_new[DATA_NAME] = []
+ data_new["name"] = name
+ for idx, line in enumerate(lines):
+ if len(line) == 3 and line[2] != "-":
+ if len(data_new[DATA_NAME]) == 120:
+ data.append(data_new)
+ data_new = {}
+ data_new[LABEL_NAME] = folder
+ data_new[DATA_NAME] = []
+ data_new["name"] = name
+ else:
+ data_new[DATA_NAME].append([float(i) for i in line[0:3]])
+ data.append(data_new)
+
+
+def generate_negative_data(data):
+ """Generate negative data labeled as 'negative6~8'."""
+ # Big movement -> around straight line
+ for i in range(100):
+ if i > 80:
+ dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative8"}
+ elif i > 60:
+ dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative7"}
+ else:
+ dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative6"}
+ start_x = (random.random() - 0.5) * 2000
+ start_y = (random.random() - 0.5) * 2000
+ start_z = (random.random() - 0.5) * 2000
+ x_increase = (random.random() - 0.5) * 10
+ y_increase = (random.random() - 0.5) * 10
+ z_increase = (random.random() - 0.5) * 10
+ for j in range(128):
+ dic[DATA_NAME].append([
+ start_x + j * x_increase + (random.random() - 0.5) * 6,
+ start_y + j * y_increase + (random.random() - 0.5) * 6,
+ start_z + j * z_increase + (random.random() - 0.5) * 6
+ ])
+ data.append(dic)
+ # Random
+ for i in range(100):
+ if i > 80:
+ dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative8"}
+ elif i > 60:
+ dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative7"}
+ else:
+ dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative6"}
+ for j in range(128):
+ dic[DATA_NAME].append([(random.random() - 0.5) * 1000,
+ (random.random() - 0.5) * 1000,
+ (random.random() - 0.5) * 1000])
+ data.append(dic)
+ # Stay still
+ for i in range(100):
+ if i > 80:
+ dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative8"}
+ elif i > 60:
+ dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative7"}
+ else:
+ dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative6"}
+ start_x = (random.random() - 0.5) * 2000
+ start_y = (random.random() - 0.5) * 2000
+ start_z = (random.random() - 0.5) * 2000
+ for j in range(128):
+ dic[DATA_NAME].append([
+ start_x + (random.random() - 0.5) * 40,
+ start_y + (random.random() - 0.5) * 40,
+ start_z + (random.random() - 0.5) * 40
+ ])
+ data.append(dic)
+
+
+# Write data to file
+def write_data(data_to_write, path):
+ with open(path, "w") as f:
+ for idx, item in enumerate(data_to_write): # pylint: disable=unused-variable
+ dic = json.dumps(item, ensure_ascii=False)
+ f.write(dic)
+ f.write("\n")
+
+
+if __name__ == "__main__":
+ data = []
+ for idx1, folder in enumerate(folders):
+ for idx2, name in enumerate(names):
+ prepare_original_data(folder, name, data,
+ "./%s/output_%s_%s.txt" % (folder, folder, name))
+ for idx in range(5):
+ prepare_original_data("negative", "negative%d" % (idx + 1), data,
+ "./negative/output_negative_%d.txt" % (idx + 1))
+ generate_negative_data(data)
+ print("data_length: " + str(len(data)))
+ if not os.path.exists("./data"):
+ os.makedirs("./data")
+ write_data(data, "./data/complete_data")
diff --git a/samples/modules/tensorflow/magic_wand/train/data_prepare_test.py b/samples/modules/tensorflow/magic_wand/train/data_prepare_test.py
new file mode 100644
index 00000000000..4c0e3f0657b
--- /dev/null
+++ b/samples/modules/tensorflow/magic_wand/train/data_prepare_test.py
@@ -0,0 +1,75 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Test for data_prepare.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import csv
+import json
+import os
+import unittest
+from data_prepare import generate_negative_data
+from data_prepare import prepare_original_data
+from data_prepare import write_data
+
+
+class TestPrepare(unittest.TestCase):
+
+ def setUp(self): # pylint: disable=g-missing-super-call
+ self.file = "./%s/output_%s_%s.txt" % (folders[0], folders[0], names[0]) # pylint: disable=undefined-variable
+ self.data = []
+ prepare_original_data(folders[0], names[0], self.data, self.file) # pylint: disable=undefined-variable
+
+ def test_prepare_data(self):
+ num = 0
+ with open(self.file, "r") as f:
+ lines = csv.reader(f)
+ for idx, line in enumerate(lines): # pylint: disable=unused-variable
+ if len(line) == 3 and line[2] == "-":
+ num += 1
+ self.assertEqual(len(self.data), num)
+ self.assertIsInstance(self.data, list)
+ self.assertIsInstance(self.data[0], dict)
+ self.assertEqual(list(self.data[-1]), ["gesture", "accel_ms2_xyz", "name"])
+ self.assertEqual(self.data[0]["name"], names[0]) # pylint: disable=undefined-variable
+
+ def test_generate_negative(self):
+ original_len = len(self.data)
+ generate_negative_data(self.data)
+ self.assertEqual(original_len + 300, len(self.data))
+ generated_num = 0
+ for idx, data in enumerate(self.data): # pylint: disable=unused-variable
+ if data["name"] == "negative6" or data["name"] == "negative7" or data[
+ "name"] == "negative8":
+ generated_num += 1
+ self.assertEqual(generated_num, 300)
+
+ def test_write_data(self):
+ data_path_test = "./data/data0"
+ write_data(self.data, data_path_test)
+ with open(data_path_test, "r") as f:
+ lines = f.readlines()
+ self.assertEqual(len(lines), len(self.data))
+ self.assertEqual(json.loads(lines[0]), self.data[0])
+ self.assertEqual(json.loads(lines[-1]), self.data[-1])
+ os.remove(data_path_test)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/samples/modules/tensorflow/magic_wand/train/data_split.py b/samples/modules/tensorflow/magic_wand/train/data_split.py
new file mode 100644
index 00000000000..5448ed90020
--- /dev/null
+++ b/samples/modules/tensorflow/magic_wand/train/data_split.py
@@ -0,0 +1,90 @@
+# Lint as: python3
+# coding=utf-8
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Mix and split data.
+
+Mix different people's data together and randomly split them into train,
+validation and test. These data would be saved separately under "/data".
+It will generate new files with the following structure:
+
+├── data
+│ ├── complete_data
+│ ├── test
+│ ├── train
+│ └── valid
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import random
+from data_prepare import write_data
+
+
+# Read data
+def read_data(path):
+ data = []
+ with open(path, "r") as f:
+ lines = f.readlines()
+ for idx, line in enumerate(lines): # pylint: disable=unused-variable
+ dic = json.loads(line)
+ data.append(dic)
+ print("data_length:" + str(len(data)))
+ return data
+
+
+def split_data(data, train_ratio, valid_ratio):
+ """Splits data into train, validation and test according to ratio."""
+ train_data = []
+ valid_data = []
+ test_data = []
+ num_dic = {"wing": 0, "ring": 0, "slope": 0, "negative": 0}
+ for idx, item in enumerate(data): # pylint: disable=unused-variable
+ for i in num_dic:
+ if item["gesture"] == i:
+ num_dic[i] += 1
+ print(num_dic)
+ train_num_dic = {}
+ valid_num_dic = {}
+ for i in num_dic:
+ train_num_dic[i] = int(train_ratio * num_dic[i])
+ valid_num_dic[i] = int(valid_ratio * num_dic[i])
+ random.seed(30)
+ random.shuffle(data)
+ for idx, item in enumerate(data):
+ for i in num_dic:
+ if item["gesture"] == i:
+ if train_num_dic[i] > 0:
+ train_data.append(item)
+ train_num_dic[i] -= 1
+ elif valid_num_dic[i] > 0:
+ valid_data.append(item)
+ valid_num_dic[i] -= 1
+ else:
+ test_data.append(item)
+ print("train_length:" + str(len(train_data)))
+ print("test_length:" + str(len(test_data)))
+ return train_data, valid_data, test_data
+
+
+if __name__ == "__main__":
+ data = read_data("./data/complete_data")
+ train_data, valid_data, test_data = split_data(data, 0.6, 0.2)
+ write_data(train_data, "./data/train")
+ write_data(valid_data, "./data/valid")
+ write_data(test_data, "./data/test")
diff --git a/samples/modules/tensorflow/magic_wand/train/data_split_person.py b/samples/modules/tensorflow/magic_wand/train/data_split_person.py
new file mode 100644
index 00000000000..8634b2f385a
--- /dev/null
+++ b/samples/modules/tensorflow/magic_wand/train/data_split_person.py
@@ -0,0 +1,75 @@
+# Lint as: python3
+# coding=utf-8
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Split data into train, validation and test dataset according to person.
+
+That is, use some people's data as train, some other people's data as
+validation, and the rest ones' data as test. These data would be saved
+separately under "/person_split".
+
+It will generate new files with the following structure:
+├──person_split
+│ ├── test
+│ ├── train
+│ └──valid
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import random
+from data_split import read_data
+from data_split import write_data
+
+
+def person_split(whole_data, train_names, valid_names, test_names):
+ """Split data by person."""
+ random.seed(30)
+ random.shuffle(whole_data)
+ train_data = []
+ valid_data = []
+ test_data = []
+ for idx, data in enumerate(whole_data): # pylint: disable=unused-variable
+ if data["name"] in train_names:
+ train_data.append(data)
+ elif data["name"] in valid_names:
+ valid_data.append(data)
+ elif data["name"] in test_names:
+ test_data.append(data)
+ print("train_length:" + str(len(train_data)))
+ print("valid_length:" + str(len(valid_data)))
+ print("test_length:" + str(len(test_data)))
+ return train_data, valid_data, test_data
+
+
+if __name__ == "__main__":
+ data = read_data("./data/complete_data")
+ train_names = [
+ "hyw", "shiyun", "tangsy", "dengyl", "jiangyh", "xunkai", "negative3",
+ "negative4", "negative5", "negative6"
+ ]
+ valid_names = ["lsj", "pengxl", "negative2", "negative7"]
+ test_names = ["liucx", "zhangxy", "negative1", "negative8"]
+ train_data, valid_data, test_data = person_split(data, train_names,
+ valid_names, test_names)
+ if not os.path.exists("./person_split"):
+ os.makedirs("./person_split")
+ write_data(train_data, "./person_split/train")
+ write_data(valid_data, "./person_split/valid")
+ write_data(test_data, "./person_split/test")
diff --git a/samples/modules/tensorflow/magic_wand/train/data_split_person_test.py b/samples/modules/tensorflow/magic_wand/train/data_split_person_test.py
new file mode 100644
index 00000000000..25ed8e5fe64
--- /dev/null
+++ b/samples/modules/tensorflow/magic_wand/train/data_split_person_test.py
@@ -0,0 +1,54 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Test for data_split_person.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+from data_split_person import person_split
+from data_split_person import read_data
+
+
+class TestSplitPerson(unittest.TestCase):
+
+ def setUp(self): # pylint: disable=g-missing-super-call
+ self.data = read_data("./data/complete_data")
+
+ def test_person_split(self):
+ train_names = ["dengyl"]
+ valid_names = ["liucx"]
+ test_names = ["tangsy"]
+ dengyl_num = 63
+ liucx_num = 63
+ tangsy_num = 30
+ train_data, valid_data, test_data = person_split(self.data, train_names,
+ valid_names, test_names)
+ self.assertEqual(len(train_data), dengyl_num)
+ self.assertEqual(len(valid_data), liucx_num)
+ self.assertEqual(len(test_data), tangsy_num)
+ self.assertIsInstance(train_data, list)
+ self.assertIsInstance(valid_data, list)
+ self.assertIsInstance(test_data, list)
+ self.assertIsInstance(train_data[0], dict)
+ self.assertIsInstance(valid_data[0], dict)
+ self.assertIsInstance(test_data[0], dict)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/samples/modules/tensorflow/magic_wand/train/data_split_test.py b/samples/modules/tensorflow/magic_wand/train/data_split_test.py
new file mode 100644
index 00000000000..3ab2f185c57
--- /dev/null
+++ b/samples/modules/tensorflow/magic_wand/train/data_split_test.py
@@ -0,0 +1,77 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Test for data_split.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import unittest
+from data_split import read_data
+from data_split import split_data
+
+
+class TestSplit(unittest.TestCase):
+
+ def setUp(self): # pylint: disable=g-missing-super-call
+ self.data = read_data("./data/complete_data")
+ self.num_dic = {"wing": 0, "ring": 0, "slope": 0, "negative": 0}
+ with open("./data/complete_data", "r") as f:
+ lines = f.readlines()
+ self.num = len(lines)
+
+ def test_read_data(self):
+ self.assertEqual(len(self.data), self.num)
+ self.assertIsInstance(self.data, list)
+ self.assertIsInstance(self.data[0], dict)
+ self.assertEqual(
+ set(list(self.data[-1])), set(["gesture", "accel_ms2_xyz", "name"]))
+
+ def test_split_data(self):
+ with open("./data/complete_data", "r") as f:
+ lines = f.readlines()
+ for idx, line in enumerate(lines): # pylint: disable=unused-variable
+ dic = json.loads(line)
+ for ges in self.num_dic:
+ if dic["gesture"] == ges:
+ self.num_dic[ges] += 1
+ train_data_0, valid_data_0, test_data_100 = split_data(self.data, 0, 0)
+ train_data_50, valid_data_50, test_data_0 = split_data(self.data, 0.5, 0.5)
+ train_data_60, valid_data_20, test_data_20 = split_data(self.data, 0.6, 0.2)
+ len_60 = int(self.num_dic["wing"] * 0.6) + int(
+ self.num_dic["ring"] * 0.6) + int(self.num_dic["slope"] * 0.6) + int(
+ self.num_dic["negative"] * 0.6)
+ len_50 = int(self.num_dic["wing"] * 0.5) + int(
+ self.num_dic["ring"] * 0.5) + int(self.num_dic["slope"] * 0.5) + int(
+ self.num_dic["negative"] * 0.5)
+ len_20 = int(self.num_dic["wing"] * 0.2) + int(
+ self.num_dic["ring"] * 0.2) + int(self.num_dic["slope"] * 0.2) + int(
+ self.num_dic["negative"] * 0.2)
+ self.assertEqual(len(train_data_0), 0)
+ self.assertEqual(len(train_data_50), len_50)
+ self.assertEqual(len(train_data_60), len_60)
+ self.assertEqual(len(valid_data_0), 0)
+ self.assertEqual(len(valid_data_50), len_50)
+ self.assertEqual(len(valid_data_20), len_20)
+ self.assertEqual(len(test_data_100), self.num)
+ self.assertEqual(len(test_data_0), (self.num - 2 * len_50))
+ self.assertEqual(len(test_data_20), (self.num - len_60 - len_20))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/samples/modules/tensorflow/magic_wand/train/netmodels/CNN/weights.h5 b/samples/modules/tensorflow/magic_wand/train/netmodels/CNN/weights.h5
new file mode 100644
index 00000000000..1d825b3aaf7
Binary files /dev/null and b/samples/modules/tensorflow/magic_wand/train/netmodels/CNN/weights.h5 differ
diff --git a/samples/modules/tensorflow/magic_wand/train/requirements.txt b/samples/modules/tensorflow/magic_wand/train/requirements.txt
new file mode 100644
index 00000000000..c83b8a48eb0
--- /dev/null
+++ b/samples/modules/tensorflow/magic_wand/train/requirements.txt
@@ -0,0 +1,2 @@
+numpy==1.16.2
+tensorflow==2.0.0-beta1
diff --git a/samples/modules/tensorflow/magic_wand/train/train.py b/samples/modules/tensorflow/magic_wand/train/train.py
new file mode 100644
index 00000000000..213c7c792e6
--- /dev/null
+++ b/samples/modules/tensorflow/magic_wand/train/train.py
@@ -0,0 +1,202 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=g-bad-import-order
+
+"""Build and train neural networks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import datetime
+import os # pylint: disable=duplicate-code
+from data_load import DataLoader
+
+import numpy as np # pylint: disable=duplicate-code
+import tensorflow as tf
+
+logdir = "logs/scalars/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
+
+
+def reshape_function(data, label):
+ reshaped_data = tf.reshape(data, [-1, 3, 1])
+ return reshaped_data, label
+
+
+def calculate_model_size(model):
+ print(model.summary())
+ var_sizes = [
+ np.product(list(map(int, v.shape))) * v.dtype.size
+ for v in model.trainable_variables
+ ]
+ print("Model size:", sum(var_sizes) / 1024, "KB")
+
+
+def build_cnn(seq_length):
+ """Builds a convolutional neural network in Keras."""
+ model = tf.keras.Sequential([
+ tf.keras.layers.Conv2D(
+ 8, (4, 3),
+ padding="same",
+ activation="relu",
+ input_shape=(seq_length, 3, 1)), # output_shape=(batch, 128, 3, 8)
+ tf.keras.layers.MaxPool2D((3, 3)), # (batch, 42, 1, 8)
+ tf.keras.layers.Dropout(0.1), # (batch, 42, 1, 8)
+ tf.keras.layers.Conv2D(16, (4, 1), padding="same",
+ activation="relu"), # (batch, 42, 1, 16)
+ tf.keras.layers.MaxPool2D((3, 1), padding="same"), # (batch, 14, 1, 16)
+ tf.keras.layers.Dropout(0.1), # (batch, 14, 1, 16)
+ tf.keras.layers.Flatten(), # (batch, 224)
+ tf.keras.layers.Dense(16, activation="relu"), # (batch, 16)
+ tf.keras.layers.Dropout(0.1), # (batch, 16)
+ tf.keras.layers.Dense(4, activation="softmax") # (batch, 4)
+ ])
+ model_path = os.path.join("./netmodels", "CNN")
+ print("Built CNN.")
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+ model.load_weights("./netmodels/CNN/weights.h5")
+ return model, model_path
+
+
+def build_lstm(seq_length):
+ """Builds an LSTM in Keras."""
+ model = tf.keras.Sequential([
+ tf.keras.layers.Bidirectional(
+ tf.keras.layers.LSTM(22),
+ input_shape=(seq_length, 3)), # output_shape=(batch, 44)
+ tf.keras.layers.Dense(4, activation="sigmoid") # (batch, 4)
+ ])
+ model_path = os.path.join("./netmodels", "LSTM")
+ print("Built LSTM.")
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+ return model, model_path
+
+
+def load_data(train_data_path, valid_data_path, test_data_path, seq_length):
+ data_loader = DataLoader(
+ train_data_path, valid_data_path, test_data_path, seq_length=seq_length)
+ data_loader.format()
+ return data_loader.train_len, data_loader.train_data, data_loader.valid_len, \
+ data_loader.valid_data, data_loader.test_len, data_loader.test_data
+
+
+def build_net(args, seq_length):
+ if args.model == "CNN":
+ model, model_path = build_cnn(seq_length)
+ elif args.model == "LSTM":
+ model, model_path = build_lstm(seq_length)
+ else:
+ print("Please input correct model name.(CNN LSTM)")
+ return model, model_path
+
+
+def train_net(
+ model,
+ model_path, # pylint: disable=unused-argument
+ train_len, # pylint: disable=unused-argument
+ train_data,
+ valid_len,
+ valid_data,
+ test_len,
+ test_data,
+ kind):
+ """Trains the model."""
+ calculate_model_size(model)
+ epochs = 50
+ batch_size = 64
+ model.compile(
+ optimizer="adam",
+ loss="sparse_categorical_crossentropy",
+ metrics=["accuracy"])
+ if kind == "CNN":
+ train_data = train_data.map(reshape_function)
+ test_data = test_data.map(reshape_function)
+ valid_data = valid_data.map(reshape_function)
+ test_labels = np.zeros(test_len)
+ idx = 0
+ for data, label in test_data: # pylint: disable=unused-variable
+ test_labels[idx] = label.numpy()
+ idx += 1
+ train_data = train_data.batch(batch_size).repeat()
+ valid_data = valid_data.batch(batch_size)
+ test_data = test_data.batch(batch_size)
+ model.fit(
+ train_data,
+ epochs=epochs,
+ validation_data=valid_data,
+ steps_per_epoch=1000,
+ validation_steps=int((valid_len - 1) / batch_size + 1),
+ callbacks=[tensorboard_callback])
+ loss, acc = model.evaluate(test_data)
+ pred = np.argmax(model.predict(test_data), axis=1)
+ confusion = tf.math.confusion_matrix(
+ labels=tf.constant(test_labels),
+ predictions=tf.constant(pred),
+ num_classes=4)
+ print(confusion)
+ print("Loss {}, Accuracy {}".format(loss, acc))
+ # Convert the model to the TensorFlow Lite format without quantization
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ tflite_model = converter.convert()
+
+ # Save the model to disk
+ open("model.tflite", "wb").write(tflite_model)
+
+ # Convert the model to the TensorFlow Lite format with quantization
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
+ tflite_model = converter.convert()
+
+ # Save the model to disk
+ open("model_quantized.tflite", "wb").write(tflite_model)
+
+ basic_model_size = os.path.getsize("model.tflite")
+ print("Basic model is %d bytes" % basic_model_size)
+ quantized_model_size = os.path.getsize("model_quantized.tflite")
+ print("Quantized model is %d bytes" % quantized_model_size)
+ difference = basic_model_size - quantized_model_size
+ print("Difference is %d bytes" % difference)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", "-m")
+ parser.add_argument("--person", "-p")
+ args = parser.parse_args()
+
+ seq_length = 128
+
+ print("Start to load data...")
+ if args.person == "true":
+ train_len, train_data, valid_len, valid_data, test_len, test_data = \
+ load_data("./person_split/train", "./person_split/valid",
+ "./person_split/test", seq_length)
+ else:
+ train_len, train_data, valid_len, valid_data, test_len, test_data = \
+ load_data("./data/train", "./data/valid", "./data/test", seq_length)
+
+ print("Start to build net...")
+ model, model_path = build_net(args, seq_length)
+
+ print("Start training...")
+ train_net(model, model_path, train_len, train_data, valid_len, valid_data,
+ test_len, test_data, args.model)
+
+ print("Training finished!")
diff --git a/samples/modules/tensorflow/magic_wand/train/train_magic_wand_model.ipynb b/samples/modules/tensorflow/magic_wand/train/train_magic_wand_model.ipynb
new file mode 100644
index 00000000000..d285c522fdc
--- /dev/null
+++ b/samples/modules/tensorflow/magic_wand/train/train_magic_wand_model.ipynb
@@ -0,0 +1,257 @@
+{
+ "cells": [
+ {
+ "source": [
+ "Copyright 2020 The TensorFlow Authors. All Rights Reserved.\n",
+ "\n",
+ "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "you may not use this file except in compliance with the License.\n",
+ "You may obtain a copy of the License at\n",
+ "\n",
+ " http://www.apache.org/licenses/LICENSE-2.0\n",
+ "\n",
+ "Unless required by applicable law or agreed to in writing, software\n",
+ "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "See the License for the specific language governing permissions and\n",
+ "limitations under the License."
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "1BtkMGSYQOTQ"
+ },
+ "source": [
+ "# Train a gesture recognition model for microcontroller use"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "BaFfr7DHRmGF"
+ },
+ "source": [
+ "This notebook demonstrates how to train a 20kb gesture recognition model for [TensorFlow Lite for Microcontrollers](https://tensorflow.org/lite/microcontrollers/overview). It will produce the same model used in the [magic_wand](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/examples/magic_wand) example application.\n",
+ "\n",
+ "The model is designed to be used with [Google Colaboratory](https://colab.research.google.com).\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "xXgS6rxyT7Qk"
+ },
+ "source": [
+ "Training is much faster using GPU acceleration. Before you proceed, ensure you are using a GPU runtime by going to **Runtime -> Change runtime type** and selecting **GPU**. Training will take around 5 minutes on a GPU runtime."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "LG6ErX5FRIaV"
+ },
+ "source": [
+ "## Configure dependencies\n",
+ "\n",
+ "Run the following cell to ensure the correct version of TensorFlow is used."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "STNft9TrfoVh"
+ },
+ "source": [
+ "We'll also clone the TensorFlow repository, which contains the training scripts, and copy them into our workspace."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "ygkWw73dRNda"
+ },
+ "outputs": [],
+ "source": [
+ "# Clone the repository from GitHub\n",
+ "!git clone --depth 1 -q https://github.com/tensorflow/tensorflow\n",
+ "# Copy the training scripts into our workspace\n",
+ "!cp -r tensorflow/tensorflow/lite/micro/examples/magic_wand/train train"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "pXI7R4RehFdU"
+ },
+ "source": [
+ "## Prepare the data\n",
+ "\n",
+ "Next, we'll download the data and extract it into the expected location within the training scripts' directory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "W2Sg2AKzVr2L"
+ },
+ "outputs": [],
+ "source": [
+ "# Download the data we will use to train the model\n",
+ "!wget http://download.tensorflow.org/models/tflite/magic_wand/data.tar.gz\n",
+ "# Extract the data into the train directory\n",
+ "!tar xvzf data.tar.gz -C train 1>/dev/null"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "DNjukI1Sgl2C"
+ },
+ "source": [
+ "We'll then run the scripts that split the data into training, validation, and test sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "XBqSVpi6Vxss"
+ },
+ "outputs": [],
+ "source": [
+ "# The scripts must be run from within the train directory\n",
+ "%cd train\n",
+ "# Prepare the data\n",
+ "!python data_prepare.py\n",
+ "# Split the data by person\n",
+ "!python data_split_person.py"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "5-cmVbFvhTvy"
+ },
+ "source": [
+ "## Load TensorBoard\n",
+ "\n",
+ "Now, we set up TensorBoard so that we can graph our accuracy and loss as training proceeds."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "CCx6SN9NWRPw"
+ },
+ "outputs": [],
+ "source": [
+ "# Load TensorBoard\n",
+ "%load_ext tensorboard\n",
+ "%tensorboard --logdir logs/scalars"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "ERC2Cr4PhaOl"
+ },
+ "source": [
+ "## Begin training\n",
+ "\n",
+ "The following cell will begin the training process. Training will take around 5 minutes on a GPU runtime. You'll see the metrics in TensorBoard after a few epochs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "DXmQZgbuWQFO"
+ },
+ "outputs": [],
+ "source": [
+ "!python train.py --model CNN --person true"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "4gXbVzcXhvGD"
+ },
+ "source": [
+ "## Create a C source file\n",
+ "\n",
+ "The `train.py` script writes a model, `model.tflite`, to the training scripts' directory.\n",
+ "\n",
+ "In the following cell, we convert this model into a C++ source file we can use with TensorFlow Lite for Microcontrollers."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "8wgei4OGe3Nz"
+ },
+ "outputs": [],
+ "source": [
+ "# Install xxd if it is not available\n",
+ "!apt-get -qq install xxd\n",
+ "# Save the file as a C source file\n",
+ "!xxd -i model.tflite > /content/model.cc\n",
+ "# Print the source file\n",
+ "!cat /content/model.cc"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "Train a gesture recognition model for microcontroller use",
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/samples/modules/tensorflow/magic_wand/train/train_test.py b/samples/modules/tensorflow/magic_wand/train/train_test.py
new file mode 100644
index 00000000000..2b785663df7
--- /dev/null
+++ b/samples/modules/tensorflow/magic_wand/train/train_test.py
@@ -0,0 +1,78 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Test for train.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+import numpy as np
+import tensorflow as tf
+from train import build_cnn
+from train import build_lstm
+from train import load_data
+from train import reshape_function
+
+
+class TestTrain(unittest.TestCase):
+
+ def setUp(self): # pylint: disable=g-missing-super-call
+ self.seq_length = 128
+ self.train_len, self.train_data, self.valid_len, self.valid_data, \
+ self.test_len, self.test_data = \
+ load_data("./data/train", "./data/valid", "./data/test",
+ self.seq_length)
+
+ def test_load_data(self):
+ self.assertIsInstance(self.train_data, tf.data.Dataset)
+ self.assertIsInstance(self.valid_data, tf.data.Dataset)
+ self.assertIsInstance(self.test_data, tf.data.Dataset)
+
+ def test_build_net(self):
+ cnn, cnn_path = build_cnn(self.seq_length)
+ lstm, lstm_path = build_lstm(self.seq_length)
+ cnn_data = np.random.rand(60, 128, 3, 1)
+ lstm_data = np.random.rand(60, 128, 3)
+ cnn_prob = cnn(tf.constant(cnn_data, dtype="float32")).numpy()
+ lstm_prob = lstm(tf.constant(lstm_data, dtype="float32")).numpy()
+ self.assertIsInstance(cnn, tf.keras.Sequential)
+ self.assertIsInstance(lstm, tf.keras.Sequential)
+ self.assertEqual(cnn_path, "./netmodels/CNN")
+ self.assertEqual(lstm_path, "./netmodels/LSTM")
+ self.assertEqual(cnn_prob.shape, (60, 4))
+ self.assertEqual(lstm_prob.shape, (60, 4))
+
+ def test_reshape_function(self):
+ for data, label in self.train_data:
+ original_data_shape = data.numpy().shape
+ original_label_shape = label.numpy().shape
+ break
+ self.train_data = self.train_data.map(reshape_function)
+ for data, label in self.train_data:
+ reshaped_data_shape = data.numpy().shape
+ reshaped_label_shape = label.numpy().shape
+ break
+ self.assertEqual(
+ reshaped_data_shape,
+ (int(original_data_shape[0] * original_data_shape[1] / 3), 3, 1))
+ self.assertEqual(reshaped_label_shape, original_label_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()