オルトプラスエンジニアの日常をお伝えします!

エラー・コードと型消去(Type Erasure)

こんにちは、id:mitsutaka-takadaです。

C++でエラー通知というとエラー・コードや例外が通常の手段かと思います。エラー・コードは戻り値でenumを返すことでエラーを通知します。 例外と比較してenumオブジェクトを返すのみで、とても軽量な通知手段です。

今回の記事では複数のエラー・コードを型消去によってまとめて扱う方法を書きたいと思います。

ゴール

例えば、ネットワーク・リクエストとデータベース・アクセスをしている関数があるとします。それぞれ自身のドメインに関するエラーをエラー・コード(NetworkError&DatabaseError)で通知してくるとき、この関数はエラー・コード(SomeError)として何を返せばよいでしょうか?

#include <string>
#include <experimental/string_view>

// ネットワーク関連のエラー・コード。
enum class NetworkError{
    NoError,
    SomeNetworkError
};

// データベース関連のエラー・コード。
enum class DatabaseError{
    NoError,
    SomeDatabaseError
};

std::pair<std::string, DatabaseError>
getAccountIdFromDatabase(std::string_view userId){
    return {"", DatabaseError::SomeDatabaseError};
}

std::pair<int, NetworkError>
getAccountBalanceFromNetwork(std::string_view accountId){
    return {-1, NetworkError::SomeNetworkError};
}

// エラー情報。ネットワーク&データベースの両方のエラー情報を返したい。
struct SomeError{};

std::pair<int, SomeError>
getUserBalance(std::string_view userId){

    if(auto [accountId, databaseError] = getAccountIdFromDatabase(userId);
       databaseError == DatabaseError::SomeDatabaseError){
        // データベース関連のエラー情報を返したい
        return {};
    }
    else if(auto [balance, networkError] = getAccountBalanceFromNetwork(accountId);               networkError == NetworkError::SomeNetworkError){
        // ネットワーク関連のエラー情報を返したい
        return {};
    }
    else {
        return {balance, {}};
    }
}

<system_error>

エラーコードの和集合、Variantを利用する(下記、おまけ参照)など、いくつか考えられると思いますが、 今回はC++11で導入された<system_error>を利用した方式を紹介したいと思います。

SomeErrorNetworkErrorDatabaseErrorが代入でき(型消去)、SomeErrorからその情報を取得するのが目標です。

<system_error>では、複数ドメインのエラーコードを1つのenumとして統一的に扱うためstd::error_codeを用意しています。std::error_codeはエラーコード(整数値)とカテゴリ(ドメインを表すstd::error_categoryオブジェクト)の対です。エラーコードは整数値で複数ドメインをまたいで一意性が保証されていません。例えば、NetworkError::SomeErrorDatabaseError::SomeErrorに同じ1という値が割り当てられているかもしれません。そのため単純に整数値を比較するだけではNetworkError::SomeErrorDatabaseError::SomeErrorを識別できなくなります。これを防ぐためにカテゴリを使用します。

std::error_codeを使用するには、以下の3ステップが必要です。

  1. エラーコードをstd::error_codeで使用できるように登録する。(std::is_error_code_enumの特殊化)
  2. ドメイン用のカテゴリを定義する。(std::error_categoryの派生クラスの定義)
  3. エラーコードとカテゴリの紐づけを行う。(make_error_codeオーバーロードの定義)

実際にstd::error_codeを使用したコードを見てみましょう。

#include <iostream>
#include <string>
#include <experimental/string_view>
#include <system_error>

enum class NetworkError{
    NoError,
    SomeNetworkError
};

enum class DatabaseError{
    NoError,
    SomeDatabaseError
};

// 1. エラーコードをstd::error_codeで利用できるように登録。
namespace std{
    template<>
    struct std::is_error_code_enum<NetworkError> : std::true_type{};

    template<>
    struct std::is_error_code_enum<DatabaseError> : std::true_type{};
}

// 2. ドメイン用カテゴリの定義。
struct NetworkErrorCategory : std::error_category{
    const char* name() const noexcept override{
        return "NetworkErrorCategory";
    }

    std::string message(int ev) const override{
        switch(static_cast<NetworkError>(ev)){
            case NetworkError::SomeNetworkError:
                return "some network error occured";
            case NetworkError::NoError:
                return "no error";
        }
    }
};

struct DatabaseErrorCategory : std::error_category{
    const char* name() const noexcept override{
        return "DatabaseErrorCategory";
    }

    std::string message(int ev) const override{
        switch(static_cast<NetworkError>(ev)){
            case NetworkError::SomeNetworkError:
                return "some database error occured";
            case NetworkError::NoError:
                return "no error";
        }
    }
};

// カテゴリオブジェクトの比較にアドレス比較を用いるため、
// カテゴリオブジェクトはシングルトンでなければいけません!
NetworkErrorCategory const networkErrorCategoryInstance;
DatabaseErrorCategory const databaseErrorCategoryInstance;

// 3. エラーコードとカテゴリの紐づけ。
inline
std::error_code make_error_code(NetworkError ne){
    return {static_cast<std::underlying_type_t<NetworkError>>(ne), networkErrorCategoryInstance};
}

inline
std::error_code make_error_code(DatabaseError de){
    return {static_cast<std::underlying_type_t<DatabaseError>>(de), databaseErrorCategoryInstance};
}

std::pair<std::string, DatabaseError>
getAccountIdFromDatabase(std::string_view userId){
    return {"", DatabaseError::SomeDatabaseError};
}

std::pair<int, NetworkError>
getAccountBalanceFromNetwork(std::string_view accountId){
    return {-1, NetworkError::SomeNetworkError};
}

std::pair<int, std::error_code>
getUserBalance(std::string_view userId){
    if(auto [accountId, databaseError] = getAccountIdFromDatabase(userId);
        databaseError == DatabaseError::SomeDatabaseError){
        return {-1, databaseError};
    }
    else if(auto [balance, networkError] = getAccountBalanceFromNetwork(accountId);
            networkError == NetworkError::SomeNetworkError) {
        return {-1, networkError};
    }
    else {
        // Successを表現するにはデフォルト・コンストラクタを使用する。
        return {balance, std::error_code{}};
    }
}

int main(){

    if(auto const [balance, error] = getUserBalance("mitsutaka-takeda");
       !error // エラーがあるときは、std::error_codeオブジェクトがtrueになる。
       ){
        // 成功!
        std::cout << "my balance is " << balance << std::endl;
    }
    else{
        // 失敗!
        if(error == NetworkError::SomeNetworkError){
           // handle network error!
        }
        else if(error == DatabaseError::SomeDatabaseError){
           // handle database error!
        }
    }
}

まず、getUserBalanceで複数ドメインのエラーをstd::error_codeとして統一できていることに注目してください。各ドメインのエラーコードNetworkErrorDatabaseErrorからstd::error_codeへの変換は暗黙的に行われます。

またmain関数のエラーハンドリングで、std::error_codeからエラー情報を取得する際、各ドメインのエラーコード(NetworkError::SomeNetworkErrorDatabaseError::SomeDatabaseError)と直接比較しています。

std::error_code自体はポリモーフィズムも利用せず整数値とオブジェクトへの参照の対で軽量な構造体であり、既存のエラーコードに非侵入的に使用できるため色々な場面で活躍できます。またC++の標準ライブラリでも使用されており統一されたエラーハンドリングを行うための基礎になります。

他にも複数のエラーコードをグルーピングするstd::error_conditionなど応用もあるので興味がある方は参考のリンクを見てください。

参考

おまけ

エラーコードの和集合

SomeErrorNetworkErrorDatabaseErrorの値の和として定義することで、2つのエラー情報を持つエラーコードを返すことができます。SomeErrorNetworkErrorDatabaseErrorのコードに対応するコードをすべて追加して、NetworkError/DatabaseErrorからSomeErrorへの変換処理fromNetworkError/fromDatabaseErrorを書きます。

std::error_codeと比較すると、各ドメインのエラーコードへの修正がSomeErrorなど他の箇所にも影響を与えます。

#include <string>
#include <experimental/string_view>

enum class NetworkError{
    NoError,
    SomeNetworkError
};

enum class DatabaseError{
    NoError,
    SomeDatabaseError
};

// ネットワークとデータベースのエラーコードの和集合。
enum class SomeError{
    NoError,
    SomeNetworkError,
    SomeDatabaseError
};

SomeError
fromNetworkError(NetworkError networkError){
    // NetworkErrorからSomeErrorへの変換。
    return networkError == NetworkError::NoError ? SomeError::NoError : SomeError::SomeNetworkError;
}

SomeError
fromDatabaseError(DatabaseError databaseError){
    // DatabaseErrorからSomeErrorへの変換。
    return databaseError == DatabaseError::NoError ? SomeError::NoError : SomeError::SomeDatabaseError;
}

std::pair<std::string, DatabaseError>
getAccountIdFromDatabase(std::string_view userId){
    return {"", DatabaseError::SomeDatabaseError};
}

std::pair<int, NetworkError>
getAccountBalanceFromNetwork(std::string_view accountId){
    return {-1, NetworkError::SomeNetworkError};
}

std::pair<int, SomeError>
getUserBalance(std::string_view userId){
    if(auto [accountId, databaseError] = getAccountIdFromDatabase(userId);
       databaseError == DatabaseError::SomeDatabaseError){
        // DatabaseErrorをSomeErrorに変換して返す。
        return {-1, fromDatabaseError(databaseError)};
    }
    else if(auto [balance, networkError] = getAccountBalanceFromNetwork(accountId);
            networkError == NetworkError::SomeNetworkError) {
        // NetworkErrorをSomeErrorに変換して返す。
        return {-1, fromNetworkError(networkError)};
    }
    else {
        return {balance, SomeError::NoError};
    }
}

C++17 std::variant

SomeErrorNetworkError/DatabaseErrorのC++17で導入されるvariantとして定義する方法です。和集合と比べて、対応する値の定義や変換処理は不要になります。

std::error_codeを使用した方法と比較すると、SomeErrorの型を事前に決定しておかなければいけません。 例えば、getUserBalanceが他ドメイン(ファイルシステム)のエラーを追加で扱わなければいけないとき、 SomeErrorの定義をstd::variant<NoError, NetworkError, DatabaseError, FileSystemError>に変更しなければいけません。std::error_codeでは、型情報は消去されているので、そのような修正入りません。

#include <string>
#include <experimental/string_view>
#include <variant>

enum class NetworkError{
    NoError,
    SomeNetworkError
};

enum class DatabaseError{
    NoError,
    SomeDatabaseError
};

// ネットワークとデータベースのエラーコードのVariant。
struct NoError{};
using SomeError = std::variant<NoError, NetworkError, DatabaseError>;

std::pair<std::string, DatabaseError>
getAccountIdFromDatabase(std::string_view userId){
    return {"", DatabaseError::SomeDatabaseError};
}

std::pair<int, NetworkError>
getAccountBalanceFromNetwork(std::string_view accountId){
    return {-1, NetworkError::SomeNetworkError};
}

std::pair<int, SomeError>
getUserBalance(std::string_view userId){
    if(auto [accountId, databaseError] = getAccountIdFromDatabase(userId);
        databaseError == DatabaseError::SomeDatabaseError){
        // DatabaseErrorをSomeErrorに変換して返す。
        return {-1, databaseError};
    }
    else if(auto [balance, networkError] = getAccountBalanceFromNetwork(accountId);
            networkError == NetworkError::SomeNetworkError) {
        // NetworkErrorをSomeErrorに変換して返す。
        return {-1, networkError};
    }
    else {
        return {balance, NoError{}};
    }
}

型安全な通貨型

仕事ではJava&C#を書いてるC++愛好家のid:mitsutaka-takadaです。

今日は型安全な通貨型の設計について書いてみようと思います。

ゴール

何かと話題のビットコインですが、ビットコインには通貨の単位としてBTCとsatoshiと呼ばれる単位があります。1BTC = 100,000,000Sathoshi = 100 Million Satoshiという関係になります。日本では昔は銭という単位がありましたが、現在は円だけなのでデノミネーションという言葉は意識しないかもしれません。

今回設計する型安全な通貨では許可していない演算はコンパイル時エラーにしバグの侵入を防ぎつつ、 必要なデノミネーションの変換を自動的に行うことで使い勝手の良い型になることを目指します。

Satoshi originalPrice{100};
auto premiumPrice = price + 10; // 10の単位がわからないのでコンパイル・エラー。

Btc myWalletBalance{1};
auto recievePayment = myWalletBalance + originalPrice; // 単位の自動変換。
assert(receivePayment == Satoshi{100'000'100});

設計

通貨は、何単位あるかという量を表現する型(100や1)と、デノミネーション(SatoshiやBtc)を表現する型の組み合わせで表現できそうです。 クラス・テンプレートMonetaryAmountで通貨を表現すると以下のようになります。 型パラメータRepが量を表現する型で、Denomがデノミネーションを表現する型です。

template<typename Rep, typename Denom>
class MonetaryAmount{

};

量を保持できるようにRep型のメンバ変数とコンストラクタを追加します。

template <typename Rep, typename Denom>
class MonetaryAmount{
    Rep count_;
public:
    template <typename Rep2>
    constexpr MonetaryAmount(Rep2 const& count)
    : count_(count)
    {}
};

次に加算演算子を追加します。まずは同じデノミネーションのみの加算を考慮します。 加算演算子をフリー関数で定義するために、量にアクセスするためのCountメンバ関数を定義します。 また、動作確認のために等価演算子も定義します。 クラス・テンプレートに型引数を与えて、ここまでで動作確認をしてみます。単位がついていないプリミティブ型との 加算は意図通りコンパイル・エラーになります。

#include <cstdint>

template <typename Rep, typename Denom>
class MonetaryAmount{
    Rep count_;
public:
    template <typename Rep2>
    constexpr MonetaryAmount(Rep2 const& count)
    : count_(count)
    {}

    constexpr Rep Count() const { return count_; }
};

template <typename Rep, typename Denom>
MonetaryAmount<Rep, Denom>
constexpr operator+(
    MonetaryAmount<Rep, Denom> const& lhs,
    MonetaryAmount<Rep, Denom> const& rhs
    )
{
    return lhs.Count() + rhs.Count();
}

template <typename Rep, typename Denom>
bool
constexpr operator==(
    MonetaryAmount<Rep, Denom> const& lhs,
    MonetaryAmount<Rep, Denom> const& rhs
    )
{
    return lhs.Count() == rhs.Count();
}

// 動作確認。
using Satoshi = MonetaryAmount<std::int64_t, struct ignored>;

constexpr Satoshi a{1}, b{2};

static_assert(a + b == Satoshi{3});
static_assert(a + 10); // コンパイル・エラー。10の単位が不明。

次に異なるデノミネーション間(Btc <-> Satoshi)の変換をサポートしてみます。 デノミネーション間の変換情報には何が必要でしょうか?BtcとSatoshiを変換するには、 1Btcが100,000,000Satoshiという情報があれば良さそうです。Sastoshiのほうが細かい単位なので、 Satoshiを基準に考えると、1Satoshiは1/100,000,000 Btc(1億分の1)になります。 型でこれを表現するには標準ライブラリのクラステンプレートstd::ratioを使用します。

std::ratioは分数を表現する型です。例えば3分の2はstd::ratio<2, 3>と表現できます。 分数が型で表現されていることに注意してください。std::ratio<2, 3>std::ratio<1,3>は異なる型です。

このstd::ratioを利用すると、Btcは MonetaryAmount<std::int64_t, std::ratio<100'000'000, 1>> 、 SatoshiはMonetaryAmount<std::int64_t, std::ratio<1>>で表現できます。つまり型パラメータDenomは、 1 Countあたり何Satoshiに相当するかの情報を表しています。Satoshiの1 Countは1 Satoshiなので、Denom = std::ratio<1>、 Btcの1 Countは100,000,000 Satoshiに相当するので、Denom = std::ratio<100'000'000>となります。

異なるデノミネーションを持つ通貨の演算は、より細かいデノミネーションに合わせて量を行うことで実現します。 つまり、1 Btc + 100 Satoshiは、1 BtcをSatoshiに変換して、100,000,000 Satoshi + 100 Satoshi = 100,000,100 Satoshiになります。

BtcとSatoshiは違う型であることに注意します。クラス・テンプレートMonetaryAmountに異なる型を指定して実体化するからです。 そのため、1 Btcから100,000,000 Satoshiへの変換は型の変換が必要です。 この型変換のために標準ライブラリに用意されたクラス・テンプレートstd::common_typeと、異なる通貨型をとるコンストラクタを利用します。

#include <cstdint>
#include <ratio>
#include <type_traits>

template <typename Rep, typename Denom>
class MonetaryAmount{
    Rep count_;
public:
    template <typename Rep2>
    constexpr MonetaryAmount(Rep2 const& count)
    : count_(count)
    {}

    template<typename Rep2, typename Denom2,
             // 解説1
             typename = std::enable_if_t<std::ratio_divide<Denom2, Denom>::den == 1> >
    constexpr MonetaryAmount(MonetaryAmount<Rep2, Denom2> const& m)
         : count_(m.Count() * std::ratio_divide<Denom2, Denom>::num)
    {}

    constexpr Rep Count() const{
        return count_;
    }
};

template<typename Rep1, typename Denom1, typename Rep2, typename Denom2>
// 解説2
std::common_type_t<MonetaryAmount<Rep1, Denom1>,
                   MonetaryAmount<Rep2, Denom2> >
constexpr operator+(
    MonetaryAmount<Rep1, Denom1> const& lhs,
    MonetaryAmount<Rep2, Denom2> const& rhs
    )
{
    using common_t = std::common_type_t<MonetaryAmount<Rep1, Denom1>,
                                        MonetaryAmount<Rep2, Denom2> >;
    return static_cast<common_t>(lhs).Count() + static_cast<common_t>(rhs).Count();
}

template<typename Rep1, typename Denom1, typename Rep2, typename Denom2>
bool
constexpr operator==(MonetaryAmount<Rep1, Denom1> const& lhs,
                     MonetaryAmount<Rep2, Denom2> const& rhs){
    using common_t = std::common_type_t<MonetaryAmount<Rep1, Denom1>,
                                        MonetaryAmount<Rep2, Denom2> >;
    return static_cast<common_t>(lhs).Count() == static_cast<common_t>(rhs).Count();
}

namespace detail {
    // gcdのヘルパ関数。
    template <typename Integral,
              typename = std::enable_if_t<std::is_integral<Integral>::value > >
    constexpr Integral abs(Integral x){
        return x < 0 ? -x : x;
    }

    // 下のgcdのヘルパ関数。
    template <typename M, typename N>
    constexpr std::common_type_t<M, N> gcd(M m, N n){
        return n == 0 ? abs(m) : gcd(n , abs(m) % abs(n));
    }

    // 型レベルで最大公約数(GCD)を計算する関数。
    template<std::intmax_t Num1, std::intmax_t Denom1, std::intmax_t Num2, std::intmax_t Denom2>
    constexpr auto gcd(std::ratio<Num1, Denom1> const& x,
                       std::ratio<Num2, Denom2> const& y){
        return std::ratio<gcd(Num1, Num2), Denom1 * Denom2>{};
    }
} // namespace detail

namespace std {
    // 解説2
    template<typename Rep1, typename Denom1, typename Rep2, typename Denom2>
    struct common_type<MonetaryAmount<Rep1, Denom1>,
                       MonetaryAmount<Rep2, Denom2> >{
        using type = MonetaryAmount<
            std::common_type_t<Rep1, Rep2>, decltype(detail::gcd(Denom1{}, Denom2{}))>;
    };
} // namespace std

// 動作確認。
using Satoshi = MonetaryAmount<std::int64_t, std::ratio<1>>;
using Btc     = MonetaryAmount<std::int64_t, std::ratio<100'000'000, 1>>;

constexpr Satoshi originalPrice{100};
// constexpr auto premiumPrice = originalPrice + 10; // 10の単位がわからないのでコンパイル・エラー。
constexpr Btc myWalletBalance{1};

// デノミネーションの自動変換。
static_assert((myWalletBalance + originalPrice) == Satoshi{100'000'100});

解説1

SFINAE(Substitution Failure Is Not A Error)を利用して、コンストラクタの利用を制限しています。

異なるデノミネーション間の変換はより大きなデノミネーションから細かなデノミネーションへの変換は誤差なしに行えますが、 逆の変換は切り捨てが発生して誤差が生まれてしまいます。例えば、1 Btcは100'000'000 Satoshiですが、 1 SatoshiはBtcの整数単位では切り捨てられて0 Btcになってしまいます。

この条件をstd::ratio_divide<Denom2, Denom>::den == 1で表現しています。 std::ratio_divideは分数の割り算を行います。3分の1割る2分の1はstd::ratio_divide<std::ratio<1/3>, std::ratio<1, 2>> == std::ratio<2, 3>となります。BtcからSatoshiへの変換はstd::ratio_divide<std::ratio<100'000'000, 1>, std::ratio<1>> == std::ratio<100'000'000, 1>となり変換が許可されますが、SatoshiからBtcへの変換はstd::ratio_divice<std::ratio<1>, std::ratio<100'000'000, 1>> == std::ratio<1, 100'000'000>となり許可されません。

解説2

std::common_typeは2つの型を取って、共通の型を返すクラス・テンプレートです。ユーザ定義型に対して使用する場合は、 std名前空間内で特殊化します。型から型への型レベルの関数です。

型Btc(MonetaryAmount<std::int64_t, std::ratio<100'000'000, 1>>)と型Satoshi(MonetaryAmount<std::int64_t, std::ratio<1>>)の 共通の型は何になるでしょうか。上記のように異なるデノミネーションの演算は、より細かなデノミネーションに合わせて行うことで誤差なしに行えます。 そのためBtcとSatoshiの場合は共通の型はSatoshiとします。

最後に

今回の記事では安全な仮想通貨型の設計を見てきました。一見複雑そうなことをしているように見えますが、ほとんどのロジックは型レベルで行われています。Btc/Satoshiオブジェクトの実態は実は64bitの整数のみです。コンパイラの最適化レベルを上げるとBtc/Satoshiを利用したコードはint64_tを利用したコードとまったく同じになります。コンパイラの最適化処理には脱帽です。記事内のすべてのコードはコンパイル可能になっているので、興味があればCompiler Explorerなどに張り付けていろいろいじってみてください。

この設計はstd::chronoライブラリを参考にしています。個人的にはC++で最も綺麗に設計されたライブラリの1つだと思います。 std::chronoについて詳しく知りたい人は以下の動画がおすすめです。

CPPCON2016 A \ Tutorial by Howard Hinnant

改めて「ITエンジニアのための機械学習理論入門」を改めて読む 〜第2章前編〜

こんにちは。オルトプラスラボに入って2週間と4日の橘です。

今回は前回に引き続き、2章を読んでいきます。
2章は機械学習の基礎中の基礎である2乗誤差についてです。
基礎とは言え、微分や行列が出てくるため、その辺をじっくり見ていきたいと思います。

多項式近似と誤差関数の設定

多項式近似とは何か、を見ていく前に、多項式近似のイメージを次の画像で見てみましょう。



多項式近似とは、「適当な線を引き、学習させたい点(データ)に近づくように線を曲げる」ようなイメージです。まず、これを頭のなかに入れておいてください。

f:id:s_tachibana:20170711175337p:plain


それでは数学的な説明に移ります。多項式とは、

 f(x) = w_{0} + w_{1} x + w_{2} + x^{2} + ... + w_{M} x^{M}

を言います。上の図でいうところの青い線が多項式です。つまり、この多項式を色々動かしてデータを分析したり予測したりしていくわけです。

ちなみに、予測したい各点を  (x_{1},t_{1}), (x_{2}, t_{2}), ... , (x_{N}, t_{N}) と表します。例えば、 x_{1}地点の f(x_{1})との差は次のように見ることができます。

f:id:s_tachibana:20170711185733p:plain

この差を縮めるように、線である多項式を調整していきます。この時、多項式の中の xはデータの値のため、変更できません。そのため、変更できるパラメータは w_{0}, w_{1}, ..., w_{M}です。つまり、機械学習で求めていく値はこの w_{0}, w_{1}, ..., w_{M}ということになります。このことはとても重要なため、よく覚えておいてください。

各データの点と線の差が次の式です。2乗誤差といいます。

 E_{D} = \frac{1}{2} \sum_{n=1}^{N} \{ f(x_{n}) - t_{n} \}^2

なぜ  \frac{1}{2}しているのかというと、後々の計算を楽にするためです。 \frac{1}{2}しても誤差が0になるわけではないので、多めに見てください。各データと線の差である f(x_{n}) - t_{n} を2乗しているのは、すべての差を正の値にするためです。なぜ正の値にする必要があるかというと、例えば3つの点の誤差を見たときに(-0.5, 1, -0.5)だったとき、誤差を足してしまうと0になってしまい、正確にデータを予測しているように見えてしまいます。こうなることを防ぐために、すべての差を2乗しているわけです。この誤差が小さくなるように、 w_{0}, w_{1}, ..., w_{M}の値を求めていきます。


ここからがP64の数学徒の小部屋の解説に入ります。ここで微分が登場します。なぜ微分が登場するかは、たまたま偶然ミラクルに同じ苗字の橘氏がデータホテル様のテックブログで解説してくれているので、そちらをご参照ください。いやー、偶然ってあるんだなー。

datahotel.io

要するに、「ある地点で微分した結果が0のとき、その地点で最大値か最小値になる」という性質があります。これを、各 w_{0}, w_{1}, ..., w_{M}について求めていくわけです。 w_{0}, w_{1}, ..., w_{M}の中のある w_{m'}で微分したときに0になると仮定します。(あえて本の中での m m'とが逆にしています。)

 \sum_{n=1}^{N} ( \sum_{m=0}^{M} w_{m} x_{n}^{m} - t_{n} ) x_{n}^{m'} = 0

一見複雑に見えますが \sumは足し算ですので、順番を変更する事ができます。本の中でも「作為的」と書いてありますが、この計算を行列で表現するための工夫です。

 \sum_{m=0}^{M}w_{m} \sum_{n=0}^{N} x_{n}^{m} m_{n}^{m'} - \sum_{n=0}^{N} t_{n} x_{n}^{m'} = 0

この式は m'について微分した結果ですが、各mについて微分した結果を行列の形に表したものが、次の式です。

 w^{T}\Phi^{T}\Phi - t^{T} \Phi = 0

各行列は本の65ページのとおりです。なぜ、こうなるか困惑した方も多いのではないかと思います。ですので、N=2、M=2のときで計算を試してみました。行列の演算の仕方はデータホテル様のテックブログを御覧ください。


datahotel.io


f:id:s_tachibana:20170720125117j:plain

f:id:s_tachibana:20170720130323j:plain


行列の中の式が

 \sum_{m=0}^{M}w_{m} \sum_{n=0}^{N} x_{n}^{m} m_{n}^{m'} - \sum_{n=0}^{N} t_{n} x_{n}^{m'}

と同じになっています。行列の各行が w_{0}, w_{1}, w_{2}について微分した式と一致していることがわかるかと思います。これをmが0からMのときまで実行しても同じ結果が得られます。この

 w^{T}\Phi^{T}\Phi - t^{T} \Phi = 0

はwについての方程式ですので、w に関して求めると次のような結果になります。

 w = (\Phi^{T} \Phi)^{-1} \Phi^{T}t

 \Phi  x_{1}, x_{2}, ..., x_{n}からなる行列ですので、既にわかっているデータです。また、 tも各 t_{1}, t_{2}, ..., t_{n}からなるベクトルですので、既にわかっているデータです。つまり、この行列の計算をしてしまえば、最も適切な w_{0}, w_{1}, ..., w_{m} が求められます。


(行列の正定値性、ヘッセ行列に関しては次回解説します。)

サンプルコード

Numpyのnp.polyfitメソッドを使うことで、簡単に多項式近似できます。

import numpy as np
import matplotlib.pyplot  as plt
import pandas as pd
from pandas import Series, DataFrame
from numpy.random import normal

# データセットを用意
def create_dataset(num):
    dataset = DataFrame(columns=["x", "y"])
    for i in range(num):
        x = float(i)/float(num-1)
        # 平均0.0、分散0.3の正規分布を誤差として与えている
        y = np.sin(2 * np.pi * x) + normal(scale=0.3)
        dataset = dataset.append(Series([x, y], index=["x", "y"]), ignore_index=True)
    return dataset

if __name__ == "__main__":
    // データセットを作る
    df = create_dataset(10)

    X = df.x.values
    y = df.y.values
    x = np.arange(0, 1.1, 0.01)

    // データセットをプロットする
    plt.scatter(X, y)
    // 近似した線を引く
    plt.plot(x, np.poly1d(np.polyfit(X, y, 4))(x), color="r")
    plt.show()


実行結果は以下のとおりです。

f:id:s_tachibana:20170720211226p:plain


次回

次回は第2章の後半と今回紹介できなかった正定値性、ヘッセ行列に関して解説します。