samples: add tensorflow magic wand sample train scripts
Adds training scripts to TensorFlow Magic Wand sample.

Signed-off-by: Lauren Murphy <lauren.murphy@intel.com>
This commit is contained in:
parent
83a036d738
commit
14e6d1cf5d
16 changed files with 1598 additions and 0 deletions
191
samples/modules/tensorflow/magic_wand/train/README.md
Normal file
@@ -0,0 +1,191 @@
# Gesture Recognition Magic Wand Training Scripts

## Introduction

The scripts in this directory can be used to train a TensorFlow model that
classifies gestures based on accelerometer data. The code uses Python 3.7 and
TensorFlow 2.0. The resulting model is less than 20KB in size.

The following document contains instructions on using the scripts to train a
model and on capturing your own training data.

This project was inspired by the [Gesture Recognition Magic Wand](https://github.com/jewang/gesture-demo)
project by Jennifer Wang.

## Training

### Dataset

Three magic gestures were chosen, and data were collected from 7
different people. Some random long movement sequences were collected and divided
into shorter pieces, which made up "negative" data along with some other
automatically generated random data.

The dataset can be downloaded from the following URL:

[download.tensorflow.org/models/tflite/magic_wand/data.tar.gz](http://download.tensorflow.org/models/tflite/magic_wand/data.tar.gz)
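After extraction, the per-gesture data directories are expected to sit alongside
the training scripts. Based on the paths read by `data_prepare.py`, the layout
looks like this (file names follow the `output_{gesture_name}_{person_name}.txt`
convention described later):

```
├── wing
│   ├── output_wing_hyw.txt
│   └── ...
├── ring
├── slope
└── negative
    ├── output_negative_1.txt
    └── ...
```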
### Training in Colab

The following [Google Colaboratory](https://colab.research.google.com)
notebook demonstrates how to train the model. It's the easiest way to get
started:

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

If you'd prefer to run the scripts locally, use the following instructions.

### Running the scripts

Use the following command to install the required dependencies:

```shell
pip install -r requirements.txt
```

There are two ways to train the model:

- Random data split, which mixes different people's data together and randomly
  splits them into training, validation, and test sets
- Person data split, which splits the data by person

#### Random data split

Using a random split results in higher training accuracy than a person split,
but inferior performance on new data.

```shell
$ python data_prepare.py
$ python data_split.py
$ python train.py --model CNN --person false
```

#### Person data split

Using a person data split results in lower training accuracy but better
performance on new data.

```shell
$ python data_prepare.py
$ python data_split_person.py
$ python train.py --model CNN --person true
```

#### Model type

In the `--model` argument, you can provide `CNN` or `LSTM`. The CNN
model has a smaller size and lower latency.
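Internally, `train.py` maps this argument to the matching model-building
function; a condensed view of its `build_net` helper (see `train.py` below for
the full version):

```python
# Condensed from train.py's build_net; not a separate API.
def build_net(args, seq_length):
  if args.model == "CNN":
    model, model_path = build_cnn(seq_length)   # small Conv2D stack
  elif args.model == "LSTM":
    model, model_path = build_lstm(seq_length)  # bidirectional LSTM
  else:
    raise ValueError("Please provide a correct model name (CNN or LSTM).")
  return model, model_path
```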
## Collecting new data

To obtain new training data using the
[SparkFun Edge development board](https://sparkfun.com/products/15170), you can
modify one of the examples in the [SparkFun Edge BSP](https://github.com/sparkfun/SparkFun_Edge_BSP)
and deploy it using the Ambiq SDK.

### Install the Ambiq SDK and SparkFun Edge BSP

Follow SparkFun's
[Using SparkFun Edge Board with Ambiq Apollo3 SDK](https://learn.sparkfun.com/tutorials/using-sparkfun-edge-board-with-ambiq-apollo3-sdk/all)
guide to set up the Ambiq SDK and SparkFun Edge BSP.

#### Modify the example code

First, `cd` into
`AmbiqSuite-Rel2.2.0/boards/SparkFun_Edge_BSP/examples/example1_edge_test`.

##### Modify `src/tf_adc/tf_adc.c`

Add `true` in line 62 as the second parameter of function
`am_hal_adc_samples_read`.

##### Modify `src/main.c`

Add the line below in `int main(void)`, just before the line `while(1)`:

```cc
am_util_stdio_printf("-,-,-\r\n");
```

Change the following line in `while(1){...}`:

```cc
am_util_stdio_printf("Acc [mg] %04.2f x, %04.2f y, %04.2f z, Temp [deg C] %04.2f, MIC0 [counts / 2^14] %d\r\n", acceleration_mg[0], acceleration_mg[1], acceleration_mg[2], temperature_degC, (audioSample) );
```

to:

```cc
am_util_stdio_printf("%04.2f,%04.2f,%04.2f\r\n", acceleration_mg[0], acceleration_mg[1], acceleration_mg[2]);
```

#### Flash the binary

Follow the instructions in
[SparkFun's guide](https://learn.sparkfun.com/tutorials/using-sparkfun-edge-board-with-ambiq-apollo3-sdk/all#example-applications)
to flash the binary to the device.

#### Collect accelerometer data

First, in a new terminal window, run the following command to begin logging
output to `output.txt`:

```shell
$ script output.txt
```

Next, in the same window, use `screen` to connect to the device:

```shell
$ screen ${DEVICENAME} 115200
```

Output information collected from the accelerometer sensor will be shown on the
screen and saved in `output.txt`, in the format of "x,y,z" per line.

Press the `RST` button to start capturing a new gesture, then press Button 14
when it ends. New data will begin with a line "-,-,-".
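For example, a session that captures two gestures produces a log shaped like
this (the accelerometer values below are illustrative):

```
-,-,-
-659.20,176.40,-1041.60
-660.00,178.10,-1043.80
...
-,-,-
-512.70,209.30,-990.10
...
```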
To exit `screen`, hit `Ctrl+A`, immediately followed by the `K` key,
then hit the `Y` key. Then run

```shell
$ exit
```

to stop logging data. Data will be saved in `output.txt`. For compatibility
with the training scripts, change the file name to include the person's name and
the gesture name, in the following format:

```
output_{gesture_name}_{person_name}.txt
```

#### Edit and run the scripts

Edit the following files to include your new gesture names (replacing
"wing", "ring", and "slope"):

- `data_load.py`
- `data_prepare.py`
- `data_split.py`

Edit the following files to include your new person names (replacing "hyw",
"shiyun", "tangsy", "dengyl", "jiangyh", "xunkai", "lsj", "pengxl", "liucx",
and "zhangxy"):

- `data_prepare.py`
- `data_split_person.py`

Finally, run the commands described earlier to train a new model.
74
samples/modules/tensorflow/magic_wand/train/data_augmentation.py
Normal file
@@ -0,0 +1,74 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-bad-import-order

"""Data augmentation that will be used in data_load.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random

import numpy as np


def time_wrapping(molecule, denominator, data):
  """Generate (molecule/denominator)x speed data."""
  tmp_data = [[0
               for i in range(len(data[0]))]
              for j in range((int(len(data) / molecule) - 1) * denominator)]
  for i in range(int(len(data) / molecule) - 1):
    for j in range(len(data[i])):
      for k in range(denominator):
        tmp_data[denominator * i +
                 k][j] = (data[molecule * i + k][j] * (denominator - k) +
                          data[molecule * i + k + 1][j] * k) / denominator
  return tmp_data


def augment_data(original_data, original_label):
  """Perform data augmentation."""
  new_data = []
  new_label = []
  for idx, (data, label) in enumerate(zip(original_data, original_label)):  # pylint: disable=unused-variable
    # Original data
    new_data.append(data)
    new_label.append(label)
    # Sequence shift
    for num in range(5):  # pylint: disable=unused-variable
      new_data.append((np.array(data, dtype=np.float32) +
                       (random.random() - 0.5) * 200).tolist())
      new_label.append(label)
    # Random noise
    tmp_data = [[0 for i in range(len(data[0]))] for j in range(len(data))]
    for num in range(5):
      for i in range(len(tmp_data)):  # pylint: disable=consider-using-enumerate
        for j in range(len(tmp_data[i])):
          tmp_data[i][j] = data[i][j] + 5 * random.random()
      new_data.append(tmp_data)
      new_label.append(label)
    # Time warping
    fractions = [(3, 2), (5, 3), (2, 3), (3, 4), (9, 5), (6, 5), (4, 5)]
    for molecule, denominator in fractions:
      new_data.append(time_wrapping(molecule, denominator, data))
      new_label.append(label)
    # Movement amplification
    for molecule, denominator in fractions:
      new_data.append(
          (np.array(data, dtype=np.float32) * molecule / denominator).tolist())
      new_label.append(label)
  return new_data, new_label
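Note: `augment_data` expands every recorded sequence into 25 variants
(1 original + 5 shifted + 5 noisy + 7 time-warped + 7 amplitude-scaled), which
is exactly what the test below asserts. A minimal usage sketch with made-up
input:

```python
import numpy as np

from data_augmentation import augment_data, time_wrapping

# One fake 128-sample, 3-axis accelerometer sequence.
seq = np.random.rand(128, 3).tolist()

# Resample to (4/5)x speed; length becomes (128 // 4 - 1) * 5 = 155.
slower = time_wrapping(4, 5, seq)
print(len(slower), len(slower[0]))  # 155 3

# Each (sequence, label) pair becomes 25 augmented pairs.
aug_data, aug_labels = augment_data([seq], ["wing"])
print(len(aug_data), len(aug_labels))  # 25 25
```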
58
samples/modules/tensorflow/magic_wand/train/data_augmentation_test.py
Normal file
@@ -0,0 +1,58 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-bad-import-order

"""Test for data_augmentation.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest

import numpy as np

from data_augmentation import augment_data
from data_augmentation import time_wrapping


class TestAugmentation(unittest.TestCase):

  def test_time_wrapping(self):
    original_data = np.random.rand(10, 3).tolist()
    wrapped_data = time_wrapping(4, 5, original_data)
    self.assertEqual(len(wrapped_data), int(len(original_data) / 4 - 1) * 5)
    self.assertEqual(len(wrapped_data[0]), len(original_data[0]))

  def test_augment_data(self):
    original_data = [
        np.random.rand(128, 3).tolist(),
        np.random.rand(66, 2).tolist(),
        np.random.rand(9, 1).tolist()
    ]
    original_label = ["data", "augmentation", "test"]
    augmented_data, augmented_label = augment_data(original_data,
                                                   original_label)
    self.assertEqual(25 * len(original_data), len(augmented_data))
    self.assertIsInstance(augmented_data, list)
    self.assertEqual(25 * len(original_label), len(augmented_label))
    self.assertIsInstance(augmented_label, list)
    for i in range(len(original_label)):  # pylint: disable=consider-using-enumerate
      self.assertEqual(augmented_label[25 * i], original_label[i])


if __name__ == "__main__":
  unittest.main()
106
samples/modules/tensorflow/magic_wand/train/data_load.py
Normal file
@@ -0,0 +1,106 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-bad-import-order

"""Load data from the specified paths and format them for training."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json

import numpy as np
import tensorflow as tf

from data_augmentation import augment_data

LABEL_NAME = "gesture"
DATA_NAME = "accel_ms2_xyz"


class DataLoader(object):
  """Loads data and prepares for training."""

  def __init__(self, train_data_path, valid_data_path, test_data_path,
               seq_length):
    self.dim = 3
    self.seq_length = seq_length
    self.label2id = {"wing": 0, "ring": 1, "slope": 2, "negative": 3}
    self.train_data, self.train_label, self.train_len = self.get_data_file(
        train_data_path, "train")
    self.valid_data, self.valid_label, self.valid_len = self.get_data_file(
        valid_data_path, "valid")
    self.test_data, self.test_label, self.test_len = self.get_data_file(
        test_data_path, "test")

  def get_data_file(self, data_path, data_type):  # pylint: disable=no-self-use
    """Get train, valid and test data from files."""
    data = []
    label = []
    with open(data_path, "r") as f:
      lines = f.readlines()
      for idx, line in enumerate(lines):  # pylint: disable=unused-variable
        dic = json.loads(line)
        data.append(dic[DATA_NAME])
        label.append(dic[LABEL_NAME])
    if data_type == "train":
      data, label = augment_data(data, label)
    length = len(label)
    print(data_type + "_data_length:" + str(length))
    return data, label, length

  def pad(self, data, seq_length, dim):  # pylint: disable=no-self-use
    """Get neighbour padding."""
    noise_level = 20
    padded_data = []
    # Before-neighbour padding
    tmp_data = (np.random.rand(seq_length, dim) - 0.5) * noise_level + data[0]
    tmp_data[(seq_length -
              min(len(data), seq_length)):] = data[:min(len(data), seq_length)]
    padded_data.append(tmp_data)
    # After-neighbour padding
    tmp_data = (np.random.rand(seq_length, dim) - 0.5) * noise_level + data[-1]
    tmp_data[:min(len(data), seq_length)] = data[:min(len(data), seq_length)]
    padded_data.append(tmp_data)
    return padded_data

  def format_support_func(self, padded_num, length, data, label):
    """Support function for format. (Helps format train, valid and test.)"""
    # Add 2 paddings, initialize data and label
    length *= padded_num
    features = np.zeros((length, self.seq_length, self.dim))
    labels = np.zeros(length)
    # Get padding for train, valid and test
    for idx, (data, label) in enumerate(zip(data, label)):  # pylint: disable=redefined-argument-from-local
      padded_data = self.pad(data, self.seq_length, self.dim)
      for num in range(padded_num):
        features[padded_num * idx + num] = padded_data[num]
        labels[padded_num * idx + num] = self.label2id[label]
    # Turn into tf.data.Dataset
    dataset = tf.data.Dataset.from_tensor_slices(
        (features, labels.astype("int32")))
    return length, dataset

  def format(self):
    """Format data (including padding, etc.) and get the dataset for the model."""
    padded_num = 2
    self.train_len, self.train_data = self.format_support_func(
        padded_num, self.train_len, self.train_data, self.train_label)
    self.valid_len, self.valid_data = self.format_support_func(
        padded_num, self.valid_len, self.valid_data, self.valid_label)
    self.test_len, self.test_data = self.format_support_func(
        padded_num, self.test_len, self.test_data, self.test_label)
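Note: a minimal sketch of how `DataLoader` is driven, mirroring `load_data` in
`train.py`; the paths assume `data_prepare.py` and `data_split.py` have already
been run:

```python
from data_load import DataLoader

loader = DataLoader("./data/train", "./data/valid", "./data/test",
                    seq_length=128)
loader.format()  # pads each sequence twice and builds tf.data.Dataset objects

# After format(), *_len is the padded example count and *_data is a Dataset.
print(loader.train_len)
for features, label in loader.train_data.take(1):
  print(features.shape, label.numpy())  # (128, 3) and an int32 class id
```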
95
samples/modules/tensorflow/magic_wand/train/data_load_test.py
Normal file
@@ -0,0 +1,95 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-bad-import-order

"""Test for data_load.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
from data_load import DataLoader

import tensorflow as tf


class TestLoad(unittest.TestCase):

  def setUp(self):  # pylint: disable=g-missing-super-call
    self.loader = DataLoader(
        "./data/train", "./data/valid", "./data/test", seq_length=512)

  def test_get_data(self):
    self.assertIsInstance(self.loader.train_data, list)
    self.assertIsInstance(self.loader.train_label, list)
    self.assertIsInstance(self.loader.valid_data, list)
    self.assertIsInstance(self.loader.valid_label, list)
    self.assertIsInstance(self.loader.test_data, list)
    self.assertIsInstance(self.loader.test_label, list)
    self.assertEqual(self.loader.train_len, len(self.loader.train_data))
    self.assertEqual(self.loader.train_len, len(self.loader.train_label))
    self.assertEqual(self.loader.valid_len, len(self.loader.valid_data))
    self.assertEqual(self.loader.valid_len, len(self.loader.valid_label))
    self.assertEqual(self.loader.test_len, len(self.loader.test_data))
    self.assertEqual(self.loader.test_len, len(self.loader.test_label))

  def test_pad(self):
    original_data1 = [[2, 3], [1, 1]]
    expected_data1_0 = [[2, 3], [2, 3], [2, 3], [2, 3], [1, 1]]
    expected_data1_1 = [[2, 3], [1, 1], [1, 1], [1, 1], [1, 1]]
    original_data2 = [[-2, 3], [-77, -681], [5, 6], [9, -7], [22, 3333],
                      [9, 99], [-100, 0]]
    expected_data2 = [[-2, 3], [-77, -681], [5, 6], [9, -7], [22, 3333]]
    padding_data1 = self.loader.pad(original_data1, seq_length=5, dim=2)
    padding_data2 = self.loader.pad(original_data2, seq_length=5, dim=2)
    for i in range(len(padding_data1[0])):
      for j in range(len(padding_data1[0].tolist()[0])):
        self.assertLess(
            abs(padding_data1[0].tolist()[i][j] - expected_data1_0[i][j]),
            10.001)
    for i in range(len(padding_data1[1])):
      for j in range(len(padding_data1[1].tolist()[0])):
        self.assertLess(
            abs(padding_data1[1].tolist()[i][j] - expected_data1_1[i][j]),
            10.001)
    self.assertEqual(padding_data2[0].tolist(), expected_data2)
    self.assertEqual(padding_data2[1].tolist(), expected_data2)

  def test_format(self):
    self.loader.format()
    expected_train_label = int(self.loader.label2id[self.loader.train_label[0]])
    expected_valid_label = int(self.loader.label2id[self.loader.valid_label[0]])
    expected_test_label = int(self.loader.label2id[self.loader.test_label[0]])
    for feature, label in self.loader.train_data:  # pylint: disable=unused-variable
      format_train_label = label.numpy()
      break
    for feature, label in self.loader.valid_data:
      format_valid_label = label.numpy()
      break
    for feature, label in self.loader.test_data:
      format_test_label = label.numpy()
      break
    self.assertEqual(expected_train_label, format_train_label)
    self.assertEqual(expected_valid_label, format_valid_label)
    self.assertEqual(expected_test_label, format_test_label)
    self.assertIsInstance(self.loader.train_data, tf.data.Dataset)
    self.assertIsInstance(self.loader.valid_data, tf.data.Dataset)
    self.assertIsInstance(self.loader.test_data, tf.data.Dataset)


if __name__ == "__main__":
  unittest.main()
164
samples/modules/tensorflow/magic_wand/train/data_prepare.py
Normal file
@@ -0,0 +1,164 @@
# Lint as: python3
# coding=utf-8
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Prepare data for further process.

Read data from "/slope", "/ring", "/wing", "/negative" and save them
in "/data/complete_data" in python dict format.

It will generate a new file with the following structure:
├── data
│   └── complete_data
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import csv
import json
import os
import random

LABEL_NAME = "gesture"
DATA_NAME = "accel_ms2_xyz"
folders = ["wing", "ring", "slope"]
names = [
    "hyw", "shiyun", "tangsy", "dengyl", "zhangxy", "pengxl", "liucx",
    "jiangyh", "xunkai"
]


def prepare_original_data(folder, name, data, file_to_read):
  """Read collected data from files."""
  if folder != "negative":
    with open(file_to_read, "r") as f:
      lines = csv.reader(f)
      data_new = {}
      data_new[LABEL_NAME] = folder
      data_new[DATA_NAME] = []
      data_new["name"] = name
      for idx, line in enumerate(lines):  # pylint: disable=unused-variable
        if len(line) == 3:
          if line[2] == "-" and data_new[DATA_NAME]:
            data.append(data_new)
            data_new = {}
            data_new[LABEL_NAME] = folder
            data_new[DATA_NAME] = []
            data_new["name"] = name
          elif line[2] != "-":
            data_new[DATA_NAME].append([float(i) for i in line[0:3]])
      data.append(data_new)
  else:
    with open(file_to_read, "r") as f:
      lines = csv.reader(f)
      data_new = {}
      data_new[LABEL_NAME] = folder
      data_new[DATA_NAME] = []
      data_new["name"] = name
      for idx, line in enumerate(lines):
        if len(line) == 3 and line[2] != "-":
          if len(data_new[DATA_NAME]) == 120:
            data.append(data_new)
            data_new = {}
            data_new[LABEL_NAME] = folder
            data_new[DATA_NAME] = []
            data_new["name"] = name
          else:
            data_new[DATA_NAME].append([float(i) for i in line[0:3]])
      data.append(data_new)


def generate_negative_data(data):
  """Generate negative data labeled as 'negative6~8'."""
  # Big movement -> around straight line
  for i in range(100):
    if i > 80:
      dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative8"}
    elif i > 60:
      dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative7"}
    else:
      dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative6"}
    start_x = (random.random() - 0.5) * 2000
    start_y = (random.random() - 0.5) * 2000
    start_z = (random.random() - 0.5) * 2000
    x_increase = (random.random() - 0.5) * 10
    y_increase = (random.random() - 0.5) * 10
    z_increase = (random.random() - 0.5) * 10
    for j in range(128):
      dic[DATA_NAME].append([
          start_x + j * x_increase + (random.random() - 0.5) * 6,
          start_y + j * y_increase + (random.random() - 0.5) * 6,
          start_z + j * z_increase + (random.random() - 0.5) * 6
      ])
    data.append(dic)
  # Random
  for i in range(100):
    if i > 80:
      dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative8"}
    elif i > 60:
      dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative7"}
    else:
      dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative6"}
    for j in range(128):
      dic[DATA_NAME].append([(random.random() - 0.5) * 1000,
                             (random.random() - 0.5) * 1000,
                             (random.random() - 0.5) * 1000])
    data.append(dic)
  # Stay still
  for i in range(100):
    if i > 80:
      dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative8"}
    elif i > 60:
      dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative7"}
    else:
      dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative6"}
    start_x = (random.random() - 0.5) * 2000
    start_y = (random.random() - 0.5) * 2000
    start_z = (random.random() - 0.5) * 2000
    for j in range(128):
      dic[DATA_NAME].append([
          start_x + (random.random() - 0.5) * 40,
          start_y + (random.random() - 0.5) * 40,
          start_z + (random.random() - 0.5) * 40
      ])
    data.append(dic)


# Write data to file
def write_data(data_to_write, path):
  with open(path, "w") as f:
    for idx, item in enumerate(data_to_write):  # pylint: disable=unused-variable
      dic = json.dumps(item, ensure_ascii=False)
      f.write(dic)
      f.write("\n")


if __name__ == "__main__":
  data = []
  for idx1, folder in enumerate(folders):
    for idx2, name in enumerate(names):
      prepare_original_data(folder, name, data,
                            "./%s/output_%s_%s.txt" % (folder, folder, name))
  for idx in range(5):
    prepare_original_data("negative", "negative%d" % (idx + 1), data,
                          "./negative/output_negative_%d.txt" % (idx + 1))
  generate_negative_data(data)
  print("data_length: " + str(len(data)))
  if not os.path.exists("./data"):
    os.makedirs("./data")
  write_data(data, "./data/complete_data")
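Note: each line of the generated `./data/complete_data` is one JSON object with
the keys used above; an illustrative record (values made up, real sequences are
much longer):

```
{"gesture": "wing", "accel_ms2_xyz": [[-766.0, 132.5, 709.8], [-751.2, 129.4, 716.3]], "name": "hyw"}
```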
75
samples/modules/tensorflow/magic_wand/train/data_prepare_test.py
Normal file
@@ -0,0 +1,75 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Test for data_prepare.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import csv
import json
import os
import unittest
from data_prepare import folders
from data_prepare import generate_negative_data
from data_prepare import names
from data_prepare import prepare_original_data
from data_prepare import write_data


class TestPrepare(unittest.TestCase):

  def setUp(self):  # pylint: disable=g-missing-super-call
    self.file = "./%s/output_%s_%s.txt" % (folders[0], folders[0], names[0])
    self.data = []
    prepare_original_data(folders[0], names[0], self.data, self.file)

  def test_prepare_data(self):
    num = 0
    with open(self.file, "r") as f:
      lines = csv.reader(f)
      for idx, line in enumerate(lines):  # pylint: disable=unused-variable
        if len(line) == 3 and line[2] == "-":
          num += 1
    self.assertEqual(len(self.data), num)
    self.assertIsInstance(self.data, list)
    self.assertIsInstance(self.data[0], dict)
    self.assertEqual(list(self.data[-1]), ["gesture", "accel_ms2_xyz", "name"])
    self.assertEqual(self.data[0]["name"], names[0])

  def test_generate_negative(self):
    original_len = len(self.data)
    generate_negative_data(self.data)
    self.assertEqual(original_len + 300, len(self.data))
    generated_num = 0
    for idx, data in enumerate(self.data):  # pylint: disable=unused-variable
      if data["name"] == "negative6" or data["name"] == "negative7" or data[
          "name"] == "negative8":
        generated_num += 1
    self.assertEqual(generated_num, 300)

  def test_write_data(self):
    data_path_test = "./data/data0"
    write_data(self.data, data_path_test)
    with open(data_path_test, "r") as f:
      lines = f.readlines()
    self.assertEqual(len(lines), len(self.data))
    self.assertEqual(json.loads(lines[0]), self.data[0])
    self.assertEqual(json.loads(lines[-1]), self.data[-1])
    os.remove(data_path_test)


if __name__ == "__main__":
  unittest.main()
90
samples/modules/tensorflow/magic_wand/train/data_split.py
Normal file
@@ -0,0 +1,90 @@
# Lint as: python3
# coding=utf-8
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mix and split data.

Mix different people's data together and randomly split them into train,
validation and test. These data would be saved separately under "/data".
It will generate new files with the following structure:

├── data
│   ├── complete_data
│   ├── test
│   ├── train
│   └── valid
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import random
from data_prepare import write_data


# Read data
def read_data(path):
  data = []
  with open(path, "r") as f:
    lines = f.readlines()
    for idx, line in enumerate(lines):  # pylint: disable=unused-variable
      dic = json.loads(line)
      data.append(dic)
  print("data_length:" + str(len(data)))
  return data


def split_data(data, train_ratio, valid_ratio):
  """Splits data into train, validation and test according to ratio."""
  train_data = []
  valid_data = []
  test_data = []
  num_dic = {"wing": 0, "ring": 0, "slope": 0, "negative": 0}
  for idx, item in enumerate(data):  # pylint: disable=unused-variable
    for i in num_dic:
      if item["gesture"] == i:
        num_dic[i] += 1
  print(num_dic)
  train_num_dic = {}
  valid_num_dic = {}
  for i in num_dic:
    train_num_dic[i] = int(train_ratio * num_dic[i])
    valid_num_dic[i] = int(valid_ratio * num_dic[i])
  random.seed(30)
  random.shuffle(data)
  for idx, item in enumerate(data):
    for i in num_dic:
      if item["gesture"] == i:
        if train_num_dic[i] > 0:
          train_data.append(item)
          train_num_dic[i] -= 1
        elif valid_num_dic[i] > 0:
          valid_data.append(item)
          valid_num_dic[i] -= 1
        else:
          test_data.append(item)
  print("train_length:" + str(len(train_data)))
  print("test_length:" + str(len(test_data)))
  return train_data, valid_data, test_data


if __name__ == "__main__":
  data = read_data("./data/complete_data")
  train_data, valid_data, test_data = split_data(data, 0.6, 0.2)
  write_data(train_data, "./data/train")
  write_data(valid_data, "./data/valid")
  write_data(test_data, "./data/test")
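Note: a quick sketch of the ratio behavior. Per-gesture counts are truncated
with `int()`, so the test split absorbs any rounding remainder (this is exactly
what `data_split_test.py` below verifies):

```python
from data_split import read_data, split_data

data = read_data("./data/complete_data")
train, valid, test = split_data(data, 0.6, 0.2)
# Roughly 60/20/20 per gesture class; `test` picks up the rounding slack.
print(len(train), len(valid), len(test))
```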
75
samples/modules/tensorflow/magic_wand/train/data_split_person.py
Normal file
@@ -0,0 +1,75 @@
# Lint as: python3
# coding=utf-8
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Split data into train, validation and test dataset according to person.

That is, use some people's data as train, some other people's data as
validation, and the rest ones' data as test. These data would be saved
separately under "/person_split".

It will generate new files with the following structure:
├── person_split
│   ├── test
│   ├── train
│   └── valid
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
from data_split import read_data
from data_split import write_data


def person_split(whole_data, train_names, valid_names, test_names):
  """Split data by person."""
  random.seed(30)
  random.shuffle(whole_data)
  train_data = []
  valid_data = []
  test_data = []
  for idx, data in enumerate(whole_data):  # pylint: disable=unused-variable
    if data["name"] in train_names:
      train_data.append(data)
    elif data["name"] in valid_names:
      valid_data.append(data)
    elif data["name"] in test_names:
      test_data.append(data)
  print("train_length:" + str(len(train_data)))
  print("valid_length:" + str(len(valid_data)))
  print("test_length:" + str(len(test_data)))
  return train_data, valid_data, test_data


if __name__ == "__main__":
  data = read_data("./data/complete_data")
  train_names = [
      "hyw", "shiyun", "tangsy", "dengyl", "jiangyh", "xunkai", "negative3",
      "negative4", "negative5", "negative6"
  ]
  valid_names = ["lsj", "pengxl", "negative2", "negative7"]
  test_names = ["liucx", "zhangxy", "negative1", "negative8"]
  train_data, valid_data, test_data = person_split(data, train_names,
                                                   valid_names, test_names)
  if not os.path.exists("./person_split"):
    os.makedirs("./person_split")
  write_data(train_data, "./person_split/train")
  write_data(valid_data, "./person_split/valid")
  write_data(test_data, "./person_split/test")
54
samples/modules/tensorflow/magic_wand/train/data_split_person_test.py
Normal file
@@ -0,0 +1,54 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Test for data_split_person.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
from data_split_person import person_split
from data_split_person import read_data


class TestSplitPerson(unittest.TestCase):

  def setUp(self):  # pylint: disable=g-missing-super-call
    self.data = read_data("./data/complete_data")

  def test_person_split(self):
    train_names = ["dengyl"]
    valid_names = ["liucx"]
    test_names = ["tangsy"]
    dengyl_num = 63
    liucx_num = 63
    tangsy_num = 30
    train_data, valid_data, test_data = person_split(self.data, train_names,
                                                     valid_names, test_names)
    self.assertEqual(len(train_data), dengyl_num)
    self.assertEqual(len(valid_data), liucx_num)
    self.assertEqual(len(test_data), tangsy_num)
    self.assertIsInstance(train_data, list)
    self.assertIsInstance(valid_data, list)
    self.assertIsInstance(test_data, list)
    self.assertIsInstance(train_data[0], dict)
    self.assertIsInstance(valid_data[0], dict)
    self.assertIsInstance(test_data[0], dict)


if __name__ == "__main__":
  unittest.main()
77
samples/modules/tensorflow/magic_wand/train/data_split_test.py
Normal file
@@ -0,0 +1,77 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Test for data_split.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import unittest
from data_split import read_data
from data_split import split_data


class TestSplit(unittest.TestCase):

  def setUp(self):  # pylint: disable=g-missing-super-call
    self.data = read_data("./data/complete_data")
    self.num_dic = {"wing": 0, "ring": 0, "slope": 0, "negative": 0}
    with open("./data/complete_data", "r") as f:
      lines = f.readlines()
      self.num = len(lines)

  def test_read_data(self):
    self.assertEqual(len(self.data), self.num)
    self.assertIsInstance(self.data, list)
    self.assertIsInstance(self.data[0], dict)
    self.assertEqual(
        set(list(self.data[-1])), set(["gesture", "accel_ms2_xyz", "name"]))

  def test_split_data(self):
    with open("./data/complete_data", "r") as f:
      lines = f.readlines()
      for idx, line in enumerate(lines):  # pylint: disable=unused-variable
        dic = json.loads(line)
        for ges in self.num_dic:
          if dic["gesture"] == ges:
            self.num_dic[ges] += 1
    train_data_0, valid_data_0, test_data_100 = split_data(self.data, 0, 0)
    train_data_50, valid_data_50, test_data_0 = split_data(self.data, 0.5, 0.5)
    train_data_60, valid_data_20, test_data_20 = split_data(self.data, 0.6, 0.2)
    len_60 = int(self.num_dic["wing"] * 0.6) + int(
        self.num_dic["ring"] * 0.6) + int(self.num_dic["slope"] * 0.6) + int(
            self.num_dic["negative"] * 0.6)
    len_50 = int(self.num_dic["wing"] * 0.5) + int(
        self.num_dic["ring"] * 0.5) + int(self.num_dic["slope"] * 0.5) + int(
            self.num_dic["negative"] * 0.5)
    len_20 = int(self.num_dic["wing"] * 0.2) + int(
        self.num_dic["ring"] * 0.2) + int(self.num_dic["slope"] * 0.2) + int(
            self.num_dic["negative"] * 0.2)
    self.assertEqual(len(train_data_0), 0)
    self.assertEqual(len(train_data_50), len_50)
    self.assertEqual(len(train_data_60), len_60)
    self.assertEqual(len(valid_data_0), 0)
    self.assertEqual(len(valid_data_50), len_50)
    self.assertEqual(len(valid_data_20), len_20)
    self.assertEqual(len(test_data_100), self.num)
    self.assertEqual(len(test_data_0), (self.num - 2 * len_50))
    self.assertEqual(len(test_data_20), (self.num - len_60 - len_20))


if __name__ == "__main__":
  unittest.main()
samples/modules/tensorflow/magic_wand/train/netmodels/CNN/weights.h5
Binary file not shown.
2
samples/modules/tensorflow/magic_wand/train/requirements.txt
Normal file
@@ -0,0 +1,2 @@
numpy==1.16.2
tensorflow==2.0.0-beta1
202
samples/modules/tensorflow/magic_wand/train/train.py
Normal file
@@ -0,0 +1,202 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-bad-import-order

"""Build and train neural networks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import datetime
import os  # pylint: disable=duplicate-code
from data_load import DataLoader

import numpy as np  # pylint: disable=duplicate-code
import tensorflow as tf

logdir = "logs/scalars/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)


def reshape_function(data, label):
  reshaped_data = tf.reshape(data, [-1, 3, 1])
  return reshaped_data, label


def calculate_model_size(model):
  print(model.summary())
  var_sizes = [
      np.product(list(map(int, v.shape))) * v.dtype.size
      for v in model.trainable_variables
  ]
  print("Model size:", sum(var_sizes) / 1024, "KB")


def build_cnn(seq_length):
  """Builds a convolutional neural network in Keras."""
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(
          8, (4, 3),
          padding="same",
          activation="relu",
          input_shape=(seq_length, 3, 1)),  # output_shape=(batch, 128, 3, 8)
      tf.keras.layers.MaxPool2D((3, 3)),  # (batch, 42, 1, 8)
      tf.keras.layers.Dropout(0.1),  # (batch, 42, 1, 8)
      tf.keras.layers.Conv2D(16, (4, 1), padding="same",
                             activation="relu"),  # (batch, 42, 1, 16)
      tf.keras.layers.MaxPool2D((3, 1), padding="same"),  # (batch, 14, 1, 16)
      tf.keras.layers.Dropout(0.1),  # (batch, 14, 1, 16)
      tf.keras.layers.Flatten(),  # (batch, 224)
      tf.keras.layers.Dense(16, activation="relu"),  # (batch, 16)
      tf.keras.layers.Dropout(0.1),  # (batch, 16)
      tf.keras.layers.Dense(4, activation="softmax")  # (batch, 4)
  ])
  model_path = os.path.join("./netmodels", "CNN")
  print("Built CNN.")
  if not os.path.exists(model_path):
    os.makedirs(model_path)
  model.load_weights("./netmodels/CNN/weights.h5")
  return model, model_path


def build_lstm(seq_length):
  """Builds an LSTM in Keras."""
  model = tf.keras.Sequential([
      tf.keras.layers.Bidirectional(
          tf.keras.layers.LSTM(22),
          input_shape=(seq_length, 3)),  # output_shape=(batch, 44)
      tf.keras.layers.Dense(4, activation="sigmoid")  # (batch, 4)
  ])
  model_path = os.path.join("./netmodels", "LSTM")
  print("Built LSTM.")
  if not os.path.exists(model_path):
    os.makedirs(model_path)
  return model, model_path


def load_data(train_data_path, valid_data_path, test_data_path, seq_length):
  data_loader = DataLoader(
      train_data_path, valid_data_path, test_data_path, seq_length=seq_length)
  data_loader.format()
  return data_loader.train_len, data_loader.train_data, data_loader.valid_len, \
      data_loader.valid_data, data_loader.test_len, data_loader.test_data


def build_net(args, seq_length):
  if args.model == "CNN":
    model, model_path = build_cnn(seq_length)
  elif args.model == "LSTM":
    model, model_path = build_lstm(seq_length)
  else:
    # Fail fast instead of returning unbound names.
    raise ValueError("Please provide a correct model name (CNN or LSTM).")
  return model, model_path


def train_net(
    model,
    model_path,  # pylint: disable=unused-argument
    train_len,  # pylint: disable=unused-argument
    train_data,
    valid_len,
    valid_data,
    test_len,
    test_data,
    kind):
  """Trains the model."""
  calculate_model_size(model)
  epochs = 50
  batch_size = 64
  model.compile(
      optimizer="adam",
      loss="sparse_categorical_crossentropy",
      metrics=["accuracy"])
  if kind == "CNN":
    train_data = train_data.map(reshape_function)
    test_data = test_data.map(reshape_function)
    valid_data = valid_data.map(reshape_function)
  test_labels = np.zeros(test_len)
  idx = 0
  for data, label in test_data:  # pylint: disable=unused-variable
    test_labels[idx] = label.numpy()
    idx += 1
  train_data = train_data.batch(batch_size).repeat()
  valid_data = valid_data.batch(batch_size)
  test_data = test_data.batch(batch_size)
  model.fit(
      train_data,
      epochs=epochs,
      validation_data=valid_data,
      steps_per_epoch=1000,
      validation_steps=int((valid_len - 1) / batch_size + 1),
      callbacks=[tensorboard_callback])
  loss, acc = model.evaluate(test_data)
  pred = np.argmax(model.predict(test_data), axis=1)
  confusion = tf.math.confusion_matrix(
      labels=tf.constant(test_labels),
      predictions=tf.constant(pred),
      num_classes=4)
  print(confusion)
  print("Loss {}, Accuracy {}".format(loss, acc))
  # Convert the model to the TensorFlow Lite format without quantization
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  tflite_model = converter.convert()

  # Save the model to disk
  open("model.tflite", "wb").write(tflite_model)

  # Convert the model to the TensorFlow Lite format with quantization
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
  tflite_model = converter.convert()

  # Save the model to disk
  open("model_quantized.tflite", "wb").write(tflite_model)

  basic_model_size = os.path.getsize("model.tflite")
  print("Basic model is %d bytes" % basic_model_size)
  quantized_model_size = os.path.getsize("model_quantized.tflite")
  print("Quantized model is %d bytes" % quantized_model_size)
  difference = basic_model_size - quantized_model_size
  print("Difference is %d bytes" % difference)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--model", "-m")
  parser.add_argument("--person", "-p")
  args = parser.parse_args()

  seq_length = 128

  print("Start to load data...")
  if args.person == "true":
    train_len, train_data, valid_len, valid_data, test_len, test_data = \
        load_data("./person_split/train", "./person_split/valid",
                  "./person_split/test", seq_length)
  else:
    train_len, train_data, valid_len, valid_data, test_len, test_data = \
        load_data("./data/train", "./data/valid", "./data/test", seq_length)

  print("Start to build net...")
  model, model_path = build_net(args, seq_length)

  print("Start training...")
  train_net(model, model_path, train_len, train_data, valid_len, valid_data,
            test_len, test_data, args.model)

  print("Training finished!")
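Note: after training, the quantized model written by `train.py` can be
sanity-checked on the host with the TFLite interpreter. A minimal sketch for
the CNN variant (the input window below is random, purely illustrative):

```python
import numpy as np
import tensorflow as tf

# Load the quantized model produced by train.py.
interpreter = tf.lite.Interpreter(model_path="model_quantized.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# One fake 128-sample window shaped like the CNN input: (1, 128, 3, 1).
window = np.random.rand(1, 128, 3, 1).astype(np.float32)
interpreter.set_tensor(input_details[0]["index"], window)
interpreter.invoke()

scores = interpreter.get_tensor(output_details[0]["index"])
print("Scores (wing, ring, slope, negative):", scores)
```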
257
samples/modules/tensorflow/magic_wand/train/train_magic_wand_model.ipynb
Normal file
@@ -0,0 +1,257 @@
{
  "cells": [
    {
      "source": [
        "Copyright 2020 The TensorFlow Authors. All Rights Reserved.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "  http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ],
      "cell_type": "markdown",
      "metadata": {}
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1BtkMGSYQOTQ"
      },
      "source": [
        "# Train a gesture recognition model for microcontroller use"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BaFfr7DHRmGF"
      },
      "source": [
        "This notebook demonstrates how to train a 20kb gesture recognition model for [TensorFlow Lite for Microcontrollers](https://tensorflow.org/lite/microcontrollers/overview). It will produce the same model used in the [magic_wand](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/examples/magic_wand) example application.\n",
        "\n",
        "The model is designed to be used with [Google Colaboratory](https://colab.research.google.com).\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xXgS6rxyT7Qk"
      },
      "source": [
        "Training is much faster using GPU acceleration. Before you proceed, ensure you are using a GPU runtime by going to **Runtime -> Change runtime type** and selecting **GPU**. Training will take around 5 minutes on a GPU runtime."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "LG6ErX5FRIaV"
      },
      "source": [
        "## Configure dependencies\n",
        "\n",
        "Run the following cell to ensure the correct version of TensorFlow is used."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "STNft9TrfoVh"
      },
      "source": [
        "We'll also clone the TensorFlow repository, which contains the training scripts, and copy them into our workspace."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ygkWw73dRNda"
      },
      "outputs": [],
      "source": [
        "# Clone the repository from GitHub\n",
        "!git clone --depth 1 -q https://github.com/tensorflow/tensorflow\n",
        "# Copy the training scripts into our workspace\n",
        "!cp -r tensorflow/tensorflow/lite/micro/examples/magic_wand/train train"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "pXI7R4RehFdU"
      },
      "source": [
        "## Prepare the data\n",
        "\n",
        "Next, we'll download the data and extract it into the expected location within the training scripts' directory."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "W2Sg2AKzVr2L"
      },
      "outputs": [],
      "source": [
        "# Download the data we will use to train the model\n",
        "!wget http://download.tensorflow.org/models/tflite/magic_wand/data.tar.gz\n",
        "# Extract the data into the train directory\n",
        "!tar xvzf data.tar.gz -C train 1>/dev/null"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DNjukI1Sgl2C"
      },
      "source": [
        "We'll then run the scripts that split the data into training, validation, and test sets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "XBqSVpi6Vxss"
      },
      "outputs": [],
      "source": [
        "# The scripts must be run from within the train directory\n",
        "%cd train\n",
        "# Prepare the data\n",
        "!python data_prepare.py\n",
        "# Split the data by person\n",
        "!python data_split_person.py"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5-cmVbFvhTvy"
      },
      "source": [
        "## Load TensorBoard\n",
        "\n",
        "Now, we set up TensorBoard so that we can graph our accuracy and loss as training proceeds."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "CCx6SN9NWRPw"
      },
      "outputs": [],
      "source": [
        "# Load TensorBoard\n",
        "%load_ext tensorboard\n",
        "%tensorboard --logdir logs/scalars"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ERC2Cr4PhaOl"
      },
      "source": [
        "## Begin training\n",
        "\n",
        "The following cell will begin the training process. Training will take around 5 minutes on a GPU runtime. You'll see the metrics in TensorBoard after a few epochs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "DXmQZgbuWQFO"
      },
      "outputs": [],
      "source": [
        "!python train.py --model CNN --person true"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "4gXbVzcXhvGD"
      },
      "source": [
        "## Create a C source file\n",
        "\n",
        "The `train.py` script writes a model, `model.tflite`, to the training scripts' directory.\n",
        "\n",
        "In the following cell, we convert this model into a C++ source file we can use with TensorFlow Lite for Microcontrollers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "8wgei4OGe3Nz"
      },
      "outputs": [],
      "source": [
        "# Install xxd if it is not available\n",
        "!apt-get -qq install xxd\n",
        "# Save the file as a C source file\n",
        "!xxd -i model.tflite > /content/model.cc\n",
        "# Print the source file\n",
        "!cat /content/model.cc"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "Train a gesture recognition model for microcontroller use",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
78
samples/modules/tensorflow/magic_wand/train/train_test.py
Normal file
@@ -0,0 +1,78 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Test for train.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest

import numpy as np
import tensorflow as tf
from train import build_cnn
from train import build_lstm
from train import load_data
from train import reshape_function


class TestTrain(unittest.TestCase):

  def setUp(self):  # pylint: disable=g-missing-super-call
    self.seq_length = 128
    self.train_len, self.train_data, self.valid_len, self.valid_data, \
        self.test_len, self.test_data = \
        load_data("./data/train", "./data/valid", "./data/test",
                  self.seq_length)

  def test_load_data(self):
    self.assertIsInstance(self.train_data, tf.data.Dataset)
    self.assertIsInstance(self.valid_data, tf.data.Dataset)
    self.assertIsInstance(self.test_data, tf.data.Dataset)

  def test_build_net(self):
    cnn, cnn_path = build_cnn(self.seq_length)
    lstm, lstm_path = build_lstm(self.seq_length)
    cnn_data = np.random.rand(60, 128, 3, 1)
    lstm_data = np.random.rand(60, 128, 3)
    cnn_prob = cnn(tf.constant(cnn_data, dtype="float32")).numpy()
    lstm_prob = lstm(tf.constant(lstm_data, dtype="float32")).numpy()
    self.assertIsInstance(cnn, tf.keras.Sequential)
    self.assertIsInstance(lstm, tf.keras.Sequential)
    self.assertEqual(cnn_path, "./netmodels/CNN")
    self.assertEqual(lstm_path, "./netmodels/LSTM")
    self.assertEqual(cnn_prob.shape, (60, 4))
    self.assertEqual(lstm_prob.shape, (60, 4))

  def test_reshape_function(self):
    for data, label in self.train_data:
      original_data_shape = data.numpy().shape
      original_label_shape = label.numpy().shape
      break
    self.train_data = self.train_data.map(reshape_function)
    for data, label in self.train_data:
      reshaped_data_shape = data.numpy().shape
      reshaped_label_shape = label.numpy().shape
      break
    self.assertEqual(
        reshaped_data_shape,
        (int(original_data_shape[0] * original_data_shape[1] / 3), 3, 1))
    self.assertEqual(reshaped_label_shape, original_label_shape)


if __name__ == "__main__":
  unittest.main()