Rust 也能實(shí)現(xiàn)神經(jīng)網(wǎng)絡(luò)?
作者 |?Nathan J. Goldbaum
譯者 | 彎月,責(zé)編 | 屠敏
出品 | CSDN(ID:CSDNnews)
以下為譯文:
?
我在前一篇帖子(http://neuralnetworksanddeeplearning.com/chap1.html)中介紹了MNIST數(shù)據(jù)集(http://yann.lecun.com/exdb/mnist/)以及分辨手寫(xiě)數(shù)字的問(wèn)題。在這篇文章中,我將利用前一篇帖子中的代碼,通過(guò)Rust實(shí)現(xiàn)一個(gè)簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)。我的目標(biāo)是探索用Rust實(shí)現(xiàn)數(shù)據(jù)科學(xué)工作流程的性能以及人工效率。
Python的實(shí)現(xiàn)
?
我在前一篇帖子中描述了一個(gè)非常簡(jiǎn)單的單層神經(jīng)網(wǎng)絡(luò),其可以利用基于隨機(jī)梯度下降的學(xué)習(xí)算法對(duì)MNIST數(shù)據(jù)集中的手寫(xiě)數(shù)字進(jìn)行分類(lèi)。聽(tīng)起來(lái)有點(diǎn)復(fù)雜,但實(shí)際上只有150行Python代碼,以及大量注釋。
如果你想深入了解神經(jīng)網(wǎng)絡(luò)的基礎(chǔ)知識(shí),請(qǐng)仔細(xì)閱讀我的前一篇帖子。而且請(qǐng)不要只關(guān)注代碼,理解代碼工作原理的細(xì)節(jié)并不是非常重要,你需要了解Python和Rust的實(shí)現(xiàn)差異。
在前一篇帖子中,Python代碼的基本數(shù)據(jù)容器是一個(gè)Network類(lèi),它表示一個(gè)神經(jīng)網(wǎng)絡(luò),其層數(shù)和每層神經(jīng)元數(shù)可以自由控制。在內(nèi)部,Network類(lèi)由NumPy二維數(shù)組的列表表示。該網(wǎng)絡(luò)的每一層都由一個(gè)表示權(quán)重的二維數(shù)組和一個(gè)表示偏差的一維數(shù)組組成,分別包含在Network類(lèi)的屬性weights和biases中。兩者都是二維數(shù)組的列表。偏差是列向量,但仍然添加了一個(gè)無(wú)用的維度,以二維數(shù)組的形式存儲(chǔ)。Network類(lèi)的初始化程序如下所示:
?
?
class?Network(object):
????def?__init__(self,?sizes):
????????"""The?list?``sizes``?contains?the?number?of?neurons?in?the
????????respective?layers?of?the?network.??For?example,?if?the?list
????????was?[2,?3,?1]?then?it?would?be?a?three-layer?network,?with?the
????????first?layer?containing?2?neurons,?the?second?layer?3?neurons,
????????and?the?third?layer?1?neuron.??The?biases?and?weights?for?the
????????network?are?initialized?randomly,?using?a?Gaussian
????????distribution?with?mean?0,?and?variance?1.??Note?that?the?first
????????layer?is?assumed?to?be?an?input?layer,?and?by?convention?we
????????won't?set?any?biases?for?those?neurons,?since?biases?are?only
????????ever?used?in?computing?the?outputs?from?later?layers."""
????????self.num_layers?=?len(sizes)
????????self.sizes?=?sizes
????????self.biases?=?[np.random.randn(y,?1)?for?y?in?sizes[1:]]
????????self.weights?=?[np.random.randn(y,?x)
????????????????????????for?x,?y?in?zip(sizes[:-1],?sizes[1:])]
在這個(gè)簡(jiǎn)單的實(shí)現(xiàn)中,權(quán)重和偏差的初始化呈標(biāo)準(zhǔn)正態(tài)分布——即均值為零,標(biāo)準(zhǔn)差為1的正態(tài)分布。我們可以看到,偏差明確地初始化為列向量。
這個(gè)Network類(lèi)公開(kāi)了兩個(gè)用戶(hù)可以直接調(diào)用的方法。第一個(gè)是evaluate方法,它要求網(wǎng)絡(luò)嘗試識(shí)別一組測(cè)試圖像中的數(shù)字,然后根據(jù)已知的正確答案對(duì)結(jié)果進(jìn)行評(píng)分。第二個(gè)是SGD方法,它通過(guò)迭代一組圖像來(lái)運(yùn)行隨機(jī)梯度下降的學(xué)習(xí)過(guò)程,將整組圖像分解成小批次,然后根據(jù)每一小批次的圖像以及用戶(hù)指定的學(xué)習(xí)速率eta更新該網(wǎng)絡(luò)的狀態(tài);最后再根據(jù)用戶(hù)指定的迭代次數(shù),隨機(jī)選擇一組小批次圖像,重新運(yùn)行這個(gè)訓(xùn)練過(guò)程。該算法的核心(每一小批次圖像處理以及神經(jīng)網(wǎng)絡(luò)的狀態(tài)更新)代碼如下所示:
?
?
def?update_mini_batch(self,?mini_batch,?eta):
????"""Update?the?network's?weights?and?biases?by?applying
????gradient?descent?using?backpropagation?to?a?single?mini?batch.
????The?``mini_batch``?is?a?list?of?tuples?``(x,?y)``,?and?``eta``
????is?the?learning?rate."""
????nabla_b?=?[np.zeros(b.shape)?for?b?in?self.biases]
????nabla_w?=?[np.zeros(w.shape)?for?w?in?self.weights]
????for?x,?y?in?mini_batch:
????????delta_nabla_b,?delta_nabla_w?=?self.backprop(x,?y)
????????nabla_b?=?[nb+dnb?for?nb,?dnb?in?zip(nabla_b,?delta_nabla_b)]
????????nabla_w?=?[nw+dnw?for?nw,?dnw?in?zip(nabla_w,?delta_nabla_w)]
????self.weights?=?[w-(eta/len(mini_batch))*nw
????????????????????for?w,?nw?in?zip(self.weights,?nabla_w)]
????self.biases?=?[b-(eta/len(mini_batch))*nb
???????????????????for?b,?nb?in?zip(self.biases,?nabla_b)]
?
我們可以針對(duì)小批次中的每個(gè)訓(xùn)練圖像,通過(guò)反向傳播(在backprop函數(shù)中實(shí)現(xiàn))求出代價(jià)函數(shù)的梯度的估計(jì)值的總和。在處理完所有的小批次后,我們可以根據(jù)估計(jì)的梯度調(diào)整權(quán)重和偏差。更新時(shí)在分母中加入了len(mini_batch),因?yàn)槲覀兿胍∨沃兴泄烙?jì)的平均梯度。我們還可以通過(guò)調(diào)整學(xué)習(xí)速率eta來(lái)控制權(quán)重和偏差的更新速度,eta可以在全局范圍內(nèi)調(diào)整每個(gè)小批次更新的大小。
backprop函數(shù)在計(jì)算該神經(jīng)網(wǎng)絡(luò)的代價(jià)函數(shù)的梯度時(shí),首先從輸入圖像的正確輸出開(kāi)始,然后將錯(cuò)誤反向傳播至網(wǎng)絡(luò)的各層。這需要大量的數(shù)據(jù)調(diào)整,在將代碼移植到Rust時(shí)我在此花費(fèi)了大量的時(shí)間,在此篇幅有限,我無(wú)法深入講解,如果你想了解具體的詳情,請(qǐng)參照這本書(shū)(http://neuralnetworksanddeeplearning.com/chap2.html)。
?
Rust的實(shí)現(xiàn)
?
首先,我們需要弄清楚如何加載數(shù)據(jù)。這個(gè)過(guò)程非常繁瑣,所以我另寫(xiě)了一篇文章專(zhuān)門(mén)討論(https://ngoldbaum.github.io/posts/loading-mnist-data-in-rust/)。在這之后,下一步我們必須弄清楚如何用Rust表示Python代碼中的Network類(lèi)。最終我決定使用struct:
?
?
use?ndarray::Array2;
#[derive(Debug)]
struct?Network?{
????num_layers:?usize,
????sizes:?Vec
????biases:?Vec
????weights:?Vec
}
該結(jié)構(gòu)的初始化與Python的實(shí)現(xiàn)大致相同:根據(jù)每層中的神經(jīng)元數(shù)量進(jìn)行初始化。
?
?
use?rand::distributions::StandardNormal;
use?ndarray::{Array,?Array2};
use?ndarray_rand::RandomExt;
impl?Network?{
???????fn?new(sizes:?&[usize])?->?Network?{
????????let?num_layers?=?sizes.len();
????????let?mut?biases:?Vec
????????let?mut?weights:?Vec
????????for?i?in?1..num_layers?{
????????????biases.push(Array::random((sizes[i],?1),?StandardNormal));
????????????weights.push(Array::random((sizes[i],?sizes[i?-?1]),?StandardNormal));
????????}
????????Network?{
????????????num_layers:?num_layers,
????????????sizes:?sizes.to_owned(),
????????????biases:?biases,
????????????weights:?weights,
????????}
????}?
}
有一點(diǎn)區(qū)別在于,在Python中我們使用numpy.random.randn初始化偏差和權(quán)重,而在Rust中我們使用ndarray::Array::random函數(shù),并以rand::distribution::Distribution為參數(shù),允許選擇任意的分布。在上述代碼中,我們使用了rand::distributions::StandardNormal分布。注意,我們使用了三個(gè)不同的包中定義的接口,其中兩個(gè)ndarray本身和ndarray-rand由ndarray作者維護(hù),另一個(gè)rand則由其他開(kāi)發(fā)人員維護(hù)。
?
整體式包的優(yōu)點(diǎn)
?
原則上,最好不要將隨機(jī)數(shù)生成器放到ndarray代碼庫(kù)中,這樣當(dāng)rand函數(shù)支持新的隨機(jī)分布時(shí),ndarray以及Rust生態(tài)系統(tǒng)中所有需要隨機(jī)數(shù)的包都會(huì)受益。另一方面,這確實(shí)會(huì)增加一些認(rèn)知開(kāi)銷(xiāo),因?yàn)闆](méi)有集中的位置,查閱文檔時(shí)需要參考多個(gè)包的文檔。我的情況有點(diǎn)特殊,我沒(méi)想到做這個(gè)項(xiàng)目的時(shí)候,恰逢rand發(fā)布改變了其公共API的版本。導(dǎo)致ndarray-rand(依賴(lài)于rand版本0.6)和我的項(xiàng)目所依賴(lài)的版本0.7之間產(chǎn)生了不兼容性。
我聽(tīng)說(shuō)cargo和Rust的構(gòu)建系統(tǒng)可以很好地處理這類(lèi)問(wèn)題,但至少我遇到了一個(gè)非常令人困惑的錯(cuò)誤信息:我傳入的隨機(jī)數(shù)分布不能滿足Distribution這個(gè)trait的要求。雖然這話不假——它符合0.7版本的rand,但不符合ndarray-rand要求的0.6版本的rand,但這依然非常令人費(fèi)解,因?yàn)殄e(cuò)誤信息中沒(méi)有給出各種包的版本號(hào)。最后我報(bào)告了這個(gè)問(wèn)題。我發(fā)現(xiàn)這些有關(guān)API版本不兼容的錯(cuò)誤消息是Rust語(yǔ)言長(zhǎng)期存在的一個(gè)問(wèn)題。希望將來(lái)Rust可以顯示更多有用的錯(cuò)誤信息。
最后,這種關(guān)注點(diǎn)的分離給我這個(gè)新用戶(hù)帶來(lái)了很大困難。在Python中,我可以簡(jiǎn)單通過(guò)import numpy完成。我確實(shí)認(rèn)為NumPy在整體式上走得太遠(yuǎn)了(當(dāng)時(shí)打包和分發(fā)帶有C擴(kuò)展的Python代碼與現(xiàn)在相比太難了),但我也認(rèn)為在另一個(gè)極端上漸行漸遠(yuǎn),會(huì)導(dǎo)致語(yǔ)言或生態(tài)系統(tǒng)的學(xué)習(xí)難度增大。
?
類(lèi)型和所有權(quán)
?
下面我將詳細(xì)介紹一下Rust版本的update_mini_batch:
?
?
impl?Network?{
????fn?update_mini_batch(
????????&mut?self,
????????training_data:?&[MnistImage],
????????mini_batch_indices:?&[usize],
????????eta:?f64,
????)?{
????????let?mut?nabla_b:?Vec
????????let?mut?nabla_w:?Vec
????????for?i?in?mini_batch_indices?{
????????????let?(delta_nabla_b,?delta_nabla_w)?=?self.backprop(&training_data[*i]);
????????????for?(nb,?dnb)?in?nabla_b.iter_mut().zip(delta_nabla_b.iter())?{
????????????????*nb?+=?dnb;
????????????}
????????????for?(nw,?dnw)?in?nabla_w.iter_mut().zip(delta_nabla_w.iter())?{
????????????????*nw?+=?dnw;
????????????}
????????}
????????let?nbatch?=?mini_batch_indices.len()?as?f64;
????????for?(w,?nw)?in?self.weights.iter_mut().zip(nabla_w.iter())?{
????????????*w?-=?&nw.mapv(|x|?x?*?eta?/?nbatch);
????????}
????????for?(b,?nb)?in?self.biases.iter_mut().zip(nabla_b.iter())?{
????????????*b?-=?&nb.mapv(|x|?x?*?eta?/?nbatch);
????????}
????}
}
該函數(shù)使用了我定義的兩個(gè)輔助函數(shù),因此更為簡(jiǎn)潔:
?
?
fn?to_tuple(inp:?&[usize])?->?(usize,?usize)?{
????match?inp?{
????????[a,?b]?=>?(*a,?*b),
????????_?=>?panic!(),
????}
}
fn?zero_vec_like(inp:?&[Array2
????inp.iter()
????????.map(|x|?Array2::zeros(to_tuple(x.shape())))
????????.collect()
}
與Python實(shí)現(xiàn)相比,調(diào)用update_mini_batch的接口有點(diǎn)不同。這里,我們沒(méi)有直接傳遞對(duì)象列表,而是傳遞了整套訓(xùn)練數(shù)據(jù)的引用以及數(shù)據(jù)集中的索引的切片。由于這種做法不會(huì)觸發(fā)借用檢查,因此更容易理解。
在zero_vec_like中創(chuàng)建nabla_b和nabla_w與我們?cè)赑ython中使用的列表非常相似。其中有一個(gè)波折讓我有些沮喪,本來(lái)我想設(shè)法使用Array2::zeros創(chuàng)建一個(gè)初始化為零的數(shù)組,并將其傳遞給圖像的切片或Vec,這樣我就可以得到一個(gè)ArrayD實(shí)例。如果想獲得一個(gè)Array2(顯然這是一個(gè)二維數(shù)組,而不是一個(gè)通用的D維數(shù)組),我需要將一個(gè)元組傳遞給Array::zeros。然而,由于ndarray::shape會(huì)返回一個(gè)切片,我需要通過(guò)to_tuple函數(shù)手動(dòng)將切片轉(zhuǎn)換為元組。這種情況在Python很容易處理,但在Rust中,元組和切片之間的差異非常重要,就像在這個(gè)API中一樣。
利用反向傳播估計(jì)權(quán)重和偏差更新的代碼與python的實(shí)現(xiàn)結(jié)構(gòu)非常相似。我們分批訓(xùn)練每個(gè)示例圖像,并獲得二次成本梯度的估計(jì)值作為偏差和權(quán)重的函數(shù):
?
?
let?(delta_nabla_b,?delta_nabla_w)?=?self.backprop(&training_data[*i]);
然后累加這些估計(jì)值:
?
?
for?(nb,?dnb)?in?nabla_b.iter_mut().zip(delta_nabla_b.iter())?{
????*nb?+=?dnb;
}
for?(nw,?dnw)?in?nabla_w.iter_mut().zip(delta_nabla_w.iter())?{
????*nw?+=?dnw;
}
在處理完小批次后,我們根據(jù)學(xué)習(xí)速率調(diào)整權(quán)重和偏差:
?
?
let?nbatch?=?mini_batch_indices.len()?as?f64;
for?(w,?nw)?in?self.weights.iter_mut().zip(nabla_w.iter())?{
????*w?-=?&nw.mapv(|x|?x?*?eta?/?nbatch);
}
for?(b,?nb)?in?self.biases.iter_mut().zip(nabla_b.iter())?{
????*b?-=?&nb.mapv(|x|?x?*?eta?/?nbatch);
}
這個(gè)例子說(shuō)明與Python相比,在Rust中使用數(shù)組數(shù)據(jù)所付出的人力有非常大的區(qū)別。首先,我們沒(méi)有讓這個(gè)數(shù)組乘以浮點(diǎn)數(shù)eta / nbatch,而是使用了Array::mapv,并定義了一個(gè)閉包,以矢量化的方式映射了整個(gè)數(shù)組。這種做法在Python中會(huì)很慢,因?yàn)楹瘮?shù)調(diào)用非常慢。然而,在Rust中沒(méi)有太大的區(qū)別。在做減法時(shí),我們還需要通過(guò)&借用mapv的返回值,以免在迭代時(shí)消耗數(shù)組數(shù)據(jù)。在編寫(xiě)Rust代碼時(shí)需要仔細(xì)考慮函數(shù)是否消耗數(shù)據(jù)或引用,因此在編寫(xiě)類(lèi)似于Python的代碼時(shí),Rust的要求更高。另一方面,我更加確信我的代碼在編譯時(shí)是正確的。我不確定這段代碼是否有必要,因?yàn)镽ust真的很難寫(xiě),可能是因?yàn)槲业腞ust編程經(jīng)驗(yàn)遠(yuǎn)不及Python。
?
用Rust重新編寫(xiě),一切都會(huì)好起來(lái)
?
到此為止,我用Rust編寫(xiě)的代碼運(yùn)行速度超過(guò)了我最初編寫(xiě)的未經(jīng)優(yōu)化的Python代碼。然而,從Python這樣的動(dòng)態(tài)解釋語(yǔ)言過(guò)渡到Rust這樣的性能優(yōu)先的編譯語(yǔ)言,應(yīng)該能達(dá)到10倍或更高性能,然而我只觀察到大約2倍的提升。我該如何測(cè)量Rust代碼的性能?幸運(yùn)的是,有一個(gè)非常優(yōu)秀的項(xiàng)目flamegraph(https://github.com/ferrous-systems/flamegraph)可以很容易地為Rust項(xiàng)目生成火焰圖。這個(gè)工具為cargo添加了一個(gè)flamegraph子命令,因此你只需運(yùn)行cargo flamegraph,就可以運(yùn)行代碼,然后寫(xiě)一個(gè)flamegraph的svg文件,就可以通過(guò)Web瀏覽器觀測(cè)。
可能你以前從未見(jiàn)過(guò)火焰圖,因此在此簡(jiǎn)單地說(shuō)明一下,例程中程序的運(yùn)行時(shí)間比例與該例程的條形寬度成正比。主函數(shù)位于圖形的底部,主函數(shù)調(diào)用的函數(shù)堆疊在上面。你可以通過(guò)這個(gè)圖形簡(jiǎn)單地了解哪些函數(shù)在程序中占用的時(shí)間最多——圖中非常“寬”的函數(shù)都在運(yùn)行中占用了大量時(shí)間,而非常高且寬的函數(shù)棧都代表其包含非常深入的棧調(diào)用,其代碼的運(yùn)行占用了大量時(shí)間。通過(guò)以上火焰圖,我們可以看到我的程序大約一半的時(shí)間都花在了dgemm_kernel_HASWELL等函數(shù)上,這些是OpenBLAS線性代數(shù)庫(kù)中的函數(shù)。其余的時(shí)間都花在了`update_mini_batch和分配數(shù)組中等數(shù)組操作上,而程序中其他部分的運(yùn)行時(shí)間可以忽略不計(jì)。
如果我們?yōu)镻ython代碼制作了一個(gè)類(lèi)似的火焰圖,則也會(huì)看到一個(gè)類(lèi)似的模式——大部分時(shí)間花在線性代數(shù)上(在反向傳播例程中調(diào)用np.dot)。因此,由于Rust或Python中的大部分時(shí)間都花在數(shù)值線性代數(shù)庫(kù)中,所以我們永遠(yuǎn)也無(wú)法得到10倍的提速。
實(shí)際情況可能比這更糟。上述我提到的書(shū)中有一個(gè)練習(xí)是使用向量化矩陣乘法重寫(xiě)Python代碼。在這個(gè)方法中,每個(gè)小批次中所有圖像的反向傳播都需要通過(guò)一組矢量化矩陣乘法運(yùn)算完成。這需要在二維和三維數(shù)組間運(yùn)行矩陣乘法。由于每個(gè)矩陣乘法運(yùn)算使用的數(shù)據(jù)量大于非向量化的情況,因此OpenBLAS能夠更有效地使用CPU緩存和寄存器,最終可以更好地利用我的筆記本電腦上的CPU資源。重寫(xiě)的Python版本比Rust版本更快,但也只有大約兩倍左右。
原則上,我們可以用相同的方式優(yōu)化Rust代碼,但是ndarray包還不支持高于二維的矩陣乘法。我們也可以利用rayon等庫(kù)實(shí)現(xiàn)小批次更新線程的并行化。我在自己的筆記本電腦上試了試,并沒(méi)有看到任何提速,但可能更強(qiáng)大的機(jī)器有更多CPU線程。我還嘗試了使用使用不同的低級(jí)線性代數(shù)實(shí)現(xiàn),例如,利用Rust版的tensorflow和torch,但當(dāng)時(shí)我覺(jué)得我完全可以利用Python版的這些庫(kù)。
?
Rust是否適合數(shù)據(jù)科學(xué)工作流程?
?
目前,我不得不說(shuō)答案是“尚未”。如果我需要編寫(xiě)能夠?qū)⒁蕾?lài)性降到最低的、經(jīng)過(guò)優(yōu)化的低級(jí)代碼,那么我肯定會(huì)使用Rust。然而,要想利用Rust完全取代Python或C++,那么我們尚需要等待更穩(wěn)定和更完善的包生態(tài)系統(tǒng)。
原文:https://ngoldbaum.github.io/posts/python-vs-rust-nn/
本文為 CSDN 翻譯,轉(zhuǎn)載請(qǐng)注明來(lái)源出處。
【End】
還在擔(dān)憂Python的就業(yè)前景? 快來(lái) 看看這些!
https://edu.csdn.net/topic/python115?utm_source=csdn_bw
?熱 文 ?推 薦?
?這位博士都 50 多歲了,為啥還在敲代碼?
?C# 導(dǎo)出 Excel 的 6 種簡(jiǎn)單方法!你會(huì)幾種?
?這位博士都 50 多歲了,為啥還在敲代碼?
?2019 編程語(yǔ)言排行榜:Java、Python 龍爭(zhēng)虎斗!PHP 屹立不倒!
?2億日活,日均千萬(wàn)級(jí)視頻上傳,快手推薦系統(tǒng)如何應(yīng)對(duì)技術(shù)挑戰(zhàn)?
?Docker容器化部署Python應(yīng)用
?給面試官講明白:一致性Hash的原理和實(shí)踐
?預(yù)警,CSW的50萬(wàn)枚塵封BTC即將重返市場(chǎng)?
?她說(shuō):行!沒(méi)事別嫁程序員!
點(diǎn)擊閱讀原文,輸入關(guān)鍵詞,即可搜索您想要的 CSDN 文章。
?
?
?
?
?
?
?
你點(diǎn)的每個(gè)“在看”,我都認(rèn)真當(dāng)成了喜歡
更多文章、技術(shù)交流、商務(wù)合作、聯(lián)系博主
微信掃碼或搜索:z360901061

微信掃一掃加我為好友
QQ號(hào)聯(lián)系: 360901061
您的支持是博主寫(xiě)作最大的動(dòng)力,如果您喜歡我的文章,感覺(jué)我的文章對(duì)您有幫助,請(qǐng)用微信掃描下面二維碼支持博主2元、5元、10元、20元等您想捐的金額吧,狠狠點(diǎn)擊下面給點(diǎn)支持吧,站長(zhǎng)非常感激您!手機(jī)微信長(zhǎng)按不能支付解決辦法:請(qǐng)將微信支付二維碼保存到相冊(cè),切換到微信,然后點(diǎn)擊微信右上角掃一掃功能,選擇支付二維碼完成支付。
【本文對(duì)您有幫助就好】元
