続き
コードは書籍記載のgithubのそれを見ると一発です。
7章 畳み込みニューラルネットワーク
画像認識の分野で広く使われているいわゆるCNNのお話。これまでのネットワークは隣接のニューロンの間で結合があったので、全結合と呼ばれています。
全結合層の問題点
入力は画像なので、本来上下左右といった空間的なつながり、RGBであればその色方向のつながり、そう言った意味があるはずなのに全結合ではそれを1次元にまとめてしまうため、それら有効な情報が欠落してしまいます。その問題を解決するのがCNNです。
特徴マップ
畳み込み層の入出力データを特に特徴マップと呼びます。これは画像のような2次元のデータだからマップと呼んでおり、全結合の時とは少し趣が異なります。
畳み込み演算
フィルタカーネルを使って畳み込むだけ。この畳み込みが積和の演算になります。フィルタの大きさによって画像が縮むので周囲にパディングしたり、フィルタを適用する画素をスキップする数(ストライド)だったりとパラメータが出てきます。
入力に対して出力される特徴マップの大きさが変化するので注意が必要です。計算式も載っています。
RGBのような3次元データの畳み込みではそれぞれ異なるフィルタが3つあってその和を出すと思えばよいです。3×3のフィルタが3枚で重みの係数は27個で出力は2次元(1枚)になります。
これをブロックで考えるとわかりやすく、入力が例えばHxWxCの画像(特徴マップ)であれば、必要なフィルタはFHxHWxCの大きさとなり、この結果OHxOWx1の画像が得られます。このフィルタがFN個ある層であれば、この層から出てくる画像のサイズはOHxOWxFNとなります。
従って各層で必要となるフィルタの大きさを表現するためには、4次元(output_channel, input_channel, height, width)のデータが必要になります。
バイアスも存在し、最後に係数を加算します。なのでバイアスの大きさは出力のチャネル数FNだけ必要になります。
この畳み込みもバッチ処理できます。なので入出力の画像も複数束ねたバッチサイズ分必要になります。そのためフィルタと同じように4次元(batch_num, channel, height, width)のデータとなります。
プーリング層
縦横の方向を空間的に小さくする演算です。いわゆる画像の縮小で、ある単位の最大値で間引いたり、平均値で間引いたりします。学習パラメータは存在せず、入出力でチャネル数の変化もありません。特徴としては微小な位置変化に対してロバストになります。(1画素ずれたもの同士のMAXプーリングの結果が同じになる)
実装
im2col関数による工夫で、フィルタの畳み込みを行列の掛け算に落とし込んでいます。画像データを重複させて持たせるため、メモリは大きくなります。直感的にはわかるのですがコードがややこしい…。フィルタを積和する画素を横に並べ、積和する回数分縦に並べるイメージで、行列の掛け算で畳み込みを済ませようという理解です。フィルタ(重み)の方は重複させません。
このim2colはかなり一般的な実装のようで、これだけを解説しているページもいろいろあります。そちらを読むとさらにわかった気になります。
そして順伝播と逆伝播の部分だけコードを眺めてみます。
畳み込み
def forward(self, x): FN, C, FH, FW = self.W.shape N, C, H, W = x.shape out_h = 1 + int((H + 2*self.pad - FH) / self.stride) out_w = 1 + int((W + 2*self.pad - FW) / self.stride) col = im2col(x, FH, FW, self.stride, self.pad) col_W = self.W.reshape(FN, -1).T out = np.dot(col, col_W) + self.b out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2) self.x = x self.col = col self.col_W = col_W return out def backward(self, dout): FN, C, FH, FW = self.W.shape dout = dout.transpose(0,2,3,1).reshape(-1, FN) self.db = np.sum(dout, axis=0) self.dW = np.dot(self.col.T, dout) self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW) dcol = np.dot(dout, self.col_W.T) dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)
プーリング
def forward(self, x): N, C, H, W = x.shape out_h = int(1 + (H - self.pool_h) / self.stride) out_w = int(1 + (W - self.pool_w) / self.stride) col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad) col = col.reshape(-1, self.pool_h*self.pool_w) arg_max = np.argmax(col, axis=1) out = np.max(col, axis=1) out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2) self.x = x self.arg_max = arg_max return out def backward(self, dout): dout = dout.transpose(0, 2, 3, 1) pool_size = self.pool_h * self.pool_w dmax = np.zeros((dout.size, pool_size)) dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten() dmax = dmax.reshape(dout.shape + (pool_size,)) dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1) dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad) return dx
わかったようなわからないような
CNNの可視化
学習後のフィルタの2次元のデータを画像として可視化するとどんな情報に反応しているかなんとなく想像できるよね。って話です。学習前がランダムだったのに、学習後は何らかの検出器になってそうな雰囲気がわかります。
層の間の2次元データに関しても同様に観察することができて、どんなテクスチャやエッジを検出したのかがなんとなく観察できます。層の浅いところから深いところでどんなものに反応しているのか観察することができます。
col2imの関数もあるのでこいつを使えば可視化ができるのだと思われます。
8章 ディープラーニング
技術的な話ではなく、読み物ですね。
コメント