教師あり学習 -K-Nearest Neighboursで予想-





K-Nearest Neighbours(k近傍法)は、ある点のクラスを予想する際に、近くの点のクラスを用いて決定します。例えば、近くの点のクラスがA:4個、B:6個ならば、多数決にてBというクラスを予想します。

logistic regressionと異なる点は、予想値の表現です。

K-Nearest Neighboursは、クラスをAかBかと予測するのに対し、logistic regressionは、クラスAになる確率を予想します。

前者をdirect classifier、後者をprobabilistic classifierといいます。

ISLRパッケージの中のdefault dataset(債務不履行)を利用します。

library(class)
library(ISLR)
library(tidyverse)
set.seed(123)

Defaultデータセットを見てみます。

df <- Default
str(df)
## 'data.frame':    10000 obs. of  4 variables:
##  $ default: Factor w/ 2 levels "No","Yes": 1 1 1 1 1 1 1 1 1 1 ...
##  $ student: Factor w/ 2 levels "No","Yes": 1 2 1 1 1 2 1 2 1 1 ...
##  $ balance: num  730 817 1074 529 786 ...
##  $ income : num  44362 12106 31767 35704 38463 ...

変数は、default、student、balance、incomeです。

Defaultデータセットの散布図を作成します。balance(残高)をxの位置に、income(収入)をyの位置に、default(デフォルトしたかどうか)を色にマッピングします。

df %>% 
  arrange(default) %>% 
  ggplot(aes(x = balance, y = income, colour = default))+
  geom_point()

arrangeは、defaultのyesが後から(上から)プロットされるようにデータの順番を変えています。

残高が高い方が、defaultしやすいみたいです。

データの下準備

ifelse()を用いてstudentをダミー変数に変換します(0 = 学生ではない、1 = 学生)。

df$student <- ifelse(df$student=="Yes",1,0)

データセットをtrainigとtestに分割します。割合は、8:2とします。

# trainとtestの列を作る
splits <- c(rep("train", 0.8*nrow(df)),
            rep("test", 0.2*nrow(df)))
# 元のデータセットに1列加える
# sampleでランダムにサンプルされる
df <- df %>% mutate(splits=sample(splits))

#データセットを分割
default_train <- df %>% 
  filter(splits=="train") %>% 
  select(-splits)

default_test <- df %>% 
  filter(splits=="test") %>% 
  select(-splits)

K-Nearest Neighboursで予測

パッケージclassからknn()を使用します。 testデータセットのdefaultを予測します。

knnの使い方:

knn(train = トレーニングデータセット,

test = テストデータセット,

cl = トレーニングデータセットの正解列,

k = 使用する近傍点の数)

トレーニングデータセットとテストデータセットは、正解の列を抜いた形にしてください。

knn_5_prediction <- knn(train = default_train[,-1],
                  test = default_test[,-1],
                  cl = default_train[,1],
                  k = 5)

knn_5_predictionには、testデータセットのdefaultに対する予測が入っています。

head(knn_5_prediction)
## [1] No No No No No No
## Levels: No Yes

グラフを書いて、予測が正しいか見てみます。

  1. x = テストセットbalance , y = テストセットincome, color = 正解
  2. x = テストセットbalance , y = テストセットincome, color = 予測

balance(残高)をxの位置に、income(収入)をyの位置に、default(デフォルトしたかどうか)を色にマッピングします。

g1 <- default_test %>% 
  ggplot(aes(x = balance, y = income,colour = default))  +
  geom_point() +
  labs(title = "test set Answers")
  
g2 <- bind_cols(default_test, prediction = knn_5_prediction) %>% 
  ggplot(aes(x = balance, y = income,colour = prediction))  +
  geom_point() +
  labs(title = "test set Predicted by knn 5")

gridExtra::grid.arrange(g1, g2)

正解が少なめに見えます。

k近傍点を2と5と100で比べてみます。

knn_2_prediction <- knn(train = default_train[,-1],
                  test = default_test[,-1],
                  cl = default_train[,1],
                  k = 2)

knn_5_prediction <- knn(train = default_train[,-1],
                  test = default_test[,-1],
                  cl = default_train[,1],
                  k = 5)

knn_100_prediction <- knn(train = default_train[,-1],
                  test = default_test[,-1],
                  cl = default_train[,1],
                  k = 100)

g1 <- default_test %>% 
  ggplot(aes(x = balance, y = income,colour = default))  +
  geom_point() +
  labs(title = "test set Answers")+
  xlab("")
  
g2 <- bind_cols(default_test, prediction = knn_2_prediction) %>% 
  ggplot(aes(x = balance, y = income,colour = prediction))  +
  geom_point() +
  labs(title = "test set Predicted by knn 2")+
  xlab("")
g3 <- bind_cols(default_test, prediction = knn_5_prediction) %>% 
  ggplot(aes(x = balance, y = income,colour = prediction))  +
  geom_point() +
  labs(title = "test set Predicted by knn 5")+
  xlab("")

g4 <- bind_cols(default_test, prediction = knn_100_prediction) %>% 
  ggplot(aes(x = balance, y = income,colour = prediction))  +
  geom_point() +
  labs(title = "test set Predicted by knn 100")+
  xlab("")

gridExtra::grid.arrange(g1, g2, g3 ,g4, nrow = 4)

k近傍点を増やすことでover fittingしてしまい、テストセットに対しての判別能力が悪くなっていますね。

confusion matrixから、accuracy等を出したいと思います。 TP, FP, TN, FNを入れたら色々な指標を表示する関数を作ります。

confusion_matrix <- function(TP,FP,TN,FN){

accuracy = (TP + TN)/(TP + TN + FP + FN)
f1score = TP/(TP+0.5*(FP + FN))
sensitivity = TP/(TP + FN)
specificity = TN/(TN + FP)
false_positive_rate = FP/(FP + TN)
positive_predictive_value = TP/(TP + FP)
negative_predictive_value = TN/(TN + FN)

cat(paste("accuracy:",accuracy,"\n",
          "f1score:",f1score,"\n",
          "sensitivity:",sensitivity,"\n",
          "specificity:",specificity,"\n",
          "false_positive_rate:",false_positive_rate,"\n",
          "positive_predictive_value:",positive_predictive_value,"\n",
          "negative_predictive_value:",negative_predictive_value)) 
  
}

knn2のconfusion matrixを作ります。

table(true = default_test$default, predicted = knn_2_prediction)
##      predicted
## true    No  Yes
##   No  1880   46
##   Yes   55   19
TP = 19
FP = 46
TN = 1880
FN = 55
confusion_matrix(TP,FP,TN,FN)
## accuracy: 0.9495 
##  f1score: 0.273381294964029 
##  sensitivity: 0.256756756756757 
##  specificity: 0.976116303219107 
##  false_positive_rate: 0.023883696780893 
##  positive_predictive_value: 0.292307692307692 
##  negative_predictive_value: 0.971576227390181

knn5のconfusion matrixを作ります。

table(true = default_test$default, predicted = knn_5_prediction)
##      predicted
## true    No  Yes
##   No  1916   10
##   Yes   64   10
TP = 10
FP = 10
TN = 1916
FN = 64
confusion_matrix(TP,FP,TN,FN)
## accuracy: 0.963 
##  f1score: 0.212765957446809 
##  sensitivity: 0.135135135135135 
##  specificity: 0.994807892004154 
##  false_positive_rate: 0.00519210799584631 
##  positive_predictive_value: 0.5 
##  negative_predictive_value: 0.967676767676768

knn100のconfusion matrixを作ります。

table(true = default_test$default, predicted = knn_100_prediction)
##      predicted
## true    No  Yes
##   No  1926    0
##   Yes   74    0
TP = 0
FP = 0
TN = 1926
FN = 74
confusion_matrix(TP,FP,TN,FN)
## accuracy: 0.963 
##  f1score: 0 
##  sensitivity: 0 
##  specificity: 1 
##  false_positive_rate: 0 
##  positive_predictive_value: NaN 
##  negative_predictive_value: 0.963

accuracyは、knn2:0.95, knn5:0.96, knn100:0.96です。

knn100は、全部No作戦なのにかなり高いaccuracyを持っています。

これは、元のデータがunbalanceだからです。言い換えると、defaultのyesとnoが均等ではなく、NoばかりでありYesはレアであることに起因しています。


Categories:

category