使用python/numpy實現im2col的學習心得
- 背景
- 書上的程序
- 分析
- 首先是:
- 其次:
- 寫在最后
背景
最近在看深度學習的東西。使用的參考書是《深度學習入門——基于python的理論與實現》。在看到7.4時,里面引入了一個im2col的函數,從而方便講不斷循環進行地相乘相加操作變成矩陣的運算,通過空間資源換取時間效率。
為什么要這么操作和操作以后col矩陣的樣子比較好理解。由于對python和numpy不太熟悉,理解書上給出的程序實現想了很久。終于有點感覺了,記錄下來。
書上的程序
def
im2col
(
input_data
,
filter_h
,
filter_w
,
stride
=
1
,
pad
=
0
)
:
"""
Parameters
----------
input_data : 由(數據量, 通道, 高, 長)的4維數組構成的輸入數據
filter_h : 濾波器的高
filter_w : 濾波器的長
stride : 步幅
pad : 填充
Returns
-------
col : 2維數組
"""
N
,
C
,
H
,
W
=
input_data
.
shape
out_h
=
(
H
+
2
*
pad
-
filter_h
)
//
stride
+
1
out_w
=
(
W
+
2
*
pad
-
filter_w
)
//
stride
+
1
img
=
np
.
pad
(
input_data
,
[
(
0
,
0
)
,
(
0
,
0
)
,
(
pad
,
pad
)
,
(
pad
,
pad
)
]
,
'constant'
)
col
=
np
.
zeros
(
(
N
,
C
,
filter_h
,
filter_w
,
out_h
,
out_w
)
)
for
y
in
range
(
filter_h
)
:
y_max
=
y
+
stride
*
out_h
for
x
in
range
(
filter_w
)
:
x_max
=
x
+
stride
*
out_w
col
[
:
,
:
,
y
,
x
,
:
,
:
]
=
img
[
:
,
:
,
y
:
y_max
:
stride
,
x
:
x_max
:
stride
]
col
=
col
.
transpose
(
0
,
4
,
5
,
1
,
2
,
3
)
.
reshape
(
N
*
out_h
*
out_w
,
-
1
)
return
col
分析
首先只考慮一個數據,即此時 N = 1 N=1 N = 1 。并且假設數據只有一層,比如灰度圖,即 C = 1 C=1 C = 1 。假設數據的高和長分別為4,4。即 H = 4 H=4 H = 4 , W = 4 W=4 W = 4 。濾波器的長和高分別為2,2。即 f i l t e r _ h = 2 filter\_h=2 f i l t e r _ h = 2 , f i l t e r _ w = 2 filter\_w=2 f i l t e r _ w = 2 。更進一步地,將Pad簡化為0。
此時,img就是一個
4 ? 4 4*4
4
?
4
的矩陣,假設如下:
濾波器是
2 ? 2 2*2
2
?
2
的矩陣,假設為
因此,卷積層輸出是
3 ? 3 3*3
3
?
3
的矩陣。
有了這些預備設定,就可以開始理解程序了。我們重點關注兩句話。
首先是:
col
[
:
,
:
,
y
,
x
,
:
,
:
]
=
img
[
:
,
:
,
y
:
y_max
:
stride
,
x
:
x_max
:
stride
]
y和x分別代表濾波器的尺寸,由于設定
N = 1 N=1
N
=
1
、
C = 1 C=1
C
=
1
。因此可以先只看后面四個維度。那么
( y , x , : , : ) (y,x,:,:)
(
y
,
x
,
:
,
:
)
意味著矩陣前兩維和濾波器尺寸一致,即
2 ? 2 2*2
2
?
2
,后面的兩個冒號,就代表了在卷積運算(濾波)時,第y行第x列的濾波器參數,需要和img中運算的數的矩陣。
解釋一下:當y=0,x=0時,對應的a的位置,如圖
此時,完成整個卷積運算的時候,a分別需要做(3*3)=9次的乘法,每次做乘法是對應img中的數如下:
第一次:(綠色表示當前卷積時,每個濾波器參數對應的位置,紅色表示a對應的位置)
第二次:
以此類推……
所以
( 0 , 0 , : , : ) (0,0,:,:)
(
0
,
0
,
:
,
:
)
中存的數如下:(黃色標注的位置),對應濾波器參數a所以進行運算的范圍。
所以:
img
[
:
,
:
,
y
:
y_max
:
stride
,
x
:
x_max
:
stride
]
中,y和x分別表示filter中第幾行,第幾列。然后每次移動stride,直到走完img中所有的位置,抵達y_max和x_max。
這句話,以及這個for循環的作用就解釋完了。
其次:
col
=
col
.
transpose
(
0
,
4
,
5
,
1
,
2
,
3
)
.
reshape
(
N
*
out_h
*
out_w
,
-
1
)
這句話目的是把矩陣重新排列,最后呈現出適合進行矩陣運算來代替循環的形式。
所以,這個矩陣一定是
N ? o u t _ h ? o u t _ w N*out\_h*out\_w
N
?
o
u
t
_
h
?
o
u
t
_
w
行。這里就是
3 ? 3 = 9 3*3=9
3
?
3
=
9
行。有多少列呢,肯定是濾波器系數的個數,即
2 ? 2 = 4 2*2=4
2
?
2
=
4
列。
至于transpose函數中的設置,主要是為了配合后面的reshape函數的參數。
多說一句,我覺得這里transpose不要老是想著轉置,我開始也這么想,這么多維度,就轉不過來彎了。
我覺得,其實transpose就是決定一個新的取數順序,依次取出來就可以,然后能夠和原來對應上,就沒問題了。比如 a是一個三維的東西。然后b = a.transpose(1,2,0)。也就是說
a [ y ] [ z ] [ x ] = b [ x ] [ y ] [ z ] a[y][z][x] = b[x][y][z]
a
[
y
]
[
z
]
[
x
]
=
b
[
x
]
[
y
]
[
z
]
transpose第一個參數,0,表示第0維,也就是transpose以后,第0維不變,說明即便展開,輸入的img也是按順序一個一個處理完的。
第2和3的參數,之所以放 o u t _ h 和 o u t _ w out\_h和out\_w o u t _ h 和 o u t _ w 的大小,得明白reshape的操作方法。如果沒有指定order參數,并且是默認按照C的存儲格式(這里不理解可以看看reshape的參數有哪些),它是把矩陣按照從第0維開始,依次全部排列開,然后在按需求重組。所以這里,要按照 o u t _ h 和 o u t _ w out\_h和out\_w o u t _ h 和 o u t _ w 優先順序排列開,然后再使col總共就 N ? o u t _ h ? o u t _ w N*out\_h*out\_w N ? o u t _ h ? o u t _ w 行,那么reshpe函數會使每行中,就存儲一次卷積所需要所有值,即 C ? f i l t e r h ? f i l t e r w C*filter_h*filter_w C ? f i l t e r h ? ? f i l t e r w ? 列。
后面三個參數保證順序不變就行,方便和濾波器參數位置一一對應。
以上,總結成一句話:其實就是準確找到濾波器每個參數對應需要相乘的所有值,然后再變換一下矩陣的行狀,就可以了。
寫在最后
由于本人水平有限,這一點代碼都想了一下午加一晚上才明白。還得繼續努力了。加油!雖然整理出來了,感覺有些東西不太好表述清楚,大家有什么問題可以留言,多多交流,互相學習。
更多文章、技術交流、商務合作、聯系博主
微信掃碼或搜索:z360901061

微信掃一掃加我為好友
QQ號聯系: 360901061
您的支持是博主寫作最大的動力,如果您喜歡我的文章,感覺我的文章對您有幫助,請用微信掃描下面二維碼支持博主2元、5元、10元、20元等您想捐的金額吧,狠狠點擊下面給點支持吧,站長非常感激您!手機微信長按不能支付解決辦法:請將微信支付二維碼保存到相冊,切換到微信,然后點擊微信右上角掃一掃功能,選擇支付二維碼完成支付。
【本文對您有幫助就好】元
