いつもの作業の備忘録

作業を忘れがちな自分のためのブログ

【Caffe】C++でのMemoryDataLayerの扱い方

0.課題

 caffe-masterのサンプルとして提供されているcaffe.exeを利用するだけでは、Caffeを使ったアプリケーション開発には不便である。そのようなアプリケーション開発をする場合、画像データやラベルデータをプログラムの中で自由に読み込み、ネットワークにかけることが必要となる。本稿では、通常の多クラス識別の学習&評価をC++プログラムを通じて行うことをゴールとする。コード類は参考のためにGitHubに掲載する。
https://github.com/whg-res/MemoryDataLayerSample.git

1.データの準備

 今回はCaltech101データセットを用いた102クラス識別(101オブジェクトクラス+1背景クラス)を行う。利用するネットワーク構造はGoogLeNetとする。Caltech101データセットは以下のサイトからダウンロードする。
https://www.vision.caltech.edu/Image_Datasets/Caltech101/
 GoogLeNetはダウンロードしたcaffe-master中の、caffe-master/models/bvlc_googlenet/にprototxtなどのファイルが存在する。ただし、重み情報を格納したcaffemodelファイルは別途ダウンロードする必要がある。
https://github.com/BVLC/caffe/tree/master/models/bvlc_googlenet

 続いて、Caltech101データセットの学習リストと評価リストを作成する。これらは通常のcaffeでの多クラス識別と同じ以下の形式とする。

 <画像パス> <クラスNo>

 学習リストと評価リストはランダムにソートしておく。学習の際、ランダムに読み込まなければうまく収束しないためである。

G:/DATA/101_ObjectCategories/cougar_face/image_0060.jpg 22
G:/DATA/101_ObjectCategories/revolver/image_0072.jpg 76
G:/DATA/101_ObjectCategories/trilobite/image_0008.jpg 93
G:/DATA/101_ObjectCategories/crocodile_head/image_0043.jpg 26
G:/DATA/101_ObjectCategories/airplanes/image_0167.jpg 1
G:/DATA/101_ObjectCategories/Faces/image_0423.jpg 37
G:/DATA/101_ObjectCategories/Faces_easy/image_0417.jpg 38
G:/DATA/101_ObjectCategories/scorpion/image_0066.jpg 82

2.prototxtファイルの編集

 続いて、ネットワークの定義を変更する。変更箇所である2か所について説明する。

 まず、入力のデータレイヤを変更する。デフォルトでは、LMDBにデータベース化されたデータを読み込む設定になているが、プログラムから直接利用する場合は不便である。今回はプログラムからメモリに画像ファイルを読み込み、読み込んだデータをネットワークの入力に設定することが可能なMemoryDataLayerを利用する。入力層を以下のように変更する。

layer {
  name: "data"
  type: "MemoryData"
  top: "data"
  top: "label"
  include {
    phase: TRAIN
  }
  memory_data_param{
    batch_size: 5
    channels: 3
    height: 224
    width: 224
  }
}
layer {
  name: "data"
  type: "MemoryData"
  top: "data"
  top: "label"
  include {
    phase: TEST
  }
  memory_data_param {
    channels: 3
    height: 224
    width: 224
    batch_size: 50
  }
}

 入力層のtypeをMemoryDataに設定し、そのサイズ情報をmemory_data_paramで設定している。MemoryDataには2つの入力が設定できる。ひとつは3次元データ(data)であり、今回のケースでは画像に対応する。もうひとつは、ラベルデータ(label)である。これはスカラー量と決まっており、今回の場合はクラス番号がそれにあたる。

 次に、出力層の出力次元数を変更する。GoogLeNetの場合、3つの出力層が存在するため、それらすべてを変更する必要がある。下記は最後の出力層loss3/classifier_の変更例であるが、loss2/classifier_、loss1/classifier_も同様に変更する必要がある。変更すべきはnum_outputの値。今回は101クラス+背景クラスを識別するため102に設定する。

layer {
  name: "loss3/classifier_"
  type: "InnerProduct"
  bottom: "pool5/7x7_s1"
  top: "loss3/classifier_"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  inner_product_param {
    num_output: 102
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}

3.Caffe利用コードの作成

 実際にC++でコードを書く。基本的には、①ネットワークを用意する、②読み込んだデータをネットワークにセットする、③学習or評価を実行するという流れになる。今回はMemoryDataLayerを使う上で重要な②について言及しておく。それ以外の部分はソースコードを参照頂きたい。まず、ネットワークに設定するデータはfloat型の1次元配列として用意する。

	train_input_data	= new float[train_data_size*HEIGHT*WIDTH*CHANNEL];
	train_label			= new float[train_data_size];
	test_input_data		= new float[test_data_size*HEIGHT*WIDTH*CHANNEL];
	test_label			= new float[test_data_size];

 カラー(3チャネル)の全画像データを1次元に格納する必要がある。ラベルは1画像に対して長さ1(スカラー)なので配列としては画像枚数分の長さを確保する。このデータをネットワークに設定する。

	//ネットワークに反映
	const auto train_input_layer = boost::dynamic_pointer_cast<MemoryDataLayer<float>>( net->layer_by_name("data") );
	const auto test_input_layer = boost::dynamic_pointer_cast<MemoryDataLayer<float>>( test_net[0]->layer_by_name("data") );
	train_input_layer->Reset((float*)train_input_data, (float*)train_label, train_data_size);
	test_input_layer->Reset((float*)test_input_data, (float*)test_label, test_data_size);

 上記では学習ネットワークnetと評価ネットワークtest_net[0]の"data"と名付けられたMemoryDataLayerに読み込んだ画像データとラベルデータをセットしている部分である。この状態でSolve()関数を実行するだけで、設定されたバッチサイズで学習が行われるため、学習の場合はバッチ処理をあまり意識する必要はない。
 また、肝心のデータ読み込み部分readImgListToFloat()は以下のようにする。

void readImgListToFloat(string list_path, float *data, float *label, int data_len){

	ifstream ifs;
	string str;
	int n = 0;
	ifs.open(list_path, std::ios::in);
	if (!ifs){ LOG(INFO) << "cannot open " << list_path; return; }

	float mean[CHANNEL] = { 104, 117, 123 };

	while (getline(ifs, str)){
		vector<string> entry = split(str, ' ');
		cout << "reading: " << entry[0] << endl;
		cv::Mat img = cv::imread(entry[0]);
		cv::Mat resized_img;
		cv::resize(img, resized_img, cv::Size(WIDTH, HEIGHT));
		for (int y = 0; y < HEIGHT; y++){
			for (int x = 0; x < WIDTH; x++){
				data[y*resized_img.cols + x + resized_img.cols*resized_img.rows*0 + WIDTH * HEIGHT * CHANNEL * n]
					= resized_img.data[y*resized_img.step + x*resized_img.elemSize() + 0] - mean[0];
				data[y*resized_img.cols + x + resized_img.cols*resized_img.rows*1 + WIDTH * HEIGHT * CHANNEL * n]
					= resized_img.data[y*resized_img.step + x*resized_img.elemSize() + 1] - mean[1];
				data[y*resized_img.cols + x + resized_img.cols*resized_img.rows*2 + WIDTH * HEIGHT * CHANNEL * n]
					= resized_img.data[y*resized_img.step + x*resized_img.elemSize() + 2] - mean[2];
			}
		}
		label[n] = stof(entry[1]);
		n++;
	}
}

 今回、画像はBチャネル、Gチャネル、Rチャネルの順で配置されるようにしている。
f:id:whg_res:20170211205558p:plain
 内部で平均値を引いているので、順番が間違っているとまずい。※実際は多少違ってもFine-tuneをかければそこまで大きな問題にならないかもしれないが

 また、評価の際はバッチサイズを意識する必要がある。ネットワークにセットする画像数をバッチサイズ分に限定する必要があり、Forward()関数で得られた結果もセットしたバッチサイズ分だけ帰ってくることを意識してコードを組む必要がある。それぞれ、以下のプログラムでReset()しているデータサイズと、帰ってきた値(result)から結果を表示する部分を参考にしてもらいたい。

	for (int batch = 0; batch < batch_iter; batch++){
		//入力データを選択的にネットワークにセット&識別
		input_test_layer->Reset((float*)test_input_data + batch * WIDTH*HEIGHT*CHANNEL * batch_size, (float*)test_label + batch * batch_size, batch_size);
		const auto result = test_net.Forward();

		//結果を受け取り、一番スコアの高いクラスに識別する
		const auto data = result[1]->cpu_data();
		for (int i = 0; i < batch_size; i++){
			int max_id = 0;
			float max = 0;
			for (int j = 0; j < NCLASS; j++){
				if (max < data[i * NCLASS + j]){
					max = data[i * NCLASS + j];
					max_id = j;
				}
			}
			cout << max_id << ", " << max << endl;
		}
	}

4.実行例

 通常のcaffe.exeによるクラス識別と機能的には変わらないので、データをロードした後は見慣れた学習画面が現れる。学習誤差が減少していることが見て取れる。私の環境では25epoch程度回せばそれなりに収束した。
f:id:whg_res:20170211210611p:plain


【参考】
https://www.vision.caltech.edu/Image_Datasets/Caltech101/
https://github.com/BVLC/caffe/tree/master/models/bvlc_googlenet