Создание счётчика жестов «дай пять» с помощью глубокого обучения +4


Около десяти лет я хотел реализовать эту глупую идею – измерить ускорение руки человека, чтобы подсчитать, сколько раз он дает пять в течение дня. Я не знал, как решить данную задачу, используя классические подходы к разработке алгоритмов, основанные на знакомых мне правилах, поэтому проект приостановили. Но когда я делал серию видеороликов MATLAB Tech Talk по Deep Learning, я понял, что Deep Learning идеально подходит для решения этой проблемы!

Темой четвертого видео в серии было «Трансферное обучение», и оказалось, что это была ключевая концепция, которая мне понадобилась для быстрой реализации алгоритма, считающего сколько раз я даю пять в течение дня. В этой статье я подробно расскажу о коде и об инструментах, которые я использовал для подсчета количества жестов «дай пять». Надеюсь, вы сможете использовать это пример как отправную точку для решения сложных задач классификации, над которыми сидели последние 10 лет.

Итак, приступим!

Обзор оборудования

Настроить оборудование довольно просто. У меня есть акселерометр, подключенный к Arduino Uno через шину I2C. Затем Arduino подключается к моему компьютеру через USB.

Чтобы измерить ускорение, я использую MPU-9250. Это инерциальный измерительный датчик с 9 степенями свободы от TDK InvenSense. Вместо того, чтобы интегрировать датчик в свою собственную схему, я использую макетную плату, через которую обеспечивается питание и подводятся контакты связи I2C.

Вы можете видеть, что мое устройство довольно грубо сконструировано с использованием лишь макетной платы и нескольких перемычек, но это не так уж и плохо, ведь вам не нужно будет заниматься сложной настройкой для правильной работы устройства.

Чтение данных с акселерометра в MATLAB

Для считывания данных об ускорении от MPU-9250 через Arduino используется MATLAB Support Package for Arduino Hardware. Этот пакет позволяет вам общаться с Arduino без необходимости компилировать для него код. Кроме того, есть встроенная функция mpu9250, позволяющая считывать показания датчика с помощью однострочной команды.

Для подключения к Arduino, создания экземпляра объекта MPU9250 и считывания показаний акселерометра требуется всего три строки кода.

Предварительная обработка данных и скалограмм

Если вы смотрели 4-е видео из серии Tech Talk о глубоком обучении, вы знаете, что я решил преобразовать данные о трехосевом ускорении в изображение, чтобы воспользоваться преимуществами GoogLeNet – сети, обученной распознавать изображения. В частности, я использовал непрерывное вейвлет-преобразование для создания скалограммы.

Скалограмма – это временно-частотное представление, которое подходит для сигналов, существующих в нескольких масштабах. То есть сигналы, которые являются низкочастотными и медленно меняющимися, но затем время от времени прерываются высокочастотными переходными процессами. Оказалось, что они полезны для визуализации данных ускорения руки во время жеста «дай пять».

Версия кода MATLAB, которую я использовал для построения вышеуказанного графика:

close all
clear

 % If your computer is not able to run this real-time, reduce the sample 
% rate or comment out the scalogram part
fs = 50; % Run at 50 Hz

a = arduino('COM3', 'Uno', 'Libraries', 'I2C');  % Change to your arduino
imu = mpu9250(a);

buffer_length_sec = 2; % Seconds of data to store in buffer
accel = zeros(floor(buffer_length_sec * fs) + 1, 3); % Init buffer

t = 0:1/fs:(buffer_length_sec(end)); % Time vector

subplot(2, 1, 1)
plot_accel = plot(t, accel); % Set up accel plot
axis([0, buffer_length_sec, -50, 50]);

subplot(2, 1, 2)
plot_scale = image(zeros(224, 224, 3)); % Set up scalogram

tic % Start timer
last_read_time = 0;

i = 0;
% Run for 20 seconds
while(toc <= 20)
    current_read_time = toc;
    if (current_read_time - last_read_time) >= 1/fs
        i = i + 1;

        accel(1:end-1, :) = accel(2:end, :); % Shift values in FIFO buffer
        accel(end, :) = readAcceleration(imu);

        plot_accel(1).YData = accel(:, 1);
        plot_accel(2).YData = accel(:, 2);
        plot_accel(3).YData = accel(:, 3);

        % Only run scalogram every 3rd sample to save on compute time
        if mod(i, 3) == 0

        fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs, ...
            'VoicesPerOctave', 12);
        sig = accel(:, 1);
        [cfs, ~] = wt(fb, sig);
        cfs_abs = abs(cfs);
        accel_i = imresize(cfs_abs/8, [224 224]);


        fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs, ...
            'VoicesPerOctave', 12);
        sig = accel(:, 2);
        [cfs, ~] = wt(fb, sig);
        cfs_abs = abs(cfs);
        accel_i(:, :, 2) = imresize(cfs_abs/8, [224 224]);

        fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs, ...
            'VoicesPerOctave', 12);
        sig = accel(:, 3);
        [cfs, ~] = wt(fb, sig);
        cfs_abs = abs(cfs);
        accel_i(:, :, 3) = imresize(cfs_abs/8, [224 224]);

        if~(isempty(accel_i(accel_i>1)))
            accel_i(accel_i>1) = 1;
        end

        plot_scale.CData = accel_i;
        end

        last_read_time = current_read_time;
    end
end

Обратите внимание: этот код использует функцию cwtfilterbank для создания скалограммы, которая является частью Wavelet Toolbox. Если у вас нет доступа к этому набору инструментов, и вы не хотите писать код самостоятельно, попробуйте применить другой тип частотно-временной визуализации. Возможно, сработает спектограмма или какой-нибудь другой алгоритм, который вы придумаете. Что бы вы ни выбрали, идея заключается в создании изображения, в котором будут присутствовать уникальные и различимые паттерны «дай пять». Видно, что скалограмма работает неплохо, хотя можно попробовать и другие методы.

Создание обучающих данных

Чтобы научить сеть распознавать жест «дай пять», нам нужно несколько примеров того, как выглядит и как не выглядит этот жест. Поскольку мы начнем с предварительно обученной сети, нам не понадобится столько обучающих примеров, сколько необходимо при обучении сети с нуля. Я не знаю точно, сколько обучающих данных нужно, чтобы полностью охватить пространство решений для всех возможных «дай пять», однако я собрал данные для ста «дай пять» и ста «не дай пять», и это, похоже, неплохо сработало. Я думаю, что, если бы я действительно создавал продукт, я бы использовал гораздо больше примеров. Вы можете поиграть с объемом помеченных данных для обучения и посмотреть, как это повлияет на результат.

Сбор 200 изображений кажется большой работой, но я написал небольшой скрипт, который циклически просматривает их одно за другим и сохраняет изображения в соответствующей папке. Я дважды запускал данный скрипт: один раз с меткой «high five» с изображениями, сохраняемыми в папку data/high_five, и один раз с меткой «no_high_five» с изображениями, сохраняемыми в папке data/no_high_five.

 % This script collects training data and places it in the specified
% label subfolder. 3 seconds of data is collected from the
% sensor but only keeps and saves off the last 2 seconds.
% This gives the user some buffer time to start the high five.

% The program pauses between images and prompts the user to continue.

% Note, you'll want to move the figure away from the MATLAB window so that
% you can see the acceleration after you respond to the wait prompt.

close all
clear all

% If your computer is not able to run this real-time, reduce the sample rate
fs = 50; % Run at 50 Hz

parentDir = pwd;
dataDir = 'data';

%% Set the label for the data that you are generating
% labels = 'no_high_five';
labels = 'high_five';

a = arduino('COM3', 'Uno', 'Libraries', 'I2C');  % Change to your arduino
imu = mpu9250(a);

buffer_length_sec = 2; % Seconds of data to store in buffer
accel = zeros(floor(buffer_length_sec * fs) + 1, 3);  % Init buffer

t = 0:1/fs:(buffer_length_sec(end)); % Time vector

subplot(2, 1, 1)
plot_accel = plot(t, accel); % Set up accel plot
axis([0 buffer_length_sec -50 50]);

subplot(2, 1, 2)
plot_scale = image(zeros(224, 224, 3)); % Set up scalogram

for j = 1:100 % Collect 100 images

    % Prompt user to be ready to record next high five
    H = input('Hit enter when ready: ');

    tic % Start timer
    last_read_time = 0;

    i = 0;
    % Run for 3 seconds
    while(toc <= 3)
        current_read_time = toc;
        if (current_read_time - last_read_time) >= 1/fs
            i = i + 1;

            accel(1:end-1, :) = accel(2:end, :);  % Shift values in buffer
            accel(end, :) = readAcceleration(imu);

            plot_accel(1).YData = accel(:, 1);
            plot_accel(2).YData = accel(:, 2);
            plot_accel(3).YData = accel(:, 3);

            % Run scalogram every 3rd sample

            if mod(i, 3) == 0

                fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs, ...
                    'VoicesPerOctave', 12);
                sig = accel(:, 1);
                [cfs, ~] = wt(fb, sig);
                cfs_abs = abs(cfs);
                accel_i = imresize(cfs_abs/8, [224 224]);


                fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs, ...
                    'VoicesPerOctave', 12);
                sig = accel(:, 2);
                [cfs, ~] = wt(fb, sig);
                cfs_abs = abs(cfs);
                accel_i(:, :, 2) = imresize(cfs_abs/8, [224 224]);

                fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs, ...
                    'VoicesPerOctave', 12);
                sig = accel(:, 3);
                [cfs, ~] = wt(fb, sig);
                cfs_abs = abs(cfs);
                accel_i(:, :, 3) = imresize(cfs_abs/8, [224 224]);

                if~(isempty(accel_i(accel_i>1)))
                    accel_i(accel_i>1) = 1;
                end

                plot_scale.CData = accel_i;
            end

            last_read_time = current_read_time;
        end
    end

    % Save image to data folder
    imageRoot = fullfile(parentDir,dataDir);
    imgLoc = fullfile(imageRoot,char(labels));
    imFileName = strcat(char(labels),'_',num2str(j),'.jpg');

    imwrite(plot_scale.CData, fullfile(imgLoc,imFileName), 'JPEG');
end

После запуска скрипта я вручную просмотрел обучающие данные и удалил изображения, которые, как я думал, могут ухудшить результаты обучения. Это были изображения, где жест «дай пять» не находился в середине кадра, или изображения, на которых я знал, что сделал плохое движение рукой. На гифке ниже я удалил изображение 49, потому что оно не было в центре кадра.

Трансферное обучение и GoogLeNet

Когда все мои обучающие данные находятся в соответствующих папках, следующим шагом будет настройка сети. В этой части я следовал примеру Classify Time Series Using Wavelet Analysis and Deep Learning, но вместо того, чтобы запускать все через сценарий MATLAB, мне было проще настроить и обучить сеть с помощью приложения Deep Network Designer.

Я начал с предварительно обученной сети GoogLeNet, чтобы воспользоваться всеми знаниями этой сети для распознавания объектов на изображениях. GoogLeNet была обучена распознавать на изображениях такие вещи, как рыба и хот-доги – явно не то, что я ищу, – но именно в таком случае полезно трансферное обучение. Благодаря трансферному обучению я могу сохранить большую часть существующей сети без изменений и заменить только два слоя в конце сети, которые объединяют общие функции в конкретные шаблоны, которые я ищу. Затем, когда я обучаю сеть, необходимо обновлять в основном только эти два слоя, поэтому обучение происходит намного быстрее с трансферным обучением.

Если вы хотите точно знать, как я заменил слои и какие параметры обучения я использовал, рекомендую вам следовать примеру MATLAB или посмотреть Tech Talk, однако можете сами поэкспериментировать с сетью. Вы можете попробовать начать с другой предварительно обученной сети, такой как SqueezeNet, или можно заменить больше слоев в GoogLeNet или изменить параметры обучения. Здесь есть много вариантов, и я думаю, что отклонение от того, что я сделал, может помочь вам развить некоторую интуицию в отношении того, как все эти параметры влияют на результат.

Обучение сети

Когда сеть уже готова, очень просто обучить ее с помощью Deep Network Designer. Во вкладке Data Tab я импортировал данные обучения, выбрав папку, в которой я сохранил набор изображений «дал пять и не дал пять». Также было выделено 20 процентов изображений, которые будут использоваться для проверки в процессе обучения.

Затем на вкладке training tab я устанавливаю свои параметры обучения. Здесь я использовал те же параметры, которые использовались в примере MATLAB, однако еще раз рекомендую вам поиграть с некоторыми из этих значений и посмотреть, как они влияют на результаты.

Обучение на моем единственном процессоре заняло чуть более четырех минут и достигло точности около 97%. Неплохо для пары часов работы!

Тестирование счетчика

Теперь, когда у меня есть обученная сеть, я использую функцию classify из Deep Learning Toolbox, чтобы передавать скалограмму в каждый момент времени выборки и возвращать метку «high_five» или «no_high_five». Если возвращаемая метка была «high_five», я увеличиваю счетчик. Чтобы не считать один и тот же жест несколько раз, поскольку данные об ускорении проходят по всему буферу, я добавил задержку, которая не будет считать новый жест «дай пять», если с момента предыдущего жеста прошло не менее двух секунд.

Ниже приведен код, который я использовал для подсчета жеста:

close all
clear

%% Update to the name of your trained network
load trainedGN
trainedNetwork = trainedGN;

% If your computer is not able to run this real-time, reduce the sample
% rate or comment out the scalogram part
fs = 50; % Run at 50 Hz

a = arduino('COM3', 'Uno', 'Libraries', 'I2C');  % Change to your arduino
imu = mpu9250(a);

buffer_length_sec = 2; % Seconds of data to store in buffer
accel = zeros(floor(buffer_length_sec * fs) + 1, 3); % Init buffer

t = 0:1/fs:(buffer_length_sec(end)); % Time vector

% Set up plots
h = figure;
h.Position = [100         100        900         700];
p1 = subplot(2, 1, 1);
plot_accel = plot(t, accel);
plot_accel(1).LineWidth = 3;
plot_accel(2).LineWidth = 3;
plot_accel(3).LineWidth = 3;
p1.FontSize = 20;
p1.Title.String = 'Acceleration';
axis([0 t(end) -50 60]);
xlabel('Seconds');
ylabel('Acceleration, mpss');
grid on;
label_string = text(1.3, 45, 'No High Five');
label_string.Interpreter = 'none';
label_string.FontSize = 25;
count_string = text(0.1, 45, 'High five counter:');
count_string.Interpreter = 'none';
count_string.FontSize = 15;
val_string = text(0.65, 45, '0');
val_string.Interpreter = 'none';
val_string.FontSize = 15;
p2 = subplot(2, 1, 2);
scale_accel = image(zeros(224, 224, 3));
p2.Title.String = 'Scalogram';
p2.FontSize = 20;

telapse = 0;
hfcount = 0;

tic  % Start timer
last_read_time = 0;

i = 0;
% Run high five counter for 20 seconds
while(toc <= 20)
    current_read_time = toc;
    if (current_read_time - last_read_time) >= 1/fs
        i = i + 1;
        telapse = telapse + 1;
        % Read accel
        accel(1:end-1, :) = accel(2:end, :); % Shift values in FIFO buffer
        accel(end, :) = readAcceleration(imu);

        plot_accel(1).YData = accel(:, 1);
        plot_accel(2).YData = accel(:, 2);
        plot_accel(3).YData = accel(:, 3);

        % Only run scalogram every 3rd sample to save on compute time
        if mod(i, 3) == 0

            % Scalogram
            fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs, ...
                    'VoicesPerOctave', 12);
                sig = accel(:, 1);
                [cfs, ~] = wt(fb, sig);
                cfs_abs = abs(cfs);
                accel_i = imresize(cfs_abs/8, [224 224]);


                fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs, ...
                    'VoicesPerOctave', 12);
                sig = accel(:, 2);
                [cfs, ~] = wt(fb, sig);
                cfs_abs = abs(cfs);
                accel_i(:, :, 2) = imresize(cfs_abs/8, [224 224]);

                fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs, ...
                    'VoicesPerOctave', 12);
                sig = accel(:, 3);
                [cfs, ~] = wt(fb, sig);
                cfs_abs = abs(cfs);
                accel_i(:, :, 3) = imresize(cfs_abs/8, [224 224]);

            % Saturate pixels at 1
            if ~(isempty(accel_i(accel_i>1)))
                accel_i(accel_i>1) = 1;
            end

            scale_accel.CData = im2uint8(accel_i);

            % Classify Scalogram
            [YPred,probs] = classify(trainedNetwork,scale_accel.CData);
            if strcmp(string(YPred), 'high_five')
                label_string.BackgroundColor = [1 0 0];
                label_string.String = "High Five!";
                % Only count if 100 samples have past since last high five
                if telapse > 100
                    hfcount = hfcount + 1;
                    val_string.String = string(hfcount);
                    telapse = 0;
                end
            else
                label_string.BackgroundColor = [1 1 1];
                label_string.String = "No High Five";
            end
        end
    end
end

И вот он в действии!




К сожалению, не доступен сервер mySQL