本篇實作的目標是利用 Maximize Likelihood Estimation(MLE) 的手段訓練一個能夠辨認手寫文字。想了解 MLE 的數學意義的可以看前篇理論篇;我的完整原始碼可以參考這裡。下面我們會先了解資料類別,接著建模,然後訓練並測試。在開始之前,我們先確認一下軟體版本的狀態吧!所需要的module如下圖一:
接著我們檢查一下自己的tensorflow是哪個版本。我使用的版本如圖二:
根據使用的硬體不同,Tensorflow分為許多版本。最基本的是CPU版本,意思是只用CPU來訓練模型,特點是安裝簡單、使用方便;GPU版本則是利用了GPU去處理平行矩陣運算,因此會快很多,這也是最常用的加速方案;我使用的是M1處理器的版本,有neural engine加持,效能還是很不錯的,但就是記憶體小了一點。如果有想要在M1處理器訓練模型的讀者,我建議要買到16GB記憶體的版本。因為機器模型沒有一定的資料量是練不起來的,經驗上是一萬筆左右。也就是說,記憶體的大小必須要跨過資料量門檻,才能談訓練。
MNIST
MNIST 是 Modified National Institute of Standards and Technology database 的簡寫,是一個手寫數字資料庫,包含了零到九的十個阿拉伯數字。該資料庫已經貼心的將圖片與標籤一對一的做好了,讀法如下圖三:
為什麼有些資料不能用於訓練?
資料庫把資料分作訓練用(training) 以及測試用(testing) 兩堆,因為在測試模型的時候要使用模型沒見過的資料。如果使用訓練用資料測試模型的話,那我們就無法判斷模型是否已經過擬合(overfitting)了。過擬合通常會發生在模型太強但資料太簡單的情況。在思想實驗中,如果模型的參數很多,其輸出就可以在空間中折成任意的形狀,並輕易地穿過每一個資料點。更糟的是、由於資料並不完備(也不可能完備),輸出會在沒有資料的空間中恣意拉伸,這並不是我們想要的。我們希望模型也要能夠在未知的資料集做得好。如果能用沒看過的資料檢驗模型的話,就可以發現模型是不是過擬合了。請記得,測試資料只能用來作為指標,本身不能防止模型過擬合。未來我會找機會討論防止過擬合的作法。
Dense Neural Network
為了簡單起見,我們使用最最簡單的Dense Neural Network (DNN) 來判斷手寫文字吧。模型本身不是我們本節想要討論的主題,所以讀者可以先把以下結論記在心上,以後我們再慢慢討論。第一、DNN本質上是矩陣乘法,也就是輸出等於輸入乘上一個或多個張量(Tensor)。如果張量的每一個參數都屬於實數的話,那麼該矩陣乘法就幾乎是線性的。第二、在相同的參數量的情況下,非線性方程式能比線性方程式學得更快更好,所以要添加一個活化函數(Activation function)在兩個張量之間使得模型更加非線性化。示範的模型如下圖四:
本次的核心命題是猜數字,所以輸出必須是一種機率。在最後一層(layer)我使用了softmax,目的是使輸出機率化。請記得機率化的方式有千百種,這裡使用softmax是為了搭配損失函數KL divergence。
更新模型
模型的所有參數都是用亂數初始化的,也就是說未經訓練的模型只會亂猜。我們接著要調整參數直到模型可以準確的預測答案,這個過程稱為「最佳化」。因為這次要示範的是MLE,所以我選擇「交叉熵越小越好」作為我的最佳化目標,也可稱之為「對數似然度越高越好」或「KL divergence 越小越好」。我在理論篇解釋了為什麼這三者等價。疑?我總是說最小化KL divergence 就是最小化交叉熵。為什麼我要刻意呼叫交叉熵呢?因為算KL divergence 需要比較多電腦資源(多算一個熵,即使熵對參數的微分為零),所以我們改用交叉熵越小越好作為最佳化的目標。
最佳化算法我使用Adam,他是隨機梯度法(Stochastic Gradient Decent)的延伸型,本質上也是基於梯度的數值法。與隨機梯度法不同的是,Adam也會利用歷史梯度評估大尺度下的梯度,有點像投資學中「月線」的概念,但計算上與月線不相同。選好了最佳化目標以及最佳化算法後,就可以準備開始訓練模型了。訓練開始之前要告訴模型我們剛剛決定的最佳化作法,令loss是最佳化目標、optimizer是最佳化算法,寫法如圖五:
Metric
Metric也是mleD.compile()
的輸入之一字面上的意思是恆量模型好壞的指標,功能上會影響指令「evaluation(評估)」的回傳值,但不會直接決定哪一組參數做為最終的模型。所以嚴格來說,metrics是個名不符實的引數。其他metrics範例如圖六。當metric為loss的時候,evaluation就會回傳我們指定的損失函數的值;當metric為accuracy的時候,evaluation就會回傳猜中的百分比。
檢查模型的形狀
圖五的指令「summary」會總結模型的形狀,並且印給使用者看,例如圖七。這一步雖然不會直接影響模型的效能,但是我強烈強烈建議任何模型在訓練之前最好都利用「summary」印一下模型的形狀。理由是我們的記憶力有限:模型的性能跟它的形狀有關,例如有幾層、一層有幾個參數、用的是什麼最佳化算法等等。操作上我們一次只會改變一個變因,接著看模型有沒有進步,於是我們經常會面對許多類似的程式碼與模型。如果沒有把當時使用的模型印下來,幾天後又忘記當初是用什麼模型的話,那就很頭大了。
執行最佳化
以往訓練模型要自己寫迴圈抓資料、算預測機率、算損失函數、算梯度、更新參數、等等等等。Tensorflow 上的Keras API很貼心的提供一個function自動的完成所有更新模型該做的事。讓我們瞧瞧Keras是怎麼做的吧!如圖八:
下方的更新過程都是指令「fit」完成的,一共花了16個epoch完成訓練,並且準確率達到驚人的97.59%。讓我們一個個的分析fit的引數(argument)吧!
分派訓練資料
x_train、y_train都是訓練資料。x_train是手寫數字,y_train是相對應的標籤。batch_size指的是一次要使用多少對資料更新模型,「一對資料」指的是一筆來自x_train的資料搭配上相對應的來自y_train的資料。這個指令的目的是讓訓練的過程可平行化。如果每次更新參數時都只能丟入一對資料,那麼資料與資料之間就存在著先後關係,也就不能同時運算了。如果一次性投入128筆資料,那麼這128個梯度彼此之間就不存在先後關係,也就可以同時運算了。batch_size越大,對更新是有好處的。由多個梯度平均出來的梯度通常可以引導模型更快地走向最佳解。但是一次性投入越多對資料,電腦就要花越多時間計算,也會需要更多的記憶體。記得要在好處與壞處之間取得平衡。
Epoch、Validation data及Metric
當模型利用每一對訓練資料訓練過一次後,我們稱之為完成了一個epoch。引數epochs指的是最多完成幾次的epoch。300就是最多300次,然後就會停下來不再更新參數。引數shuffle是指每完成一個epoch就把訓練資料重新洗牌一次(x和y的配對關係不變),目的是避免模型記得資料的順序。validation_data 也是用來評判模型好壞的。跟「metric」不同,「validation_data」決定的是要用什麼資料評判模型的好壞。「metric」決定的是要用什麼算法決定模型的好壞。validation_data也不會直接決定哪一組參數為最終模型。
訓練小幫手Callbacks
引數callbacks扮演小幫手的角色,提供直接更新參數以外的服務。常用的有ModelCheckpoint、EarlyStopping。
ModelCheckpoint
ModelCheckpoint可以決定哪些參數會被存下來,條件由monitor決定。我們可以選幾個預設的條件,讓我們舉兩個常見的例子:如果monitor是「val_loss」,那就是當損失函數套用在validation data上並且損失函數的值創新低時,相對應的參數會被記錄在硬碟;如果monitor是「accuracy」,那就是當損失函數套用在training data上並且準確率創新高時,相對應的參數會被記錄在硬碟。會有這樣的差異是因為Tensorflow 默認accuracy越高越好,而loss越低越好。
EarlyStopping
EarlyStopping會根據指定條件停止更新參數。該條件也是由「monitor」指定,並且跟ModelCheckpoint的monitor用的是一樣的邏輯。舉個例子,我令monitor=‘val_loss’
且patient=5
。接下來本次訓練中,如果損失函數沒有在連續的五個epochs內創新低的話,那麼訓練中止。通常模型連續幾次無法得到更低的損失函數的話,接下來也只會越來越糟。我們稱這個「泰極否來」的過程為「過擬合(overfitting)」。我強調一下這不可以說是「損失函數無法下降後,模型進入了過擬合。」這樣的說法恰恰倒果為因。應該是模型先進入過擬合的狀態,損失函數無法進一步下降只是現象。
小結
MLE模型的概念並不複雜,概要是想辦法讓預測越接近事實越好。但要命的是魔鬼藏在細節裡。並不是所有的命題都可以用機器學習解決,最重要的是要有足夠多的資料。也不是任意的模型都有足夠的潛力猜出正解。尋找最佳模型的過程中,嘗試不同的深度、不同的參數量、甚至是不同的損失函數都是可能是必要的。這過程並沒有什麼一定這樣那樣做,未知的方法沒試過是不知道的,MLE僅僅是大家試誤後得到的其中一種好方法而已。所幸軟體工程師已經開發了許多方便的工具協助我們建立想要的模型,也可以避免走別人走過的彎路。文末我也提供我的原始碼,如果讀者有機會遇到類似的其他命題,請不要客氣使用。
備註
MNIST曾經是圖片辨識的benchmark,可是已經式微了。其中一個理由是圖片辨識模型越來越強,導致其結果缺乏鑒別度。實際上隨便一個上過半學期機器學習課程的大學生都可以做出一個模型把手寫數字辨識得七七八八。另一個理由是風格單一,這些零到九的數字雖然各有不同,但都是手寫風格。深度學習現在不只講究要判斷得準,還要能夠生成足以混淆人眼的逼真資料。更有甚者,還會提供可控條件,讓使用者調整想要的輸出資料特色。而MNIST很難評判這些特質,所以科學家已經不用MNIST作為benchmark了。