samples: add tensorflow magic wand sample train scripts

Adds training scripts to TensorFlow Magic Wand sample.

Signed-off-by: Lauren Murphy <lauren.murphy@intel.com>
Lauren Murphy 2021-04-26 22:07:44 -05:00 committed by Anas Nashif
commit 14e6d1cf5d
16 changed files with 1598 additions and 0 deletions

train/README.md
@@ -0,0 +1,191 @@
# Gesture Recognition Magic Wand Training Scripts
## Introduction
The scripts in this directory can be used to train a TensorFlow model that
classifies gestures based on accelerometer data. The code uses Python 3.7 and
TensorFlow 2.0. The resulting model is less than 20KB in size.
This document contains instructions for using the scripts to train a model
and for capturing your own training data.
This project was inspired by the [Gesture Recognition Magic Wand](https://github.com/jewang/gesture-demo)
project by Jennifer Wang.
## Training
### Dataset
Three magic gestures were chosen, and data was collected from seven
different people. Some long, random movement sequences were also collected and
divided into shorter pieces; together with some automatically generated random
data, these make up the "negative" data.
The dataset can be downloaded from the following URL:
[download.tensorflow.org/models/tflite/magic_wand/data.tar.gz](http://download.tensorflow.org/models/tflite/magic_wand/data.tar.gz)
### Training in Colab
The following [Google Colaboratory](https://colab.research.google.com)
notebook demonstrates how to train the model. It's the easiest way to get
started:
<table class="tfo-notebook-buttons" align="left">
<td>
<a target="_blank" href="https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
</td>
<td>
<a target="_blank" href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
</td>
</table>
If you'd prefer to run the scripts locally, use the following instructions.
### Running the scripts
Use the following command to install the required dependencies:
```shell
pip install -r requirements.txt
```
There are two ways to train the model:
- Random data split, which mixes different people's data together and randomly
splits them into training, validation, and test sets
- Person data split, which splits the data by person
#### Random data split
Using a random split results in higher training accuracy than a person split,
but inferior performance on new data.
```shell
$ python data_prepare.py
$ python data_split.py
$ python train.py --model CNN --person false
```
#### Person data split
Using a person data split results in lower training accuracy but better
performance on new data.
```shell
$ python data_prepare.py
$ python data_split_person.py
$ python train.py --model CNN --person true
```
#### Model type
In the `--model` argument, you can provide either `CNN` or `LSTM`. The CNN
model is smaller and has lower latency.
## Collecting new data
To obtain new training data using the
[SparkFun Edge development board](https://sparkfun.com/products/15170), you can
modify one of the examples in the [SparkFun Edge BSP](https://github.com/sparkfun/SparkFun_Edge_BSP)
and deploy it using the Ambiq SDK.
### Install the Ambiq SDK and SparkFun Edge BSP
Follow SparkFun's
[Using SparkFun Edge Board with Ambiq Apollo3 SDK](https://learn.sparkfun.com/tutorials/using-sparkfun-edge-board-with-ambiq-apollo3-sdk/all)
guide to set up the Ambiq SDK and SparkFun Edge BSP.
#### Modify the example code
First, `cd` into
`AmbiqSuite-Rel2.2.0/boards/SparkFun_Edge_BSP/examples/example1_edge_test`.
##### Modify `src/tf_adc/tf_adc.c`
On line 62, add `true` as the second parameter of the function
`am_hal_adc_samples_read`.
##### Modify `src/main.c`
Add the line below in `int main(void)`, just before the line `while(1)`:
```cc
am_util_stdio_printf("-,-,-\r\n");
```
Change the following line in `while(1){...}`:
```cc
am_util_stdio_printf("Acc [mg] %04.2f x, %04.2f y, %04.2f z, Temp [deg C] %04.2f, MIC0 [counts / 2^14] %d\r\n", acceleration_mg[0], acceleration_mg[1], acceleration_mg[2], temperature_degC, (audioSample) );
```
to:
```cc
am_util_stdio_printf("%04.2f,%04.2f,%04.2f\r\n", acceleration_mg[0], acceleration_mg[1], acceleration_mg[2]);
```
#### Flash the binary
Follow the instructions in
[SparkFun's guide](https://learn.sparkfun.com/tutorials/using-sparkfun-edge-board-with-ambiq-apollo3-sdk/all#example-applications)
to flash the binary to the device.
#### Collect accelerometer data
First, in a new terminal window, run the following command to begin logging
output to `output.txt`:
```shell
$ script output.txt
```
Next, in the same window, use `screen` to connect to the device:
```shell
$ screen ${DEVICENAME} 115200
```
Output collected from the accelerometer sensor will be shown on the screen and
saved in `output.txt`, one "x,y,z" reading per line.
Press the `RST` button to start capturing a new gesture, then press Button 14
when it ends. Each new gesture will begin with a line reading "-,-,-".
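For example, a captured log might look like this (the readings are
hypothetical values):
```
-,-,-
-766.57,-5.93,-213.68
-765.89,-6.04,-212.71
```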
To exit `screen`, press `Ctrl+A`, immediately followed by the `K` key,
then press the `Y` key. Then run
```shell
$ exit
```
to stop logging data. Data will be saved in `output.txt`. For compatibility
with the training scripts, rename the file to include the person's name and
the gesture name, in the following format:
```
output_{gesture_name}_{person_name}.txt
```
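Note that `data_prepare.py` expects each file inside a directory named after
its gesture; the path it reads for each gesture/person pair is built like this
(excerpted from the script):
```python
# Path read by data_prepare.py for each (folder, name) pair,
# e.g. ./wing/output_wing_hyw.txt
"./%s/output_%s_%s.txt" % (folder, folder, name)
```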
#### Edit and run the scripts
Edit the following files to include your new gesture names (replacing
"wing", "ring", and "slope"):
- `data_load.py`
- `data_prepare.py`
- `data_split.py`
Edit the following files to include your new person names (replacing "hyw",
"shiyun", "tangsy", "dengyl", "jiangyh", "xunkai", "lsj", "pengxl", "liucx",
and "zhangxy"); a sketch of the edit is shown after this list:
- `data_prepare.py`
- `data_split_person.py`
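Both lists live near the top of `data_prepare.py` in this commit; a minimal
sketch of the edit, using hypothetical gesture and person names, looks like
this:
```python
# data_prepare.py -- replace the label lists with your own.
# "swipe"/"circle"/"zigzag" and the person names below are hypothetical.
folders = ["swipe", "circle", "zigzag"]  # was ["wing", "ring", "slope"]
names = ["alice", "bob", "carol"]        # was ["hyw", "shiyun", ...]
```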
Finally, run the commands described earlier to train a new model.

train/data_augmentation.py
@@ -0,0 +1,74 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-bad-import-order
"""Data augmentation that will be used in data_load.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
def time_wrapping(molecule, denominator, data):
"""Generate (molecule/denominator)x speed data."""
tmp_data = [[0
for i in range(len(data[0]))]
for j in range((int(len(data) / molecule) - 1) * denominator)]
for i in range(int(len(data) / molecule) - 1):
for j in range(len(data[i])):
for k in range(denominator):
tmp_data[denominator * i +
k][j] = (data[molecule * i + k][j] * (denominator - k) +
data[molecule * i + k + 1][j] * k) / denominator
return tmp_data
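# Example: time_wrapping(3, 2, data) turns a 30-sample sequence into
# (int(30 / 3) - 1) * 2 = 18 samples, i.e. roughly (3/2)x playback speed.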
def augment_data(original_data, original_label):
"""Perform data augmentation."""
new_data = []
new_label = []
for idx, (data, label) in enumerate(zip(original_data, original_label)): # pylint: disable=unused-variable
# Original data
new_data.append(data)
new_label.append(label)
# Sequence shift
for num in range(5): # pylint: disable=unused-variable
new_data.append((np.array(data, dtype=np.float32) +
(random.random() - 0.5) * 200).tolist())
new_label.append(label)
# Random noise. Build a fresh list on each iteration; reusing one list would
# append five references to the same (last-mutated) object.
for num in range(5):
tmp_data = [[0 for i in range(len(data[0]))] for j in range(len(data))]
for i in range(len(tmp_data)): # pylint: disable=consider-using-enumerate
for j in range(len(tmp_data[i])):
tmp_data[i][j] = data[i][j] + 5 * random.random()
new_data.append(tmp_data)
new_label.append(label)
# Time warping
fractions = [(3, 2), (5, 3), (2, 3), (3, 4), (9, 5), (6, 5), (4, 5)]
for molecule, denominator in fractions:
new_data.append(time_wrapping(molecule, denominator, data))
new_label.append(label)
# Movement amplification
for molecule, denominator in fractions:
new_data.append(
(np.array(data, dtype=np.float32) * molecule / denominator).tolist())
new_label.append(label)
return new_data, new_label
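# Note: each input sequence yields 25 outputs in total: the original, 5
# shifted copies, 5 noisy copies, 7 time-warped copies and 7 amplitude-scaled
# copies (data_augmentation_test.py asserts this 25x expansion).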

train/data_augmentation_test.py
@@ -0,0 +1,58 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-bad-import-order
"""Test for data_augmentation.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import numpy as np
from data_augmentation import augment_data
from data_augmentation import time_wrapping
class TestAugmentation(unittest.TestCase):
def test_time_wrapping(self):
original_data = np.random.rand(10, 3).tolist()
wrapped_data = time_wrapping(4, 5, original_data)
self.assertEqual(len(wrapped_data), int(len(original_data) / 4 - 1) * 5)
self.assertEqual(len(wrapped_data[0]), len(original_data[0]))
def test_augment_data(self):
original_data = [
np.random.rand(128, 3).tolist(),
np.random.rand(66, 2).tolist(),
np.random.rand(9, 1).tolist()
]
original_label = ["data", "augmentation", "test"]
augmented_data, augmented_label = augment_data(original_data,
original_label)
self.assertEqual(25 * len(original_data), len(augmented_data))
self.assertIsInstance(augmented_data, list)
self.assertEqual(25 * len(original_label), len(augmented_label))
self.assertIsInstance(augmented_label, list)
for i in range(len(original_label)): # pylint: disable=consider-using-enumerate
self.assertEqual(augmented_label[25 * i], original_label[i])
if __name__ == "__main__":
unittest.main()

train/data_load.py
@@ -0,0 +1,106 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-bad-import-order
"""Load data from the specified paths and format them for training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import numpy as np
import tensorflow as tf
from data_augmentation import augment_data
LABEL_NAME = "gesture"
DATA_NAME = "accel_ms2_xyz"
class DataLoader(object):
"""Loads data and prepares for training."""
def __init__(self, train_data_path, valid_data_path, test_data_path,
seq_length):
self.dim = 3
self.seq_length = seq_length
self.label2id = {"wing": 0, "ring": 1, "slope": 2, "negative": 3}
self.train_data, self.train_label, self.train_len = self.get_data_file(
train_data_path, "train")
self.valid_data, self.valid_label, self.valid_len = self.get_data_file(
valid_data_path, "valid")
self.test_data, self.test_label, self.test_len = self.get_data_file(
test_data_path, "test")
def get_data_file(self, data_path, data_type): # pylint: disable=no-self-use
"""Get train, valid and test data from files."""
data = []
label = []
with open(data_path, "r") as f:
lines = f.readlines()
for idx, line in enumerate(lines): # pylint: disable=unused-variable
dic = json.loads(line)
data.append(dic[DATA_NAME])
label.append(dic[LABEL_NAME])
if data_type == "train":
data, label = augment_data(data, label)
length = len(label)
print(data_type + "_data_length:" + str(length))
return data, label, length
def pad(self, data, seq_length, dim): # pylint: disable=no-self-use
"""Get neighbour padding."""
noise_level = 20
padded_data = []
# Before- Neighbour padding
tmp_data = (np.random.rand(seq_length, dim) - 0.5) * noise_level + data[0]
tmp_data[(seq_length -
min(len(data), seq_length)):] = data[:min(len(data), seq_length)]
padded_data.append(tmp_data)
# After- Neighbour padding
tmp_data = (np.random.rand(seq_length, dim) - 0.5) * noise_level + data[-1]
tmp_data[:min(len(data), seq_length)] = data[:min(len(data), seq_length)]
padded_data.append(tmp_data)
return padded_data
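# pad() returns two seq_length-long copies of `data`: one aligned to the end
# (noise padding before) and one aligned to the start (noise padding after).
# With noise_level = 20, padding stays within +/-10 of the neighbouring sample.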
def format_support_func(self, padded_num, length, data, label):
"""Support function for format.(Helps format train, valid and test.)"""
# Add 2 padding, initialize data and label
length *= padded_num
features = np.zeros((length, self.seq_length, self.dim))
labels = np.zeros(length)
# Get padding for train, valid and test
for idx, (data, label) in enumerate(zip(data, label)): # pylint: disable=redefined-argument-from-local
padded_data = self.pad(data, self.seq_length, self.dim)
for num in range(padded_num):
features[padded_num * idx + num] = padded_data[num]
labels[padded_num * idx + num] = self.label2id[label]
# Turn into tf.data.Dataset
dataset = tf.data.Dataset.from_tensor_slices(
(features, labels.astype("int32")))
return length, dataset
def format(self):
"""Format data(including padding, etc.) and get the dataset for the model."""
padded_num = 2
self.train_len, self.train_data = self.format_support_func(
padded_num, self.train_len, self.train_data, self.train_label)
self.valid_len, self.valid_data = self.format_support_func(
padded_num, self.valid_len, self.valid_data, self.valid_label)
self.test_len, self.test_data = self.format_support_func(
padded_num, self.test_len, self.test_data, self.test_label)

train/data_load_test.py
@@ -0,0 +1,95 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-bad-import-order
"""Test for data_load.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
from data_load import DataLoader
import tensorflow as tf
class TestLoad(unittest.TestCase):
def setUp(self): # pylint: disable=g-missing-super-call
self.loader = DataLoader(
"./data/train", "./data/valid", "./data/test", seq_length=512)
def test_get_data(self):
self.assertIsInstance(self.loader.train_data, list)
self.assertIsInstance(self.loader.train_label, list)
self.assertIsInstance(self.loader.valid_data, list)
self.assertIsInstance(self.loader.valid_label, list)
self.assertIsInstance(self.loader.test_data, list)
self.assertIsInstance(self.loader.test_label, list)
self.assertEqual(self.loader.train_len, len(self.loader.train_data))
self.assertEqual(self.loader.train_len, len(self.loader.train_label))
self.assertEqual(self.loader.valid_len, len(self.loader.valid_data))
self.assertEqual(self.loader.valid_len, len(self.loader.valid_label))
self.assertEqual(self.loader.test_len, len(self.loader.test_data))
self.assertEqual(self.loader.test_len, len(self.loader.test_label))
def test_pad(self):
original_data1 = [[2, 3], [1, 1]]
expected_data1_0 = [[2, 3], [2, 3], [2, 3], [2, 3], [1, 1]]
expected_data1_1 = [[2, 3], [1, 1], [1, 1], [1, 1], [1, 1]]
original_data2 = [[-2, 3], [-77, -681], [5, 6], [9, -7], [22, 3333],
[9, 99], [-100, 0]]
expected_data2 = [[-2, 3], [-77, -681], [5, 6], [9, -7], [22, 3333]]
padding_data1 = self.loader.pad(original_data1, seq_length=5, dim=2)
padding_data2 = self.loader.pad(original_data2, seq_length=5, dim=2)
for i in range(len(padding_data1[0])):
for j in range(len(padding_data1[0].tolist()[0])):
self.assertLess(
abs(padding_data1[0].tolist()[i][j] - expected_data1_0[i][j]),
10.001)
for i in range(len(padding_data1[1])):
for j in range(len(padding_data1[1].tolist()[0])):
self.assertLess(
abs(padding_data1[1].tolist()[i][j] - expected_data1_1[i][j]),
10.001)
self.assertEqual(padding_data2[0].tolist(), expected_data2)
self.assertEqual(padding_data2[1].tolist(), expected_data2)
def test_format(self):
self.loader.format()
expected_train_label = int(self.loader.label2id[self.loader.train_label[0]])
expected_valid_label = int(self.loader.label2id[self.loader.valid_label[0]])
expected_test_label = int(self.loader.label2id[self.loader.test_label[0]])
for feature, label in self.loader.train_data: # pylint: disable=unused-variable
format_train_label = label.numpy()
break
for feature, label in self.loader.valid_data:
format_valid_label = label.numpy()
break
for feature, label in self.loader.test_data:
format_test_label = label.numpy()
break
self.assertEqual(expected_train_label, format_train_label)
self.assertEqual(expected_valid_label, format_valid_label)
self.assertEqual(expected_test_label, format_test_label)
self.assertIsInstance(self.loader.train_data, tf.data.Dataset)
self.assertIsInstance(self.loader.valid_data, tf.data.Dataset)
self.assertIsInstance(self.loader.test_data, tf.data.Dataset)
if __name__ == "__main__":
unittest.main()

train/data_prepare.py
@@ -0,0 +1,164 @@
# Lint as: python3
# coding=utf-8
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Prepare data for further process.
Read data from "/slope", "/ring", "/wing", "/negative" and save them
in "/data/complete_data" in python dict format.
It will generate a new file with the following structure:
data
   complete_data
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import csv
import json
import os
import random
LABEL_NAME = "gesture"
DATA_NAME = "accel_ms2_xyz"
folders = ["wing", "ring", "slope"]
names = [
"hyw", "shiyun", "tangsy", "dengyl", "zhangxy", "pengxl", "liucx",
"jiangyh", "xunkai"
]
def prepare_original_data(folder, name, data, file_to_read):
"""Read collected data from files."""
if folder != "negative":
with open(file_to_read, "r") as f:
lines = csv.reader(f)
data_new = {}
data_new[LABEL_NAME] = folder
data_new[DATA_NAME] = []
data_new["name"] = name
for idx, line in enumerate(lines): # pylint: disable=unused-variable
if len(line) == 3:
if line[2] == "-" and data_new[DATA_NAME]:
data.append(data_new)
data_new = {}
data_new[LABEL_NAME] = folder
data_new[DATA_NAME] = []
data_new["name"] = name
elif line[2] != "-":
data_new[DATA_NAME].append([float(i) for i in line[0:3]])
data.append(data_new)
else:
with open(file_to_read, "r") as f:
lines = csv.reader(f)
data_new = {}
data_new[LABEL_NAME] = folder
data_new[DATA_NAME] = []
data_new["name"] = name
for idx, line in enumerate(lines):
if len(line) == 3 and line[2] != "-":
if len(data_new[DATA_NAME]) == 120:
data.append(data_new)
data_new = {}
data_new[LABEL_NAME] = folder
data_new[DATA_NAME] = []
data_new["name"] = name
else:
data_new[DATA_NAME].append([float(i) for i in line[0:3]])
data.append(data_new)
def generate_negative_data(data):
"""Generate negative data labeled as 'negative6~8'."""
# Big movement -> around straight line
for i in range(100):
if i > 80:
dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative8"}
elif i > 60:
dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative7"}
else:
dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative6"}
start_x = (random.random() - 0.5) * 2000
start_y = (random.random() - 0.5) * 2000
start_z = (random.random() - 0.5) * 2000
x_increase = (random.random() - 0.5) * 10
y_increase = (random.random() - 0.5) * 10
z_increase = (random.random() - 0.5) * 10
for j in range(128):
dic[DATA_NAME].append([
start_x + j * x_increase + (random.random() - 0.5) * 6,
start_y + j * y_increase + (random.random() - 0.5) * 6,
start_z + j * z_increase + (random.random() - 0.5) * 6
])
data.append(dic)
# Random
for i in range(100):
if i > 80:
dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative8"}
elif i > 60:
dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative7"}
else:
dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative6"}
for j in range(128):
dic[DATA_NAME].append([(random.random() - 0.5) * 1000,
(random.random() - 0.5) * 1000,
(random.random() - 0.5) * 1000])
data.append(dic)
# Stay still
for i in range(100):
if i > 80:
dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative8"}
elif i > 60:
dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative7"}
else:
dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative6"}
start_x = (random.random() - 0.5) * 2000
start_y = (random.random() - 0.5) * 2000
start_z = (random.random() - 0.5) * 2000
for j in range(128):
dic[DATA_NAME].append([
start_x + (random.random() - 0.5) * 40,
start_y + (random.random() - 0.5) * 40,
start_z + (random.random() - 0.5) * 40
])
data.append(dic)
# Write data to file
def write_data(data_to_write, path):
with open(path, "w") as f:
for idx, item in enumerate(data_to_write): # pylint: disable=unused-variable
dic = json.dumps(item, ensure_ascii=False)
f.write(dic)
f.write("\n")
if __name__ == "__main__":
data = []
for idx1, folder in enumerate(folders):
for idx2, name in enumerate(names):
prepare_original_data(folder, name, data,
"./%s/output_%s_%s.txt" % (folder, folder, name))
for idx in range(5):
prepare_original_data("negative", "negative%d" % (idx + 1), data,
"./negative/output_negative_%d.txt" % (idx + 1))
generate_negative_data(data)
print("data_length: " + str(len(data)))
if not os.path.exists("./data"):
os.makedirs("./data")
write_data(data, "./data/complete_data")

train/data_prepare_test.py
@@ -0,0 +1,75 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for data_prepare.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import csv
import json
import os
import unittest
from data_prepare import generate_negative_data
from data_prepare import prepare_original_data
from data_prepare import write_data
class TestPrepare(unittest.TestCase):
def setUp(self): # pylint: disable=g-missing-super-call
self.file = "./%s/output_%s_%s.txt" % (folders[0], folders[0], names[0]) # pylint: disable=undefined-variable
self.data = []
prepare_original_data(folders[0], names[0], self.data, self.file) # pylint: disable=undefined-variable
def test_prepare_data(self):
num = 0
with open(self.file, "r") as f:
lines = csv.reader(f)
for idx, line in enumerate(lines): # pylint: disable=unused-variable
if len(line) == 3 and line[2] == "-":
num += 1
self.assertEqual(len(self.data), num)
self.assertIsInstance(self.data, list)
self.assertIsInstance(self.data[0], dict)
self.assertEqual(list(self.data[-1]), ["gesture", "accel_ms2_xyz", "name"])
self.assertEqual(self.data[0]["name"], names[0]) # pylint: disable=undefined-variable
def test_generate_negative(self):
original_len = len(self.data)
generate_negative_data(self.data)
self.assertEqual(original_len + 300, len(self.data))
generated_num = 0
for idx, data in enumerate(self.data): # pylint: disable=unused-variable
if data["name"] == "negative6" or data["name"] == "negative7" or data[
"name"] == "negative8":
generated_num += 1
self.assertEqual(generated_num, 300)
def test_write_data(self):
data_path_test = "./data/data0"
write_data(self.data, data_path_test)
with open(data_path_test, "r") as f:
lines = f.readlines()
self.assertEqual(len(lines), len(self.data))
self.assertEqual(json.loads(lines[0]), self.data[0])
self.assertEqual(json.loads(lines[-1]), self.data[-1])
os.remove(data_path_test)
if __name__ == "__main__":
unittest.main()

train/data_split.py
@@ -0,0 +1,90 @@
# Lint as: python3
# coding=utf-8
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mix and split data.
Mix different people's data together and randomly split them into train,
validation and test. These data would be saved separately under "/data".
It will generate new files with the following structure:
data
   complete_data
   test
   train
   valid
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import random
from data_prepare import write_data
# Read data
def read_data(path):
data = []
with open(path, "r") as f:
lines = f.readlines()
for idx, line in enumerate(lines): # pylint: disable=unused-variable
dic = json.loads(line)
data.append(dic)
print("data_length:" + str(len(data)))
return data
def split_data(data, train_ratio, valid_ratio):
"""Splits data into train, validation and test according to ratio."""
train_data = []
valid_data = []
test_data = []
num_dic = {"wing": 0, "ring": 0, "slope": 0, "negative": 0}
for idx, item in enumerate(data): # pylint: disable=unused-variable
for i in num_dic:
if item["gesture"] == i:
num_dic[i] += 1
print(num_dic)
train_num_dic = {}
valid_num_dic = {}
for i in num_dic:
train_num_dic[i] = int(train_ratio * num_dic[i])
valid_num_dic[i] = int(valid_ratio * num_dic[i])
random.seed(30)
random.shuffle(data)
for idx, item in enumerate(data):
for i in num_dic:
if item["gesture"] == i:
if train_num_dic[i] > 0:
train_data.append(item)
train_num_dic[i] -= 1
elif valid_num_dic[i] > 0:
valid_data.append(item)
valid_num_dic[i] -= 1
else:
test_data.append(item)
print("train_length:" + str(len(train_data)))
print("test_length:" + str(len(test_data)))
return train_data, valid_data, test_data
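# Example: split_data(data, 0.6, 0.2) yields roughly a 60/20/20 split,
# applied per gesture class so each set keeps the class balance.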
if __name__ == "__main__":
data = read_data("./data/complete_data")
train_data, valid_data, test_data = split_data(data, 0.6, 0.2)
write_data(train_data, "./data/train")
write_data(valid_data, "./data/valid")
write_data(test_data, "./data/test")

train/data_split_person.py
@@ -0,0 +1,75 @@
# Lint as: python3
# coding=utf-8
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Split data into train, validation and test dataset according to person.
That is, use some people's data as train, some other people's data as
validation, and the rest ones' data as test. These data would be saved
separately under "/person_split".
It will generate new files with the following structure:
person_split
   test
   train
   valid
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
from data_split import read_data
from data_split import write_data
def person_split(whole_data, train_names, valid_names, test_names):
"""Split data by person."""
random.seed(30)
random.shuffle(whole_data)
train_data = []
valid_data = []
test_data = []
for idx, data in enumerate(whole_data): # pylint: disable=unused-variable
if data["name"] in train_names:
train_data.append(data)
elif data["name"] in valid_names:
valid_data.append(data)
elif data["name"] in test_names:
test_data.append(data)
print("train_length:" + str(len(train_data)))
print("valid_length:" + str(len(valid_data)))
print("test_length:" + str(len(test_data)))
return train_data, valid_data, test_data
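# Example: person_split(data, ["hyw"], ["lsj"], ["liucx"]) routes each
# recording to a set based on its "name" field; recordings whose name appears
# in none of the three lists are dropped.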
if __name__ == "__main__":
data = read_data("./data/complete_data")
train_names = [
"hyw", "shiyun", "tangsy", "dengyl", "jiangyh", "xunkai", "negative3",
"negative4", "negative5", "negative6"
]
valid_names = ["lsj", "pengxl", "negative2", "negative7"]
test_names = ["liucx", "zhangxy", "negative1", "negative8"]
train_data, valid_data, test_data = person_split(data, train_names,
valid_names, test_names)
if not os.path.exists("./person_split"):
os.makedirs("./person_split")
write_data(train_data, "./person_split/train")
write_data(valid_data, "./person_split/valid")
write_data(test_data, "./person_split/test")

train/data_split_person_test.py
@@ -0,0 +1,54 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for data_split_person.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
from data_split_person import person_split
from data_split_person import read_data
class TestSplitPerson(unittest.TestCase):
def setUp(self): # pylint: disable=g-missing-super-call
self.data = read_data("./data/complete_data")
def test_person_split(self):
train_names = ["dengyl"]
valid_names = ["liucx"]
test_names = ["tangsy"]
dengyl_num = 63
liucx_num = 63
tangsy_num = 30
train_data, valid_data, test_data = person_split(self.data, train_names,
valid_names, test_names)
self.assertEqual(len(train_data), dengyl_num)
self.assertEqual(len(valid_data), liucx_num)
self.assertEqual(len(test_data), tangsy_num)
self.assertIsInstance(train_data, list)
self.assertIsInstance(valid_data, list)
self.assertIsInstance(test_data, list)
self.assertIsInstance(train_data[0], dict)
self.assertIsInstance(valid_data[0], dict)
self.assertIsInstance(test_data[0], dict)
if __name__ == "__main__":
unittest.main()

train/data_split_test.py
@@ -0,0 +1,77 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for data_split.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import unittest
from data_split import read_data
from data_split import split_data
class TestSplit(unittest.TestCase):
def setUp(self): # pylint: disable=g-missing-super-call
self.data = read_data("./data/complete_data")
self.num_dic = {"wing": 0, "ring": 0, "slope": 0, "negative": 0}
with open("./data/complete_data", "r") as f:
lines = f.readlines()
self.num = len(lines)
def test_read_data(self):
self.assertEqual(len(self.data), self.num)
self.assertIsInstance(self.data, list)
self.assertIsInstance(self.data[0], dict)
self.assertEqual(
set(list(self.data[-1])), set(["gesture", "accel_ms2_xyz", "name"]))
def test_split_data(self):
with open("./data/complete_data", "r") as f:
lines = f.readlines()
for idx, line in enumerate(lines): # pylint: disable=unused-variable
dic = json.loads(line)
for ges in self.num_dic:
if dic["gesture"] == ges:
self.num_dic[ges] += 1
train_data_0, valid_data_0, test_data_100 = split_data(self.data, 0, 0)
train_data_50, valid_data_50, test_data_0 = split_data(self.data, 0.5, 0.5)
train_data_60, valid_data_20, test_data_20 = split_data(self.data, 0.6, 0.2)
len_60 = int(self.num_dic["wing"] * 0.6) + int(
self.num_dic["ring"] * 0.6) + int(self.num_dic["slope"] * 0.6) + int(
self.num_dic["negative"] * 0.6)
len_50 = int(self.num_dic["wing"] * 0.5) + int(
self.num_dic["ring"] * 0.5) + int(self.num_dic["slope"] * 0.5) + int(
self.num_dic["negative"] * 0.5)
len_20 = int(self.num_dic["wing"] * 0.2) + int(
self.num_dic["ring"] * 0.2) + int(self.num_dic["slope"] * 0.2) + int(
self.num_dic["negative"] * 0.2)
self.assertEqual(len(train_data_0), 0)
self.assertEqual(len(train_data_50), len_50)
self.assertEqual(len(train_data_60), len_60)
self.assertEqual(len(valid_data_0), 0)
self.assertEqual(len(valid_data_50), len_50)
self.assertEqual(len(valid_data_20), len_20)
self.assertEqual(len(test_data_100), self.num)
self.assertEqual(len(test_data_0), (self.num - 2 * len_50))
self.assertEqual(len(test_data_20), (self.num - len_60 - len_20))
if __name__ == "__main__":
unittest.main()

train/requirements.txt
@@ -0,0 +1,2 @@
numpy==1.16.2
tensorflow==2.0.0-beta1

train/train.py
@@ -0,0 +1,202 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-bad-import-order
"""Build and train neural networks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import datetime
import os # pylint: disable=duplicate-code
from data_load import DataLoader
import numpy as np # pylint: disable=duplicate-code
import tensorflow as tf
logdir = "logs/scalars/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
def reshape_function(data, label):
reshaped_data = tf.reshape(data, [-1, 3, 1])
return reshaped_data, label
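# For the CNN, each (seq_length, 3) window becomes (seq_length, 3, 1) so it
# can be fed to Conv2D as a single-channel "image".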
def calculate_model_size(model):
print(model.summary())
var_sizes = [
np.product(list(map(int, v.shape))) * v.dtype.size
for v in model.trainable_variables
]
print("Model size:", sum(var_sizes) / 1024, "KB")
def build_cnn(seq_length):
"""Builds a convolutional neural network in Keras."""
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(
8, (4, 3),
padding="same",
activation="relu",
input_shape=(seq_length, 3, 1)), # output_shape=(batch, 128, 3, 8)
tf.keras.layers.MaxPool2D((3, 3)), # (batch, 42, 1, 8)
tf.keras.layers.Dropout(0.1), # (batch, 42, 1, 8)
tf.keras.layers.Conv2D(16, (4, 1), padding="same",
activation="relu"), # (batch, 42, 1, 16)
tf.keras.layers.MaxPool2D((3, 1), padding="same"), # (batch, 14, 1, 16)
tf.keras.layers.Dropout(0.1), # (batch, 14, 1, 16)
tf.keras.layers.Flatten(), # (batch, 224)
tf.keras.layers.Dense(16, activation="relu"), # (batch, 16)
tf.keras.layers.Dropout(0.1), # (batch, 16)
tf.keras.layers.Dense(4, activation="softmax") # (batch, 4)
])
model_path = os.path.join("./netmodels", "CNN")
print("Built CNN.")
if not os.path.exists(model_path):
os.makedirs(model_path)
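# Note: the next line unconditionally loads pre-trained weights and will
# fail if ./netmodels/CNN/weights.h5 does not exist yet.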
model.load_weights("./netmodels/CNN/weights.h5")
return model, model_path
def build_lstm(seq_length):
"""Builds an LSTM in Keras."""
model = tf.keras.Sequential([
tf.keras.layers.Bidirectional(
tf.keras.layers.LSTM(22),
input_shape=(seq_length, 3)), # output_shape=(batch, 44)
tf.keras.layers.Dense(4, activation="sigmoid") # (batch, 4)
])
model_path = os.path.join("./netmodels", "LSTM")
print("Built LSTM.")
if not os.path.exists(model_path):
os.makedirs(model_path)
return model, model_path
def load_data(train_data_path, valid_data_path, test_data_path, seq_length):
data_loader = DataLoader(
train_data_path, valid_data_path, test_data_path, seq_length=seq_length)
data_loader.format()
return data_loader.train_len, data_loader.train_data, data_loader.valid_len, \
data_loader.valid_data, data_loader.test_len, data_loader.test_data
def build_net(args, seq_length):
if args.model == "CNN":
model, model_path = build_cnn(seq_length)
elif args.model == "LSTM":
model, model_path = build_lstm(seq_length)
else:
print("Please input correct model name.(CNN LSTM)")
return model, model_path
def train_net(
model,
model_path, # pylint: disable=unused-argument
train_len, # pylint: disable=unused-argument
train_data,
valid_len,
valid_data,
test_len,
test_data,
kind):
"""Trains the model."""
calculate_model_size(model)
epochs = 50
batch_size = 64
model.compile(
optimizer="adam",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"])
if kind == "CNN":
train_data = train_data.map(reshape_function)
test_data = test_data.map(reshape_function)
valid_data = valid_data.map(reshape_function)
test_labels = np.zeros(test_len)
idx = 0
for data, label in test_data: # pylint: disable=unused-variable
test_labels[idx] = label.numpy()
idx += 1
train_data = train_data.batch(batch_size).repeat()
valid_data = valid_data.batch(batch_size)
test_data = test_data.batch(batch_size)
model.fit(
train_data,
epochs=epochs,
validation_data=valid_data,
steps_per_epoch=1000,
validation_steps=int((valid_len - 1) / batch_size + 1),
callbacks=[tensorboard_callback])
loss, acc = model.evaluate(test_data)
pred = np.argmax(model.predict(test_data), axis=1)
confusion = tf.math.confusion_matrix(
labels=tf.constant(test_labels),
predictions=tf.constant(pred),
num_classes=4)
print(confusion)
print("Loss {}, Accuracy {}".format(loss, acc))
# Convert the model to the TensorFlow Lite format without quantization
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# Save the model to disk
open("model.tflite", "wb").write(tflite_model)
# Convert the model to the TensorFlow Lite format with quantization
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_model = converter.convert()
# Save the model to disk
open("model_quantized.tflite", "wb").write(tflite_model)
basic_model_size = os.path.getsize("model.tflite")
print("Basic model is %d bytes" % basic_model_size)
quantized_model_size = os.path.getsize("model_quantized.tflite")
print("Quantized model is %d bytes" % quantized_model_size)
difference = basic_model_size - quantized_model_size
print("Difference is %d bytes" % difference)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", "-m")
parser.add_argument("--person", "-p")
args = parser.parse_args()
seq_length = 128
print("Start to load data...")
if args.person == "true":
train_len, train_data, valid_len, valid_data, test_len, test_data = \
load_data("./person_split/train", "./person_split/valid",
"./person_split/test", seq_length)
else:
train_len, train_data, valid_len, valid_data, test_len, test_data = \
load_data("./data/train", "./data/valid", "./data/test", seq_length)
print("Start to build net...")
model, model_path = build_net(args, seq_length)
print("Start training...")
train_net(model, model_path, train_len, train_data, valid_len, valid_data,
test_len, test_data, args.model)
print("Training finished!")

train/train_magic_wand_model.ipynb
@@ -0,0 +1,257 @@
{
"cells": [
{
"source": [
"Copyright 2020 The TensorFlow Authors. All Rights Reserved.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
" http://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "1BtkMGSYQOTQ"
},
"source": [
"# Train a gesture recognition model for microcontroller use"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "BaFfr7DHRmGF"
},
"source": [
"This notebook demonstrates how to train a 20kb gesture recognition model for [TensorFlow Lite for Microcontrollers](https://tensorflow.org/lite/microcontrollers/overview). It will produce the same model used in the [magic_wand](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/examples/magic_wand) example application.\n",
"\n",
"The model is designed to be used with [Google Colaboratory](https://colab.research.google.com).\n",
"\n",
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
" </td>\n",
"</table>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xXgS6rxyT7Qk"
},
"source": [
"Training is much faster using GPU acceleration. Before you proceed, ensure you are using a GPU runtime by going to **Runtime -> Change runtime type** and selecting **GPU**. Training will take around 5 minutes on a GPU runtime."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "LG6ErX5FRIaV"
},
"source": [
"## Configure dependencies\n",
"\n",
"Run the following cell to ensure the correct version of TensorFlow is used."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "STNft9TrfoVh"
},
"source": [
"We'll also clone the TensorFlow repository, which contains the training scripts, and copy them into our workspace."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ygkWw73dRNda"
},
"outputs": [],
"source": [
"# Clone the repository from GitHub\n",
"!git clone --depth 1 -q https://github.com/tensorflow/tensorflow\n",
"# Copy the training scripts into our workspace\n",
"!cp -r tensorflow/tensorflow/lite/micro/examples/magic_wand/train train"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "pXI7R4RehFdU"
},
"source": [
"## Prepare the data\n",
"\n",
"Next, we'll download the data and extract it into the expected location within the training scripts' directory."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "W2Sg2AKzVr2L"
},
"outputs": [],
"source": [
"# Download the data we will use to train the model\n",
"!wget http://download.tensorflow.org/models/tflite/magic_wand/data.tar.gz\n",
"# Extract the data into the train directory\n",
"!tar xvzf data.tar.gz -C train 1>/dev/null"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "DNjukI1Sgl2C"
},
"source": [
"We'll then run the scripts that split the data into training, validation, and test sets."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "XBqSVpi6Vxss"
},
"outputs": [],
"source": [
"# The scripts must be run from within the train directory\n",
"%cd train\n",
"# Prepare the data\n",
"!python data_prepare.py\n",
"# Split the data by person\n",
"!python data_split_person.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "5-cmVbFvhTvy"
},
"source": [
"## Load TensorBoard\n",
"\n",
"Now, we set up TensorBoard so that we can graph our accuracy and loss as training proceeds."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "CCx6SN9NWRPw"
},
"outputs": [],
"source": [
"# Load TensorBoard\n",
"%load_ext tensorboard\n",
"%tensorboard --logdir logs/scalars"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ERC2Cr4PhaOl"
},
"source": [
"## Begin training\n",
"\n",
"The following cell will begin the training process. Training will take around 5 minutes on a GPU runtime. You'll see the metrics in TensorBoard after a few epochs."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "DXmQZgbuWQFO"
},
"outputs": [],
"source": [
"!python train.py --model CNN --person true"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "4gXbVzcXhvGD"
},
"source": [
"## Create a C source file\n",
"\n",
"The `train.py` script writes a model, `model.tflite`, to the training scripts' directory.\n",
"\n",
"In the following cell, we convert this model into a C++ source file we can use with TensorFlow Lite for Microcontrollers."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "8wgei4OGe3Nz"
},
"outputs": [],
"source": [
"# Install xxd if it is not available\n",
"!apt-get -qq install xxd\n",
"# Save the file as a C source file\n",
"!xxd -i model.tflite > /content/model.cc\n",
"# Print the source file\n",
"!cat /content/model.cc"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "Train a gesture recognition model for microcontroller use",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

train/train_test.py
@@ -0,0 +1,78 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for train.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import numpy as np
import tensorflow as tf
from train import build_cnn
from train import build_lstm
from train import load_data
from train import reshape_function
class TestTrain(unittest.TestCase):
def setUp(self): # pylint: disable=g-missing-super-call
self.seq_length = 128
self.train_len, self.train_data, self.valid_len, self.valid_data, \
self.test_len, self.test_data = \
load_data("./data/train", "./data/valid", "./data/test",
self.seq_length)
def test_load_data(self):
self.assertIsInstance(self.train_data, tf.data.Dataset)
self.assertIsInstance(self.valid_data, tf.data.Dataset)
self.assertIsInstance(self.test_data, tf.data.Dataset)
def test_build_net(self):
cnn, cnn_path = build_cnn(self.seq_length)
lstm, lstm_path = build_lstm(self.seq_length)
cnn_data = np.random.rand(60, 128, 3, 1)
lstm_data = np.random.rand(60, 128, 3)
cnn_prob = cnn(tf.constant(cnn_data, dtype="float32")).numpy()
lstm_prob = lstm(tf.constant(lstm_data, dtype="float32")).numpy()
self.assertIsInstance(cnn, tf.keras.Sequential)
self.assertIsInstance(lstm, tf.keras.Sequential)
self.assertEqual(cnn_path, "./netmodels/CNN")
self.assertEqual(lstm_path, "./netmodels/LSTM")
self.assertEqual(cnn_prob.shape, (60, 4))
self.assertEqual(lstm_prob.shape, (60, 4))
def test_reshape_function(self):
for data, label in self.train_data:
original_data_shape = data.numpy().shape
original_label_shape = label.numpy().shape
break
self.train_data = self.train_data.map(reshape_function)
for data, label in self.train_data:
reshaped_data_shape = data.numpy().shape
reshaped_label_shape = label.numpy().shape
break
self.assertEqual(
reshaped_data_shape,
(int(original_data_shape[0] * original_data_shape[1] / 3), 3, 1))
self.assertEqual(reshaped_label_shape, original_label_shape)
if __name__ == "__main__":
unittest.main()