From 14e6d1cf5d27d94419b460a68d150aae6bacced9 Mon Sep 17 00:00:00 2001 From: Lauren Murphy Date: Mon, 26 Apr 2021 22:07:44 -0500 Subject: [PATCH] samples: add tensorflow magic wand sample train scripts Adds training scripts to TensorFlow Magic Wand sample. Signed-off-by: Lauren Murphy --- .../tensorflow/magic_wand/train/README.md | 191 +++++++++++++ .../magic_wand/train/data_augmentation.py | 74 +++++ .../train/data_augmentation_test.py | 58 ++++ .../tensorflow/magic_wand/train/data_load.py | 106 ++++++++ .../magic_wand/train/data_load_test.py | 95 +++++++ .../magic_wand/train/data_prepare.py | 164 +++++++++++ .../magic_wand/train/data_prepare_test.py | 75 +++++ .../tensorflow/magic_wand/train/data_split.py | 90 ++++++ .../magic_wand/train/data_split_person.py | 75 +++++ .../train/data_split_person_test.py | 54 ++++ .../magic_wand/train/data_split_test.py | 77 ++++++ .../magic_wand/train/netmodels/CNN/weights.h5 | Bin 0 -> 40512 bytes .../magic_wand/train/requirements.txt | 2 + .../tensorflow/magic_wand/train/train.py | 202 ++++++++++++++ .../train/train_magic_wand_model.ipynb | 257 ++++++++++++++++++ .../tensorflow/magic_wand/train/train_test.py | 78 ++++++ 16 files changed, 1598 insertions(+) create mode 100644 samples/modules/tensorflow/magic_wand/train/README.md create mode 100644 samples/modules/tensorflow/magic_wand/train/data_augmentation.py create mode 100644 samples/modules/tensorflow/magic_wand/train/data_augmentation_test.py create mode 100644 samples/modules/tensorflow/magic_wand/train/data_load.py create mode 100644 samples/modules/tensorflow/magic_wand/train/data_load_test.py create mode 100644 samples/modules/tensorflow/magic_wand/train/data_prepare.py create mode 100644 samples/modules/tensorflow/magic_wand/train/data_prepare_test.py create mode 100644 samples/modules/tensorflow/magic_wand/train/data_split.py create mode 100644 samples/modules/tensorflow/magic_wand/train/data_split_person.py create mode 100644 
samples/modules/tensorflow/magic_wand/train/data_split_person_test.py create mode 100644 samples/modules/tensorflow/magic_wand/train/data_split_test.py create mode 100644 samples/modules/tensorflow/magic_wand/train/netmodels/CNN/weights.h5 create mode 100644 samples/modules/tensorflow/magic_wand/train/requirements.txt create mode 100644 samples/modules/tensorflow/magic_wand/train/train.py create mode 100644 samples/modules/tensorflow/magic_wand/train/train_magic_wand_model.ipynb create mode 100644 samples/modules/tensorflow/magic_wand/train/train_test.py diff --git a/samples/modules/tensorflow/magic_wand/train/README.md b/samples/modules/tensorflow/magic_wand/train/README.md new file mode 100644 index 00000000000..55f1617c9a9 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/README.md @@ -0,0 +1,191 @@ +# Gesture Recognition Magic Wand Training Scripts + +## Introduction + +The scripts in this directory can be used to train a TensorFlow model that +classifies gestures based on accelerometer data. The code uses Python 3.7 and +TensorFlow 2.0. The resulting model is less than 20KB in size. + +The following document contains instructions on using the scripts to train a +model, and capturing your own training data. + +This project was inspired by the [Gesture Recognition Magic Wand](https://github.com/jewang/gesture-demo) +project by Jennifer Wang. + +## Training + +### Dataset + +Three magic gestures were chosen, and data were collected from 7 +different people. Some random long movement sequences were collected and divided +into shorter pieces, which made up "negative" data along with some other +automatically generated random data. 
+ +The dataset can be downloaded from the following URL: + +[download.tensorflow.org/models/tflite/magic_wand/data.tar.gz](http://download.tensorflow.org/models/tflite/magic_wand/data.tar.gz) + +### Training in Colab + +The following [Google Colaboratory](https://colab.research.google.com) +notebook demonstrates how to train the model. It's the easiest way to get +started: + + + + +
+ Run in Google Colab + + View source on GitHub +
+ +If you'd prefer to run the scripts locally, use the following instructions. + +### Running the scripts + +Use the following command to install the required dependencies: + +```shell +pip install -r requirements.txt +``` + +There are two ways to train the model: + +- Random data split, which mixes different people's data together and randomly + splits them into training, validation, and test sets +- Person data split, which splits the data by person + +#### Random data split + +Using a random split results in higher training accuracy than a person split, +but inferior performance on new data. + +```shell +$ python data_prepare.py + +$ python data_split.py + +$ python train.py --model CNN --person false +``` + +#### Person data split + +Using a person data split results in lower training accuracy but better +performance on new data. + +```shell +$ python data_prepare.py + +$ python data_split_person.py + +$ python train.py --model CNN --person true +``` + +#### Model type + +In the `--model` argument, you can provide `CNN` or `LSTM`. The CNN +model has a smaller size and lower latency. + +## Collecting new data + +To obtain new training data using the +[SparkFun Edge development board](https://sparkfun.com/products/15170), you can +modify one of the examples in the [SparkFun Edge BSP](https://github.com/sparkfun/SparkFun_Edge_BSP) +and deploy it using the Ambiq SDK. + +### Install the Ambiq SDK and SparkFun Edge BSP + +Follow SparkFun's +[Using SparkFun Edge Board with Ambiq Apollo3 SDK](https://learn.sparkfun.com/tutorials/using-sparkfun-edge-board-with-ambiq-apollo3-sdk/all) +guide to set up the Ambiq SDK and SparkFun Edge BSP. + +#### Modify the example code + +First, `cd` into +`AmbiqSuite-Rel2.2.0/boards/SparkFun_Edge_BSP/examples/example1_edge_test`. + +##### Modify `src/tf_adc/tf_adc.c` + +Add `true` in line 62 as the second parameter of function +`am_hal_adc_samples_read`.
+ +##### Modify `src/main.c` + +Add the line below in `int main(void)`, just before the line `while(1)`: + +```cc +am_util_stdio_printf("-,-,-\r\n"); +``` + +Change the following lines in `while(1){...}` + +```cc +am_util_stdio_printf("Acc [mg] %04.2f x, %04.2f y, %04.2f z, Temp [deg C] %04.2f, MIC0 [counts / 2^14] %d\r\n", acceleration_mg[0], acceleration_mg[1], acceleration_mg[2], temperature_degC, (audioSample) ); +``` + +to: + +```cc +am_util_stdio_printf("%04.2f,%04.2f,%04.2f\r\n", acceleration_mg[0], acceleration_mg[1], acceleration_mg[2]); +``` + +#### Flash the binary + +Follow the instructions in +[SparkFun's guide](https://learn.sparkfun.com/tutorials/using-sparkfun-edge-board-with-ambiq-apollo3-sdk/all#example-applications) +to flash the binary to the device. + +#### Collect accelerometer data + +First, in a new terminal window, run the following command to begin logging +output to `output.txt`: + +```shell +$ script output.txt +``` + +Next, in the same window, use `screen` to connect to the device: + +```shell +$ screen ${DEVICENAME} 115200 +``` + +Output information collected from the accelerometer sensor will be shown on the +screen and saved in `output.txt`, in the format of "x,y,z" per line. + +Press the `RST` button to start capturing a new gesture, then press Button 14 +when it ends. New data will begin with a line "-,-,-". + +To exit `screen`, hit `Ctrl+A`, immediately followed by the `K` key, +then hit the `Y` key. Then run + +```shell +$ exit +``` + +to stop logging data. Data will be saved in `output.txt`.
For compatibility +with the training scripts, change the file name to include the person's name and +the gesture name, in the following format: + +``` +output_{gesture_name}_{person_name}.txt +``` + +#### Edit and run the scripts + +Edit the following files to include your new gesture names (replacing +"wing", "ring", and "slope"): + +- `data_load.py` +- `data_prepare.py` +- `data_split.py` + +Edit the following files to include your new person names (replacing "hyw", +"shiyun", "tangsy", "dengyl", "jiangyh", "xunkai", "lsj", "pengxl", "liucx", +and "zhangxy"): + +- `data_prepare.py` +- `data_split_person.py` + +Finally, run the commands described earlier to train a new model. diff --git a/samples/modules/tensorflow/magic_wand/train/data_augmentation.py b/samples/modules/tensorflow/magic_wand/train/data_augmentation.py new file mode 100644 index 00000000000..45d08c3fec6 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_augmentation.py @@ -0,0 +1,74 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== +# pylint: disable=g-bad-import-order + +"""Data augmentation that will be used in data_load.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random + +import numpy as np + + +def time_wrapping(molecule, denominator, data): + """Generate (molecule/denominator)x speed data.""" + tmp_data = [[0 + for i in range(len(data[0]))] + for j in range((int(len(data) / molecule) - 1) * denominator)] + for i in range(int(len(data) / molecule) - 1): + for j in range(len(data[i])): + for k in range(denominator): + tmp_data[denominator * i + + k][j] = (data[molecule * i + k][j] * (denominator - k) + + data[molecule * i + k + 1][j] * k) / denominator + return tmp_data + + +def augment_data(original_data, original_label): + """Perform data augmentation.""" + new_data = [] + new_label = [] + for idx, (data, label) in enumerate(zip(original_data, original_label)): # pylint: disable=unused-variable + # Original data + new_data.append(data) + new_label.append(label) + # Sequence shift + for num in range(5): # pylint: disable=unused-variable + new_data.append((np.array(data, dtype=np.float32) + + (random.random() - 0.5) * 200).tolist()) + new_label.append(label) + # Random noise + tmp_data = [[0 for i in range(len(data[0]))] for j in range(len(data))] + for num in range(5): + for i in range(len(tmp_data)): # pylint: disable=consider-using-enumerate + for j in range(len(tmp_data[i])): + tmp_data[i][j] = data[i][j] + 5 * random.random() + new_data.append(tmp_data) + new_label.append(label) + # Time warping + fractions = [(3, 2), (5, 3), (2, 3), (3, 4), (9, 5), (6, 5), (4, 5)] + for molecule, denominator in fractions: + new_data.append(time_wrapping(molecule, denominator, data)) + new_label.append(label) + # Movement amplification + for molecule, denominator in fractions: + new_data.append( + (np.array(data, dtype=np.float32) * 
molecule / denominator).tolist()) + new_label.append(label) + return new_data, new_label diff --git a/samples/modules/tensorflow/magic_wand/train/data_augmentation_test.py b/samples/modules/tensorflow/magic_wand/train/data_augmentation_test.py new file mode 100644 index 00000000000..8da07e2e580 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_augmentation_test.py @@ -0,0 +1,58 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# pylint: disable=g-bad-import-order + +"""Test for data_augmentation.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +import numpy as np + +from data_augmentation import augment_data +from data_augmentation import time_wrapping + + +class TestAugmentation(unittest.TestCase): + + def test_time_wrapping(self): + original_data = np.random.rand(10, 3).tolist() + wrapped_data = time_wrapping(4, 5, original_data) + self.assertEqual(len(wrapped_data), int(len(original_data) / 4 - 1) * 5) + self.assertEqual(len(wrapped_data[0]), len(original_data[0])) + + def test_augment_data(self): + original_data = [ + np.random.rand(128, 3).tolist(), + np.random.rand(66, 2).tolist(), + np.random.rand(9, 1).tolist() + ] + original_label = ["data", "augmentation", "test"] + augmented_data, augmented_label = augment_data(original_data, + original_label) + self.assertEqual(25 * len(original_data), len(augmented_data)) + self.assertIsInstance(augmented_data, list) + self.assertEqual(25 * len(original_label), len(augmented_label)) + self.assertIsInstance(augmented_label, list) + for i in range(len(original_label)): # pylint: disable=consider-using-enumerate + self.assertEqual(augmented_label[25 * i], original_label[i]) + + +if __name__ == "__main__": + unittest.main() diff --git a/samples/modules/tensorflow/magic_wand/train/data_load.py b/samples/modules/tensorflow/magic_wand/train/data_load.py new file mode 100644 index 00000000000..35ff4825a8b --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_load.py @@ -0,0 +1,106 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=g-bad-import-order + +"""Load data from the specified paths and format them for training.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +import numpy as np +import tensorflow as tf + +from data_augmentation import augment_data + +LABEL_NAME = "gesture" +DATA_NAME = "accel_ms2_xyz" + + +class DataLoader(object): + """Loads data and prepares for training.""" + + def __init__(self, train_data_path, valid_data_path, test_data_path, + seq_length): + self.dim = 3 + self.seq_length = seq_length + self.label2id = {"wing": 0, "ring": 1, "slope": 2, "negative": 3} + self.train_data, self.train_label, self.train_len = self.get_data_file( + train_data_path, "train") + self.valid_data, self.valid_label, self.valid_len = self.get_data_file( + valid_data_path, "valid") + self.test_data, self.test_label, self.test_len = self.get_data_file( + test_data_path, "test") + + def get_data_file(self, data_path, data_type): # pylint: disable=no-self-use + """Get train, valid and test data from files.""" + data = [] + label = [] + with open(data_path, "r") as f: + lines = f.readlines() + for idx, line in enumerate(lines): # pylint: disable=unused-variable + dic = json.loads(line) + data.append(dic[DATA_NAME]) + label.append(dic[LABEL_NAME]) + if data_type == "train": + data, label = augment_data(data, label) + length = len(label) + print(data_type + "_data_length:" + str(length)) + return data, 
label, length + + def pad(self, data, seq_length, dim): # pylint: disable=no-self-use + """Get neighbour padding.""" + noise_level = 20 + padded_data = [] + # Before- Neighbour padding + tmp_data = (np.random.rand(seq_length, dim) - 0.5) * noise_level + data[0] + tmp_data[(seq_length - + min(len(data), seq_length)):] = data[:min(len(data), seq_length)] + padded_data.append(tmp_data) + # After- Neighbour padding + tmp_data = (np.random.rand(seq_length, dim) - 0.5) * noise_level + data[-1] + tmp_data[:min(len(data), seq_length)] = data[:min(len(data), seq_length)] + padded_data.append(tmp_data) + return padded_data + + def format_support_func(self, padded_num, length, data, label): + """Support function for format.(Helps format train, valid and test.)""" + # Add 2 padding, initialize data and label + length *= padded_num + features = np.zeros((length, self.seq_length, self.dim)) + labels = np.zeros(length) + # Get padding for train, valid and test + for idx, (data, label) in enumerate(zip(data, label)): # pylint: disable=redefined-argument-from-local + padded_data = self.pad(data, self.seq_length, self.dim) + for num in range(padded_num): + features[padded_num * idx + num] = padded_data[num] + labels[padded_num * idx + num] = self.label2id[label] + # Turn into tf.data.Dataset + dataset = tf.data.Dataset.from_tensor_slices( + (features, labels.astype("int32"))) + return length, dataset + + def format(self): + """Format data(including padding, etc.) 
and get the dataset for the model.""" + padded_num = 2 + self.train_len, self.train_data = self.format_support_func( + padded_num, self.train_len, self.train_data, self.train_label) + self.valid_len, self.valid_data = self.format_support_func( + padded_num, self.valid_len, self.valid_data, self.valid_label) + self.test_len, self.test_data = self.format_support_func( + padded_num, self.test_len, self.test_data, self.test_label) diff --git a/samples/modules/tensorflow/magic_wand/train/data_load_test.py b/samples/modules/tensorflow/magic_wand/train/data_load_test.py new file mode 100644 index 00000000000..82864974dad --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_load_test.py @@ -0,0 +1,95 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# pylint: disable=g-bad-import-order + +"""Test for data_load.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +from data_load import DataLoader + +import tensorflow as tf + + +class TestLoad(unittest.TestCase): + + def setUp(self): # pylint: disable=g-missing-super-call + self.loader = DataLoader( + "./data/train", "./data/valid", "./data/test", seq_length=512) + + def test_get_data(self): + self.assertIsInstance(self.loader.train_data, list) + self.assertIsInstance(self.loader.train_label, list) + self.assertIsInstance(self.loader.valid_data, list) + self.assertIsInstance(self.loader.valid_label, list) + self.assertIsInstance(self.loader.test_data, list) + self.assertIsInstance(self.loader.test_label, list) + self.assertEqual(self.loader.train_len, len(self.loader.train_data)) + self.assertEqual(self.loader.train_len, len(self.loader.train_label)) + self.assertEqual(self.loader.valid_len, len(self.loader.valid_data)) + self.assertEqual(self.loader.valid_len, len(self.loader.valid_label)) + self.assertEqual(self.loader.test_len, len(self.loader.test_data)) + self.assertEqual(self.loader.test_len, len(self.loader.test_label)) + + def test_pad(self): + original_data1 = [[2, 3], [1, 1]] + expected_data1_0 = [[2, 3], [2, 3], [2, 3], [2, 3], [1, 1]] + expected_data1_1 = [[2, 3], [1, 1], [1, 1], [1, 1], [1, 1]] + original_data2 = [[-2, 3], [-77, -681], [5, 6], [9, -7], [22, 3333], + [9, 99], [-100, 0]] + expected_data2 = [[-2, 3], [-77, -681], [5, 6], [9, -7], [22, 3333]] + padding_data1 = self.loader.pad(original_data1, seq_length=5, dim=2) + padding_data2 = self.loader.pad(original_data2, seq_length=5, dim=2) + for i in range(len(padding_data1[0])): + for j in range(len(padding_data1[0].tolist()[0])): + self.assertLess( + abs(padding_data1[0].tolist()[i][j] - expected_data1_0[i][j]), + 10.001) + 
for i in range(len(padding_data1[1])): + for j in range(len(padding_data1[1].tolist()[0])): + self.assertLess( + abs(padding_data1[1].tolist()[i][j] - expected_data1_1[i][j]), + 10.001) + self.assertEqual(padding_data2[0].tolist(), expected_data2) + self.assertEqual(padding_data2[1].tolist(), expected_data2) + + def test_format(self): + self.loader.format() + expected_train_label = int(self.loader.label2id[self.loader.train_label[0]]) + expected_valid_label = int(self.loader.label2id[self.loader.valid_label[0]]) + expected_test_label = int(self.loader.label2id[self.loader.test_label[0]]) + for feature, label in self.loader.train_data: # pylint: disable=unused-variable + format_train_label = label.numpy() + break + for feature, label in self.loader.valid_data: + format_valid_label = label.numpy() + break + for feature, label in self.loader.test_data: + format_test_label = label.numpy() + break + self.assertEqual(expected_train_label, format_train_label) + self.assertEqual(expected_valid_label, format_valid_label) + self.assertEqual(expected_test_label, format_test_label) + self.assertIsInstance(self.loader.train_data, tf.data.Dataset) + self.assertIsInstance(self.loader.valid_data, tf.data.Dataset) + self.assertIsInstance(self.loader.test_data, tf.data.Dataset) + + +if __name__ == "__main__": + unittest.main() diff --git a/samples/modules/tensorflow/magic_wand/train/data_prepare.py b/samples/modules/tensorflow/magic_wand/train/data_prepare.py new file mode 100644 index 00000000000..727167f4d05 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_prepare.py @@ -0,0 +1,164 @@ +# Lint as: python3 +# coding=utf-8 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Prepare data for further process. + +Read data from "/slope", "/ring", "/wing", "/negative" and save them +in "/data/complete_data" in python dict format. + +It will generate a new file with the following structure: +├── data +│   └── complete_data +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import csv +import json +import os +import random + +LABEL_NAME = "gesture" +DATA_NAME = "accel_ms2_xyz" +folders = ["wing", "ring", "slope"] +names = [ + "hyw", "shiyun", "tangsy", "dengyl", "zhangxy", "pengxl", "liucx", + "jiangyh", "xunkai" +] + + +def prepare_original_data(folder, name, data, file_to_read): + """Read collected data from files.""" + if folder != "negative": + with open(file_to_read, "r") as f: + lines = csv.reader(f) + data_new = {} + data_new[LABEL_NAME] = folder + data_new[DATA_NAME] = [] + data_new["name"] = name + for idx, line in enumerate(lines): # pylint: disable=unused-variable + if len(line) == 3: + if line[2] == "-" and data_new[DATA_NAME]: + data.append(data_new) + data_new = {} + data_new[LABEL_NAME] = folder + data_new[DATA_NAME] = [] + data_new["name"] = name + elif line[2] != "-": + data_new[DATA_NAME].append([float(i) for i in line[0:3]]) + data.append(data_new) + else: + with open(file_to_read, "r") as f: + lines = csv.reader(f) + data_new = {} + data_new[LABEL_NAME] = folder + data_new[DATA_NAME] = [] + data_new["name"] = name + for idx, line in enumerate(lines): + if 
len(line) == 3 and line[2] != "-": + if len(data_new[DATA_NAME]) == 120: + data.append(data_new) + data_new = {} + data_new[LABEL_NAME] = folder + data_new[DATA_NAME] = [] + data_new["name"] = name + else: + data_new[DATA_NAME].append([float(i) for i in line[0:3]]) + data.append(data_new) + + +def generate_negative_data(data): + """Generate negative data labeled as 'negative6~8'.""" + # Big movement -> around straight line + for i in range(100): + if i > 80: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative8"} + elif i > 60: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative7"} + else: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative6"} + start_x = (random.random() - 0.5) * 2000 + start_y = (random.random() - 0.5) * 2000 + start_z = (random.random() - 0.5) * 2000 + x_increase = (random.random() - 0.5) * 10 + y_increase = (random.random() - 0.5) * 10 + z_increase = (random.random() - 0.5) * 10 + for j in range(128): + dic[DATA_NAME].append([ + start_x + j * x_increase + (random.random() - 0.5) * 6, + start_y + j * y_increase + (random.random() - 0.5) * 6, + start_z + j * z_increase + (random.random() - 0.5) * 6 + ]) + data.append(dic) + # Random + for i in range(100): + if i > 80: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative8"} + elif i > 60: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative7"} + else: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative6"} + for j in range(128): + dic[DATA_NAME].append([(random.random() - 0.5) * 1000, + (random.random() - 0.5) * 1000, + (random.random() - 0.5) * 1000]) + data.append(dic) + # Stay still + for i in range(100): + if i > 80: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative8"} + elif i > 60: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative7"} + else: + dic = {DATA_NAME: [], LABEL_NAME: "negative", "name": "negative6"} + start_x = (random.random() - 0.5) * 2000 + start_y = (random.random() - 
0.5) * 2000 + start_z = (random.random() - 0.5) * 2000 + for j in range(128): + dic[DATA_NAME].append([ + start_x + (random.random() - 0.5) * 40, + start_y + (random.random() - 0.5) * 40, + start_z + (random.random() - 0.5) * 40 + ]) + data.append(dic) + + +# Write data to file +def write_data(data_to_write, path): + with open(path, "w") as f: + for idx, item in enumerate(data_to_write): # pylint: disable=unused-variable + dic = json.dumps(item, ensure_ascii=False) + f.write(dic) + f.write("\n") + + +if __name__ == "__main__": + data = [] + for idx1, folder in enumerate(folders): + for idx2, name in enumerate(names): + prepare_original_data(folder, name, data, + "./%s/output_%s_%s.txt" % (folder, folder, name)) + for idx in range(5): + prepare_original_data("negative", "negative%d" % (idx + 1), data, + "./negative/output_negative_%d.txt" % (idx + 1)) + generate_negative_data(data) + print("data_length: " + str(len(data))) + if not os.path.exists("./data"): + os.makedirs("./data") + write_data(data, "./data/complete_data") diff --git a/samples/modules/tensorflow/magic_wand/train/data_prepare_test.py b/samples/modules/tensorflow/magic_wand/train/data_prepare_test.py new file mode 100644 index 00000000000..4c0e3f0657b --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_prepare_test.py @@ -0,0 +1,75 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Test for data_prepare.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import csv +import json +import os +import unittest +from data_prepare import generate_negative_data +from data_prepare import prepare_original_data +from data_prepare import write_data + + +class TestPrepare(unittest.TestCase): + + def setUp(self): # pylint: disable=g-missing-super-call + self.file = "./%s/output_%s_%s.txt" % (folders[0], folders[0], names[0]) # pylint: disable=undefined-variable + self.data = [] + prepare_original_data(folders[0], names[0], self.data, self.file) # pylint: disable=undefined-variable + + def test_prepare_data(self): + num = 0 + with open(self.file, "r") as f: + lines = csv.reader(f) + for idx, line in enumerate(lines): # pylint: disable=unused-variable + if len(line) == 3 and line[2] == "-": + num += 1 + self.assertEqual(len(self.data), num) + self.assertIsInstance(self.data, list) + self.assertIsInstance(self.data[0], dict) + self.assertEqual(list(self.data[-1]), ["gesture", "accel_ms2_xyz", "name"]) + self.assertEqual(self.data[0]["name"], names[0]) # pylint: disable=undefined-variable + + def test_generate_negative(self): + original_len = len(self.data) + generate_negative_data(self.data) + self.assertEqual(original_len + 300, len(self.data)) + generated_num = 0 + for idx, data in enumerate(self.data): # pylint: disable=unused-variable + if data["name"] == "negative6" or data["name"] == "negative7" or data[ + "name"] == "negative8": + generated_num += 1 + self.assertEqual(generated_num, 300) + + def test_write_data(self): + data_path_test = "./data/data0" + write_data(self.data, data_path_test) + with open(data_path_test, "r") as f: + lines = f.readlines() + self.assertEqual(len(lines), len(self.data)) + self.assertEqual(json.loads(lines[0]), self.data[0]) + self.assertEqual(json.loads(lines[-1]), 
self.data[-1]) + os.remove(data_path_test) + + +if __name__ == "__main__": + unittest.main() diff --git a/samples/modules/tensorflow/magic_wand/train/data_split.py b/samples/modules/tensorflow/magic_wand/train/data_split.py new file mode 100644 index 00000000000..5448ed90020 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_split.py @@ -0,0 +1,90 @@ +# Lint as: python3 +# coding=utf-8 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mix and split data. + +Mix different people's data together and randomly split them into train, +validation and test. These data would be saved separately under "/data". 
# Read data
def read_data(path):
  """Loads a JSON-lines file.

  Args:
    path: Path of a text file holding one JSON document per line.

  Returns:
    A list with one decoded object per non-blank line, in file order.
  """
  data = []
  with open(path, "r") as f:
    for line in f:
      # Skip blank lines (e.g. a stray trailing newline) instead of letting
      # json.loads fail on an empty document.
      if line.strip():
        data.append(json.loads(line))
  print("data_length:" + str(len(data)))
  return data


def split_data(data, train_ratio, valid_ratio):
  """Splits data into train, validation and test according to ratio.

  The ratios are applied per gesture class ("wing", "ring", "slope",
  "negative"). Each per-class quota is truncated with int(), and whatever is
  left over after the train and validation quotas are filled goes to the test
  split. Records whose "gesture" is not one of the four classes are left out
  of every split.

  The records are shuffled with a fixed seed (30) before being assigned. The
  shuffle works on a copy, so the caller's list keeps its order and repeated
  calls on the same list give the same result.

  Args:
    data: List of dicts, each with a "gesture" key.
    train_ratio: Fraction of each class assigned to the train split.
    valid_ratio: Fraction of each class assigned to the validation split.

  Returns:
    A (train_data, valid_data, test_data) tuple of lists.
  """
  train_data = []
  valid_data = []
  test_data = []
  num_dic = {"wing": 0, "ring": 0, "slope": 0, "negative": 0}
  for item in data:
    if item["gesture"] in num_dic:
      num_dic[item["gesture"]] += 1
  print(num_dic)
  train_num_dic = {ges: int(train_ratio * num) for ges, num in num_dic.items()}
  valid_num_dic = {ges: int(valid_ratio * num) for ges, num in num_dic.items()}
  shuffled = list(data)
  random.seed(30)
  random.shuffle(shuffled)
  for item in shuffled:
    gesture = item["gesture"]
    if gesture not in num_dic:
      continue
    if train_num_dic[gesture] > 0:
      train_data.append(item)
      train_num_dic[gesture] -= 1
    elif valid_num_dic[gesture] > 0:
      valid_data.append(item)
      valid_num_dic[gesture] -= 1
    else:
      test_data.append(item)
  print("train_length:" + str(len(train_data)))
  print("valid_length:" + str(len(valid_data)))
  print("test_length:" + str(len(test_data)))
  return train_data, valid_data, test_data


if __name__ == "__main__":
  data = read_data("./data/complete_data")
  train_data, valid_data, test_data = split_data(data, 0.6, 0.2)
  write_data(train_data, "./data/train")
  write_data(valid_data, "./data/valid")
  write_data(test_data, "./data/test")
b/samples/modules/tensorflow/magic_wand/train/data_split_person.py new file mode 100644 index 00000000000..8634b2f385a --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/data_split_person.py @@ -0,0 +1,75 @@ +# Lint as: python3 +# coding=utf-8 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Split data into train, validation and test dataset according to person. + +That is, use some people's data as train, some other people's data as +validation, and the remaining people's data as test. The splits are saved +separately under "./person_split". 
def person_split(whole_data, train_names, valid_names, test_names):
  """Split data by person.

  Every record is routed by its "name" field: to train if the name is in
  train_names, otherwise to validation if it is in valid_names, otherwise to
  test if it is in test_names. Records whose name is in none of the three
  collections are left out of every split.

  The records are shuffled with a fixed seed (30) before being routed. The
  shuffle works on a copy, so the caller's list keeps its order and repeated
  calls on the same list give the same result.

  Args:
    whole_data: List of dicts, each with a "name" key.
    train_names: Names whose records go to the train split.
    valid_names: Names whose records go to the validation split.
    test_names: Names whose records go to the test split.

  Returns:
    A (train_data, valid_data, test_data) tuple of lists.
  """
  shuffled = list(whole_data)
  random.seed(30)
  random.shuffle(shuffled)
  train_data = []
  valid_data = []
  test_data = []
  for data in shuffled:
    if data["name"] in train_names:
      train_data.append(data)
    elif data["name"] in valid_names:
      valid_data.append(data)
    elif data["name"] in test_names:
      test_data.append(data)
  print("train_length:" + str(len(train_data)))
  print("valid_length:" + str(len(valid_data)))
  print("test_length:" + str(len(test_data)))
  return train_data, valid_data, test_data


if __name__ == "__main__":
  data = read_data("./data/complete_data")
  train_names = [
      "hyw", "shiyun", "tangsy", "dengyl", "jiangyh", "xunkai", "negative3",
      "negative4", "negative5", "negative6"
  ]
  valid_names = ["lsj", "pengxl", "negative2", "negative7"]
  test_names = ["liucx", "zhangxy", "negative1", "negative8"]
  train_data, valid_data, test_data = person_split(data, train_names,
                                                   valid_names, test_names)
  # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
  os.makedirs("./person_split", exist_ok=True)
  write_data(train_data, "./person_split/train")
  write_data(valid_data, "./person_split/valid")
  write_data(test_data, "./person_split/test")
class TestSplitPerson(unittest.TestCase):
  """Exercises person_split on the complete data set."""

  def setUp(self):  # pylint: disable=g-missing-super-call
    """Loads ./data/complete_data before every test."""
    self.data = read_data("./data/complete_data")

  def test_person_split(self):
    """Routes one person to each split and checks sizes and element types."""
    train_data, valid_data, test_data = person_split(
        self.data, ["dengyl"], ["liucx"], ["tangsy"])
    splits = (train_data, valid_data, test_data)
    # Expected record counts for dengyl, liucx and tangsy respectively.
    expected_sizes = (63, 63, 30)
    for split, expected in zip(splits, expected_sizes):
      self.assertEqual(len(split), expected)
    for split in splits:
      self.assertIsInstance(split, list)
    for split in splits:
      self.assertIsInstance(split[0], dict)


if __name__ == "__main__":
  unittest.main()
class TestSplit(unittest.TestCase):
  """Exercises read_data and split_data on the complete data set."""

  def setUp(self):  # pylint: disable=g-missing-super-call
    """Loads the data set and counts the lines of the backing file."""
    self.data = read_data("./data/complete_data")
    self.num_dic = dict.fromkeys(("wing", "ring", "slope", "negative"), 0)
    with open("./data/complete_data", "r") as f:
      self.num = len(f.readlines())

  def test_read_data(self):
    """read_data yields one dict per line carrying the three record keys."""
    self.assertEqual(len(self.data), self.num)
    self.assertIsInstance(self.data, list)
    self.assertIsInstance(self.data[0], dict)
    self.assertEqual(
        set(self.data[-1]), {"gesture", "accel_ms2_xyz", "name"})

  def test_split_data(self):
    """Split sizes follow the per-gesture truncated quotas."""
    # Tally the records of each gesture straight from the file.
    with open("./data/complete_data", "r") as f:
      for line in f.readlines():
        gesture = json.loads(line)["gesture"]
        if gesture in self.num_dic:
          self.num_dic[gesture] += 1

    def expected_len(ratio):
      # Sum of the per-gesture quotas, each truncated with int().
      return sum(
          int(self.num_dic[ges] * ratio)
          for ges in ("wing", "ring", "slope", "negative"))

    train_data_0, valid_data_0, test_data_100 = split_data(self.data, 0, 0)
    train_data_50, valid_data_50, test_data_0 = split_data(self.data, 0.5, 0.5)
    train_data_60, valid_data_20, test_data_20 = split_data(self.data, 0.6, 0.2)
    len_60 = expected_len(0.6)
    len_50 = expected_len(0.5)
    len_20 = expected_len(0.2)
    self.assertEqual(len(train_data_0), 0)
    self.assertEqual(len(train_data_50), len_50)
    self.assertEqual(len(train_data_60), len_60)
    self.assertEqual(len(valid_data_0), 0)
    self.assertEqual(len(valid_data_50), len_50)
    self.assertEqual(len(valid_data_20), len_20)
    self.assertEqual(len(test_data_100), self.num)
    self.assertEqual(len(test_data_0), (self.num - 2 * len_50))
    self.assertEqual(len(test_data_20), (self.num - len_60 - len_20))


if __name__ == "__main__":
  unittest.main()
z_>JlP`u)PdFAV&`z%LB^!oV*K{KCL54E(~tFAV&`z%LB^!oV*K{I|luxKX1A%Pp$xgV?kcGoC%$B2u{U_H<&qlG3*wuH+9pKwCz%M_ZURsx1s} zhZAmZi%8qcYEPH!Xv^of!p?Ys*Ap)d zm^~{f^j{RkKkJF5`W0!|rf=#gQJ>Ba2zIt_PtKe@W3jV++a#r$sy|OzYIl)-NSSs` zm5h8F=Gs3i({6^5g8i`${!zJ{MCbimOWJ(P-{;3Le8y0jF2&CV6H8B%WCd*ExMUN=8!Y#U8)txQ2HM*&HXIEJIkk)vx)!Ix%le5HW8PA3w1awxi9EH-=tRAI%5dq{VT9$c zz%ce8qz`w7;dXn$sK-%sG}(vtMm=cn`P0bssxCBj@B>c8T8$V@-Nv>QSwYNKMQm^w z0gl!#uwm!|GDT$=EzHs&^}e$=dHdppv!dBwnQap$&maG`Hk^`=GcYeDN+ zFM7Y;m+12B3EE61|J|40ZB>ps3Hs3cwIbTIaR4=nYa#nCf1_JE?4;H6Rzm89rNp82 z4Nb_qNw4VU(6OIikQGX*u4eChP&!(h9v#8B*3Udb)W>h7YnkJenV!xvHU&kQ^QKV2 z)<9B!dLlN>&!Qm{lOQlz)pdQ^J37-kiRQQ}kq-_rR6KPwEFW4$RZ5;yyWGpvYuIwU z(sGBE49z4q?w4WEiIX&Bc?0IT9VAuV%t*zj8X8i5k_?Z0NE(i+xe8p?aP@1BqaeGA zdi|)RnIFoiyxRmiBrlI9`pA=g&r<00M|HH;A^~H{BWQ7wAsy^hN|Y}wAp^xquK(S? z&#&?Q!oV*K{KCM$7XyEl8{6bIsUMT{+$0VE{c>QeRofW(e;VpvMvN+wPYm9kf(F^ps}&lc7}K#P8bHkf6Q4fAPE1koesh zF+b~@wE3>>vZjD>RAU@}v@3w5^!?Xi{pfY@wxATvG z#An3)+Ti~*_)8Ax{Kj*MhzNFe;q&1o$~lm~&zZ zt-Y*F!zT}=r3;plRMn9vRvM1{FSVGeork&Z3u(cSBxsP+pq_f0Fmd-tTEpqlbiOyu zy5EZkM@Dcqo16)g<{}CX>PTxd;vjOS7ELG~NALJ8B>vH^#3@OG@)ceoKbRpkkGqoE z_#q@mZz_#_=?Ec9$6>D4d{RHGGY!8qht{7v55iuPIllRMuxDDq+F=}hJI0L`_b)A}jrv3hSNYrSd&DSbGB_!^Ffg3ElGdQ5~i z9UftUSvJ@>6k+u}ZOrO-8N=mHfG5~sqRKN&b$9{xogV|cdpY5&e`6C@jUz(Y6b$h@ zgv=)y6mA_%ZQM>l^Za4N>9q<`zu$|dZX{UqbRwzml?%cRJ6#&$1`%6V7sA{L;#g-t zBJgpewhd~8S2&0W68&h|Tzw2#Falyfccmpv6EHG?tZ-GTi^mE*lGvpO4S9YMqdQ!O zpcCdq@Nf}wxPUO(JCW%*o>nY>jCx+<2rpt6CVp;#SYty>_3=QD*TbpsiWjCXc!!ZE zR+8K@Z%kjiNMg(AP3@oT1mE3{*^rhx)SK0ZmgY^R+^Xf2k&Sgx-n0?wH>(lBf&nyg zi7TyHVnb60nL>RJKNKFd#cE4sS{po*l$cjzxMvWNS71o&7m3ZKr!mbsW=GU_@JT}w z5c|lXn7ZmH3idn#-#TZKtLTV=!>(M}{lln!Q;Vkdvw@)LUL^O_N}4Lujqo+~X#F9E z)XP~yeOVQHjz5IC=}JU>pdraA4x`o6R?x)8)r4QEiOmsAX%h8Jk^`H-03vn{pyJiz2yEJv~;=AY?=eW|9ST7Es zwq5)Q>uN))QH4C!iym^dX_118h#Rm`}I-1 z&XE>axY7!DUy@~3j7+G6U(^M%Ww(sbz5ixpcPt~ht6U*8Z8<5U8^M0JBQW*$DC}%b z@2pfOo(sAmGxi(HTGA=P$|Fnqxi^b8q8+%tse*%KiA&TJB5 
z-W6@4LJ41K1Lo%Jf=JCknwyah)%;S~cU!_MZ5EN-rwgE@XFM`X9b8IO#}TKm=P_nPFJk*ug~aw534;9@tk{sFu6!DW+(c-Y{h05Bg$AP0OLJ?)LwZFp7#V0G;^h7 z%^1^!j#Ef-RR`iBCm@Mu8qr-?2Yj=yZ2jb^ShIg3DfyvC`Aaj9&t2w1u31CNC{bT1@QKY>02XC#Gslf?63z>S_26j69dp)Q46yX2TR(W+cFnYl^_wFLW-u>rV^b zd^P<5jYRQ`ez75nRez!jimS2nm?%F$ZmD`*{JOuTHC(&CfWw5)Cn znoV9#PMU<%SOo_n+;zo8J%XV%YbFp!LuDdtsAFx17*cWi7p~=i4z+gcOZ+vgDBors zn|k2_XVj%9HKN^VY3%}<^}vKMD{MKhZ{dWQ(}gI~UQZy=onXm<&eSw-ZS8=OhR^9YKZWR8&&ef`WZ1oNuHjG6@}vyb>e9cildS-Or=B zIuc&}XchI^OfhnY3l$th@N_?cf0Zx)Z1<4Lk<#$rFE8d=wVkm9e>!UYRbKqBwR`-X z9;;$-+pvZ0qyMD`{?ljtmG}t*Kg*X=`9s*gj-=pHxzAIY@lS?QJ(gf#TSLC>{ry+# zfsNYLK*IL<{i2_L@=|{Q!{($iJQ8COf z`e{eRui#(Bz~9ji)BbLdw9^0A^doJz&F=A+2Y&u|;YVs0k;4DEh>mYyWTVb!kd^`#XE+@8kT#7}}Qlvn5nI z!#@xCuX3u*?*KAw*Q}W8c5rF!pVR+-xwUped(%$s;qSp;<<_A8v-m$e{*U-Z>iD)H z{Nfw`;h*BaWd}d=4Zd`g?fZ!|`5$fHtZ#ROkn*L!dv8!Wo?l@b27b10)^7Q~V@68n zxugH@aR1v%bKa%lq)}Ut>(Pj@=*+D>+zdveOIY!YP;|b0A7a(ru&%xULhY3CaV0~P zduT$+D_ugJ-?OlUZHCV$M^v z3u^$CWIjn0E`!tRI`o!aIb@lAXSwdND7 zK95~DC*Y>mIcP=+1m%2XN3Ew2RxC$#8!br9Hw}6qNSP`wxrqr$R?vJ(1NfgDxQoUj zsPDd)JKrQnO14_!?vcv0``H6HbAL3<@;ZbQ`NeQAP=PL1S0}B}-<&f)9u;1Qqr?uc|LJ{1OyG3ulwuY}t-H4~jqoZ@sgG(f@HO3q*I8mix}z^Buj z!Lv%6wM%V$p-qhJIXKi9)5Tq7dtNwyZ^!P8)d) zOX6H52FoNcUtmJYF0O*8ge@?A{#r;J-Hg$9+_3NG2Ix_*Nrh3ZZ1~SPQt(rPl4yqg^kav1>VCXh&gNv!Z00HFry5UR2X_~ z1dn8CY~ehU+&OD7lh{2qX!b50A}bh)+`X=#WAzSvH;#kql6NlmQf0_j)e@M-D3EB2 zq2S;egT?29K&hJvNgZ<)zjkYZflV90CiWbgxPJ=x4seIo%EzFX@fMzPcR|oG9yV6W z5Sb4q#3qtYr!3iyytz-=g1Q(G_6cBH)UA=xUcj;;7PMk~F_*ehnI_CJ1LrSKKv=58 zJ&n|$mPeFHanBx5e(yEhN;v}tj<&e^hyq>q)sh4q{emyJY{J`9uet#^APQu(Z@9@!59x55D(t^rpC|%_Rcg4meC3dNF&M#Mncs;LWWy>qT^P~+o(QhBfX4GI_`Wt-eIRrPnPsGwE3ShV! 
zao_SIkov=yosg1as2HwqkNKVsavBJ}M2gH7nhVYts5E>>?a zD;zmjG;^F1%~>)HJ_Yee?bKe_O_@SPbP@>SzHr%h_G0AuV{p`+N9WE+N6pm+G&DN` z7M%{oIS-34wlSGoYcc`njL5~WGrmB?Ts@ReHpJ#Gx49md@8Mga0WH;cfo1CQ#BJ6D z33o37vw(vTI_x%je~kd5_7KK5=fd0+b&@+r4s-99V2NfSTQYt>=Qt<;_Ro_ezT?#C zYv*#PUi;0(sIH4-PA^bA*_R_0(b#DKk6a)55^rV9#kCn9CHyZ(G_+2Sm{=YHKUW_0 z3$mbQbFSb~B(IW%-M5(#@tq16!Je_K*=2QfQd*C(`hHwDeFe##O+F+Z z>IemyrEJlAgtGAq*`_Z2Flb8#_vVrj708SQ#d+`Hoq|41b{>Q!Lte3o*)njgNS>~; z{tnvTWr>d2D?E^*L;`KqsqjsebM5^%Y(aD~TsE!2%R}^udf+Wo-CqaFHG!~yj3TLi zIF$=7FTu-U+H{rX3z#XWf?RJ~cBbE3FzabW_vJi=x*diRJ~Njs>v$H;SILutu?yhG z3oB};n*jwSd*PymDz?NtMZV5bZenr;ROom>b>l~o@W2&kqgCbXlg%c?`(zVVk59vT zw_a>sb}1-Mss!Z|?hu`*h~|?de5`w@OY3vV8rg11s~7I!d>_Pdk7hLB#VsGeuThcs zdR^syw91lLk7@AHMv<0HWZ0~``RtU-4>5Xd9ol#|px%hjXsL1=FW)hz>tf7kNFQ@d zQ#GWgeREK8)@STD&yosBxx63Gff(PH?1^i@x^)&97V3~WD^p-?pXZ<$8iod)I$)m_mZUV&6nHJqxv=L! zIA!HCs2!Wjc^>Tpnx8EwcViBY3hRe+zyh7aY~J6)UH` z1M8J)G%drB3iNm)n}XX86};pMAbn*4WG%C#<>i&|;)*7o zobLfQE{c%p#?F2o>pA?@}8?d$+5FGR$%Cx>FD0547{fasOzCfS8VSNXJ)U(E9a{5 zZ6_WG4mFDghuMMifELI#S;7i*TtrhX)Je6%4Gc7X4YKqEEGlqCALFlBY%>l#&8Op- zGiF3Dd^Y5N>xe<`)3_SHQxI=zNgt^j(A*EVG2ioow!nkKISP9=G#Ozb9g#6yu4DPrMm&7PkXR3y2~PW$prUyr zJk`Adt{yrhYS~8MoO;8BwOVxFl5pIV_84*|yhcW>CJLJp2jllPVvB1Y2<};M3)4)9 z!YmoWw-e&PC7&_8cN&*S^PqOZ8+LHtuFx+;n;hMe0^yraz?R5-Wcb_re|`N8&C{b$ zoehAUhi|~!sjcW_uL8;|3}N~Q4(%QVWA)7~P~`gtB9&F)e0N#m_Tn^@>1lxP{z`Un zsv2$bt%sW9o>>1x4>rkn#FPcPWTTlSi76S1?pOOj*uo?zS(yb1xEjQ*3%S&$d~Qy; zI&FAhLO1u6BLkh6qK&^T`X2d#L3JCri`r*F?W+=DS9w75npa$Ymu*m)eH_MkD-*>* zRw(=Q2aX@ag8@sTaDH|<&N+Gy_qy^)S0{B^@W=|_@qQGa>j@`=7h>wO3jDbDGV~2= zf@qT@&fk41L|d$985OxAFBJ(t8~K#=k5;06*55;gW*M^BSykedxkNyEhPGL)D1 zk`wlF=4!i|K~OgpiC^19WU>-reAfwJ`b>k|+VuraKh~t5Cl;XMnFW}fp@>_vui%eD zBYMhMhA@`iEN}5v&ScUiP|hBRah;547B>vkzUUJ1i3~P3-VTRv{tQm<&fwFT&A2%K z6vjl#(i<5W=(yD%kLdHryCoK6L??;wV0ty@J}@6`!k2<)s}~Ia@Cv*oJk@Ho0f||< z9YYT(V5veWzP8wmOvET>uh=C}_6)cd*>or$_zsw)*M(iTYmw@4ouO#rW;m88L(Z8M zgV(A&Hf+i&jNW7mEhfVt)v^yuq}+BIWsA30?De_~_59_Bp1#AB#&5sGi~MDx39 
zzr~IiZNHiar8W1zSNThEHDv*F)66yD;OFDb4JB9@ng3NN`I>G}!b6c3wUQcjsD= zvJ^8qG2%3iKcz{N7M6kD?L(+I`aYh0XGV?uwP?nO9VlKb!=|p8hxIZWP$No#BoF9_ z7LEHMq&f+Yo-IXRD{El#jHp7g9ND>hAJoRD1JkiD7p?n}6{=^rOv&huHWxNZt>wFD@@IbiiD2oZJ~&?+iy*mxpL{Xbr|x&4Dt*r)+%VK3K18Kxeu=hiIh?=sWi_ z?DO~pf=2e^z16jzjVT7aKs8Qdf$Q( zTE_In>Af(w;~H#riH7<`!4PV%NpF7CAU*e(Qd-M;ahMyp$ zwSYD0?SW{aLxxW($2#6lTy|ZR#Kq{+o?t)=#IHHwn&&wGWChr|<>S32@u>A(iq{2!HAs|n{2^ERTd-UrOsl#Vlx-i3K%HefbSiSC!sft-;%>ENzLTVuDP z)5{XB+U1~&xJK7ylB*n1J^TjF1*%c`eO5SYnhx#0rU_q9kmwU;Ou^A^CPZYcMA-!@ z)Vi1C%otJ4G8TF?sBt@LN_Zc=+6_nT+<+FlH>15zCm2;%0_B;nFw%A$7l{8Ft>PBvI$ZkJFv z8JUQl7Cu~6p$T}M3C0U91z10>o)hde6-CA}$PHKspI_ZW_kmm4(i6t?F=tN8?H+>t zh(26R=@yJ|Y=-NNmw>4Wg=fy@q@|ZSt*D(1GtB~U@JS2sbk?ILdlMk`S~2@xU_v#= zpT*p79eUI%&zQGsK%&BTLDbk!a-fTo8=LKlAx`}=AkWZ53_;j|y zUKBhzBg*ye#&-7GjlCy}FrlOjoiCh+!|IZAeEdHKr>P&5HLhij9UoW+L=f-h}18 zui(Kv9(}X)K8CGm!BmrMm#KrRQSPt|ZJeu21**AN6KqV=9^8Wt8WV9#jRHwnk%4xD zm!eQT*TrnW3h)k(fTp(+-Qp{K?6b_2)W2BI1;%Dc==lU}iZh|9N6%nL;6+aH(pNaT z<_jENXGT^WnG3?z>zy~oYZHr{6DWLROc|FOE=fz@gSCbU&DGw3wtaL#bINY8lsSS+ zXDCs(jtW%R?*ON8MuVIx;So10K!X)!;LS*MnV-~2Sn_v>?X!g|(Fle^OK*dhhd)-R z?B=}kw{gx|#~_El6(8OHj)JR-T-pa6Dzn&>Y>>!3(IZ+|{d206KXxyB+|Ga`i?X0Z zJRjdGD^Y&vcQ5xjM*5(H%`E@izGAVR)3Zi~_(tYE~L{ zPDw%Lbr+YazMt_)Pi?}J=t~yuUWlq=FF@;?>tJn?3NKpLp|M2&Qg^!%W}mR4`gKvGDknE1_0EKQPKJd4=`Ppnx;8y= zFAD3=HRCG9Ti`ftDZ~V4q0#qxbg<6DXLC)+t){2oP^E<8m_SJAIv%UXz2`zEuH!0J zO7unsrWl#J4a$yL(Ox%F`ma z6MAf?LYRtog_-7WaIEA_uxjIa(G!Ur!C1NBv2n`8tbZ=7_Awxp>)wNj|5-?nIg6eh z4zl@$W|)xG6%&s;TJ$PP((`T(gy=@s#bEh6N$rhM) zQiFCIcm)mKZk3!9H{hI$6hf|8!W8B%_$!V@#quK65NYuK_VKlbR*Yj7mpfS3>C6Y(c48Xb|%R`AVHdCXq!T$dw|>oAY2Ty_`F9q5X| zH#C45ysOAsUY(j4&cN|=hQOwA>pTv`&WD@O-LMu1O3urM zZ(?y&pcp*QJz}@kyo7156-hlej@`0Ng*+K%OrNBw5&kx0nTXp(nX76bcIrF!P<0hv z-%|{`v~(zYZ7VzuRw2pO=kOL^p4z%=!O>U_=(j9Xtvd({BQ*(cJaB6ob*cIr$(oyY z3yPZS;r0E_SlJp2u|p29f|X^QP4IkR-Y$lj3sN9U;W^i+U`TwmEI7|(EpF_19}M{t zg@V~y+=3elr2a!VYN@CbpMBRc#P>Dlls^n)4^iOKdO?lh5@;wsiqE&HlUHk+(alMX 
zZizaG1ySzMBIt#l{TJdN-Hfh7*%3b)WSpp-!7xNF8Vs)Pz-0QWHy|%)Z<8Npd z{n%_yPlv^z+nA$Zmlu!ARwg4K3q-NIt=KW051`kUj@WWjR-$9!p@kZUqV5}U&oddq z<2FO_10xWwJy7&uP#P}zU`F}c{_y4#Lt@Sw(ojP=8gN#Z zeGZY-X(}4%p9#w^G$6mWH_Xw`2ENV#m$kOnq2+NDJfCGsJ-2IfnLjQ;6W0yy; zs0>VX)CryT9!`o%ur=lq3fIVqx>>!H+;{knT3+R7rj-Y8)h6J+&*!l`#ss znW%f&!*KP_Xxh_~_FG^`66a7hG{+oz251t?8E^30?!9+<1%O_N%cqr3Kc=@`!0&2Hxwt7e^XekqEaW z82j`Nq-Cm8jYw5$pIM0NDZ^mQVgr&mJDSbYr~{$Tb!XdedBAtODpGi$L4x+of#g>Q zQ8~>FPWZ{wTV)3%ccQd#?aY0^j11)#%+(@&_x8sIUCCL#-(vRAlWDkj%Uw+WZVd<1 z@=;}vC7HJQJ>IH61f_e9K;oQjkl)e_`}>#?CglbfUf_nIQ>xLXj~vwCJtR6qu{@&) zk1sYRPOa*6Z@f9twaCPfDf6+uzKP4zkHS<}3syfzLPwqftnZ`J1yMebPdW3kDXVA^>7-SvohOGM` zEYR7`KK3^tV#y7YShW>w<-Tlay_<)-7M{XsHzUz=pa}CjWI=O}6gEpa0WG%ZQ6&pL zeK-0IitjvO%Z`R~%+v!#`E~E1EG&eLR*UEE=u7Uh%pJ&8fA+;({lnmRy&3t8Ebi)O zPK=6kxX?j|VC_Rq+H1oP42j;wiaQ8Ice?w*gAf@y^I;vVyTKt$io{1(_knxt8kpXD zC}zHW0aG#&I{0tH%Aoh?;oXe*&VrQA)}f<9o`KH2`M4nECEBJbQ=x1ETm85RlMY+a zmZ2id-T8!d?=cK^j(!OjZ`^|sg;{uL-)q=CB^_(Rb)k6O2G9v8gRnl2@X)Lp;KiGA zvHK2SNW)Gb=En4G*K`mU_mk{#VlMoeEVZAu9J3T9_eE>PD9m=^WT%OtV$D*JIl?3U z_4^>SSAVeCzlZZM=>)#36uGEVnZOV5;{?$RlL}uIUG7ci0Kbo@ZjQ1+^xD9bC20^MF%fR>`M(G z9x?!!+l0F$qeI-Y3%SWF_h7JbH0bFa#EXw|;biL!OnvQOK4g>Ept3RWjM*3EF0#fQf7@Xl|c~t)md^>K{gGLw6uDw7x>;fA)trMmN=+bew z-{WA-F<7@S2AqaiVzitA1g(0-=0-cR;?=`lBqnUAK6{;IN(Zv%J{|uB>7A0FF>Lx~P)ec+FtsTolB3Xmz*=&T>AsQ~QC{i~#p0(J($kv#(N&^T&d``HcV9BKtkL#sg;XT=KM_2p*t zwxYLsThdK?A3&Yya~wHRla>wJ&b~;0j%#PHm)!4tC0aO0gB+0PJ&d-xfPzGpI`nD| zitW#F^|xNI?!ocwxk)cTZ^vU`g$~t!Q;sD0C>K@Ws!F2Jht2$_F3dPFFSL z4b|hktF&;=Z8_Say9PsV?SM<4AEKs#7M-xlgcQ`uLUg?&WIntOUO5l&V(LZQv0H~s z+%G9V8MJI0stBLwFI=d$va!SV!Gxy-7&60c5yO*mZ-ZOVA>19HBDp)8gyQpV+>0PBtRFv_ z_3U;MmadlEy_?_xMpoCj)B6f>TyJ%9q*$K}8Kg*NhndprH!@_pExmlD{5*>Aapw|TVxBxS;ne}SG#&am?1KJ@G^==|skt;a<5-4k3-UUjGLeYBhC=e{l76qv25gmtm<5P`%WzsA6D?Sam4O{bfL#pj3m!-;z#9b$k zRWD&{9*byB@3`FbvRJn_J+ zTaP*BY!{X*e1#px#9JCGkcrAUA^{IY!-KTahkJZoPpyt2cpfg2ZnhI!YOBaW9K|Sj9k?j#?3G%Vy$9s)hiR4o|FN(i<;O2gUxAoiQJ_$ z&4?TgyntiV<6zV1+gLT>FboJfjT6Rbk+B_|%*VDfg!I7lxLe}2sk 
zo-qsq8*LQvYFM5-s0h-A{&;NA7wggRHiyebP)ce7D(rv1D7;3KX7$hH7WoashdcCWQ=%MQ_3SOQvJ#&fgSFtsJ%Wa3yK#e;>r%z|DlEBfNR!{SVCBh;P_d~S@E025ZARh)G_XiSBsQ&<>1R=F ze#xa~AA#rtnvnYRs*B)42|9jog}bW9F{?=vuCIuJ(hieR92^fp{#!ABS~UcEJ_lQ7 zIXqu2L(i2QfzcA$`WUeh_yyW*YPUTuM~oQ~v0o2e0xU^FpDS2m@{MDz%8R0Yl);dz zx+LQsLssrlC&m5Dfpx!!v0G$)GZNePFVrJQw=M&4-$n;lw(zY6vIa)d0M;a z8D~3Cg<3iq5u-Z}80%WdIet6=4ZAx+?Ckv@o7Dg_?Mp#@&N}F?qe)8NOorS`0;u9B&yJ_SPY&z+VNye{mO@%%k?idR<4;YY$ zsdpjZvMj02m&mKvsEg^{WBAmy0E{&?N#MzB7(CMvIW`&CM}7$g<)H7;%-Gj?Jad$W}?e3W~mOgR5#UxpUM6|k{q z53|b0P2u?7YP2YtgLh3;N!JP;vTaui=02Urop9ZNNw2^xOdK_v4a@v=h$PU+iF<+M2TE?RHmVI%ON(1!|0{y_$6*T zo~*C~{=n-ZllOkOywr*W`*ep6+Sidc<|5ubWkBOw>*4!Vd766Jf_r3k9!6*J$cck8 zWZjU%=)cH;@d0&p|?|g?wi^E}Fml|xSD#VqVR`f0N8OH?hX;Y9ceb7nrhRWr6 zP+S}XxjJvSExW8}amP$JI$ENmOgoOAce-QYI!khM2#3?%qoCQf8WVFjz{RpAFp5ut zfis?9pVMEF-|$)TZcu4cI}lcYacY z6%Pgg-B$<9SUoyKqPvp6G8R3LFMuz*IUJ(JBb8OM)N}lGu0whboSf>61!@#KESZOO zh7G7aBN1&oSkn6QN^IGzbud|5hp6AP$Mb<-aN43mOiapU-zur%@HLXRTePllpIBA0 zxy6Lknit@Q@jUu=hBjT|V@~_t&cUcn=G6Ie6NG%z!$*D>z$RP;Gl#zc#RuuYlS{z7 zTPx7EJ_(eSSK}een-F&oQDGL(INc==! 
z?u0()_Mp2}FF0{D5shME*vFl7G4X8!#Md<(fImI13`ya1r$$ z`39TC2cS$_32#wVym!l-CcPGesw2g@eYL2Jx5S=xtpS^&P@;2Zh^$2bt~qE*j%dnI zua<={`5=$T790kruLiWJI3EVT8wW|#6lP?1AL_ z^Xsr``&9VWOO@0wOW{1use;z?Ld@$^1eqThdT3Q8OuDB)cuPuIg`Yh>vWFaqOm%*jgh)gE0>J=6r&4j#nY6S0yIg*#rTl8pP*;9yw%o4p%-& zz?bV(Npi_P%->!OFVa&Ws`DPm6+0kfpIDf>;}TanYdci*Im+Igvk_DJTae}tC&6i0 zGdtpjB=<=XZE?{WY||p-K7vkjZ%sc+n5P7pW@ zg=+_+{lWp5m0HJf#0DZ&R8UxQ%z6IetFUskJo)k}1AX&DIFY&nDN5t8v*I3{dccZQ z=en{_t{PLjO;6$B;xkwxcZ=gUDX^N))8MgzC28951xEeI2SJ9&B`;i#7WFy}CwFB6 zlk(bGm^l%3Zz~bIrt6ScB@e!Zek^myTvQJoq5iNUo0?hX!fvWXPn$>V5SAg#q8yQj zNSmG=t4ln>^x?D7QB2Z#jD-iyN%(|Lh!PoophpX+o1Z|tC0EdX$Wbn6W)f&#u7jAl zEAfy?G&s4{pw%5E+Bip!#?N{U{&TlM_I4f-ys?L;pE5Ciio`yui=yc6r#LkgJ{fDc z3kQyA!BdkC;G?4kRLTE1p1vWm+Z<8?+s(sJdB`U=@Zf2D`?e#P6rBY1*mzuVKm~o1 z`0SNgIncF9iSUZ=ak0h;Z1#RDBEM@m=*?PyatS=rXNf*(QF6pLrwxek*e20%#bfX( zQJFT*Qlfd`Wf(GYE^9Mr5b$j)Mfc`eQBAw8@On-Y_U&2=wTCxyS;k>dTx&p$JhZ|0 zOFRgC=D6HBEBQMytRJO}e5+0htb$xHxQsG}|>9>A(aSa#a&( z&uTb7$O+}wnxwK~OH^?Gwv!JHqD=!gfXvQ>_0`-c(pDI=5{(b3DIy-kpS=gC?VaJu zFabTGG8-lP{D7*;!Qv1P%xg%L&T-}Zu4^%PF43Ze4`Q)3V!-Zql%iSnFQ{;f#q=gA?%HLHtIYZA zz@l<+|KcJ;t(T!=kP(-MAqDr)-|>fs0@_>uXACS%z$|+Kd&X3sj#$ALd2Js_$Imoi zw{tlg1{2MoSbrSy+UqEB7%S4!>4G`;V?e%cE)%)*IF~P@&6+okMrR*gcI?9dIBD=b zF8ipK5r>~;tSy9ej^`3wxR6Ju?3;^%(dCrPBn&51d&Av%LiUom4xY;?#>{bCF0ex# zr5$=1Bd)HOre^5T1Gn7A+5$7yWzT-_uujC&N2}3m#{tAp1(sy}ia8A#a3snG(oVE7 z7oRj>p>_$Z-c-Sam97QPTuXY|nnsX}kwQ|o3CnA$rYx%7VZGaiqNC9;daB`WG(0sM z`8Da3Xx$Dp%%6nX3xpWN<)&0^90UhP=R&AWEDkY>VY&-!(B)PL_{C?UcHv;GXwXH8 z=6BR&xdj`$Di5kkzJ&`ddi0tnmUPYXSWJ;wfL8EbV1$#QJ^mDxJ-rgLj%~)Y2S-JM zPmG~+Un-Mk)hG(8w8ui6q&DnHN}b7z%N{$n?ThaqS8>q@BR zwk~EReG`37CgQFc9;o4MbHa{hA$(>x_!{fN1V0m2uFKWWBsqaqtRGHQjG&_gj!qvhz)lJN7rmuH-70t9Ap+>j%>o@`3EM0xl2e-55B${tC6!&4{IDjY6lK z49Kd}V_V0^!TfEbApYELSTWR`c8e>3%ouZYPjH1rk(Z#&r2;3)&ZB$yA!_-LL+M0m zI-_f!gsqB4AZxydR?}T@hJ7#=b-zRFgcN&;SpMHgV0=&>)Rt6!d{$N#c z8K(G*#AN|8s(0P^%fdHn7b>NDvk7s zbCINEB>V4`E3$vV*Qx>abBNeyv#NhO5l$T+N@eC``ah4ve$@N5{~LUqzhrzaJ`UF% zgy3?h5%R& 
zgqNsX*j9CJ^mV>8CoB2;&zG0{!QRlvtDi3dm7FfB{X@hi_DX1iX+?~J!~\n", + " \n", + " Run in Google Colab\n", + " \n", + " \n", + " View source on GitHub\n", + " \n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xXgS6rxyT7Qk" + }, + "source": [ + "Training is much faster using GPU acceleration. Before you proceed, ensure you are using a GPU runtime by going to **Runtime -> Change runtime type** and selecting **GPU**. Training will take around 5 minutes on a GPU runtime." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "LG6ErX5FRIaV" + }, + "source": [ + "## Configure dependencies\n", + "\n", + "Run the following cell to ensure the correct version of TensorFlow is used." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "STNft9TrfoVh" + }, + "source": [ + "We'll also clone the TensorFlow repository, which contains the training scripts, and copy them into our workspace." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "ygkWw73dRNda" + }, + "outputs": [], + "source": [ + "# Clone the repository from GitHub\n", + "!git clone --depth 1 -q https://github.com/tensorflow/tensorflow\n", + "# Copy the training scripts into our workspace\n", + "!cp -r tensorflow/tensorflow/lite/micro/examples/magic_wand/train train" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "pXI7R4RehFdU" + }, + "source": [ + "## Prepare the data\n", + "\n", + "Next, we'll download the data and extract it into the expected location within the training scripts' directory." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "W2Sg2AKzVr2L" + }, + "outputs": [], + "source": [ + "# Download the data we will use to train the model\n", + "!wget http://download.tensorflow.org/models/tflite/magic_wand/data.tar.gz\n", + "# Extract the data into the train directory\n", + "!tar xvzf data.tar.gz -C train 1>/dev/null" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "DNjukI1Sgl2C" + }, + "source": [ + "We'll then run the scripts that split the data into training, validation, and test sets." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "XBqSVpi6Vxss" + }, + "outputs": [], + "source": [ + "# The scripts must be run from within the train directory\n", + "%cd train\n", + "# Prepare the data\n", + "!python data_prepare.py\n", + "# Split the data by person\n", + "!python data_split_person.py" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "5-cmVbFvhTvy" + }, + "source": [ + "## Load TensorBoard\n", + "\n", + "Now, we set up TensorBoard so that we can graph our accuracy and loss as training proceeds." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "CCx6SN9NWRPw" + }, + "outputs": [], + "source": [ + "# Load TensorBoard\n", + "%load_ext tensorboard\n", + "%tensorboard --logdir logs/scalars" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ERC2Cr4PhaOl" + }, + "source": [ + "## Begin training\n", + "\n", + "The following cell will begin the training process. Training will take around 5 minutes on a GPU runtime. You'll see the metrics in TensorBoard after a few epochs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "DXmQZgbuWQFO" + }, + "outputs": [], + "source": [ + "!python train.py --model CNN --person true" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4gXbVzcXhvGD" + }, + "source": [ + "## Create a C source file\n", + "\n", + "The `train.py` script writes a model, `model.tflite`, to the training scripts' directory.\n", + "\n", + "In the following cell, we convert this model into a C++ source file we can use with TensorFlow Lite for Microcontrollers." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "8wgei4OGe3Nz" + }, + "outputs": [], + "source": [ + "# Install xxd if it is not available\n", + "!apt-get -qq install xxd\n", + "# Save the file as a C source file\n", + "!xxd -i model.tflite > /content/model.cc\n", + "# Print the source file\n", + "!cat /content/model.cc" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "Train a gesture recognition model for microcontroller use", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/samples/modules/tensorflow/magic_wand/train/train_test.py b/samples/modules/tensorflow/magic_wand/train/train_test.py new file mode 100644 index 00000000000..2b785663df7 --- /dev/null +++ b/samples/modules/tensorflow/magic_wand/train/train_test.py @@ -0,0 +1,78 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class TestTrain(unittest.TestCase):
  """Exercises load_data, build_cnn, build_lstm and reshape_function."""

  def setUp(self):  # pylint: disable=g-missing-super-call
    """Loads the three data sets with a sequence length of 128."""
    self.seq_length = 128
    (self.train_len, self.train_data, self.valid_len, self.valid_data,
     self.test_len, self.test_data) = load_data(
         "./data/train", "./data/valid", "./data/test", self.seq_length)

  def test_load_data(self):
    """load_data returns tf.data.Dataset objects for all three splits."""
    for dataset in (self.train_data, self.valid_data, self.test_data):
      self.assertIsInstance(dataset, tf.data.Dataset)

  def test_build_net(self):
    """Both builders return Sequential models emitting four values per row."""
    cnn, cnn_path = build_cnn(self.seq_length)
    lstm, lstm_path = build_lstm(self.seq_length)
    cnn_input = tf.constant(np.random.rand(60, 128, 3, 1), dtype="float32")
    lstm_input = tf.constant(np.random.rand(60, 128, 3), dtype="float32")
    cnn_prob = cnn(cnn_input).numpy()
    lstm_prob = lstm(lstm_input).numpy()
    self.assertIsInstance(cnn, tf.keras.Sequential)
    self.assertIsInstance(lstm, tf.keras.Sequential)
    self.assertEqual(cnn_path, "./netmodels/CNN")
    self.assertEqual(lstm_path, "./netmodels/LSTM")
    self.assertEqual(cnn_prob.shape, (60, 4))
    self.assertEqual(lstm_prob.shape, (60, 4))

  def test_reshape_function(self):
    """reshape_function turns data into (-1, 3, 1) and keeps the label shape."""
    # Only the first element of each data set is inspected.
    data, label = next(iter(self.train_data))
    original_data_shape = data.numpy().shape
    original_label_shape = label.numpy().shape
    self.train_data = self.train_data.map(reshape_function)
    data, label = next(iter(self.train_data))
    reshaped_data_shape = data.numpy().shape
    reshaped_label_shape = label.numpy().shape
    expected_rows = int(original_data_shape[0] * original_data_shape[1] / 3)
    self.assertEqual(reshaped_data_shape, (expected_rows, 3, 1))
    self.assertEqual(reshaped_label_shape, original_label_shape)


if __name__ == "__main__":
  unittest.main()