Juliaでスパム判定の機械学習分類器を作る
あと1週間でクリスマスイブですね。
Julia Advent Calendar 2014の17日目の記事(@yutajuly)です。
ここでは、スパムデータ分類を例にとり、実データ特有の問題に対応した機械学習モデルの構築をJuliaで実装します。
■分類器構築の流れ
以下のような流れを考え、スパムデータを分類する分類器を構築する。
1. 学習データの取得
2. 前処理の前半戦
2−1. ラベル付け
2−2. 特徴選択・抽出
3. 前処理の後半戦
3−1. データのスケール調整
3−2. 不均衡データ処理
4. 分類器の構築
4−1. パラメータチューニング
4−2. 精度評価
■1.学習データの取得
ここでは、HP研が収集したSpam E-mail Databaseを扱う。Rのkernlabパッケージにあるspamデータです。
- データ数:4601通(spam:1813通, non-spam:2788通)
- 1〜57列目:各メールの特徴量(単語や記号文字の出現頻度、大文字の連なりの長さ etc)
- 58列目:各メールのラベル(spam ,non-spamを表す)
読み込みは以下のように実行。
(Cドライブ直下にspam.csvが置いてある想定)
##データを読み込む Pkg.add("DataFrames") #初回のみ実行 using DataFrames spam = readtable("C:spam.csv", header = true) head(spam) #データ数行確認
■2.前処理の前半戦
よく教科書にあるデータでは、当然のように、ラベルも特徴量も所与。
ここで扱う上述のスパムデータも同様。
しかしながら、機械学習の一番大変なところや本質はここにあるといえる。
2-1. ラベル付け
ラベルとは正例, 負例(spam, nonspam)のこと。もちろんラベルは最初からあるわけではない。
Spam E-mail Databaseでは、HP研がGeorge(メールデータの提供者)にヒアリングをしながら、4601通のメール1つ1つにラベル付けを行ったと想定され、その手間は想像を絶する。
2-2. 特徴選択・抽出
機械学習とパターン認識において最も大切なフェーズ。どんなに優れたなアルゴリズムを使っても、分類に影響する特徴を見ないと分類できるはずがない。
Spam E-mail Databaseでは、本当の最初はメールそのものがあるだけ。メールから色んな単語を取ってきて、spam, non-spamの分類に影響する57種類の特徴量を作ったと思われる。
(参考)みにくいアヒルの子の定理
-何らかの形で特徴に重要性を考えたり、取捨選択しなければ、みにくいアヒルの子と普通のアヒルの子の区別もできない
■3.前処理の後半戦
さて、ここからようやく、ちゃんとJuliaコーディング。
とはいえ、まだまだデータの前処理。
3-1. データのスケール調整
全ての特徴量について、平均0, 分散1に揃える(など)の調整を行う処理。
取りうる値の範囲が特徴量により異なる場合において、範囲が大きい特徴量が分類に対して支配的になりうることを避けるために行う。
##特徴量のスケールを調整する l = spam[:,58] f = spam[:,1:57] spam_scaled = DataFrame() #空き箱を用意 for i in 1:ncol(f) spam_scaled[i] = (f[:,i]-mean(f[:,i]))/sqrt(var(f[:,i])) end spam_scaled[58] = l
3-2. 不均衡データ処理
正例と負例のデータ比に偏りがある場合、アルゴリズムが偏った学習をしてしまう。
これを避けるための処理を不均衡データ処理という。
例えば、10,000件のメールの内、spam10件、non-spam9,990件の場合、すべてnon-spamと判定しても正解率99.9%となるので、アルゴリズムの学習が偏りがちになる。
対処法は大きく以下の2つ。
1. 少ない方を間違えた時のペナルティを、多い方を間違えた時より大きくする
-Weighted SVM など
2. データ数を調整して正例数=負例数にする
-Over Sampling
-Under Sampling
今回は、Over Samplingで対応。
##オーバーサンプリング srand(123) #乱数シード固定 spam_scaled_p = spam_scaled[spam_scaled[:x58] .== "spam",:] #正例のデータのみ抽出 spam_scaled_f = spam_scaled[spam_scaled[:x58] .== "nonspam",:] #負例のデータのみ抽出 #正例数と負例数の差分を算出 imbalance = nrow(spam_scaled_f) - nrow(spam_scaled_p) #正例データから、先ほどの差分数をサンプリング(IDを決定) samplingID = rand(1:nrow(spam_scaled_p), imbalance) #正例データから、先ほどの差分数をサンプリング(当該IDを抽出) spam_scaled_p_over = spam_scaled_p[samplingID,:] spam_scaled_sampled = rbind(spam_scaled_p, spam_scaled_p_over, spam_scaled_f)
■4. 分類器の構築
4-1. パラメータチューニング
機械学習による分類器には、SVM(Support Vector Machine)を用いる。
パラメータチューニングは、グリッドサーチと交差検証で行う。
まずは準備。パッケージを読みこんで、データをSVMパッケージの形式に合わせる。
##パラメータチューニングの準備 #各種ライブラリの読み込み Pkg.add("SVM") #初回のみ実行 using SVM Pkg.add("MLBase") #初回のみ実行 using MLBase #SVMの入力形式に合わせ、ラベルを"spam", "nonspam"から、1, -1に変更 label = DataFrame(zeros(nrow(spam_scaled_sampled),1)) for i in 1:nrow(spam_scaled_sampled) if spam_scaled_sampled[i,58] == "spam" label[i,1] = 1 else label[i,1] = -1 end end feature = spam_scaled_sampled[:,1:57]
では、いよいよ、機械学習。
ここでは以下の2つのパッケージを用いる。
SVM.jl:SVMの中でもPegasosアルゴリズムで双対問題を解く実装を採用しているSVM.jlパッケージ。線形カーネルを用いたオンラインSVMにあたる。ちなみに、JuliaではLIBSVMの実装もある。
MLBase:グリッドサーチと交差検証には、機械学習手続きのフレームワークを提供するMLBaseパッケージを活用。 (14日目のsfchaosさんの記事も、是非ご参照ください)
この2つのパッケージを使って、交差検証とグリッドサーチによるパラメータチューニングを行う。
精度評価の指標は、Accuracyを用いる。
##交差検証によるパラメータのグリッドサーチ #交差検証は3hold。Kfold関数でデータIDを分割 cv = 3 datanum = nrow(spam_scaled_sampled) gen = collect(Kfold(datanum, cv)) model_dict = Dict() #交差検証ごとのモデルを格納する空ディクショナリ score_dict = Dict() #交差検証ごとのスコアを格納する空ディクショナリ #モデル推定関数の定義 #3holdの交差検証で3種類のデータセットでモデルを構築して格納 function estfun(lambda, k, T) for k in 1:cv learnID = gen[k] feature_learn = transpose(array(feature[learnID,:])) label_learn = array(label[learnID,:])[:,1] model_dict[k] = svm(feature_learn, label_learn, lambda = lambda, k = k, T = T) end return model_dict end #精度評価関数 #3holdの交差検証で3種類のデータセットでAccuracyを算出。3回の平均値で評価 function evalfun(model_dict) for k in 1:cv testID = setdiff(1:datanum, gen[k]) feature_test = transpose(array(feature[testID,:])) label_test = array(label[testID,:])[:,1] label_pred = predict(model_dict[k], feature_test) score_dict[k] = correctrate(int(label_test), int(label_pred)) end return mean(values(score_dict)) end #グリッドサーチの実行 out = gridtune(estfun, evalfun, ("lambda", [0.0001, 0.001, 0.01, 0.1, 1.0]), ("k", [5, 10, 50, 100, 500, 1000, 5000]), ("T", [100, 500, 1000, 5000, 10000, 50000]); ord=Forward, # 精度評価関数の評価値は、昇順か、降順かを指定 verbose=true) # 推定毎の結果出力の有無を指定 #ベストモデル、ベストパラメータ、ベストスコアを出力 best_model, best_cfg, best_score = out
4-2. 精度評価
上記の出力の結果を示す。
ベストパラメータ
lambda | k | T |
---|---|---|
0.01 | 5000 | 50000 |
ベストスコア
Accuracy |
---|
0.922 |
ちなみに、svm.jlのsrcコードを覗くとわかるが、デフォルトパラメータは、lambda:0.1, k:5, T:100であり、この場合、今回のデータではAccuracyが0.871であった。
パラメータチューニング大事。
なお、MLBaseでは、Accuracy以外にも、混合行列、Precision, Recall, ROC曲線なども精度指標として出力できる。
最後に、ベストパラメータモデルで、混合行列と各種指標を算出する。
##ベストパラメータでの各種評価指標を出力 #ベストパラメータモデルを作成(データは交差検証の1hold目を用いる) lambda = best_cfg[1] k = best_cfg[2] T = best_cfg[3] #交差検証ごとの混合行列の計算 cmatrix = [0 0; 0 0] for k in 1:cv learnID = gen[k] feature_learn = transpose(array(feature[learnID,:])) label_learn = array(label[learnID,:])[:,1] testID = setdiff(1:datanum, gen[k]) feature_test = transpose(array(feature[testID,:])) label_test = array(label[testID,:])[:,1] best_model = svm(feature_learn, label_learn, lambda = lambda, k = k, T = T) pred = int(predict(best_model, feature_test)) #分類の出力が1, 2, …じゃないとダメぽい。-1, 1だったので変換して入力 gt0 = int(label_test) gt = DataFrame(zeros(length(gt0),1)) for i in 1:length(gt0) if gt0[i,1] == 1 gt[i,1] = 1 else gt[i,1] = 2 end end pred0 = int(predict(best_model, feature_test)) pred = DataFrame(zeros(length(pred0),1)) for i in 1:length(pred0) if pred0[i,1] == 1 pred[i,1] = 1 else pred[i,1] = 2 end end cmatrix = cmatrix + confusmat(2, int(array(gt))[:], int(array(pred))[:]) end #混合行列から、precision, recall, Fvalueを算出 precision = cmatrix[1, 1]/sum(cmatrix[:, 1]) recall = cmatrix[1, 1]/sum(cmatrix[1, :]) Fvalue = 2/(1/precision + 1/recall)
結果は以下。
混合行列
2537 251 182 2606
混合行列から各種指標を算出
precision | recall | Fvalue |
---|---|---|
0.933 | 0.910 | 0.921 |
どうもありがとうございました。
あと1週間、アドベントカレンダーを楽しみましょう。メリークリスマス!
■参考文献
@chezou, Julia v0.3.0でSVM.jlを使う
@sfchaos, Juliaによる機械学習の予測モデル構築・評価
@sfchaos, 不均衡データのクラス分類
@sleepy_yoshi, SVM実践ガイド (A Practical Guide to Support Vector Classification)
さいごの碧, kernlabパッケージのspamデータ