☰ Оглавление

Метрики в машинном обучении: precision, recall и не только

Почему я это написал. Долгое время я был бэкэнд разработчиком. Вращался среди коллег и прекрасно их понимал. Но бэкенд нужен и людям, занимающимся машинным обучением. В какой-то момент я попал в их логово. И вот тут я поднял, что вообще не понимаю их язык. Думаю, что не только я оказывался в такой ситуации, поэтому сейчас расскажу всё понятными словами для нормальных людей.

Немного про машинное обучение

ML работает всего с несколькими вещами:

Приведу пример. Допустим у вас есть яблоня. Каждый год вы ломаете голову, опрыскивать ли её от долгоносика. Решили обучить модель, что она вам предсказывала нашествия долгоносика.

Данные: Допустим, вы знаете только весенние температуры в виде одного числа. (Не важно, что это, допустим, просто средняя температура.) То есть у вас есть одна фича — температура.

Разметка У вас есть журнал по годам, где отмечено, был долгоносик или нет. То есть каждой температуре вы можете сопоставить "правильный" ответ. Обратите внимание, что одной и той же температуре могут соответствовать несколько ответов, они могут быть разными (за разные годы)...

Модель Пусть у нас будет модель с одним параметром: пограничной температурой. Модель будет просто говорить "да", если температура выше какой-то черты и "нет" — если ниже. Можно было бы придумать модель с двумя параметрами (она бы смотрела на интервал), или ещё сложнее, но мы сейчас возьмём самую простую, для наглядности.

Предсказания Если применить модель к температурам (фичам), то получим предсказания.

Немного кода

Чтобы было с чем играть, вот вам код. Тут есть и данные, и модель, и всё о чём мы будем говорить.

#!/usr/bin/env python
# coding: U8

import numpy as np

# Наши тестовые данные: набор наблюдений — пар: температура, наличие долгоносика

TEST = np.array([  # 0 - нет долгоносика, 1 - есть
        [0, 0],
        [1, 1],
        [2, 0],
        [3, 0], [3, 0],
        [4, 0], [4, 1],
        [5, 1], [5, 1],
        [6, 1], [6, 0],
        [7, 0], [7, 0],
        [8, 0],
        [9, 1],
        ])

# Наша модель очень проста. Единственный параметр модели — threshold

class Model(object):
    def __init__(self, threshold):
        self.threshold = threshold
    def predict(self, data):
        return data >= self.threshold

# Получаем метрики

def metrics(test_data, model):
    data = test_data[..., 0]  # входные данные (массив температур)
    observations = test_data[..., 1]  # фактические наблюдения
    prediction = model.predict(data)  # предсказания (результат применения модели)

    true_positive = np.logical_and(prediction, observations)  # и в прогнозе, и в реальности было "да"
    false_positive = np.logical_and(prediction, np.logical_not(observations))  # прогноз сказал "да", а в реальности было "нет"
    true_negative = np.logical_and(np.logical_not(prediction), np.logical_not(observations))  # прогноз — "нет" и он прав
    false_negative = np.logical_and(np.logical_not(prediction), observations)  # прогноз — "нет" и ошибся

    tp, fp, tn, fn = (x.sum() for x in (true_positive, false_positive, true_negative, false_negative))

    accuracy = (tp + tn) / (tp + fp + fn + tn)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1_score = 2 * (recall * precision) / (recall + precision)

    #dump(data, observations, prediction, true_positive, false_positive, true_negative, false_negative)

    return tp, fp, tn, fn, accuracy, precision, recall, f1_score

def dump(temp, obs, pred, tp, fp, tn, fn):
    print('{:5} {}'.format('Temp', ' '.join(map(str, temp))))
    for name, v, comment in (
            ('Obs', obs, '(TP+FN)'),
            ('Pred', pred, '(TP+FP)'),
            ('TP', tp, ''),
            ('FP', fp, ''),
            ('TN', tn, ''),
            ('FN', fn, '')):
        print('{:5} {} {:2}{:>10}'.format(name, ' '.join({False: '.', True: 'T'}[bool(x)] for x in v), sum(v), comment))
    pass

print(' T TP FP TN FN Accur Prec  Recll F1')
for model_param in range(10):
    m = metrics(TEST, Model(model_param))
    print('{:2d} {:2d} {:2d} {:2d} {:2d} {:5.3f} {:5.3f} {:5.3f} {:5.3f}'.format(model_param, *m))

Если это запустить, то мы получим метрики для разных моделей:

T TP FP TN FN Accur Prec  Recll F1
0  6  9  0  0 0.400 0.400 1.000 0.571
1  6  8  1  0 0.467 0.429 1.000 0.600
2  5  8  1  1 0.400 0.385 0.833 0.526
3  5  7  2  1 0.467 0.417 0.833 0.556
4  5  5  4  1 0.600 0.500 0.833 0.625
5  4  4  5  2 0.600 0.500 0.667 0.571
6  2  4  5  4 0.467 0.333 0.333 0.333
7  1  3  6  5 0.467 0.250 0.167 0.200
8  1  1  8  5 0.600 0.500 0.167 0.250
9  1  0  9  5 0.667 1.000 0.167 0.286

T — это параметр модели. То есть мы получили метрики, фактически, для 10 разных моделей.

Если раскомментировать dump() то будет видна детальная информация.

Сейчас мы со всем разберёмся.

TP, TN, FP, FN и друге буквы

Когда люди начинают жонглировать этими буквами, с непривычки, можно очень легко запутаться и потерять нить. Чтобы всё встало на свои места, нам понадобится ещё несколько букв и пара полезных соотношений.

Давайте посмотрим на входные данные (разметку, фактические наблюдения). Введём две буквы:

Теперь посмотрим на прогнозы модели. Здесь тоже есть positive и negative, но их сразу же делят на четыре группы:

Тут важно проникнуться простыми соотношениями:

P = TP + FN
N = TN + FP

Остановитесь тут и подумайте минуту.

Ну и, конечно, ясно что такое TP+FP (это все ответы "да", полученные от модели) и NT+FN (все ответы "нет").

Ценность метрик

У нас появились первые метрики. Давайте посмотрим, на сколько они полезны.

Метрики нужны, чтобы понять, какая модель лучше. Выше мы видели все метрики для всех моделей. Поглядите на колонки TP, TN, FP, FN.

Видно, что ни одна из этих метрик не позволяет нам выбрать лучшую модель. Например, модель, которая всегда говорит только "да", показывает лучший TP. Это и понятно: везде, где в наблюдениях было "да", наша модель сказала "да". Однако, ясно, что это глупейшая модель.

Аналогично не работают и другие три метрики. Нужно что-то получше.

Accuracy

Первое, что приходит в голову: давайте поделим все правильные ответы на все вообще ответы.

            TP + TN          TP + TN
Accuracy = ───────── = ───────────────────
             P + N      TP + FN + TN + FP

Такая метрика уже лучше, чем ничего, но всё же, она очень плоха. Даже в моём примере (хотя я не подгонял специально числа) видно, что, с одной стороны, разумные модели имеют высокую точность, однако, побеждает по точности просто самая пессимистичная модель.

Вы можете поиграться с данными и посмотреть, как это происходит. Но понять смысл очень просто на другом примере. Допустим вы хотите предсказывать землетрясения (какое-то очень редкое явление). Ясно, что по этой метрике всегда будет побеждать модель, которая даже не пытается ничего предсказывать, а просто говорит всегда "нет". Те же модели, которые будут пытаться говорить когда-то "да", будут иногда ошибаться в позитивных прогнозах и сразу же терять очки.

Поэтому, про эту метрику вы, скорее всего, даже не услышите никогда. Я тут её упомянул только чтобы показать её неэффективность при, кажущейся, логичности.

Precision, recall и их друзья

Лично я чаще всего сталкивался именно с этими словами. При том, что, как мне кажется, это самые неудачные варианты. Я буду приводить альтернативные называния, которые, как мне кажется, на много лучше отражают суть.

Recall aka sensitivity, hit rate, or true positive rate (TPR)

Мне кажется, hit rate и TPR лучше всего отражают суть. В этой метрике мы рассматриваем только P-случаи: когда в реальных наблюдениях было "да". И считаем, какую долю из этих случаев модель предсказала правильно.

Все случаи "нет" мы отбрасываем.

                          TP       TP
TPR (recall, hit rate) = ──── = ─────────
                          P      TP + FN

Recall сам по себе довольно бесполезен. Взгляните на результаты для нашей модели: модель, которая всегда тупо говорить "да" — безусловно побеждает. Фактически, recall пропорционален TP, если P — константа (напомню, что это просто количество ответов "да" в наших фактических данных).

У recall есть брат-близнец:

Specificity, selectivity or true negative rate (TNR)

       TN       TN
TNR = ──── = ─────────
       N      TN + FP

Здесь верны все те же самые оговорки. Специфичность, фактически, пропорциональна TN.

Важно, так же, заметить, что если T и P сильно отличаются (как в примере с землетрясениями), то сравнивать recall и специфичность надо очень осторожно.

Precision aka positive predictive value (PPV)

Какая часть наших предсказаний "да" действительно сбылась:

         TP
PPV = ─────────
       TP + FP

Недостатки этой метрики аналогичны: она вообще никак не учитывает предсказания "нет". Из наших результатов видно, что побеждает модель, которая почти всегда говорить "нет". Она как бы снижает риск проиграть, выводя большую часть своих ответ за рамки рассмотрения.

У этой метрики есть аналогичный близнец

Negative predictive value (NPV)

         TN
NPV = ─────────
       TN + FN

Какая часть "нет"-предсказаний сбылась.

И ещё немного метрик

Перечислю кратко и другие метрики. Это далеко не все существующие, а просто аналоги вышеперечисленных, только относительно отрицательных прогнозов.

Miss rate aka false negative rate (FNR)

       FN       FN
FNT = ──── = ─────────
       P      TP + FN

Fall-out aka false positive rate (FPR)

       FP       FP
FPR = ──── = ─────────
       N      TN + FP

False discovery rate (FDR)

         FP
FDR = ─────────
       FP + TP

False omission rate (FOR)

         FN
FOR = ─────────
       FN + TN

И что же со всем этим делать

Как вы уже видели, каждая из этих метрик рассматривает только какое-то подмножество предсказаний. Поэтому их эффективность очень сомнительна.

Однако, их очень часто используют для двух вещей:

На втором я хотел бы остановиться в некотором философском ключе.

Давайте задумаемся, а что значит "одна модель лучше другой"? Единого ответа тут нет.

В нашем примере с долгоносиком всё зависит от наших приоритетов.

Если мы хотим ни в коем случае не потерять урожай, то нам надо максимизировать TP любой ценой. Фактически, в предельном случае, мы можем выкинуть любые модели и просто опрыскивать дерево химикатами всегда.

Если мы хотим минимизировать применение ядов, то нам надо максимизировать TN. В предельном случае, нам просто надо никогда не опрыскивать дерево: потеря урожая для нас не так страшна, как безосновательное применение ядохимикатов.

В реальной же жизни, мы ищем некоторый компромисс. Во многих случаях он может быть совершенно чётко сформулирован, с учётом цен на химикаты, стоимости урожая, репутационных потерь и прочего.

Не редко, люди придумывают собственные метрики. Но есть и готовые, пригодные во многих случаях.

F1-score

Это комбинация recall и precision:

          recall * precision
F1 = 2 * ───────────────────
          recall + precision

Но мне кажется, поведение этой функции становится гораздо понятней, если записать её так:

               2
F1 = ─────────────────────
         1           1
     ───────── + ─────────
      recall     precision

То есть, это гармоническое среднее.

Максимальный F1-score мы получим, если и recall, и precision достаточно далеки от нуля. Он позволяет найти некое компромиссное решение, фактически, между максимизацией TP по разным шкалам.

Это не единственная возможная метрика. И у неё, как вы видите, тоже есть чёткий фокус, а значит и недостатки. Однако, даже её достаточно, чтобы среди наших моделей выиграла та, у которой пограничная температура равна 4. Давайте ещё раз взглянем на сравнение всех моделей:

T TP FP TN FN Accur Prec  Recll F1
0  6  9  0  0 0.400 0.400 1.000 0.571
1  6  8  1  0 0.467 0.429 1.000 0.600
2  5  8  1  1 0.400 0.385 0.833 0.526
3  5  7  2  1 0.467 0.417 0.833 0.556
4  5  5  4  1 0.600 0.500 0.833 0.625 <-- победитель по F1
5  4  4  5  2 0.600 0.500 0.667 0.571
6  2  4  5  4 0.467 0.333 0.333 0.333
7  1  3  6  5 0.467 0.250 0.167 0.200
8  1  1  8  5 0.600 0.500 0.167 0.250
9  1  0  9  5 0.667 1.000 0.167 0.286

И вот детализация по этой конкретной модели (с T=4):

Temp  0 1 2 3 3 4 4 5 5 6 6 7 7 8 9
Obs   . T . . . . T T T T . . . . T  6   (TP+FN)
Pred  . . . . . T T T T T T T T T T 10   (TP+FP)
TP    . . . . . . T T T T . . . . T  5
FP    . . . . . T . . . . T T T T .  5
TN    T . T T T . . . . . . . . . .  4
FN    . T . . . . . . . . . . . . .  1

Вы можете взять мой код, раскомментировать функцию dump() и посмотреть детализацию по всем моделям.

Видно, что ответ разумный, но выработан с фокусом на TP. То есть на то, чтобы больше перебдеть. Возможно, это как раз то, что нужно. А может и нет. Поэтому и существует ещё множество других метрик, но основа у них одинаковая.

Надеюсь, я пролил некоторый свет на вопрос.