いつもの作業の備忘録

作業を忘れがちな自分のためのブログ

【Caffe】Fine Tuningする

以下のサイトの後半に書かれているファインチューニングの手順を実行
Caffeで手軽に画像分類 - Yahoo! JAPAN Tech Blog

1.データ準備

学習時には3つのデータが必要。

  • train.txt:ネットワークのパラメタ(重み)を学習するデータ
  • val.txt :学習時に定期的にテストし、過学習を起こしていないかチェックするデータ
  • test.txt :本番テスト用データ

上記はすべて画像ファイルパスと正解クラスがタブ区切りとなったデータ。

$ cat train.txt
accordion/image_0001.jpg 0
accordion/image_0002.jpg 0
   :

train.txtとtest.txtはせっかくなので、前回利用した20枚の学習データ、8枚の評価データと同じものを準備。val.txtは新しく10枚ほど調達

2.平均画像生成

Caffeでは平均画像をあらかじめ引き算した画像を利用するらしい。背景が決まっているときなどは特に有効そう。平均画像を使わなくても学習できるらしいが、今回は平均画像を利用する。
Caffeの仕様が変わって元サイト通りにやると Incorrect data field size というエラーが出てコアダンプするため以下のように変更

$ $CAFFE_HOME/build/tools/convert_imageset.bin -gray -resize_width 256 -resize_height 256 $CAFFE_HOME/data/101_ObjectCategories/ train.txt caltech101_train_leveldb
$ $CAFFE_HOME/build/tools/convert_imageset.bin -gray -resize_width 256 -resize_height 256 $CAFFE_HOME/data/101_ObjectCategories/ val.txt caltech101_val_leveldb
$ $CAFFE_HOME/build/tools/compute_image_mean.bin caltech101_train_leveldb/ caltech101_mean.binaryproto

3.モデルファイルの編集

もともとbvlc_reference_caffenetで使っていたモデルを元にFine Tuningを実施するため、元のモデルをコピーし、ネットワークの名前を変更する

$ cd $CAFFE_HOME
$ cp $CAFFE_HOME/models/bvlc_reference_caffenet/*.prototxt $CAFFE_HOME
$ sed -i -e 's/fc8/fc8ft/g' train_val.prototxt deploy.prototxt

ファイル編集
solver.prototxt

#変更前
net: "models/bvlc_reference_caffenet/train_val.prototxt"
test_iter: 1000
test_interval: 1000
base_lr: 0.01
lr_policy: "step"
gamma: 0.1
stepsize: 100000
display: 20
max_iter: 450000
momentum: 0.9
weight_decay: 0.0005
snapshot: 10000
snapshot_prefix: "models/bvlc_reference_caffenet/caffenet_train"
solver_mode: GPU
↓
#変更後
net: "train_val.prototxt"
test_iter: 1000
test_interval: 1000
base_lr: 0.001
lr_policy: "step"
gamma: 0.1
stepsize: 100000
display: 20
max_iter: 450000
momentum: 0.9
weight_decay: 0.0005
snapshot: 10000
snapshot_prefix: "snapshot"
solver_mode: GPU

train_val.prototxt

#変更前
  transform_param {
    mirror: true
    crop_size: 227
    mean_file: "data/ilsvrc12/imagenet_mean.binaryproto"
  }
↓
#変更後
  transform_param {
    mirror: true
    crop_size: 227
    mean_file: "caltech101_mean.binaryproto"
  }
#変更前
  data_param {
    source: "examples/imagenet/ilsvrc12_train_lmdb"
    batch_size: 256
    backend: LMDB
  }
↓
#変更後
  data_param {
    source: "caltech101_train_leveldb"
    batch_size: 256
    backend: LMDB
  }
#変更前
  transform_param {
    mirror: false
    crop_size: 227
    mean_file: "data/ilsvrc12/imagenet_mean.binaryproto"
  }
↓
#変更後
  transform_param {
    mirror: false
    crop_size: 227
    mean_file: "caltech101_mean.binaryproto"
  }
#変更前
  data_param {
    source: "examples/imagenet/ilsvrc12_val_lmdb"
    batch_size: 50
    backend: LMDB
  } 
↓
#変更後
  data_param {
    source: "caltech101_val_leveldb"
    batch_size: 50
    backend: LMDB
  }
#変更前
  inner_product_param {
    num_output: 1000
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
↓
#変更後
  inner_product_param {
    num_output: 102
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }

deploy.prototxt

#変更前
name: "CaffeNet"
layer {
  name: "data"
  type: "Input"
  top: "data"
  input_param { shape: { dim: 10 dim: 3 dim: 227 dim: 227 } }
}
↓
#変更後
name: "CaffeNet"
layer {
  name: "data"
  type: "Input"
  top: "data"
  input_param { shape: { dim: 10 dim: 3 dim: 256 dim: 256 } }
}
#変更前
layer {
  name: "fc8ft"
  type: "InnerProduct"
  bottom: "fc7"
  top: "fc8ft"
  inner_product_param {
    num_output: 1000
  }
}
↓
#変更後
layer {
  name: "fc8ft"
  type: "InnerProduct"
  bottom: "fc7"
  top: "fc8ft"
  inner_product_param {
    num_output: 102
  }
}

4.学習実行

$ $CAFFE_HOME/build/tools/caffe train -solver solver.prototxt

結果

I0304 23:01:09.303612  5094 solver.cpp:244]     Train net output #0: loss = 0.00736864 (* 1 = 0.00736864 loss)
I0304 23:01:09.303623  5094 sgd_solver.cpp:106] Iteration 5920, lr = 0.001
I0304 23:01:14.809386  5094 solver.cpp:228] Iteration 5940, loss = 0.0244083
I0304 23:01:14.809422  5094 solver.cpp:244]     Train net output #0: loss = 0.0244083 (* 1 = 0.0244083 loss)
I0304 23:01:14.809435  5094 sgd_solver.cpp:106] Iteration 5940, lr = 0.001
I0304 23:01:20.310364  5094 solver.cpp:228] Iteration 5960, loss = 0.0299701
I0304 23:01:20.310396  5094 solver.cpp:244]     Train net output #0: loss = 0.0299701 (* 1 = 0.0299701 loss)
I0304 23:01:20.310410  5094 sgd_solver.cpp:106] Iteration 5960, lr = 0.001
I0304 23:01:25.807876  5094 solver.cpp:228] Iteration 5980, loss = 0.00960639
I0304 23:01:25.807914  5094 solver.cpp:244]     Train net output #0: loss = 0.00960641 (* 1 = 0.00960641 loss)
I0304 23:01:25.807925  5094 sgd_solver.cpp:106] Iteration 5980, lr = 0.001
I0304 23:01:31.037081  5094 solver.cpp:337] Iteration 6000, Testing net (#0)
I0304 23:01:52.026041  5094 solver.cpp:404]     Test net output #0: accuracy = 0.31178
I0304 23:01:52.026075  5094 solver.cpp:404]     Test net output #1: loss = 7.09174 (* 1 = 7.09174 loss)
I0304 23:01:52.117260  5094 solver.cpp:228] Iteration 6000, loss = 0.00708895
I0304 23:01:52.117293  5094 solver.cpp:244]     Train net output #0: loss = 0.00708897 (* 1 = 0.00708897 loss)
I0304 23:01:52.117307  5094 sgd_solver.cpp:106] Iteration 6000, lr = 0.001
I0304 23:01:57.617044  5094 solver.cpp:228] Iteration 6020, loss = 0.0195226
I0304 23:01:57.617076  5094 solver.cpp:244]     Train net output #0: loss = 0.0195226 (* 1 = 0.0195226 loss)
I0304 23:01:57.617087  5094 sgd_solver.cpp:106] Iteration 6020, lr = 0.001
^CI0304 23:01:58.722149  5094 solver.cpp:454] Snapshotting to binary proto file snapshot_iter_6025.caffemodel
I0304 23:01:59.396268  5094 sgd_solver.cpp:273] Snapshotting solver state to binary proto file snapshot_iter_6025.solverstate
I0304 23:01:59.598551  5094 solver.cpp:301] Optimization stopped early.
I0304 23:01:59.598588  5094 caffe.cpp:222] Optimization Done.

6000回目のiterationの段階で、validationデータの正解率(Test net output #0: accuracy)が0.31178と低めで収束した。過学習している状態だと予想される。学習データが20枚しかないのが原因のように思われる。学習率を変えたり、Data Augumentationしたりと学習のノウハウが必要なところだろう


※参考
http://techblog.yahoo.co.jp/programming/caffe-intro/
http://qiita.com/CORDEA/items/9fad27ae024928b6a7b1
https://gist.github.com/rezoo/a1c8d1459b222fc5658f