オルトプラスエンジニアの日常をお伝えします!

改めて「ITエンジニアのための機械学習理論入門」を改めて読む 〜第2章前編〜

こんにちは。オルトプラスラボに入って2週間と4日の橘です。

今回は前回に引き続き、2章を読んでいきます。
2章は機械学習の基礎中の基礎である2乗誤差についてです。
基礎とは言え、微分や行列が出てくるため、その辺をじっくり見ていきたいと思います。

多項式近似と誤差関数の設定

多項式近似とは何か、を見ていく前に、多項式近似のイメージを次の画像で見てみましょう。



多項式近似とは、「適当な線を引き、学習させたい点(データ)に近づくように線を曲げる」ようなイメージです。まず、これを頭のなかに入れておいてください。

f:id:s_tachibana:20170711175337p:plain


それでは数学的な説明に移ります。多項式とは、

 f(x) = w_{0} + w_{1} x + w_{2} + x^{2} + ... + w_{M} x^{M}

を言います。上の図でいうところの青い線が多項式です。つまり、この多項式を色々動かしてデータを分析したり予測したりしていくわけです。

ちなみに、予測したい各点を  (x_{1},t_{1}), (x_{2}, t_{2}), ... , (x_{N}, t_{N}) と表します。例えば、 x_{1}地点の f(x_{1})との差は次のように見ることができます。

f:id:s_tachibana:20170711185733p:plain

この差を縮めるように、線である多項式を調整していきます。この時、多項式の中の xはデータの値のため、変更できません。そのため、変更できるパラメータは w_{0}, w_{1}, ..., w_{M}です。つまり、機械学習で求めていく値はこの w_{0}, w_{1}, ..., w_{M}ということになります。このことはとても重要なため、よく覚えておいてください。

各データの点と線の差が次の式です。2乗誤差といいます。

 E_{D} = \frac{1}{2} \sum_{n=1}^{N} \{ f(x_{n}) - t_{n} \}^2

なぜ  \frac{1}{2}しているのかというと、後々の計算を楽にするためです。 \frac{1}{2}しても誤差が0になるわけではないので、多めに見てください。各データと線の差である f(x_{n}) - t_{n} を2乗しているのは、すべての差を正の値にするためです。なぜ正の値にする必要があるかというと、例えば3つの点の誤差を見たときに(-0.5, 1, -0.5)だったとき、誤差を足してしまうと0になってしまい、正確にデータを予測しているように見えてしまいます。こうなることを防ぐために、すべての差を2乗しているわけです。この誤差が小さくなるように、 w_{0}, w_{1}, ..., w_{M}の値を求めていきます。


ここからがP64の数学徒の小部屋の解説に入ります。ここで微分が登場します。なぜ微分が登場するかは、たまたま偶然ミラクルに同じ苗字の橘氏がデータホテル様のテックブログで解説してくれているので、そちらをご参照ください。いやー、偶然ってあるんだなー。

datahotel.io

要するに、「ある地点で微分した結果が0のとき、その地点で最大値か最小値になる」という性質があります。これを、各 w_{0}, w_{1}, ..., w_{M}について求めていくわけです。 w_{0}, w_{1}, ..., w_{M}の中のある w_{m'}で微分したときに0になると仮定します。(あえて本の中での m m'とが逆にしています。)

 \sum_{n=1}^{N} ( \sum_{m=0}^{M} w_{m} x_{n}^{m} - t_{n} ) x_{n}^{m'} = 0

一見複雑に見えますが \sumは足し算ですので、順番を変更する事ができます。本の中でも「作為的」と書いてありますが、この計算を行列で表現するための工夫です。

 \sum_{m=0}^{M}w_{m} \sum_{n=0}^{N} x_{n}^{m} m_{n}^{m'} - \sum_{n=0}^{N} t_{n} x_{n}^{m'} = 0

この式は m'について微分した結果ですが、各mについて微分した結果を行列の形に表したものが、次の式です。

 w^{T}\Phi^{T}\Phi - t^{T} \Phi = 0

各行列は本の65ページのとおりです。なぜ、こうなるか困惑した方も多いのではないかと思います。ですので、N=2、M=2のときで計算を試してみました。行列の演算の仕方はデータホテル様のテックブログを御覧ください。


datahotel.io


f:id:s_tachibana:20170720125117j:plain

f:id:s_tachibana:20170720130323j:plain


行列の中の式が

 \sum_{m=0}^{M}w_{m} \sum_{n=0}^{N} x_{n}^{m} m_{n}^{m'} - \sum_{n=0}^{N} t_{n} x_{n}^{m'}

と同じになっています。行列の各行が w_{0}, w_{1}, w_{2}について微分した式と一致していることがわかるかと思います。これをmが0からMのときまで実行しても同じ結果が得られます。この

 w^{T}\Phi^{T}\Phi - t^{T} \Phi = 0

はwについての方程式ですので、w に関して求めると次のような結果になります。

 w = (\Phi^{T} \Phi)^{-1} \Phi^{T}t

 \Phi  x_{1}, x_{2}, ..., x_{n}からなる行列ですので、既にわかっているデータです。また、 tも各 t_{1}, t_{2}, ..., t_{n}からなるベクトルですので、既にわかっているデータです。つまり、この行列の計算をしてしまえば、最も適切な w_{0}, w_{1}, ..., w_{m} が求められます。


(行列の正定値性、ヘッセ行列に関しては次回解説します。)

サンプルコード

Numpyのnp.polyfitメソッドを使うことで、簡単に多項式近似できます。

import numpy as np
import matplotlib.pyplot  as plt
import pandas as pd
from pandas import Series, DataFrame
from numpy.random import normal

# データセットを用意
def create_dataset(num):
    dataset = DataFrame(columns=["x", "y"])
    for i in range(num):
        x = float(i)/float(num-1)
        # 平均0.0、分散0.3の正規分布を誤差として与えている
        y = np.sin(2 * np.pi * x) + normal(scale=0.3)
        dataset = dataset.append(Series([x, y], index=["x", "y"]), ignore_index=True)
    return dataset

if __name__ == "__main__":
    // データセットを作る
    df = create_dataset(10)

    X = df.x.values
    y = df.y.values
    x = np.arange(0, 1.1, 0.01)

    // データセットをプロットする
    plt.scatter(X, y)
    // 近似した線を引く
    plt.plot(x, np.poly1d(np.polyfit(X, y, 4))(x), color="r")
    plt.show()


実行結果は以下のとおりです。

f:id:s_tachibana:20170720211226p:plain


次回

次回は第2章の後半と今回紹介できなかった正定値性、ヘッセ行列に関して解説します。