概要
ニューラルネットワークの学習アルゴリズムを、ミニバッチ、推論、損失関数、勾配、パラメータ更新の流れとして整理する。
これまで扱った個別の処理は、学習ループの中で順番につながっている。
ここでは、入力データから損失を計算し、勾配を求めて重みを更新するまでの全体像をまとめる。
この記事で扱うこと
- ニューラルネットワーク学習の全体的な処理順序。
- ミニバッチ、損失関数、勾配、更新式の関係。
- 推論処理と学習処理の違い。
- これまでの各記事が学習アルゴリズムのどこに対応するか。
作業前に確認すること
| 確認項目 | 内容 |
|---|---|
| 前提記事 | 損失関数、数値微分、勾配降下法、重み更新の記事を確認しておく。 |
| 用語 | 重み、バイアス、ミニバッチ、学習率の意味を整理しておく。 |
| 実装視点 | 処理の順番を追いながら、どこで値が変わるかを意識する。 |
学習アルゴリズムの流れ(まとめ)
ニューラルネットワークの学習とは、訓練データを基に意図した結果となる最適な重みパラメータを判別する処理を指す。
その処理を大きく分類すると、以下の手順(*1)~(*4)の4つに分かれ、これを繰り返すことで限りなく最適な重みパラメータに近づけていく。
その結果、最適な重みパラメータによって、正解率がほぼ100%となる画像認識や膨大なデータに基づく根拠ある推測が実現できるようになる。
-
(*1)訓練データの抽出(ミニバッチの決定)
訓練データの中からランダムに100件程度(ある程度信頼性がある件数)抽出したデータ群をミニバッチといい、以降の(*2)~(*4)では、このミニバッチ単位に損失関数の結果と勾配を求めるアルゴリズムになっている。※ 参考
Python - ニューラルネットワーク: 交差エントロピー誤差のミニバッチ学習と実装サンプル > ミニバッチ学習とは -
(*2)損失関数の結果取得
ニューラルネットワークの学習では、損失関数というニューラルネットワークの性能の悪さ表す指標を基準にする。損失関数の有名どころでは2乗和誤差と交差エントロピー誤差があるが、このシリーズでは、交差エントロピー誤差にスポットを当てた実装サンプルを記載している。
※ 参考
-
(*3)勾配の取得
損失関数の結果が最も小さい値となる重みとなるように自己探索(最適な重みパラメータを探す)していくことになる。損失関数の結果は、小さいほど正解に近づいているわけだが、もちろんそこで終わりではなく今の結果より正解に近いパラメータ候補(重み、バイアス)を決めてさらに正解に近づけていく必要がある。
そこで基準になるのが重みパラメータの微分結果(勾配値)。
※ 参考
- Python - ニューラルネットワーク: 損失関数と数値微分(勾配)の実装サンプル > 損失関数と微分の関係
- Python - ニューラルネットワーク: 損失関数と数値微分(勾配)の実装サンプル > 微分のおさらい
- Python - ニューラルネットワーク: 損失関数と数値微分(勾配)の実装サンプル > 数値微分の関数定義(Python実装サンプル)
- Python - ニューラルネットワーク: 損失関数と数値微分(勾配)の実装サンプル > 数値微分の例(Python実装サンプル)
- Python - ニューラルネットワーク: 偏微分と勾配の実装サンプル > 偏微分のおさらい
- Python - ニューラルネットワーク: 偏微分と勾配の実装サンプル > 偏微分のPython実装サンプル
- Python - ニューラルネットワーク: 偏微分と勾配の実装サンプル > 勾配のPython実装サンプル
-
(*4)重みパラメータの更新
重みパラメータを勾配方向(より損失が少ない重みパラメータへ)へ微小量だけ更新する。※ 参考
勾配法についての補足
上記の勾配に関する記述は、すべて勾配降下法によってパラメータを更新する方法となる。
そしてミニバッチは、ランダム抽出したデータ群であることから確率的勾配降下法といい、ディープラーニングのフレームワークでは、一般的にSGD(stochastic gradient descent)と呼ばれる。
※ 参考
- Python - ニューラルネットワーク: 勾配降下法の実装サンプル > 勾配法とは
- Python - ニューラルネットワーク: 勾配降下法の実装サンプル > 勾配降下法のPython実装サンプル
- Python - ニューラルネットワーク: 勾配降下法の実装サンプル > 勾配降下法の実装サンプル実行例
違いを整理する
| 比較する項目 | 整理するポイント |
|---|---|
| 推論と学習の混同 | 推論は予測を出す処理、学習は重みを更新する処理。 |
| 損失を直接下げるわけではない | 勾配を使ってパラメータを動かし、その結果として損失を下げる。 |
| 一回で最適化されない | ミニバッチを変えながら何度も更新を繰り返す。 |
実務とのつながり
- 学習ログの読み方
損失や精度の推移を見ると、学習が進んでいるか、発散しているかを判断しやすい。 - フレームワーク理解
PyTorchやTensorFlowを使う場合でも、裏側では同じ流れでパラメータ更新が行われる。
まとめ
- 学習は、ミニバッチ抽出、推論、損失計算、勾配計算、重み更新を繰り返す処理。
- 損失関数はモデルの悪さを測り、勾配は改善方向を探す手がかりになる。
- 一連の流れを理解すると、深層学習フレームワークの動作も追いやすくなる。
参考文献
- 斎藤 康毅(\(2018\))『ゼロから作るDeep Learning - Pythonで学ぶディープラーニングの理論と実装』株式会社オライリー・ジャパン