前言
R語言所以強大最主要的原因是擁有眾多的套件可用,而最困擾的部分也是有如此多的套件可用。因為想從海量的套件中找出所需的套件,實務上是一件非常困難的任務。此外套件的使用也是 一個門檻,如果缺乏文件說明以及範例提供引導,要使用套件也是相當的不容易。因此在尋找相關書籍之後覺得<<Data Mining with R Learning by Case Studies>>這本書可以幫助初學者快速建立應用的基本概念,同時這些案例中所使用的套件也可以初步提供相關的應用。這系列文章是閱讀之後將案例流程及相關說明整理備忘。這個案例所展現的相關技術有:
- 資料視覺化
- 描述性統計
- 處理變數缺漏值的策略
- 迴歸分析作業
- 迴歸分析的評估指標
- 多線性迴歸分析
- 迴歸樹分析
- 運用k-fold交叉驗證對各種模型的比較選擇
- 同時使用各種模型及隨機森林的應用
對R的基本操作技術有:
- 從文字檔載入資料
- 如何從料集中取得相關的統計描述
- 基本的資料視覺化
- 處理資料集中的缺漏值
- 取得迴歸模型
- 運用上述的模型預測測試資料的預測值
問題描述
這個案例試圖預測水藻的生長。由於從河水中採樣取得水藻生長的數據必較困難,科學家想要從河水中取得水的樣本,看看是否可以從水樣本中的化學物質含量預測水藻的生長數據。水質樣本的化學物質計有下列8項:
- Maximum pH value
- Minimum value of O2(oxygen)
- Mean value of Cl (chloride)
- Mean value of NO−3(nitrates)
- Mean value of NH+4(ammonium)
- Mean of PO3−4(orthophosphate)
- Mean of total PO4(phosphate)
- Mean of chlorophyll
資料集分兩組:第一組共有200筆,每一筆資料都有上述8欄數據,及7欄水藻的生長數據。這一組使用來產生所需的預測模型。第二組只有前面8欄,將採用前述的預測模型來預測水藻的生長數據。
從R載入資料料
這兩組資料可以直接從DMwR這個套件中載入:
================================
> library(DMwR)
> head(algae) //列出前6筆資料如下:
-----------------------------------------------------------
season size speed mxPH mnO2 Cl NO3 NH4
1 winter small medium 8.00 9.8 60.800 6.238 578.000
2 spring small medium 8.35 8.0 57.750 1.288 370.000
3 autumn small medium 8.10 11.4 40.020 5.330 346.667
4 spring small medium 8.07 4.8 77.364 2.302 98.182
5 autumn small medium 8.06 9.0 55.350 10.416 233.700
6 winter small high 8.25 13.1 65.750 9.248 430.000
-----------------------------------------------
-----------------------------------------------
oPO4 PO4 Chla a1 a2 a3 a4 a5 a6
1 105.000 170.000 50.0 0.0 0.0 0.0 0.0 34.2 8.3
2 428.750 558.750 1.3 1.4 7.6 4.8 1.9 6.7 0.0
3 125.667 187.057 15.6 3.3 53.6 1.9 0.0 0.0 0.0
4 61.182 138.700 1.4 3.1 41.0 18.9 0.0 1.4 0.0
5 58.222 97.580 10.5 9.2 2.9 7.5 0.0 7.5 4.1
6 18.250 56.667 28.4 15.1 14.6 1.4 0.0 22.5 12.6
------------------------------------------------
------------------------------------------------
a7
1 0.0
2 2.1
3 9.7
4 1.4
5 1.0
6 2.9
================================
================================
algae <- read.table('Analysis.txt',
+ header=F, //不讀入欄位名稱(預設值可省略)
+ dec='.', //小數點符號(預設值可省略)
+ col.names=c('season','size','speed','mxPH','mnO2',
+ 'Cl','NO3','NH4','oPO4','PO4','Chla','a1','a2','a3','a4',
+ 'a5','a6','a7'), //設定欄位名稱
+ na.strings=c('XXXXXXX')) //將缺漏值'XXXXXXX'轉成R的NA值
================================
資料視覺化及總結
使用summary方法可以顯示資料集的統計描述:
================================
> summary(algae)
-----------------------------------------------------------
season size speed mxPH
autumn:40 large :44 high :84 Min. :5.60
spring:53 medium:84 low :33 1st Qu.:7.70
summer:44 small :70 medium:81 Median :8.06
winter:61 Mean :8.02
3rd Qu.:8.40
Max. :9.70
NA's :1
------------------------
mnO2 Cl NO3
Min. : 1.500 Min. : 0.222 Min. : 0.050
1st Qu.: 7.800 1st Qu.: 10.981 1st Qu.: 1.296
Median : 9.800 Median : 32.730 Median : 2.675
Mean : 9.125 Mean : 43.636 Mean : 3.282
3rd Qu.:10.800 3rd Qu.: 57.824 3rd Qu.: 4.446
Max. :13.400 Max. :391.500 Max. :45.650
NA's :1 NA's :8
------------------------
NH4 oPO4 PO4
Min. : 5.00 Min. : 1.00 Min. : 1.0
1st Qu.: 38.33 1st Qu.: 15.70 1st Qu.: 43.5
Median : 103.17 Median : 40.15 Median :104.0
Mean : 501.30 Mean : 73.59 Mean :138.5
3rd Qu.: 226.95 3rd Qu.: 99.33 3rd Qu.:214.0
Max. :24064.00 Max. :564.60 Max. :771.6
NA's :1
------------------------
Chla a1 a2
Min. : 0.200 Min. : 0.000 Min. : 0.000
1st Qu.: 2.000 1st Qu.: 1.525 1st Qu.: 0.000
Median : 5.475 Median : 6.950 Median : 3.000
Mean : 13.971 Mean :16.996 Mean : 7.471
3rd Qu.: 18.308 3rd Qu.:24.800 3rd Qu.:11.275
Max. :110.456 Max. :89.800 Max. :72.600
NA's :10
------------------------
a3 a4 a5
Min. : 0.000 Min. : 0.000 Min. : 0.000
1st Qu.: 0.000 1st Qu.: 0.000 1st Qu.: 0.000
Median : 1.550 Median : 0.000 Median : 2.000
Mean : 4.334 Mean : 1.997 Mean : 5.116
3rd Qu.: 4.975 3rd Qu.: 2.400 3rd Qu.: 7.500
Max. :42.800 Max. :44.600 Max. :44.400
------------------------
a6 a7
Min. : 0.000 Min. : 0.000
1st Qu.: 0.000 1st Qu.: 0.000
Median : 0.000 Median : 1.000
Mean : 6.005 Mean : 2.487
3rd Qu.: 6.975 3rd Qu.: 2.400
Max. :77.600 Max. :31.600
================================
其中season size speed這三個欄位屬於因子變數(nominal),所以列出個因子的數量。
其他的欄位是數字變數,列出相關的統計數字:
Min. : 最小值
1st Qu.:第一四分位數
Median : 中值(第三四分位數)
Mean : 平均值
3rd Qu.: 第三四分位數
Max. :最大值
NA's :缺漏值數量
你也可以使用Hmisc這個套件中的describe這個方法顯示相關的統計數字:
>library(Hmisc)
>describe(algae)
以上是以列表方式顯示相關的統計描述數據,提供初步的資料分布狀態,雖然從中值及平均值比對,以及比對第一四分位數與第三四分位數可以了解資料的偏度(skewness),從這些數據中並不是很容易其中的統計特徵。所以需要使用圖表來分析其中的統計特徵。
首先我們使用R的hist方法針對mxPH欄位畫出直方圖(histogram):
================================
> hist(algae$mxPH, prob = T)
================================
mxPH欄位資料的直方圖 |
其中prob=T表示是以機率(probabilities)代表每一區間的值。每一區間的面積代表機率的總和,而非高度。
******************************************************
問題:既然是機率,為何加總不是1?
******************************************************
你也可以設prob=F製作直方圖,此時則是以數量(frequency)繪圖。
一般我們需要檢驗所蒐集的資料是否符合常態分配,一般檢驗資料是否符合常態分配一般使用常態分位數圖(Normal QQ plots)[註:使用第一四分位數及第三四分位數繪製一條直線(理論上的常態分配),並將所有數據標示在圖上,看數據是否貼近這條直線]。以下我們使用Car這個套件來繪圖。
================================
> library(car)
//分割視窗,每一列兩欄。par()可用來設定圖形的各種參數
> par(mfrow=c(1,2))
//繪製直方圖
> hist(algae$mxPH, prob=T, xlab='',
+ main='Histogram of maximum pH value',ylim=0:1)
//繪製平滑的直方圖na.rm=T將缺漏值移除,否則如有缺漏值會發生錯誤
> lines(density(algae$mxPH,na.rm=T))
//在X軸上標示所有資料位置
> rug(jitter(algae$mxPH))
//原文是用的方法已棄用但仍可執行只是R會發出警告(此版本也會繪出95%信賴的範圍(舊版圖虛線部分))
> qq.plot(algae$mxPH,main='Normal QQ plot of maximum pH')
//新版採用下兩列指令替代qq.plot
>qqnorm(algae$mxPH, plot.it = T, main='Normal QQ plot of maximum pH')
>qqline(algae$mxPH,datax=F, distribution=qnorm)
//結束分割視窗
> par(mfrow=c(1,1))
================================
直方圖及新版常態QQ圖 |
舊版的常態QQ圖 |
*******************************************
問題:如果數據無法符合常態分配,那麼這些資料是否可用?或應如何處理這些數據?
*******************************************
在這一大串的指令當中,如果你想要知道某一指令呈現的效果為何,可以單步執行,看看每一指令產生的效果。
另外,每一指令的參數作用為何,可以在R使用?指令查看,如想要知道qqnorm的用法,可以執行下列指令:
>?qqnorm
或
>help(qqnorm)
另一個可以呈現數據分散情況的圖形是盒鬚圖或稱箱型圖(box plot)。下列指令可以呈現oPO4欄位的盒鬚圖:
================================
//繪出盒鬚圖
> boxplot(algae$oPO4, ylab = "Orthophosphate (oPO4)")
//在Y軸標出所有的資料
> rug(jitter(algae$oPO4), side = 2)
//繪出平均值的現(虛線)
> abline(h = mean(algae$oPO4, na.rm = T), lty = 2
================================
盒鬚圖 |
*******************************************
問題:實務上,遠離本體的資料是否可用?或該如何處理?
*******************************************
從圖中顯示出,大部分的資料料是集中在低的部分。代表是一種正的偏斜(positive skew)。在所有資料中只有oPO4是如此,其餘的都是向高處偏斜,甚至有些超大的值。
對於這些遠離本體的資料,我們比較關心那些『超級』偏離的值。以NH4為例,我們可以使用圖形來觀察這些『超級』值:
================================
//列出所有資料點
> plot(algae$NH4, xlab = "")
//繪製平均線
> abline(h = mean(algae$NH4, na.rm = T), lty = 1)
//繪製平均+1標準差的線
> abline(h = mean(algae$NH4, na.rm = T) + sd(algae$NH4, na.rm = T),
+ lty=2)
//繪製中數線
> abline(h = median(algae$NH4, na.rm = T), lty = 3)
//識別資料,用滑鼠點選資料點時可以顯示該資料的編號
> identify(algae$NH4)
================================
尋找超級偏離值 |
中間實線:平均值
最下面點縣:中值線
超級偏離值編號:152
此外你也可以將點選的結果儲存在變數當中:
================================
> plot(algae$NH4, xlab = "")
> clicked.lines <- identify(algae$NH4)
//點選...右鍵
> algae[clicked.lines, ]
----------------------------------------------------------
season size speed mxPH mnO2 Cl NO3
20 spring small medium 7.79 3.2 64.000 2.822
153 autumn medium high 7.30 11.8 44.205 45.650
---------------------
NH4 oPO4 PO4 Chla a1 a2 a3 a4 a5
20 8777.6 564.6 771.6 4.5 0.0 0 0 44.6 0.0
153 24064.0 44.0 34.0 53.1 2.2 0 0 1.2 5.9
---------------------
a6 a7
20 0.0 1.4
153 77.6 0.0
================================
你也可以使用下列指令搜尋超過某值的資料(列出NH4 > 19000的資料):
================================
> algae[algae$NH4 > 19000, ]
================================
這個指令除了會列出合乎條件的資料,同時也會列出所有缺漏值的資料。因此為移除缺漏值可以再加一條件:
================================
> algae[!is.na(algae$NH4) & algae$NH4 > 19000, ]
================================
接下來介紹一種一特定條件繪圖的方式。我們使用lattice套件。依特定條件,這條件一般是指以因子欄位當做條件。如依size為條件,size有三個因子:small、medium、large。繪圖時則依資料行中size欄位中的值是small、medium、large分別繪圖。
================================
> library(lattice)
//依size繪製a1資料的盒鬚圖
> bwplot(size ~ a1, data=algae, ylab='River Size',xlab='Algal A1')
================================
條件盒鬚圖 |
從上圖中可以觀察到,出現最多的是位於small的河流中。
我們也可以使用Hmisc這個套件繪製相同的圖型,只是外觀有所差異。
================================
> library(Hmisc)
> bwplot(size ~ a1, data=algae,panel=panel.bpplot,
+ probs=seq(.01,.49,by=.01), datadensity=TRUE,
+ ylab='River Size',xlab='Algal A1')
================================
Hmise套件繪製的條件盒鬚圖 |
其中菱形點代表平均值。垂直線由左至右分別是『第一四分位數』『中數』『第三四分位數』,小垂直線段代表各個資料點。與以lattice套件繪製的相比,除了出現最多的是位於small的河流中外,Hmisc繪製的更可以觀察出small河流中的數據比較分散。同時也可以看出small的平均值相對其他兩組資料是比較高的。
*******************************************
問題:資料比較分散代表什麼意義?平均值高低代表什麼意義?對資料處理有何影響?
*******************************************
*******************************************
問題:資料比較分散代表什麼意義?平均值高低代表什麼意義?對資料處理有何影響?
*******************************************
這類圖形的條件並不限制必須是因子變數,也可以使用數字變數,只要將這些連續性數字依大小順序分組即可。下面的例子是以mnO2為條件繪製a3的圖表。
================================
//number=4將mnO2依序分成4組,overlap=1/5表示各組之間有1/5的重疊,na.omit忽略NA
> minO2 <- equal.count(na.omit(algae$mnO2),
+ number=4,overlap=1/5)
//先依season分組,再依minO2分組,stripplot是lattice的另一個繪圖指令
> stripplot(season ~ a3|minO2,
+ data=algae[!is.na(algae$mnO2),])
================================
依季節及mnO2條件顯示a3的資料 資料順序是由左而右由下而上 |
缺漏值
由於資料集中常常會有一些不完整的部分,在處理時會造成程式發生錯誤或影響處理結果。因此在處理資料前必須先處理這些不完整的資料,也就是缺漏值。處理缺漏值的方式有:
- 移除有缺漏值的資料。
- 透過欄位之間的相關性,用以填入缺漏值。
- 透過其他行的相似性,用以填入缺漏值。
- 使用可以處理這些值的工具。(有限制性)
這些方法並非互補的,也就是說當你使用某一方法處理後,想要換另一種方法時,必須重新載入資料使之還原原始狀態。
移除缺漏值
如果缺漏值數量相對總資料是少的時候,移除缺漏值是可行的。
在處理前可以先列出或算出擁有缺漏值的資料:
================================
> algae[!complete.cases(algae),]//列出擁有缺漏值的資料
> nrow(algae[!complete.cases(algae),])//算出擁有缺漏值的資料
> nrow(algae[!complete.cases(algae),])//算出擁有缺漏值的資料
================================
使用下列的指令移除缺漏值:
================================
> algae <- na.omit(algae)
使用下列的指令移除缺漏值:
================================
> algae <- na.omit(algae)
================================
因為含有大量缺漏值的資料行,無法在後續預測過程當中處理,因此需要移除含有大量缺漏值得資料。如果不想移除所有含缺漏值的資料,可以加入相關條件(例如從上述指令列出所有含缺漏值的資料中62、199兩筆資料含較多(6個)的缺漏值,下列指令可僅移除這兩筆資料):
================================algae <- algae[-c(62, 199), ]
================================
如果資料數量大時,使用上述方法顯得不是很方便,可使用下列方法列出所有資料行中含有的缺漏值數量。
================================
> apply(algae, 1, function(x) sum(is.na(x)))
================================
is.na()的傳回值bool,如果TRUE則為1,FALSE則為0,因此可以sum加總結果。
is.na()的傳回值bool,如果TRUE則為1,FALSE則為0,因此可以sum加總結果。
apply()是R的一個函數,是一種元函數(meta-functions)。他的主要功能是針對某物件在某些條件下執行其他的函數或方法。此方法的原型如下:
================================
================================
apply(X, MARGIN, FUN, ...)
其中:
X: 是array,及matrix的資料集
MARGIN: 以matrix而言 1 表示行(
rows),2
表示列(columns),而 c(1, 2)
表示行與列。
FUN:表示要執行的函數或方法。其中函數或方法的宣告先有
function(x)後面接函數的內容(暫時函數(temporary function)也就是說這個函數只能在這個apply中執行),如不只一行指令可以{}包主所有程式碼。其中x是傳入函數的參數,即前面宣告的X(上例中的algae),如果x含有多個元素,apply會一一取出每一行(因為MARGIN=1),傳至FUN中執行。如果FUN含有多一個以上的參數,第二個以後的參數則在apply的第三個參數(...)中逐一傳入。
...:FUN函數的其他參數。
================================
你也可以在R中執行?
apply看到相關的說明。
在DMwR套件中有一個函數manyNAs()可以傳回缺漏值數量大於某程度之資料行編號。
================================> data(algae)
//manyNAs()第二個參數預設值是0.2表示列出超過20%的欄位是缺漏值的資料行編號> manyNAs(algae, 0.2)
[1] 62 199
//從algae中移除太多缺漏值的資料行:
> algae <- algae[-manyNAs(algae), ]algae[, 4:18]表示不使用前三個欄位。 use = "complete.obs"表示不使用缺漏值。cor()得出一個關連值的matrix,數字接近1(-1)表示兩個變數之間強烈正(負)相關。這個方法得出之完整的數字,不易閱讀。我們可以利用另一個函數sysnum()得出更簡潔的結果。================================以最常出現的值填入缺漏值
尋找最常出現值(Most Frequent Values)有許多種策略,最快最簡單的就是使用某些統計中心值(statistic of centrality)如中值、平均值等,至於選擇哪一個則是資料分配狀況而定。如果是適當的常態分配,平均值是最佳的選擇。而如果是偏斜的分配或存在遠離本體的資料,中值則是較佳的選擇。例如algae[48,]的mxPH無值,而其分配接近常態分配(參考前述)。因此可以平均值填入該缺漏值。
================================
//只填入某一行的缺漏值
> algae[48, "mxPH"] <- mean(algae$mxPH, na.rm = T)
//填入所有缺漏值
> algae[is.na(algae$Chla), "Chla"] <- median(algae$Chla, na.rm = T)
//使用DMwR套件的centralImputation()填入所有缺漏值
> data(algae)
> algae <- algae[-manyNAs(algae), ]
> algae <- centralImputation(algae)
================================
na.rm = T 表示排除缺漏值,不計入平均的運算。
上述的的方法雖然快速方便,很適合大型的資料集,但是可能導致較大的偏差,因而影響後續的分析作業。然而有些最佳化的方法可能過於複雜,可能無法應用於大型的資料採礦問題。
透過欄位之間的相關性,用以填入缺漏值
例如,透過相關性分析,我們發現某些變數與mxPH具有高度相關性,便可以使用這些變數來預測mxPH的值。是用cor()可以獲得變數間的相關性:================================> data(algae)> cor(algae[, 4:18], use = "complete.obs")================================
================================
> symnum(cor(algae[,4:18],use="complete.obs"))================================從列表中我們可以看出PO4與oPO4間(大於0.9)具有強烈的正相關。而NH4與NO3間(0.72)可能有些風險。而且如果事先移除62及199者兩筆資料,NH4與NO3便無缺漏值了。因此可以使用PO4與oPO4之關連來填入缺漏值。於是便需要建立兩者之間的線性相關:================================> data(algae) > algae <- algae[-manyNAs(algae), > lm(PO4 ~ oPO4, data = algae)----------------------------------------------------------Call: lm(formula = PO4 ~ oPO4, data = algae) Coefficients: (Intercept) oPO4 42.897 1.293================================lm()可以用來建立現行模型:Y = β0+β1X1+...+βnXn所以得到的線性關連模型:PO4=42.897+1.293*oPO4我們便可以使用這個公式來填入缺漏值。================================//填入單一行的缺漏值> algae[28, "PO4"] <- 42.897 + 1.293 * algae[28, "oPO4"]-----------------------------------------------------------//填入所有行的缺漏值> data(algae) > algae <- algae[-manyNAs(algae), ]//定義fillPO4函數處理填入缺漏值的工作 > fillPO4 <- function(oP) { + if (is.na(oP)) + return(NA) + else return(42.897 + 1.293 * oP) +}//在所有行中執行fillPO4 > algae[is.na(algae$PO4), "PO4"] <- sapply(algae[is.na(algae$PO4), + "oPO4"], fillPO4)================================sapply()與apply()類似,第一個參數為所有具有PO4缺漏值的行,並逐一帶入帶入fillPO4算出值並傳回。apply()是作用在array中的向度(行或列)。sapply()是作用在List或vector中的元素。除了數字變數之間的關連性外,我們也可以探索數字缺漏值與因子變數之間的關連性,也就是使用lattice套件的直方圖探索期間的關連性。從這上圖可以發現四季的圖型相當類似,因此可以推論mxPH並未受季節的影響。如果以河流的size分別顯示:
================================
//因為R是依據英文排序,所以先設定season因子的順序為春夏秋冬
> algae$season <- factor(algae$season, + "summer", "autumn", "winter"))
> histogram(~mxPH | season, data = algae)
================================
依據季節因子顯示各季的資料
================================
> histogram(~mxPH | size, data = algae)
================================
依據河流size因子顯示各資料 |
從上圖中可以發現size愈小數mxPH值愈低。 除此之外,還可以擴充顯示的依據條件:
================================
> histogram(~mxPH | size * speed, data = algae)
================================
依據河流size及流速因子顯示各資料 |
由上圖可以發現缺少小河流低流速的資料,其實第48筆資料正屬這部份,只是mxPH無值!
另一種也可以顯示類似資訊的圖型,此圖型列出所有的資料:
================================
> stripplot(size ~ mxPH | speed, data = algae, jitter = T)
================================
依據河流size及流速因子顯示各資料分布 |
jitter=T使用隨機數將個數據位置沿Y軸上下移動以避免過度集中互相掩蓋。
使用圖型分析可以協助探索缺漏值,但是這種方式有點乏味,因為有太多種組合可以用來分析。雖然如此,這類方法頗適合小的資料集,利用因子變數來推測缺漏值。
透過其他行的相似性,用以填入缺漏值
這個方法的假設是:如果兩個樣本之間,除了缺漏值以外的欄位皆相似,便可以推論缺漏值應相似。所謂相似必須定義一個指標(matric),這個指標如下:
x,y表示兩個資料之間的欄位。d(x,y)愈小表示相似度愈高。這個方法會找出與此含缺漏值資料最相似的另外筆10資料。然後用這10筆資料來填入缺漏值。填入的方法有二:
第一種方法是,如果是數字變數,便使用這10筆的中數填入,如果是因子變數則使用最多出現的值填入。
第二種方法是,使用加權平均的方式,權重是依據與被填入資料的距離成反比。
這個方法在DMwR套件中已實作knnImputation()方法。
================================
//第一種方法
> algae <- knnImputation(algae, k = 10, meth = "median")
//第二種方法
>algae <- knnImputation(algae,k=10)
================================
k代表要找尋相似資料的數量,預設為10。
至於要使用哪種方法,大多時是視應用領域而定。『透過其他行的相似性,用以填入缺漏值』卻是比較合理的方法,雖然有些問題(含有一些不合理的值,可能影響相似性的判斷,或者超大型資料集,包含過度複雜計算的資料)。對於超大的問題,還是可以使用隨機樣本來計算相似性。
本節介紹兩種迴歸預測模型:多變量線性迴歸(multiple linear regression)及迴歸樹(regression trees)。這兩種方法是完全不同形式的迴歸作法,而且很容易說明及執行。在真實的資料採礦項目中還是有許多其他的方法可供採用。這兩種方法對於缺漏值的處理方式不同,多線性迴歸不允許有缺漏值,所以必須先處理缺漏值的問題。迴歸樹則不在乎缺漏值,所以可以直接使用原始資料集。本節會利用前述200筆資料來探索預測模式,並用這些模式預測另外140筆的測試資料。
多變量線性迴歸
多變量線性迴歸是種常用資料分析的技術。這種模式是由多個變量Xi乘上特定係數βi加總後的結果。
================================
//首先處理缺漏值:
> data(algae)
> algae <- algae[-manyNAs(algae), ]
> clean.algae <- knnImputation(algae,k=10)
//使用lm()取得迴歸模型
> lm.a1 <- lm(a1 ~ ., data = clean.algae[, 1:12])
================================
lm()第一個參數a1 ~ .表示a1依據第二個參數指定的欄位統計迴歸模型,第二個參數指定前12欄位最為迴歸統計之來源。(注意第12個欄位是a1必須包含)如果你想依據某幾個欄位來分析第一個參數可以改寫成a1∼mxPH + NH4。'.'則代表第二參數的所有欄位。
如果想看此迴歸模型的內容可使用summary():
================================
================================
summary()會給出一些診斷的資訊,首先是線性模型是否符合的殘差(residuals),這個殘差必須是平均0(或愈接近0)愈表示模型可用,而且是一個常態分配。
對於每一個變量,R顯示其值及標準差。
要驗證每一變量的重要度,先假設其參數值為H0:βi=0(null),然後運用t-test來測試這個假設。t=βi/(標準差),最後算出Pr(>|t|)這一欄,其值如果是0.0001表示我們有99.99%的信心這個參數不是0(null)。如果結果不到90%的信心度(一個星號以上),βi=0(null)便為真。
另一個診斷指標R²(multiple 及 adjusted)表示模型的適合度,亦即模型資料的變異比率(proportion of variance)。其值愈接近1愈佳(幾乎100%呈現其變動),值愈小,適合度愈差。adjusted係數更是嚴苛,因為它參考數個迴歸模型的參數。
最後我們也測試null假設:所有變數皆無關,即H0:β1=
β2=...=βm=0。F-statistic即代表這個假設。若p-level值為0.0001即表示有99.99%的信心度null假設為非。即可以排除null假設。如果p-level。0.1表示無法通過測試,上述的t-test便無意義了。
有些診斷方式可以採用繪圖方式。如plot(lm.a1):
這個模型資料的變異比率(proportion of variance)不是很顯著(R² adjusted約32.0%)。因為F-statistic的p-level值非常小,故可以推翻目標變數與預測不相依的假設。
由於部分變數不是很顯著相關,必須將其從模型中排除。有許多方法可供選擇,下面是使用後退淘汰法(backward elimination),我們使用stats套件中的anova()來處理:
執行結果由於season的Pr(>F)接近1,代表其對於降低適合度殘差貢獻度最差。所以必需移除:
================================
> lm2.a1 <- update(lm.a1,.~.-season)
================================
update()是stats套件的方法。可以變動現有的線性模型。lms.a1即是從lm.a1中移除season變數。我們可以看看執行結果:
變異比率稍有改善但還不是很顯著(R² adjusted約32.8%),我們可以再次使用anova()來比對前後兩次的差距:
anova()使用F-test分析兩者的變異是否顯著,在這個案例中,雖然方差和(sum of the squared errors)降低(-447.62),但比對結果變異不是恨顯著(Pr(>F)=0.6971約只有30%的信心度說兩者有差異)。我們還可以再次執行直到移除所有的變數。但我們可以使用stats套件中的step()方法,它會幫我們自動反覆執行:
我們可以檢視執行結果:
最後的結果顯示變異比率(33%)仍然不理想,表示此領域問題並不適合線性假設。
另一個診斷指標R²(multiple 及 adjusted)表示模型的適合度,亦即模型資料的變異比率(proportion of variance)。其值愈接近1愈佳(幾乎100%呈現其變動),值愈小,適合度愈差。adjusted係數更是嚴苛,因為它參考數個迴歸模型的參數。
最後我們也測試null假設:所有變數皆無關,即H0:β1=
β2=...=βm=0。F-statistic即代表這個假設。若p-level值為0.0001即表示有99.99%的信心度null假設為非。即可以排除null假設。如果p-level。0.1表示無法通過測試,上述的t-test便無意義了。
有些診斷方式可以採用繪圖方式。如plot(lm.a1):
每一個適合變數值相對殘差值 標出編號者表示殘差太大 |
殘差是否符合常態分配 分配是否貼近直線 |
這個模型資料的變異比率(proportion of variance)不是很顯著(R² adjusted約32.0%)。因為F-statistic的p-level值非常小,故可以推翻目標變數與預測不相依的假設。
由於部分變數不是很顯著相關,必須將其從模型中排除。有許多方法可供選擇,下面是使用後退淘汰法(backward elimination),我們使用stats套件中的anova()來處理:
執行結果由於season的Pr(>F)接近1,代表其對於降低適合度殘差貢獻度最差。所以必需移除:
================================
> lm2.a1 <- update(lm.a1,.~.-season)
================================
update()是stats套件的方法。可以變動現有的線性模型。lms.a1即是從lm.a1中移除season變數。我們可以看看執行結果:
變異比率稍有改善但還不是很顯著(R² adjusted約32.8%),我們可以再次使用anova()來比對前後兩次的差距:
anova()使用F-test分析兩者的變異是否顯著,在這個案例中,雖然方差和(sum of the squared errors)降低(-447.62),但比對結果變異不是恨顯著(Pr(>F)=0.6971約只有30%的信心度說兩者有差異)。我們還可以再次執行直到移除所有的變數。但我們可以使用stats套件中的step()方法,它會幫我們自動反覆執行:
我們可以檢視執行結果:
最後的結果顯示變異比率(33%)仍然不理想,表示此領域問題並不適合線性假設。
迴歸樹
這是另一種迴歸模型,執行下列指令:
================================
> library(rpart)
> data(algae)
> algae <- algae[-manyNAs(algae), ]
> rt.a1 <- rpart(a1 ~ ., data = algae[, 1:12])
================================
執行rpart()產生一迴歸樹,其參數與前述的lm()一樣。檢視rt.a1的內容:
================================
split:分支條件 n:樣本數 deviance:與平均值偏差的平方和 yval:a1的平均值 |
================================
迴歸樹是解釋性(explanatory)變量邏輯推論的層級架構(hierarchy)。它會自動選取相關的變數,因此並不是所有的變數都會用到。閱讀方式,從root節點開始,然後依數字由小而大。每個數字代表一個節點。除了末節點,每一個節點都有兩個分支,每個分支都有其進入的條件。只要從root開始,依條件進行,直到末節點其最終結果(預測值)便是yval。我們可以使用graphics套件的plot(rt.a1)及text(rt.a1)或DMwR套件的prettyTree(rt.a1)來繪製樹狀圖:
用Plot()繪製的樹狀圖 |
用text()繪製的樹狀圖 |
用prettyTree()繪製的樹狀圖 |
我們也可以用summary()來檢視其內容:
================================
> summary(rt.a1)
----------------------------------------------------------
Call:
rpart(formula = a1 ~ ., data = algae[, 1:12])
n= 198
CP nsplit rel error xerror xstd
1 0.40573990 0 1.0000000 1.0089717 0.1308742
2 0.07188523 1 0.5942601 0.6757552 0.1183296
3 0.03088731 2 0.5223749 0.6591388 0.1164799
4 0.03040753 3 0.4914876 0.7006920 0.1219898
5 0.02787181 4 0.4610800 0.7020964 0.1220327
6 0.02775354 5 0.4332082 0.7020964 0.1220327
7 0.01812406 6 0.4054547 0.6953859 0.1139465
8 0.01634372 7 0.3873306 0.6959028 0.1125755
9 0.01000000 9 0.3546432 0.6957636 0.1127715
Variable importance
PO4 oPO4 NH4 Cl mxPH Chla NO3 mnO2 size season speed
25 20 15 15 9 7 3 2 1 1 1
Node number 1: 198 observations, complexity param=0.4057399
mean=16.99646, MSE=456.5722
left son=2 (147 obs) right son=3 (51 obs)
Primary splits:
PO4 < 43.818 to the right, improve=0.4048567, (1 missing)
oPO4 < 18.889 to the right, improve=0.3793450, (0 missing)
NH4 < 51.27 to the right, improve=0.3625269, (0 missing)
Cl < 7.2915 to the right, improve=0.3583409, (8 missing)
Chla < 1.15 to the right, improve=0.2533869, (10 missing)
Surrogate splits:
oPO4 < 17.5415 to the right, agree=0.944, adj=0.78, (1 split)
NH4 < 37.639 to the right, agree=0.893, adj=0.58, (0 split)
Cl < 9.0275 to the right, agree=0.858, adj=0.44, (0 split)
Chla < 1.05 to the right, agree=0.822, adj=0.30, (0 split)
mxPH < 7.295 to the right, agree=0.817, adj=0.28, (0 split)
Node number 2: 147 observations, complexity param=0.07188523
mean=8.979592, MSE=212.7831
left son=4 (140 obs) right son=5 (7 obs)
Primary splits:
Cl < 7.8065 to the right, improve=0.2071337, (1 missing)
Chla < 1.15 to the right, improve=0.1959676, (1 missing)
oPO4 < 51.118 to the right, improve=0.1651094, (0 missing)
NH4 < 49.25 to the right, improve=0.1494842, (0 missing)
PO4 < 125 to the right, improve=0.1393822, (0 missing)
Surrogate splits:
Chla < 0.6 to the right, agree=0.959, adj=0.143, (1 split)
Node number 3: 51 observations, complexity param=0.03040753
mean=40.10392, MSE=440.0541
left son=6 (28 obs) right son=7 (23 obs)
Primary splits:
mxPH < 7.87 to the left, improve=0.12171490, (1 missing)
PO4 < 6.35 to the right, improve=0.10576260, (1 missing)
Cl < 7.544 to the right, improve=0.10428070, (7 missing)
NH4 < 18.381 to the right, improve=0.10356000, (0 missing)
oPO4 < 10.625 to the right, improve=0.09644168, (0 missing)
Surrogate splits:
size splits as RRL, agree=0.78, adj=0.522, (1 split)
NO3 < 1.1875 to the right, agree=0.74, adj=0.435, (0 split)
oPO4 < 3.111 to the left, agree=0.70, adj=0.348, (0 split)
season splits as LLRR, agree=0.60, adj=0.130, (0 split)
NH4 < 22.0355 to the left, agree=0.60, adj=0.130, (0 split)
Node number 4: 140 observations, complexity param=0.03088731
mean=7.492857, MSE=154.4488
left son=8 (84 obs) right son=9 (56 obs)
Primary splits:
oPO4 < 51.118 to the right, improve=0.12913450, (0 missing)
PO4 < 125 to the right, improve=0.09908251, (0 missing)
NH4 < 41.875 to the right, improve=0.05847356, (0 missing)
NO3 < 3.2725 to the right, improve=0.05343570, (0 missing)
Chla < 3.65 to the right, improve=0.04761161, (1 missing)
Surrogate splits:
PO4 < 125 to the right, agree=0.857, adj=0.643, (0 split)
Cl < 27.8665 to the right, agree=0.721, adj=0.304, (0 split)
NO3 < 3.313 to the right, agree=0.679, adj=0.196, (0 split)
mnO2 < 9.5 to the left, agree=0.664, adj=0.161, (0 split)
season splits as RLLL, agree=0.657, adj=0.143, (0 split)
Node number 5: 7 observations
mean=38.71429, MSE=451.1098
Node number 6: 28 observations, complexity param=0.02775354
mean=33.45, MSE=409.0275
left son=12 (18 obs) right son=13 (10 obs)
Primary splits:
mxPH < 7.045 to the right, improve=0.2296931, (1 missing)
PO4 < 6.25 to the right, improve=0.2174386, (1 missing)
oPO4 < 12.375 to the right, improve=0.1721865, (0 missing)
NH4 < 17.1 to the right, improve=0.1098949, (0 missing)
season splits as LRRR, improve=0.0944271, (0 missing)
Surrogate splits:
NH4 < 11.25 to the right, agree=0.852, adj=0.556, (1 split)
oPO4 < 1.125 to the right, agree=0.852, adj=0.556, (0 split)
PO4 < 6.5835 to the right, agree=0.852, adj=0.556, (0 split)
speed splits as L-R, agree=0.778, adj=0.333, (0 split)
NO3 < 1.9675 to the left, agree=0.778, adj=0.333, (0 split)
Node number 7: 23 observations, complexity param=0.02787181
mean=48.20435, MSE=358.3091
left son=14 (12 obs) right son=15 (11 obs)
Primary splits:
PO4 < 15.177 to the right, improve=0.3057413, (0 missing)
NH4 < 20.4165 to the right, improve=0.2692864, (0 missing)
Cl < 7.544 to the right, improve=0.2055829, (0 missing)
Chla < 0.85 to the right, improve=0.1534699, (1 missing)
oPO4 < 6.25 to the right, improve=0.1013330, (0 missing)
Surrogate splits:
NH4 < 20.4165 to the right, agree=0.913, adj=0.818, (0 split)
Cl < 5.8595 to the right, agree=0.826, adj=0.636, (0 split)
NO3 < 1.353 to the right, agree=0.826, adj=0.636, (0 split)
oPO4 < 5 to the right, agree=0.783, adj=0.545, (0 split)
Chla < 0.85 to the right, agree=0.739, adj=0.455, (0 split)
Node number 8: 84 observations
mean=3.846429, MSE=40.96606
Node number 9: 56 observations, complexity param=0.01812406
mean=12.9625, MSE=274.8113
left son=18 (24 obs) right son=19 (32 obs)
Primary splits:
mnO2 < 10.05 to the right, improve=0.10646520, (0 missing)
PO4 < 101.894 to the left, improve=0.08815216, (0 missing)
oPO4 < 24.3335 to the left, improve=0.07637520, (0 missing)
size splits as LLR, improve=0.06017653, (0 missing)
mxPH < 8.35 to the right, improve=0.05440345, (0 missing)
Surrogate splits:
PO4 < 101.894 to the left, agree=0.750, adj=0.417, (0 split)
size splits as LRR, agree=0.696, adj=0.292, (0 split)
season splits as LRRR, agree=0.679, adj=0.250, (0 split)
NH4 < 89.8 to the left, agree=0.661, adj=0.208, (0 split)
mxPH < 8.025 to the right, agree=0.643, adj=0.167, (0 split)
Node number 12: 18 observations
mean=26.39444, MSE=285.8983
Node number 13: 10 observations
mean=46.15, MSE=379.7645
Node number 14: 12 observations
mean=38.18333, MSE=253.9597
Node number 15: 11 observations
mean=59.13636, MSE=243.086
Node number 18: 24 observations
mean=6.716667, MSE=52.02806
Node number 19: 32 observations, complexity param=0.01634372
mean=17.64687, MSE=390.6975
left son=38 (9 obs) right son=39 (23 obs)
Primary splits:
NO3 < 3.1875 to the right, improve=0.09580105, (0 missing)
Chla < 2.55 to the left, improve=0.08399898, (0 missing)
oPO4 < 24.917 to the left, improve=0.07524892, (0 missing)
mnO2 < 9.4 to the left, improve=0.06578127, (0 missing)
Cl < 43.7085 to the right, improve=0.04807023, (0 missing)
Surrogate splits:
mxPH < 7.55 to the left, agree=0.844, adj=0.444, (0 split)
NH4 < 224.643 to the right, agree=0.812, adj=0.333, (0 split)
PO4 < 206.7225 to the right, agree=0.812, adj=0.333, (0 split)
Cl < 84.0465 to the right, agree=0.750, adj=0.111, (0 split)
Node number 38: 9 observations
mean=7.866667, MSE=28.56444
Node number 39: 23 observations, complexity param=0.01634372
mean=21.47391, MSE=480.3263
left son=78 (13 obs) right son=79 (10 obs)
Primary splits:
mnO2 < 8 to the left, improve=0.15906320, (0 missing)
PO4 < 118.6 to the left, improve=0.10091960, (0 missing)
NH4 < 168.75 to the left, improve=0.07651249, (0 missing)
NO3 < 1.2495 to the left, improve=0.07260629, (0 missing)
mxPH < 8.26 to the right, improve=0.06930695, (0 missing)
Surrogate splits:
season splits as RLLL, agree=0.696, adj=0.3, (0 split)
size splits as LLR, agree=0.696, adj=0.3, (0 split)
speed splits as RLL, agree=0.696, adj=0.3, (0 split)
Cl < 46.9725 to the right, agree=0.696, adj=0.3, (0 split)
NH4 < 216.653 to the left, agree=0.696, adj=0.3, (0 split)
Node number 78: 13 observations
mean=13.80769, MSE=224.5807
Node number 79: 10 observations
mean=31.44, MSE=637.0704
================================
樹的建構有兩個步驟:一開始先建一個大樹,然後依據統計評估從底部修剪。因為大樹可以完美的符合訓練資料,但是可能擷取到假性的關連性,因而應用在新資料時,效能可能不佳 。這種現象稱為『過度擬合(overfitting)』。過度擬合在建模時常發生,特別是假設太過寬鬆時。
rpart()建立的樹是依據下列三原則決定其規模:1.偏差低於某門檻。2.節點中的樣本數低於某門檻。3.樹的深度高於某值。這三者是依據下列參數:cp、minsplit、maxdepth,其預設值分別為0.01、20、30。
printcp()透過cp值產生一群子樹,並評估這些子樹的效能:
================================
>printcp(rt.a1)
----------------------------------------------------------
Regression tree:
rpart(formula = a1 ~ ., data = algae[, 1:12])
Variables actually used in tree construction:
[1] Cl mnO2 mxPH NO3 oPO4 PO4
Root node error: 90401/198 = 456.57
n= 198
CP nsplit rel error xerror xstd
1 0.405740 0 1.00000 1.01123 0.130994
2 0.071885 1 0.59426 0.76443 0.115210
3 0.030887 2 0.52237 0.71454 0.114962
4 0.030408 3 0.49149 0.69195 0.106685
5 0.027872 4 0.46108 0.68825 0.106771
6 0.027754 5 0.43321 0.68398 0.107168
7 0.018124 6 0.40545 0.67585 0.104675
8 0.016344 7 0.38733 0.69122 0.105072
9 0.010000 9 0.35464 0.69876 0.099859
================================
註:由於是採隨機運算,因此上列xerror及xstd數字每次執行可能不同。
rpart()產生的是其中第9回,其cp=0.01。但第7回的相對誤差最低,是比較佳的選擇。
另一種選擇的規則是1-SE,選擇最小的樹(第2回)。誤差0.76443小於0.67585+0.104675=0.780525。
當決定所採用的樹,便可以透過cp值來修剪:
================================
> rt2.a1 <- prune(rt.a1, cp = 0.08)
> rt2.a1
----------------------------------------------------------
n= 198
node), split, n, deviance, yval
* denotes terminal node
1) root 198 90401.29 16.996460
2) PO4>=43.818 147 31279.12 8.979592 *
3) PO4< 43.818 51 22442.76 40.103920 *
================================
上述建立樹及修剪樹的過程可以用rpartXse()一次完成:
================================
> (rt.a1 <- rpartXse(a1 ~ ., data = algae[, 1:12]))
----------------------------------------------------------
n= 198
node), split, n, deviance, yval
* denotes terminal node
1) root 198 90401.29 16.996460
2) PO4>=43.818 147 31279.12 8.979592 *
3) PO4< 43.818 51 22442.76 40.103920 *
================================
R也可以讓你自行決定修剪樹的哪些節點:
================================
> first.tree <- rpart(a1 ~ ., data = algae[, 1:12])
> snip.rpart(first.tree, c(4, 7))
----------------------------------------------------------
n= 198
node), split, n, deviance, yval
* denotes terminal node
1) root 198 90401.290 16.996460
2) PO4>=43.818 147 31279.120 8.979592
4) Cl>=7.8065 140 21622.830 7.492857 *
5) Cl< 7.8065 7 3157.769 38.714290 *
3) PO4< 43.818 51 22442.760 40.103920
6) mxPH< 7.87 28 11452.770 33.450000 *
7) mxPH>=7.87 23 8241.110 48.204350
14) PO4>=15.177 12 3047.517 38.183330 *
================================
snip.rpart()第二個參數是欲修剪的節點編號。
你也可以在圖形上直接點選欲修剪的節點:
================================
> prettyTree(first.tree)
> snip.rpart(first.tree)
----------------------------------------------------------
node number: 2 n= 147
response= 8.979592
Error (dev) = 31279.12
node number: 6 n= 28
response= 33.45
Error (dev) = 11452.77
n= 198
node), split, n, deviance, yval
1) root 198 90401.290 16.996460
2) PO4>=43.818 147 31279.120 8.979592 *
3) PO4< 43.818 51 22442.760 40.103920
6) mxPH< 7.87 28 11452.770 33.450000 *
7) mxPH>=7.87 23 8241.110 48.204350
14) PO4>=15.177 12 3047.517 38.183330 *
15) PO4< 15.177 11 2673.945 59.136360 *
================================
以上是點擊2、6兩個節點的結果。
rpart()建立的樹是依據下列三原則決定其規模:1.偏差低於某門檻。2.節點中的樣本數低於某門檻。3.樹的深度高於某值。這三者是依據下列參數:cp、minsplit、maxdepth,其預設值分別為0.01、20、30。
printcp()透過cp值產生一群子樹,並評估這些子樹的效能:
================================
>printcp(rt.a1)
----------------------------------------------------------
Regression tree:
rpart(formula = a1 ~ ., data = algae[, 1:12])
Variables actually used in tree construction:
[1] Cl mnO2 mxPH NO3 oPO4 PO4
Root node error: 90401/198 = 456.57
n= 198
CP nsplit rel error xerror xstd
1 0.405740 0 1.00000 1.01123 0.130994
2 0.071885 1 0.59426 0.76443 0.115210
3 0.030887 2 0.52237 0.71454 0.114962
4 0.030408 3 0.49149 0.69195 0.106685
5 0.027872 4 0.46108 0.68825 0.106771
6 0.027754 5 0.43321 0.68398 0.107168
7 0.018124 6 0.40545 0.67585 0.104675
8 0.016344 7 0.38733 0.69122 0.105072
9 0.010000 9 0.35464 0.69876 0.099859
================================
註:由於是採隨機運算,因此上列xerror及xstd數字每次執行可能不同。
rpart()產生的是其中第9回,其cp=0.01。但第7回的相對誤差最低,是比較佳的選擇。
另一種選擇的規則是1-SE,選擇最小的樹(第2回)。誤差0.76443小於0.67585+0.104675=0.780525。
當決定所採用的樹,便可以透過cp值來修剪:
================================
> rt2.a1 <- prune(rt.a1, cp = 0.08)
> rt2.a1
----------------------------------------------------------
n= 198
node), split, n, deviance, yval
* denotes terminal node
1) root 198 90401.29 16.996460
2) PO4>=43.818 147 31279.12 8.979592 *
3) PO4< 43.818 51 22442.76 40.103920 *
================================
上述建立樹及修剪樹的過程可以用rpartXse()一次完成:
================================
> (rt.a1 <- rpartXse(a1 ~ ., data = algae[, 1:12]))
----------------------------------------------------------
n= 198
node), split, n, deviance, yval
* denotes terminal node
1) root 198 90401.29 16.996460
2) PO4>=43.818 147 31279.12 8.979592 *
3) PO4< 43.818 51 22442.76 40.103920 *
================================
R也可以讓你自行決定修剪樹的哪些節點:
================================
> first.tree <- rpart(a1 ~ ., data = algae[, 1:12])
> snip.rpart(first.tree, c(4, 7))
----------------------------------------------------------
n= 198
node), split, n, deviance, yval
* denotes terminal node
1) root 198 90401.290 16.996460
2) PO4>=43.818 147 31279.120 8.979592
4) Cl>=7.8065 140 21622.830 7.492857 *
5) Cl< 7.8065 7 3157.769 38.714290 *
3) PO4< 43.818 51 22442.760 40.103920
6) mxPH< 7.87 28 11452.770 33.450000 *
7) mxPH>=7.87 23 8241.110 48.204350
14) PO4>=15.177 12 3047.517 38.183330 *
================================
snip.rpart()第二個參數是欲修剪的節點編號。
你也可以在圖形上直接點選欲修剪的節點:
================================
> prettyTree(first.tree)
> snip.rpart(first.tree)
----------------------------------------------------------
node number: 2 n= 147
response= 8.979592
Error (dev) = 31279.12
node number: 6 n= 28
response= 33.45
Error (dev) = 11452.77
n= 198
node), split, n, deviance, yval
1) root 198 90401.290 16.996460
2) PO4>=43.818 147 31279.120 8.979592 *
3) PO4< 43.818 51 22442.760 40.103920
6) mxPH< 7.87 28 11452.770 33.450000 *
7) mxPH>=7.87 23 8241.110 48.204350
14) PO4>=15.177 12 3047.517 38.183330 *
15) PO4< 15.177 11 2673.945 59.136360 *
================================
以上是點擊2、6兩個節點的結果。
模型的評估與選擇
評估模型最常用的標準是模型的預測能力(predictive performance of the models),其評估方式是使用真實資料預測並比對目標值,計算出其平均誤差。平均絕對誤差(mean absolute error (MAE))即是其中之一。在R中使用predict()來獲得預測結果:
================================
> lm.predictions.a1 <- predict(final.lm, clean.algae)
> rt.predictions.a1 <- predict(rt.a1, algae)
================================
predict(模型,資料集)
接著計算平均絕對誤差:
================================
> (mae.a1.lm <- mean(abs(lm.predictions.a1 - algae[, "a1"])))
----------------------------------------------------------
[1] 13.10681
================================
> (mae.a1.rt <- mean(abs(rt.predictions.a1 - algae[, "a1"])))
----------------------------------------------------------
[1] 11.61717
================================
另一個常用的誤差評估值是均方誤差(mean squared error (MSE))
================================
> (mse.a1.lm <- mean((lm.predictions.a1 - algae[, "a1"])^2))
----------------------------------------------------------
[1] 295.5407
================================
> (mse.a1.rt <- mean((rt.predictions.a1 - algae[, "a1"])^2))
----------------------------------------------------------
[1] 271.3226
================================
MSE有一個缺點是評估的目標值並非同一單位。因此較不常用,所以一般採用歸一化均方誤差(normalized mean squared error(NMSE))。這個統計值計算模型效能與預測基準(目標值的平均值)的比值:
================================
> (nmse.a1.lm <- mean((lm.predictions.a1-algae[,'a1'])^2)/mean((mean(algae[,'a1'])-algae[,'a1'])^2))
----------------------------------------------------------
[1] 0.6473034
================================
> (nmse.a1.rt <- mean((rt.predictions.a1-algae[,'a1'])^2)/mean((mean(algae[,'a1'])-algae[,'a1'])^2))
----------------------------------------------------------
[1] 0.5942601
================================
NMSE是無單位的,一般從0到1。如果模型效能佳,即愈趨近預測基準,則NMSE小於1,且值愈小愈好。如果大於1表示比用平均值作為預測值的還差。
我們可以使用regr.eval()各種的迴歸評估指標:
================================
> regr.eval(algae[, "a1"], rt.predictions.a1, train.y = algae[, "a1"])
----------------------------------------------------------
mae mse rmse nmse nmae
11.6171709 271.3226161 16.4718735 0.5942601 0.6953711
================================
我們也可以視覺化模型的效能,如誤差分布圖:
================================
> old.par <- par(mfrow = c(1, 2))
> plot(lm.predictions.a1, algae[, "a1"], main = "Linear Model",xlab = "Predictions", ylab = "True Values")
> abline(0, 1, lty = 2)
> plot(rt.predictions.a1, algae[, "a1"], main = "Regression Tree",xlab = "Predictions", ylab = "True Values")
> abline(0, 1, lty = 2)
> par(old.par)
================================
誤差分布圖 |
從上圖觀察,顯示模型效能不佳,因為效能佳的時,誤差值的點應接近虛線[abline(0,1,lty=2)]部分。虛線是代表X軸等於Y軸部分。同樣的也可以使用identify()手動標示個別的點:
================================
> plot(lm.predictions.a1,algae[,'a1'],main="Linear Model",xlab="Predictions",ylab="True Values")
> abline(0,1,lty=2)
> algae[identify(lm.predictions.a1,algae[,'a1']),]
================================
以滑鼠點選以選出差異大的點。按Esc結束後,R會顯示所選點的資料。
從上圖中可以發現有些點是負值,在此案例中的領域規則裏,負值是無意義的,而大部分情況下,應該是0。因此我們可以利用領域規則,加上ifelse()判斷<0的情況來改善這個模型:(請注意評估指標,是否對於模型的效能有所提昇)
================================
> sensible.lm.predictions.a1 <- ifelse(lm.predictions.a1 < 0, 0, lm.predictions.a1)
> regr.eval(algae[, "a1"], lm.predictions.a1, stats = c("mae", "mse"))
-----------------------------------------------------------
mae mse
13.10681 295.54069
================================
> regr.eval(algae[, "a1"], sensible.lm.predictions.a1, stats = c("mae", "mse"))
-----------------------------------------------------------
mae mse
12.48276 286.28541
================================
ifelse()有三個參數,第一個是邏輯判斷,如果此邏輯判斷為True則傳回第二參數,否則傳回第三參數。
從上述的評估,迴歸樹因為NMSE比較低似乎是比較適合用來預估此140個測試樣本。但這樣的推理有個陷阱。因為我們不知道這些樣本的目標便數值,而要去評估哪些模型的效能是比較佳的。重點是從不知道真正目標值的資料中,獲得模型效能可信賴的評估。使用訓練資料計算效能值標不是可信賴的,因為其結果會有偏差。事實上,有些模型可以從訓練資料得到零預測誤差的。但是這樣的效能並不能同樣在目標值未知的新樣本中也有同樣的效能。這個現象就如同前述的過度擬合。
因此需要一些模型可以評估運用於含未知值資料的模型效能。
k-fold交叉驗證(k-fold cross-validation)便是針對小型資料集一種最常用的方法。從訓練資料中隨機取得k個相同數量的子集,對每一個k子集使用其餘k-1個子集建構模型,病理用此k子集評估其效能並記錄模型的效能,然後針對所有其他的子集進行相同的步驟。最後得到k個效能評估。所有這些模型效能評估所用的資料都不是用以建構模型的資料。一般k值都是使用10。有時我們還會執行多次,以獲得更有信賴的評估。
一般面對預測的任務,需要下列的決策:
================================
> plot(lm.predictions.a1,algae[,'a1'],main="Linear Model",xlab="Predictions",ylab="True Values")
> abline(0,1,lty=2)
> algae[identify(lm.predictions.a1,algae[,'a1']),]
================================
以滑鼠點選以選出差異大的點。按Esc結束後,R會顯示所選點的資料。
從上圖中可以發現有些點是負值,在此案例中的領域規則裏,負值是無意義的,而大部分情況下,應該是0。因此我們可以利用領域規則,加上ifelse()判斷<0的情況來改善這個模型:(請注意評估指標,是否對於模型的效能有所提昇)
================================
> sensible.lm.predictions.a1 <- ifelse(lm.predictions.a1 < 0, 0, lm.predictions.a1)
> regr.eval(algae[, "a1"], lm.predictions.a1, stats = c("mae", "mse"))
-----------------------------------------------------------
mae mse
13.10681 295.54069
================================
> regr.eval(algae[, "a1"], sensible.lm.predictions.a1, stats = c("mae", "mse"))
-----------------------------------------------------------
mae mse
12.48276 286.28541
================================
ifelse()有三個參數,第一個是邏輯判斷,如果此邏輯判斷為True則傳回第二參數,否則傳回第三參數。
從上述的評估,迴歸樹因為NMSE比較低似乎是比較適合用來預估此140個測試樣本。但這樣的推理有個陷阱。因為我們不知道這些樣本的目標便數值,而要去評估哪些模型的效能是比較佳的。重點是從不知道真正目標值的資料中,獲得模型效能可信賴的評估。使用訓練資料計算效能值標不是可信賴的,因為其結果會有偏差。事實上,有些模型可以從訓練資料得到零預測誤差的。但是這樣的效能並不能同樣在目標值未知的新樣本中也有同樣的效能。這個現象就如同前述的過度擬合。
因此需要一些模型可以評估運用於含未知值資料的模型效能。
k-fold交叉驗證(k-fold cross-validation)便是針對小型資料集一種最常用的方法。從訓練資料中隨機取得k個相同數量的子集,對每一個k子集使用其餘k-1個子集建構模型,病理用此k子集評估其效能並記錄模型的效能,然後針對所有其他的子集進行相同的步驟。最後得到k個效能評估。所有這些模型效能評估所用的資料都不是用以建構模型的資料。一般k值都是使用10。有時我們還會執行多次,以獲得更有信賴的評估。
一般面對預測的任務,需要下列的決策:
- 選擇各式模型(也可能是同模型但參數不同)。
- 選擇用於比對模型的評估指標。
- 選擇試驗的方法論,以獲得這些指標可信賴的評估。
我們使用experimentalComparison()來選擇及評估。這個方法有三個參數:1.用以比對的資料集。2.各式模型。3.試驗程序的參數。
使用者提供實作模型的函數,這樣的函數必需針對所提供的訓練資料及測試資料,實作完整的「訓練」+「測試」+「評估」的流程。在函數中,逐一呼叫這些函數,這些函數傳回使用者想要的評估指標(此例使用NMSE)。如下所示:
================================
> cv.rpart <- function(form,train,test,...) {
+ m <- rpartXse(form,train,...)
+ p <- predict(m,test)
+ mse <- mean((p-resp(form,test))^2)
+ c(nmse=mse/mean((mean(resp(form,train))-resp(form,test))^2))
+}
> cv.lm <- function(form,train,test,...) {
+ m <- lm(form,train,...)
+ p <- predict(m,test)
+ p <- ifelse(p < 0,0,p)
+ mse <- mean((p-resp(form,test))^2)
+ c(nmse=mse/mean((mean(resp(form,train))-resp(form,test))^2))
+}
================================
所有使用者實作的函數需有前三個參數:1.(form)公式(formula)2.訓練資料3.測試資料。其餘的參數(...)是experimentalComparison用來學習評估之用。其中resp()是依據公式取得資料集中的目標便數值。
接下來便是進入學習及測試流程:
================================
> res <- experimentalComparison(
+ c(dataset(a1 ~ .,clean.algae[,1:12],'a1')),
+ c(variants('cv.lm'),
+ variants('cv.rpart',se=c(0,0.5,1))),
+ cvSettings(3,10,1234))
-----------------------------------------------------------
##### CROSS VALIDATION EXPERIMENTAL COMPARISON #####
** DATASET :: a1
++ LEARNER :: cv.lm variant -> cv.lm.v1
3 x 10 - Fold Cross Validation run with seed = 1234
Repetition 1
Fold: 1 2 3 4 5 6 7 8 9 10
Repetition 2
Fold: 1 2 3 4 5 6 7 8 9 10
Repetition 3
Fold: 1 2 3 4 5 6 7 8 9 10
++ LEARNER :: cv.rpart variant -> cv.rpart.v1
3 x 10 - Fold Cross Validation run with seed = 1234
Repetition 1
Fold: 1 2 3 4 5 6 7 8 9 10
Repetition 2
Fold: 1 2 3 4 5 6 7 8 9 10
Repetition 3
Fold: 1 2 3 4 5 6 7 8 9 10
++ LEARNER :: cv.rpart variant -> cv.rpart.v2
3 x 10 - Fold Cross Validation run with seed = 1234
Repetition 1
Fold: 1 2 3 4 5 6 7 8 9 10
Repetition 2
Fold: 1 2 3 4 5 6 7 8 9 10
Repetition 3
Fold: 1 2 3 4 5 6 7 8 9 10
++ LEARNER :: cv.rpart variant -> cv.rpart.v3
3 x 10 - Fold Cross Validation run with seed = 1234
Repetition 1
Fold: 1 2 3 4 5 6 7 8 9 10
Repetition 2
Fold: 1 2 3 4 5 6 7 8 9 10
Repetition 3
Fold: 1 2 3 4 5 6 7 8 9 10
================================
experimentalComparison()方法的第 一個參數,是由dataset(<formula>,<data frame>,<label>)定義的資料集。第二個參數是由variants()定義的使用者定義包含訓練+測試+評估過程的函數。第三個參數是函數的設定值,第一個參數表示k-flods執行次數,第二個參數指定k值,第三個參數指定隨機種子。
variants()第一個參數是函數名稱,第二個參數是該函數所需之參數。執行時依據第二參數指定的參數數量,執行第一參數指定函數相同的次數,並將第二參數的每一個參數傳入函數中。
我們使用summary()來檢視實際執行結果:
================================
> summary(res)
-----------------------------------------------------------
== Summary of a Cross Validation Experiment ==
3 x 10 - Fold Cross Validation run with seed = 1234
* Data sets :: a1
* Learners :: cv.lm.v1, cv.rpart.v1, cv.rpart.v2, cv.rpart.v3
* Summary of Experiment Results:
-> Datataset: a1
*Learner: cv.lm.v1
nmse
avg 0.7196105
std 0.1833064
min 0.4678248
max 1.2218455
invalid 0.0000000
*Learner: cv.rpart.v1
nmse
avg 0.6440843
std 0.2521952
min 0.2146359
max 1.1712674
invalid 0.0000000
*Learner: cv.rpart.v2
nmse
avg 0.6873747
std 0.2669942
min 0.2146359
max 1.3356744
invalid 0.0000000
*Learner: cv.rpart.v3
nmse
avg 0.7167122
std 0.2579089
min 0.3476446
max 1.3356744
invalid 0.0000000
================================
也可以使用圖形來觀察:
================================
> plot(res)
================================
欲查看某一模型的結果:
================================
> getVariant("cv.rpart.v1", res)
----------------------------------------------------------
Learner:: "cv.rpart"
Parameter values
se = 0
================================
我們也可以同時執行7個水藻變數的預測評估:
================================
> DSs <- sapply(names(clean.algae)[12:18],
+ function(x,names.attrs) {
+ f <- as.formula(paste(x,"~ ."))
+ dataset(f,clean.algae[,c(names.attrs,x)],x)
+}, + names(clean.algae)[1:11])
>DSs
----------------------------------------------------------
$a1
Task Name :: a1
Formula :: a1 ~ .
<environment: 0x000000001071e1b8>
Task Data ::
'data.frame': 198 obs. of 12 variables:
$ a1 : num 0 1.4 3.3 3.1 9.2 15.1 2.4 18.2 25.4 17 ...
$ season: Factor w/ 4 levels "autumn","spring",..: 4 2 1 2 1 4 3 1 4 4 ...
$ size : Factor w/ 3 levels "large","medium",..: 3 3 3 3 3 3 3 3 3 3 ...
$ speed : Factor w/ 3 levels "high","low","medium": 3 3 3 3 3 1 1 1 3 1 ...
$ mxPH : num 8 8.35 8.1 8.07 8.06 8.25 8.15 8.05 8.7 7.93 ...
$ mnO2 : num 9.8 8 11.4 4.8 9 13.1 10.3 10.6 3.4 9.9 ...
$ Cl : num 60.8 57.8 40 77.4 55.4 ...
$ NO3 : num 6.24 1.29 5.33 2.3 10.42 ...
$ NH4 : num 578 370 346.7 98.2 233.7 ...
$ oPO4 : num 105 428.8 125.7 61.2 58.2 ...
$ PO4 : num 170 558.8 187.1 138.7 97.6 ...
$ Chla : num 50 1.3 15.6 1.4 10.5 ...
$a2
Task Name :: a2
Formula :: a2 ~ .
<environment: 0x0000000010317b08>
Task Data ::
'data.frame': 198 obs. of 12 variables:
$ a2 : num 0 7.6 53.6 41 2.9 14.6 1.2 1.6 5.4 0 ...
$ season: Factor w/ 4 levels "autumn","spring",..: 4 2 1 2 1 4 3 1 4 4 ...
$ size : Factor w/ 3 levels "large","medium",..: 3 3 3 3 3 3 3 3 3 3 ...
$ speed : Factor w/ 3 levels "high","low","medium": 3 3 3 3 3 1 1 1 3 1 ...
$ mxPH : num 8 8.35 8.1 8.07 8.06 8.25 8.15 8.05 8.7 7.93 ...
$ mnO2 : num 9.8 8 11.4 4.8 9 13.1 10.3 10.6 3.4 9.9 ...
$ Cl : num 60.8 57.8 40 77.4 55.4 ...
$ NO3 : num 6.24 1.29 5.33 2.3 10.42 ...
$ NH4 : num 578 370 346.7 98.2 233.7 ...
$ oPO4 : num 105 428.8 125.7 61.2 58.2 ...
$ PO4 : num 170 558.8 187.1 138.7 97.6 ...
$ Chla : num 50 1.3 15.6 1.4 10.5 ...
$a3
Task Name :: a3
Formula :: a3 ~ .
<environment: 0x0000000010898510>
Task Data ::
'data.frame': 198 obs. of 12 variables:
$ a3 : num 0 4.8 1.9 18.9 7.5 1.4 3.2 0 2.5 0 ...
$ season: Factor w/ 4 levels "autumn","spring",..: 4 2 1 2 1 4 3 1 4 4 ...
$ size : Factor w/ 3 levels "large","medium",..: 3 3 3 3 3 3 3 3 3 3 ...
$ speed : Factor w/ 3 levels "high","low","medium": 3 3 3 3 3 1 1 1 3 1 ...
$ mxPH : num 8 8.35 8.1 8.07 8.06 8.25 8.15 8.05 8.7 7.93 ...
$ mnO2 : num 9.8 8 11.4 4.8 9 13.1 10.3 10.6 3.4 9.9 ...
$ Cl : num 60.8 57.8 40 77.4 55.4 ...
$ NO3 : num 6.24 1.29 5.33 2.3 10.42 ...
$ NH4 : num 578 370 346.7 98.2 233.7 ...
$ oPO4 : num 105 428.8 125.7 61.2 58.2 ...
$ PO4 : num 170 558.8 187.1 138.7 97.6 ...
$ Chla : num 50 1.3 15.6 1.4 10.5 ...
$a4
Task Name :: a4
Formula :: a4 ~ .
<environment: 0x0000000010364358>
Task Data ::
'data.frame': 198 obs. of 12 variables:
$ a4 : num 0 1.9 0 0 0 0 3.9 0 0 2.9 ...
$ season: Factor w/ 4 levels "autumn","spring",..: 4 2 1 2 1 4 3 1 4 4 ...
$ size : Factor w/ 3 levels "large","medium",..: 3 3 3 3 3 3 3 3 3 3 ...
$ speed : Factor w/ 3 levels "high","low","medium": 3 3 3 3 3 1 1 1 3 1 ...
$ mxPH : num 8 8.35 8.1 8.07 8.06 8.25 8.15 8.05 8.7 7.93 ...
$ mnO2 : num 9.8 8 11.4 4.8 9 13.1 10.3 10.6 3.4 9.9 ...
$ Cl : num 60.8 57.8 40 77.4 55.4 ...
$ NO3 : num 6.24 1.29 5.33 2.3 10.42 ...
$ NH4 : num 578 370 346.7 98.2 233.7 ...
$ oPO4 : num 105 428.8 125.7 61.2 58.2 ...
$ PO4 : num 170 558.8 187.1 138.7 97.6 ...
$ Chla : num 50 1.3 15.6 1.4 10.5 ...
$a5
Task Name :: a5
Formula :: a5 ~ .
<environment: 0x0000000008812940>
Task Data ::
'data.frame': 198 obs. of 12 variables:
$ a5 : num 34.2 6.7 0 1.4 7.5 22.5 5.8 5.5 0 0 ...
$ season: Factor w/ 4 levels "autumn","spring",..: 4 2 1 2 1 4 3 1 4 4 ...
$ size : Factor w/ 3 levels "large","medium",..: 3 3 3 3 3 3 3 3 3 3 ...
$ speed : Factor w/ 3 levels "high","low","medium": 3 3 3 3 3 1 1 1 3 1 ...
$ mxPH : num 8 8.35 8.1 8.07 8.06 8.25 8.15 8.05 8.7 7.93 ...
$ mnO2 : num 9.8 8 11.4 4.8 9 13.1 10.3 10.6 3.4 9.9 ...
$ Cl : num 60.8 57.8 40 77.4 55.4 ...
$ NO3 : num 6.24 1.29 5.33 2.3 10.42 ...
$ NH4 : num 578 370 346.7 98.2 233.7 ...
$ oPO4 : num 105 428.8 125.7 61.2 58.2 ...
$ PO4 : num 170 558.8 187.1 138.7 97.6 ...
$ Chla : num 50 1.3 15.6 1.4 10.5 ...
$a6
Task Name :: a6
Formula :: a6 ~ .
<environment: 0x0000000008277740>
Task Data ::
'data.frame': 198 obs. of 12 variables:
$ a6 : num 8.3 0 0 0 4.1 12.6 6.8 8.7 0 0 ...
$ season: Factor w/ 4 levels "autumn","spring",..: 4 2 1 2 1 4 3 1 4 4 ...
$ size : Factor w/ 3 levels "large","medium",..: 3 3 3 3 3 3 3 3 3 3 ...
$ speed : Factor w/ 3 levels "high","low","medium": 3 3 3 3 3 1 1 1 3 1 ...
$ mxPH : num 8 8.35 8.1 8.07 8.06 8.25 8.15 8.05 8.7 7.93 ...
$ mnO2 : num 9.8 8 11.4 4.8 9 13.1 10.3 10.6 3.4 9.9 ...
$ Cl : num 60.8 57.8 40 77.4 55.4 ...
$ NO3 : num 6.24 1.29 5.33 2.3 10.42 ...
$ NH4 : num 578 370 346.7 98.2 233.7 ...
$ oPO4 : num 105 428.8 125.7 61.2 58.2 ...
$ PO4 : num 170 558.8 187.1 138.7 97.6 ...
$ Chla : num 50 1.3 15.6 1.4 10.5 ...
$a7
Task Name :: a7
Formula :: a7 ~ .
<environment: 0x000000000fe02028>
Task Data ::
'data.frame': 198 obs. of 12 variables:
$ a7 : num 0 2.1 9.7 1.4 1 2.9 0 0 0 1.7 ...
$ season: Factor w/ 4 levels "autumn","spring",..: 4 2 1 2 1 4 3 1 4 4 ...
$ size : Factor w/ 3 levels "large","medium",..: 3 3 3 3 3 3 3 3 3 3 ...
$ speed : Factor w/ 3 levels "high","low","medium": 3 3 3 3 3 1 1 1 3 1 ...
$ mxPH : num 8 8.35 8.1 8.07 8.06 8.25 8.15 8.05 8.7 7.93 ...
$ mnO2 : num 9.8 8 11.4 4.8 9 13.1 10.3 10.6 3.4 9.9 ...
$ Cl : num 60.8 57.8 40 77.4 55.4 ...
$ NO3 : num 6.24 1.29 5.33 2.3 10.42 ...
$ NH4 : num 578 370 346.7 98.2 233.7 ...
$ oPO4 : num 105 428.8 125.7 61.2 58.2 ...
$ PO4 : num 170 558.8 187.1 138.7 97.6 ...
$ Chla : num 50 1.3 15.6 1.4 10.5 ...
================================
> res.all <- experimentalComparison(
+ DSs,
+ c(variants('cv.lm'),
+ variants('cv.rpart',se=c(0,0.5,1))
+), + cvSettings(5,10,1234))
>summary(res.all)
-----------------------------------------------------------
== Summary of a Cross Validation Experiment ==
5 x 10 - Fold Cross Validation run with seed = 1234
* Data sets :: a1, a2, a3, a4, a5, a6, a7
* Learners :: cv.lm.v1, cv.rpart.v1, cv.rpart.v2, cv.rpart.v3
* Summary of Experiment Results:
-> Datataset: a1
*Learner: cv.lm.v1
nmse
avg 0.7077282
std 0.1639373
min 0.4661104
max 1.2218455
invalid 0.0000000
*Learner: cv.rpart.v1
nmse
avg 0.6423100
std 0.2399321
min 0.2146359
max 1.1712674
invalid 0.0000000
*Learner: cv.rpart.v2
nmse
avg 0.6569726
std 0.2397636
min 0.2146359
max 1.3356744
invalid 0.0000000
*Learner: cv.rpart.v3
nmse
avg 0.6875212
std 0.2348946
min 0.2886238
max 1.3356744
invalid 0.0000000
-> Datataset: a2
*Learner: cv.lm.v1
nmse
avg 1.0449317
std 0.6276144
min 0.4098820
max 3.6764115
invalid 0.0000000
*Learner: cv.rpart.v1
nmse
avg 1.0426327
std 0.2005522
min 0.7584102
max 1.9152968
invalid 0.0000000
*Learner: cv.rpart.v2
nmse
avg 1.01626123
std 0.07435826
min 0.99287105
max 1.49933643
invalid 0.00000000
*Learner: cv.rpart.v3
nmse
avg 1.000000e+00
std 2.389599e-16
min 1.000000e+00
max 1.000000e+00
invalid 0.000000e+00
-> Datataset: a3
*Learner: cv.lm.v1
nmse
avg 1.0216439
std 0.3522588
min 0.6818271
max 2.3354516
invalid 0.0000000
*Learner: cv.rpart.v1
nmse
avg 1.0406844
std 0.1823151
min 0.9234394
max 2.1525733
invalid 0.0000000
*Learner: cv.rpart.v2
nmse
avg 1.000000e+00
std 1.980954e-16
min 1.000000e+00
max 1.000000e+00
invalid 0.000000e+00
*Learner: cv.rpart.v3
nmse
avg 1.000000e+00
std 1.980954e-16
min 1.000000e+00
max 1.000000e+00
invalid 0.000000e+00
-> Datataset: a4
*Learner: cv.lm.v1
nmse
avg 2.1119756
std 3.1181959
min 0.3990221
max 15.4079825
invalid 0.0000000
*Learner: cv.rpart.v1
nmse
avg 1.0073953
std 0.1065607
min 0.4651041
max 1.4588424
invalid 0.0000000
*Learner: cv.rpart.v2
nmse
avg 1.000000e+00
std 2.774424e-16
min 1.000000e+00
max 1.000000e+00
invalid 0.000000e+00
*Learner: cv.rpart.v3
nmse
avg 1.000000e+00
std 2.774424e-16
min 1.000000e+00
max 1.000000e+00
invalid 0.000000e+00
-> Datataset: a5
*Learner: cv.lm.v1
nmse
avg 0.9316803
std 0.3478931
min 0.4980249
max 2.2706313
invalid 0.0000000
*Learner: cv.rpart.v1
nmse
avg 1.1793298
std 0.5659329
min 0.6366705
max 3.5927541
invalid 0.0000000
*Learner: cv.rpart.v2
nmse
avg 1.01137139
std 0.07049532
min 1.00000000
max 1.49445097
invalid 0.00000000
*Learner: cv.rpart.v3
nmse
avg 1.000000e+00
std 2.841612e-16
min 1.000000e+00
max 1.000000e+00
invalid 0.000000e+00
-> Datataset: a6
*Learner: cv.lm.v1
nmse
avg 0.9359697
std 0.6045963
min 0.1211761
max 3.9247377
invalid 0.0000000
*Learner: cv.rpart.v1
nmse
avg 1.0191041
std 0.1991436
min 0.8225254
max 2.3774810
invalid 0.0000000
*Learner: cv.rpart.v2
nmse
avg 1.000000e+00
std 2.451947e-16
min 1.000000e+00
max 1.000000e+00
invalid 0.000000e+00
*Learner: cv.rpart.v3
nmse
avg 1.000000e+00
std 2.451947e-16
min 1.000000e+00
max 1.000000e+00
invalid 0.000000e+00
-> Datataset: a7
*Learner: cv.lm.v1
nmse
avg 1.5238076
std 1.3762973
min 0.8670194
max 8.2205224
invalid 0.0000000
*Learner: cv.rpart.v1
nmse
avg 1.889286
std 3.528353
min 1.000000
max 20.829858
invalid 0.000000
*Learner: cv.rpart.v2
nmse
avg 1.463952
std 2.229237
min 1.000000
max 15.763627
invalid 0.000000
*Learner: cv.rpart.v3
nmse
avg 1.0295055
std 0.2086351
min 1.0000000
max 2.4752731
invalid 0.0000000
================================
其中DSs是7個資料集(a1~a7),用來放入experimentalComparison()中逐一執行。sapply()第一及第三個參數,傳入第二參數(function(x,names.attrs))作為參數。而且第一參數如果是向量,則每一個向量中的每一元素皆會逐一傳入。
paste(s1,s2)是將s1及s2字串連接成新字串。
as.formula(s)是將s字串轉換成R的公式(formula)物件。
同樣也可以圖形顯示:
================================
> plot(res.all)
================================
7種測試之NMSE |
從上圖可以發現,大部分結果不是很好,因為NMSE明顯>1。{問題:是指平均值嗎?}
我們也可以使用bestScores()找出最佳的模型:
================================
> bestScores(res.all)
----------------------------------------------------------
$a1
system score
nmse cv.rpart.v1 0.64231
$a2
system score
nmse cv.rpart.v3 1
$a3
system score
nmse cv.rpart.v2 1
$a4
system score
nmse cv.rpart.v2 1
$a5
system score
nmse cv.lm.v1 0.9316803
$a6
system score
nmse cv.lm.v1 0.9359697
$a7
system score
nmse cv.rpart.v3 1.029505
================================
除了a1以外,其結果是令人失望的。從圖『7種測試之NMSE』顯示結果的變動性,可以提供一個觀點,使用『綜效(ensemble)』試驗方法可能是一個好的替代方案。綜效是一種模型建構方法,主要是解決個別模型的限制,作法是建立一大堆各式模型,然後綜和這些模型的預測。有許多途徑可以做到綜效,除了多樣的模型獲得方式(無不同的訓練樣本、不同的變數、不同的模型技術),也可以不同預測方式(如投票、平均等)。『隨機森林(Random forests)』是由一大群樹狀模型(迴歸樹或分類樹)所組成,每棵樹都是完整的(未經修剪),而且每一步成長節點的最佳切割是從隨機屬性子集中選擇。最終結果則是平均所有樹的結果。randomForest套件randomForest()即實作這個概念。下列程式重複前述交叉驗證試驗,其中包含第三種隨機森林:
================================
> library(randomForest)
> cv.rf <- function(form,train,test,...) {
+ m <- randomForest(form,train,...)
+ p <- predict(m,test)
+ mse <- mean((p-resp(form,test))^2)
+ c(nmse=mse/mean((mean(resp(form,train))-resp(form,test))^2))
+}
> res.all <- experimentalComparison(
+ DSs,
+ c(variants('cv.lm'),
+ variants('cv.rpart',se=c(0,0.5,1)),
+ variants('cv.rf',ntree=c(200,500,700))
+), + cvSettings(5,10,1234))
================================
執行過程會花一點時間,最後使用bestScores()檢視執行結果:
================================
> bestScores(res.all)
----------------------------------------------------------
$a1
system score
nmse cv.rf.v3 0.5447361
$a2
system score
nmse cv.rf.v3 0.7777851
$a3
system score
nmse cv.rf.v2 0.9946093
$a4
system score
nmse cv.rf.v3 0.9591182
$a5
system score
nmse cv.rf.v1 0.7907947
$a6
system score
nmse cv.rf.v3 0.9126477
$a7
system score
nmse cv.rpart.v3 1.029505
================================
結果顯示,最佳的分數來自隨機森林模型,同樣的結果並不總是好的,尤其是a7。bestScores()顯示的資料,並未告訴我們這些最佳模型與其他模型之差異是否是統計上的顯著,亦即使用其他隨機資料可以得到類似結果的信賴度。使用compAnalysis()便可以得到這樣的資訊。cv.rf.v3在a1、a2、a4、a6表現最佳,下列驗證統計的顯著程度:
================================
> compAnalysis(res.all,against='cv.rf.v3',datasets=c('a1','a2','a4','a6'))
-----------------------------------------------------------
================================
其中sign.x提供了我們想要的資訊,無標示即表示cv.rf.v3於其他的模型差異的統計性顯著度低於95%,加號表示模型的平均評估指標顯著高於cv.rf.v3,而NMSE應該是愈低愈好。負號代表相對的意義。
你可以發現,不同隨機森林之間的差異,及與其他模型之間的差異,並不是統計上顯著的。相對其他模型,在大多數情況下隨機森林擁有較明顯的優勢。
*******************************************
問題:與其他模型有顯著差異,代表何種意義?
*******************************************
預測七種水藻的數量
資料採礦的主要目標是預測140樣本中七種水藻的數量。使用的模型是前述驗證過最佳模型。也就是呼叫bestScores()所看到的。也就是下列其中之一cv.rf.v3、cv.rf.v2、cv.rf.v1、cv.rpart.v3:
================================
//取得各水藻的最佳模型名稱
> bestModelsNames <- sapply(bestScores(res.all), function(x) x['nmse','system'])
> learners <- c(rf='randomForest',rpart='rpartXse')
//取得上述個名稱的實際模型名稱
> funcs <- learners[sapply(strsplit(bestModelsNames,'\\.'),
+ function(x) x[2])]
//存放個模型的參數設定
> parSetts <- lapply(bestModelsNames,
+ function(x) getVariant(x,res.all)@pars)
> bestModels <- list()
> for(a in 1:7) {
+ form <- as.formula(paste(names(clean.algae)[11+a],'~.'))
+ bestModels[[a]] <- do.call(funcs[a],
+ c(list(form,clean.algae[,c(1:11,11+a)]),parSetts[[a]]))
+}
================================
strsplit():分解字串,第一個字串是待分解字串,第二個參數是分隔的字串。如strsplit(cv.rf.v3,'\\.')結果便是'cv' 'rf' 'v3'。
getVariant()傳入各個模型名稱,傳回learner物件。@pars稱為插槽(slots)可視為物件的屬性。
do.call():呼叫第一個參數的函數名稱,而把第二個參數當做此函數的參數。
{問題:詳細的函數或類別需再額外研究。}
至此我們得到七個相對的模型。
test.algae內有這140個測試樣本,其中同樣含有缺漏值。我們也可以像之前處理的一樣,使用knnImputation()來補齊缺漏值。但這有些違背了預測建模的規則:『不要使用任何測試集的資訊建立模型』。knnImputation()有額外的參數可以用來避免使用測試集的資料建模:
================================
> clean.test.algae <- knnImputation(test.algae, k = 10, distData = algae[, 1:11])
================================
distData 參數可以讓你指定其他的資料集作為建模來源。其中使用的欄位,並未指定目標變數(1:11),這是因為測試資料集的目標變數並無資料。
接下來我們便可以計算預測值了:
================================
> preds <- matrix(ncol=7,nrow=140)
> for(i in 1:nrow(clean.test.algae))
+ preds[i,] <- sapply(1:7,
+ function(x)
+ predict(bestModels[[x]],clean.test.algae[i,])
+)
================================
至此我們得到這140筆樣本的7個數量預測。在套件中還有一個資料集algae.sols含有這140筆樣本的7個實際數量,我們可以用來計算這些模型的NMSE。
================================
//計算每種藻類的實際數量平均值
> avg.preds <- apply(algae[,12:18],2,mean)
//計算每種藻類預測結果的NMSE
> apply( ((algae.sols-preds)^2), 2,mean) /
+ apply( (scale(algae.sols,avg.preds,F)^2),2,mean)
----------------------------------------------------------
================================
scale()用來歸一化(normalize)資料集,其作法是第一個參數減第二個參數再除於第三個參數,除非這個參數是FALSE。(此處運算皆指向量)
最後的結果顯示,a7很難獲得好的分數,而其他6種則頗具競爭力,尤其是a1。因此我們可以結論,使用適當的模型選擇程序,可以獲得不錯的分數。
================================
//取得各水藻的最佳模型名稱
> bestModelsNames <- sapply(bestScores(res.all), function(x) x['nmse','system'])
> learners <- c(rf='randomForest',rpart='rpartXse')
//取得上述個名稱的實際模型名稱
> funcs <- learners[sapply(strsplit(bestModelsNames,'\\.'),
+ function(x) x[2])]
//存放個模型的參數設定
> parSetts <- lapply(bestModelsNames,
+ function(x) getVariant(x,res.all)@pars)
> bestModels <- list()
> for(a in 1:7) {
+ form <- as.formula(paste(names(clean.algae)[11+a],'~.'))
+ bestModels[[a]] <- do.call(funcs[a],
+ c(list(form,clean.algae[,c(1:11,11+a)]),parSetts[[a]]))
+}
================================
strsplit():分解字串,第一個字串是待分解字串,第二個參數是分隔的字串。如strsplit(cv.rf.v3,'\\.')結果便是'cv' 'rf' 'v3'。
getVariant()傳入各個模型名稱,傳回learner物件。@pars稱為插槽(slots)可視為物件的屬性。
do.call():呼叫第一個參數的函數名稱,而把第二個參數當做此函數的參數。
{問題:詳細的函數或類別需再額外研究。}
至此我們得到七個相對的模型。
test.algae內有這140個測試樣本,其中同樣含有缺漏值。我們也可以像之前處理的一樣,使用knnImputation()來補齊缺漏值。但這有些違背了預測建模的規則:『不要使用任何測試集的資訊建立模型』。knnImputation()有額外的參數可以用來避免使用測試集的資料建模:
================================
> clean.test.algae <- knnImputation(test.algae, k = 10, distData = algae[, 1:11])
================================
distData 參數可以讓你指定其他的資料集作為建模來源。其中使用的欄位,並未指定目標變數(1:11),這是因為測試資料集的目標變數並無資料。
接下來我們便可以計算預測值了:
================================
> preds <- matrix(ncol=7,nrow=140)
> for(i in 1:nrow(clean.test.algae))
+ preds[i,] <- sapply(1:7,
+ function(x)
+ predict(bestModels[[x]],clean.test.algae[i,])
+)
================================
至此我們得到這140筆樣本的7個數量預測。在套件中還有一個資料集algae.sols含有這140筆樣本的7個實際數量,我們可以用來計算這些模型的NMSE。
================================
//計算每種藻類的實際數量平均值
> avg.preds <- apply(algae[,12:18],2,mean)
//計算每種藻類預測結果的NMSE
> apply( ((algae.sols-preds)^2), 2,mean) /
+ apply( (scale(algae.sols,avg.preds,F)^2),2,mean)
----------------------------------------------------------
================================
scale()用來歸一化(normalize)資料集,其作法是第一個參數減第二個參數再除於第三個參數,除非這個參數是FALSE。(此處運算皆指向量)
最後的結果顯示,a7很難獲得好的分數,而其他6種則頗具競爭力,尤其是a1。因此我們可以結論,使用適當的模型選擇程序,可以獲得不錯的分數。
結論
這個案例所展現的相關技術有:
- 資料視覺化
- 描述性統計
- 處理變數缺漏值的策略
- 迴歸分析作業
- 迴歸分析的評估指標
- 多線性迴歸分析
- 迴歸樹分析
- 運用k-fold交叉驗證對各種模型的比較選擇
- 同時使用各種模型及隨機森林的應用
對R的基本操作技術有:
- 從文字檔載入資料
- 如何從料集中取得相關的統計描述
- 基本的資料視覺化
- 處理資料集中的缺漏值
- 取得迴歸模型
- 運用上述的模型預測測試資料的預測值
沒有留言:
張貼留言