損失函數是什麼? Cross-entropy 交叉熵怎麼用?

損失函數是什麼? Cross-entropy 交叉熵怎麼用?

為了比較各種預測與事實之間的差距,Tensorflow 提供了許多損失函數,但什麼時候要用哪一個呢?下面我會介紹tensorflow.keras中,與熵(entropy) 有關的損失函數。熵是衡量資訊含量的重要指標,也被廣泛地應用在機器學習上,尤其是分類與自然語言處理 (Natural Language Processing)。

Keras是什麼?跟損失函數有什麼關係?

Keras包含一系列深度學習的簡化API,從常用的神經網路架構到不同的最佳化手段都有。Keras隱藏了不必要的可調參數,並且把Tensorflow的原生API組合成一個更高階的API。我十分建議讀者從Keras內的API開始學,因為他們更接近論文上看到的數學式,比較不容易被誤用。以下我提到的損失函數都可以在 tensorflow.keras.losses 類別 (class) 內找到。表一列出常見的機器學習任務,以及用來搭配的輸出活化函數與損失函數。本篇著重於連結損失函數與任務之間的關係,想知道如何使用的話請看MLE實作篇

舉個例子,有時我們想要找出某件發生的機率,相對應的損失函數可以是交叉熵 (Cross-entropy);另一些時候是物理量,則可能的損失函數是方差法 (MeanSquaredError)。

任務輸出資料型態輸出活化函數損失函數
二元判斷scalarSigmoidBinaryCrossentropy
類別判斷vectorSoftmaxSparseCategoricalCrossentropy
圖片生成matrixSigmoidBinaryCrossentropy
自然語言生成sequenceSoftmaxSparseCategoricalCrossentropy
表一:常見的任務、活化函數及損失函數。

為什麼Keras沒有提供「熵(Entropy)」函數?

因為熵不能作為損失函數使用。如果某件事有$n$種可能性、每種發生的機率是$p_i$。那麼它的熵可以寫成式(1)。根據定義,熵只跟事實有關並且跟預測無關。因此「熵」對模型參數的微分為零,也就無法用來更新參數,交叉熵 (Cross-entropy) 或KL Divergence才能評估事實與預測的差距。交叉熵是從熵延伸出來的概念,交叉熵是指對模型所輸出的機率分佈$q_i$來說,事實$p_i$有多令人驚訝。方程式可以寫作式(2)。KL Divergence的定義是「用$q_i$預測$p_i$」比「用$p_i$預測$p_i$」更讓人驚訝多少,也就是「交叉熵」減去「熵」的意思。寫作式(3)。

$H(p) = – \Sigma_i^n p_i \cdot log (p_i)$ -(1)

$H(p,q) = – \Sigma_i^n p_i \cdot log (q_i)$ -(2)

$D_{KL} = H(p,q) – H(p)$ -(3)

為什麼Keras有那麼多種交叉熵?

為的是讓電腦的計算更有效率,這與數學無關。舉個例子,如果只是要分辨「有」或「沒有」,那麼用一個範圍在零到一的數即可。當然也可以使用兩個數分別代表「有」或「沒有」,但是比較沒效率,因為用 1 減去任何一方都會等於剩下的那一方。Binary Cross-entropy 正是為了更快的比較「有」或「沒有」的機率分佈而存在的。Keras 提供的所有交叉熵的數學概念都是相同的,但是適用的 domain 不一樣,導致有不同的名字。

Binary Cross-entropy

這個損失函數比較的是兩個二元機率分佈 (Binary Probability Distribution) 的交叉熵。Binary Cross-entropy 經常用在兩個地方,一是為二元判斷、二是生成圖片。二元判斷很容易理解,就是用0到1為資料評分,例如:判斷 IMDB 上的留言是好評或是差評。生成圖片就比較有趣了,我們把一張圖的每一個點的亮度都當作一種機率,零代表最暗、一代表最亮。圖一我用3×3的圖片示範如何用機率表示亮度。

圖一:用零到一表示3×3圖片的每一個點的亮度

如果是有顏色的圖,那就用三個機率分別代表紅綠藍三種顏色。所以儘管模型輸出的圖片有許多資料點,但每一個資料點都可以用零到一代表該點的亮度,剛好使像素的亮度符合機率的定義。每個像素的最大與最低亮度都是獨立的,所以可以用 Binary Cross-entropy 。

Categorical Cross-entropy

請使用 Categorical Cross-entropy作為多元機率分布的損失函數 。其實用 Categorical Cross-entropy 作為二元機率分佈的損失函數也是可以的,但是比較沒效率。Binary Cross-entropy 的輸入用一個數就可以表達兩種結果的機率。 Categorical Cross-entropy 則必須使用兩個數才能表達兩種結果發生的機率,儘管他們其中一方減一必定等於另一方。

自然語言生成用的損失函數也是 Categorical Cross-entropy ,因為機器模型總是從給定的字典中選一個字作為輸出,字典有多大,類別就有多少種。也因為機器不一定只講一個字,也有可能一次講一堆字,例如翻譯機器人。翻譯機器人的輸出相當於一串條件機率。一個機器輸出的句子長得像圖二。每一個輸出都用機率分佈代表,然後選取機率最大的字作為輸出。

圖二:神經網路生成句子示意圖。每次要生成下一個文字前,神經網路會根據前文產生機率分布,然後從中採樣文字作為輸出。

One-hot encoding

Categorical Cross-entropy 的比較的是兩組機率:事實的機率分佈$p$與預測的機率分佈$q$。大部分的情況我們不會知道事實的機率分佈長怎樣。如果已經知道了,其實我也不需要訓練模型了對吧!但是沒有事實的機率分佈,我們要怎麼令損失函數收斂?

舉語言當例子,人在對話的時候,可以聽到對方講出來的每一個字,但不會知道他在講話的時候正在想什麼事。所以我們會透過大量採樣的方式近似事實的機率分佈,如式(4):

$E[H(p, q)] \sim \Sigma_i H(p_i, q_i) $ -(4)

訓練模型的時候我們會先收集很多觀測到的事件$p_i$,用來推敲事實的機率分佈$p$。只要$i$足夠大,式(4)就越準。由於收集到的事件$p_i$已經發生,所以我們把該事件相對應的類別的機率標為 1 、其餘為零,這個做法就叫做 One-hot encoding ,如圖三:

圖三:十個類別的 One-hot encoding 範例,意思是第三類別的發生機率100%。

類別越多,One-hot encoding 就越沒效率,因為沒有用卻佔空間的零會變多。事實的機率分佈總是未知的,使得這類的 One-hot encoding 相當常見。tensorflow 團隊也因應這個情境做出 Sparse Categorical Cross-entropy。

Sparse Categorical Cross-entropy

本函數跟Categorical Cross-entropy效果是一樣的,卻省資源得多。Sparse matrix 的意思是只有不為零的欄位會存在記憶體,其餘的不會。這麼做可以節省記憶體,同時經過 tensorflow 的優化,計算上更快。該損失函數要求輸入是「事實的類別」以及輸出是「模型的機率分佈」。

某些情況下不適用 Sparse Categorical Cross-entropy。舉例來說,為了穩定模型、或是為了避免過擬合 (overfitting) ,我們可能會在事實的機率分佈中添加亂數,使得事實的機率不再是 One-hot encoding ,也因此不能繼續使用 Sparse Categorical Cross-entropy 。

為什麼 KL divergence 很少作為損失函數?

因為使用 KL Divergence 或 Cross-entropy 搭配隨機梯度下降法 (Stochastic Gradient Descent) 會得到一樣的效果,詳情請看這裡。但 KL Divergence 會消耗更多的電腦資源,所以不建議使用 。但也不是所有的 KL Divergence 都可以用 Cross-entropy 替換掉。有些特殊的模型、例如 Proximal Policy Optimization 使用 KL Divergence 懲罰快速更新的模型;Cooperative Training 使用從 KL Divergence 延伸出來的損失函數 JS Divergence 來更新模型。

並不一定要用上述的損失函數

以上提到的損失函數與任務的相對關係是從前人的經驗得到的。換句話說,不這麼做也是可以的,還有其他的可能性存在。舉個例子:在對抗生成網路 (Generative Adversarial Network) 中,有一個模型負責生成資料,另一個負責判斷真偽。Wasserstein GAN 系列的論文會使用 Earth mover distance 來更新負責判斷真偽的模型而不是 Cross-entropy。並且在特定任務得到更好的結果。前面提到的Proximal Policy Optimization 及 Coperative Training 用的也不是單純Cross-entropy。也有一派學者主張要用Quantile Regression替代Mean Square Root。所以請讀者不要把表一的做法當作金科玉律一般看待。

小結

Tensorflow 提供了很多名稱類似但是用途相差甚多的損失函數,剛接觸的讀者可能比較難從名稱辨別用途。為了避免不了解而誤用,這也是我寫這一篇的原因。他們都是有優化過的,跑起來比自己寫得快得多。我並不建議自己撰寫損失函數,有現成的輪胎(程式碼)就用現成的,把心力用在真正想知道的事物吧。


所有文章分類

訂閱我吧

不再錯過每一篇新文章

*

Yi-Lung Chiu