Как классифицировать данные без разметки +54


Пользователи iFunny ежедневно загружают в приложение около 100 000 единиц контента, среди которого не только мемы, но и расизм, насилие, порнография и другие недопустимые вещи. 

Раньше мы отсматривали это вручную, а сейчас разрабатываем автоматическую модерацию на основе свёрточных нейросетей. Систему уже обучили на разделение контента по трём классам: она распознает, что пропустить в ленты пользователей, что удалить, а что скрыть из общей ленты. Чтобы сделать алгоритмы точнее, решили добавить конкретизацию причины удаления контента, у которого до этого не было подобной разметки. 

Как мы это в итоге сделали — расскажу под катом на наглядном примере. Статья рассчитана на тех, кто знаком с Python (при этом необязательно разбираться в Data Science и Machine Learning).

Классификация без разметки

Задача: сделать классификацию объектов.

Дано: множество данных без разметки и каких-либо подробностей.

Решение:

Для начала загрузим данные и проведём их первичный анализ: 

from sklearn.datasets import load_digits
 
dataset = load_digits()
dataset['data'].shape

У нас в наличии датасет размера (1797, 64). Это сравнительно небольшой набор данных — он меньше 2000, но и этого может быть достаточно, если выборка  репрезентативная (отражает особенности всего исследуемого множества). При этом у каждого объекта 64 признака — если они все бинарные (принимают значение 0 и 1), то нам потребуется 2^64 примеров, чтобы покрыть все возможные варианты. Для признаков, которые принимают 3 и больше значений, размер всеобъемлющей выборки будет ещё больше. На практике лишь небольшое число признаков несёт основную информацию об объекте и принимает гораздо меньше значений из допустимого множества. 

Выведем несколько строк из набора на экран:

dataset.data[10:15]

Полезно бывает смотреть на сырые данные без дополнительных агрегаций информации. Например, сейчас видно, что массив сохранен в формате float, но не видно ни одного элемента с числом после точки, будто бы все они целочисленные.

Перед работой с любыми данными стоит смотреть на статистику по разным признакам (столбцам). Взглянем на несколько случайных столбцов — возьмем с 30-го по 35-й и выведем статистику с помощью библиотеки pandas. 

Метод describe позволяет посмотреть набор самых часто используемых статистик из таблицы ниже. Значения признаков группируются около нуля, на что указывают их средний показатель. Также есть признаки с нулевым значением у всех объектов выборки, значит они неинформативны и их можно не использовать при дальнейшем анализе.

dataset_df = pd.DataFrame(dataset.data[:, 30:35])
dataset_df.describe()

Есть большое количество методов для анализа данных, многие из которых связаны с графическим отображением. Один из любимых способов Data Science инженеров — график попарных корреляций. Он позволяет обнаружить зависимость между признаками, которая может вести к уменьшению признакового пространства. Также с его помощью можно обнаружить корреляцию между признаком и таргетом (искомой величиной), но у нас нет разметки, поэтому данный сценарий нереализуем.

import seaborn as sns
sns.pairplot(dataset_df);

В нашем случае видно лишь то, что все признаки принимают целочисленные значения. Отсутствие парных корреляций не исключает наличия зависимости между большим числом признаков одновременно. Но увидеть такие особенности данных невозможно — у нас 64-мерное признаковое пространство. Даже если в нём есть области, где объекты группируются, то обнаружить это каким-либо графическим методом будет крайне сложно (а может и совсем невозможно). 

В такой ситуации нужно уменьшить размерность пространства признаков, отобразив его в двух- или трёхмерном, с которыми наше сознание в состоянии справиться.

Понижение размерности

Для начала избавимся от константных признаков. Выше мы отметили наличие признаков со значением 0 у всех объектов, поэтому спокойно их удаляем во всей выборке. Наша цель — разделить объекты, а значит, основная информация заключается в отличии их друг от друга. 

Есть множество способов уменьшить размерность признакового пространства, сохранив его информативность. Для статьи возьмём алгоритм UMap, так как уже используем его в своих задачах. Одно из его преимуществ перед другими алгоритмами нелинейного снижения размерности — возможность обучать на одном наборе данных, а затем использовать его в дальнейшем на новых данных, применяя одно и то же преобразование.

Используем уже готовую библиотеку. Здесь самый важный параметр — количество компонент, которое нужно получить на выходе (до какой размерности сжать текущее пространство признаков). Выбираем два, потому что 2D-плоскость можно наглядно отобразить на рисунке:

import umap
reducer = umap.UMAP(n_components=2, random_state=47)

Делаем обучение командой fit. Данных не так много, поэтому обучаем на всём наборе, но, как было сказано ранее, он может быть меньше итогового:

reducer.fit(dataset.data)

Далее преобразуем все данные:

embeddings = reducer.transform(dataset.data)

И на выходе получаем уменьшенную размерность — количество образцов то же самое, но признаков всего два: (1797, 2).

Коротко о том, как это работает: UMap строит взвешенный граф, соединяя ребрами ближайших соседей в n-мерном пространстве, затем создает другой граф в низкоразмерном пространстве и приближает его к исходному так, чтобы сохранить относительное положение объектов. То есть близкие объекты оставляет ближе, дальние — дальше, но уже в уменьшенной размерности.

Построим график полученных 2D-векторов :

plt.scatter(embeddings[:, 0], embeddings[:, 1], s=5)

На графике видно 10 больших групп точек и ещё несколько поменьше. Проведём кластеризацию — разобьём на области, основываясь на каком-либо параметре или правиле. 

Кластеризация

Воспользуемся алгоритмом k-средних (KMeans), который основывается на минимизации суммарного квадратичного отклонения точек кластеров от центров этих кластеров.

Задаем поиск 10 кластеров (на предыдущем графике видно 10), делаем обучение и предсказание итоговых классов:

clustering = KMeans(n_clusters=10)
classes = clustering.fit_predict(embeddings)

Раскрасим картинку с кластерами. Алгоритм очень хорошо их разделил:

plt.scatter(embeddings[:, 0], embeddings[:, 1], c=classes, cmap='Spectral', s=5)
plt.colorbar(boundaries=np.arange(11)-0.5).set_ticks(np.arange(10))

Полученные порядковые номера кластеров можно считать классами неразмеченной выборки. Для классификации новых данных нужно последовательно применить к ним уже обученные алгоритмы UMap и KMeans и получить номер кластера для этих объектов.

А теперь открою небольшую тайну — это были не просто данные.

Исходные данные

В тренировочном примере данные являются картинками 8×8 пикселей с рукописными числами. Если значения интенсивности всех пикселей слева-направо и сверху-вниз выложить в одну строку, то получится вектор длины 64 — именно тот, с которым работали до этого. Интенсивность в пикселе записана в формате uint8 и принимает только целочисленные значения от 0 до 255, а значит наши наблюдения в самом начале были верны. 

Всего в датасете представлены цифры от 0 до 9, то есть как раз 10 классов (столько же кластеров нам удалось выделить):

Теперь у нас есть истинный класс и известно, какой строке соответствует какая метка. Если изобразить истинное распределение классов в пространстве меньшей размерности с помощью найденного преобразования, то получится следующее:

Точность классификации

На картинке выше видно, что в большинстве случаев отличаются только цвета, которые отвечают за номер кластера. Это связано с тем, что метод k-средних расставлял метки случайным образом, не вкладывая в 0-й класс смысл наличия нулей в его изображениях. Если поменять нумерацию, то станет видно, какое число примеров было выделено правильно. 

Есть много метрик, которые одним числом указывают, насколько способ хорош. Самая известная — точность (accuracy), которая является отношением верных ответов ко всем примерам в тестовом наборе. У такого подхода есть большой недостаток — он не говорит, в чем именно ошибка. Использование этой и других интегральных метрик будет особенно неудобным в случае многоклассовой классификации, где по одному числу непонятно, какие классы путаются между собой. 

Именно в такой ситуации мы сейчас находимся, поэтому стоит обратиться к матрице ошибок. Для ее построения используем библиотеку pycm:

from matplotlib.pyplot import cm
from pycm import ConfusionMatrix

y_true = dataset.target
conf_matrix = ConfusionMatrix(actual_vector=y_true, predict_vector=y_pred)
conf_matrix.plot(cmap=cm.Greens, number_label=True);

В этом коде y_pred — перенумерованные значения кластеров, найденных нами ранее. В качестве нового значения использовался наиболее встречающийся в нём истинный класс. Ниже показана полученная матрица ошибок:

  • По горизонтали — классы, предсказанные нашим методом. 

  • По вертикали — истинные классы.

  • В клетках пересечения — количество объектов, удовлетворяющих двум условиям. 

27 образцов из истинного класса единиц почему-то определились как шестерки. Разберёмся, почему так вышло и посмотрим на картинки из датасета.

Единицы, классифицированные как шестерки
Единицы, классифицированные как шестерки

На первый взгляд эти объекты не выглядят, как шестерки. Вернёмся к истинной разметке классов и увидим, что есть небольшая группа единиц, которая очень далеко от остальных и находится как раз ближе к шестеркам. 

А реальные шестёрки и правда порой похожи на единицы (особенно первая в третьем ряду и третья в первом), поэтому тут вопросы не к нашей модели, а к тому, кто так пишет: 

Шестерки, классифицированные верно
Шестерки, классифицированные верно

Вместо заключения

Похожим образом мы работаем над тем, чтобы спорный контент вроде пейзажей, оружия и девушек в купальниках попадался не всем пользователям, а только тем, кто не против. Только вместо значений пикселей, как это было в нашем примере, берутся определённые паттерны.

Эти паттерны выделяет нейросеть, предобученная на большом наборе данных. Но для основной задачи удаления нежелательного контента в своём исходном состоянии она нам не подходит, потому что не знает три наших класса:

  • approved — картинки идут в раздел приложения collective;

  • not suitable — не попадают в общую ленту, но остаются в ленте пользователя (девушки в купальниках и мужчины в плавках, селфи и всё, что не является мемами);

  • risked — такой контент получает бан и перестает быть доступным для всех пользователей iFunny (расизм, порнография, расчленёнка и всё, что попадает под определение «противоправный контент»).

Нам предстояло дообучить сеть на эти классы. Но об этом подробно поговорим уже в следующей статье.




К сожалению, не доступен сервер mySQL