簡介與概述
此處指的 KMean 名詞上可能與它處不同,如 KMeans、k-means、KMeans 等,中文譯為 K 平均。較不建議譯為中文,很容易與股票裡之 K 平均線作為聯想。
假設有 X 筆資料,現欲將此 X 筆資料分類成 K 類,條件是分配到同一類的資料,必須盡可能的在各方面條件都相仿,此時便適用 KMean。
KMean 演算法屬 clustering 之演算法,以 C 語言而言本身並不難以實現,但實際在應用時往往很少「單純」拿它來應用。但 KMean 改善之演算法非常多 ( 近年來不知有沒有人繼續做這方面改善 ),如 k-mean++、KHS (沒記錯的話應是 KHS) 演算法等。
粗略流程
step 0 : 讀入 DCNT 個資料點,設定相關參數。
step 1 : 初始化。從 DCNT 個資料隨機取出 K 個不同之資料當初始重心。
step 2 : 更新 (資料-叢聚) 對應表。
step 3 : 依據 (資料-叢聚) 對應表 ,重新計算各叢聚之重心。
step 4 : 若未達收斂條件,回到 step 2。
所謂收斂條件,筆者採用三種
1. 限定最大迭代次數。
2. 限定最小變動點 (必設) - 該資料點在迭代過程中,改變所屬叢聚的個動稱之。
3. 限定總距離之變動 - 每筆資料到各所屬叢聚重心之距離總合,若已沒發生變化也視為收斂條件。
一般而言大多是以變動點個數作為收斂條件。但演算法到後期,可能會有資料點,會一直不斷地在兩個叢聚中一直輪流跑,於是再新增限定最大迭代次數作為收斂條件。
定義距離
在 Data mining 中,距離定義其實不少,較常見的為歐氏距離。現假設資料有 DIM 個屬性,則 x 到 y 的距離為
sum = (xi - yi)^2 , for 0 <= i < DIM
distance(x, y) = sum ^ (1/2)
簡單的說即是將 x, y 每個屬性的差值平方後,加總,最後開根號。
但,在 KMean 整個流程跑下來,最後會發現,上面的「開根號」是沒必要的,因它不影響資料到叢聚距離的順序關係,故筆者在寫類似程式時,並不會特別再開根號上去。故此處所定義的距離為所有屬性差值之平方,為 SSE (sum of squared errors)。
所需變數
先假設資料有 DATA_CNT 筆,屬性 DIM 個,欲分成 K 叢。
雖 KMean 算簡單,但對新手而言也非一次就能順利完成。資料結構使用上全選陣列即可。至少必須包含下面三項
1. src_data[DATA_CNT][DIM] :原始資料,共有 DATA_CNT 筆,每筆屬性(維度、構面)有 DIM 個。如原始資料可能為學生成績,50位學生、3科成績,此時 DATA_CNT=50,DIM=3。
2. center[K][DIM]:各叢聚重心,因有 K 個叢聚,且每個叢聚重心也來自其他資料,故也有 DIM 個屬性。
3. table[DATA_CNT]:紀錄每個資料點所屬之重心。
至於下面二項是建議也要用上的,因可以加快程式之進行。
4. dis_k[K][DIM]:紀錄此叢聚裡,屬於此叢聚之所有資料各屬性之總合。如,假設資料有 2 個屬性(DIM=2),且第 1 叢聚有 3 筆資料,分別為 {1,2}, {3,4}, {5,6},則 dis_k[1][0] = 1+3+5= 9,dis_k[1][1] = 2+4+6 = 12。
5. cent_c[K]:紀錄第 k 叢聚裡,共有幾筆資料數。
程式碼
這只是一份 present code,拋開記憶體配置管理,使用最簡單之方式撰之。
/*******************************************************************/
/* */
/* filename : KMeans.c */
/* author : edison.shih/edisonx */
/* compiler : Visual C++ 2008 */
/* date : 2010.03.07 */
/* */
/*******************************************************************/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
// ------------------------------------
// param define
#define DCNT 100 /* 資料個數 */
#define DIM 3 /* 資料維度 */
#define K 5 /* 叢聚個數 */
#define MAX_ITER 20 /* 最大迭代 */
#define MIN_PT 0 /* 最小變動點 */
#define LOW 20 /* 資料下限 */
#define UP 300 /* 資料上限 */
// ------------------------------------
// variable
double data[DCNT][DIM]; /* 原始資料 */
double cent[DCNT][DIM]; /* 重心 */
double dis_k[K][DIM]; /* 叢聚距離 */
int table[DCNT]; /* 資料所屬叢聚*/
int cent_c[K]; /* 該叢聚資料數*/
// ------------------------------------
// function declare
double cal_dis(double *x, double *y, int dim);
void get_data(); // 取得資料
void kmeans_init(); // 演算法初始化
double update_table(int* ch_pt); // 更新table
void update_cent(); // 更新重心位置
void print_cent(); // 顯示重心位置
// ------------------------------------
// main function
int main()
{
int ch_pt; /* 紀錄變動之點 */
int iter=0; /* 迭代計數器 */
double sse1; /* 上一迭代之sse */
double sse2; /* 此次迭代之sse */
srand((unsigned)time(NULL));
get_data(); /* step 0 - 取得資料 */
kmeans_init(); /* step 1 - 初始化,隨機取得重心 */
sse2 = update_table(&ch_pt); /* step 2 - 更新一次對應表 */
do{
sse1 = sse2, ++iter;
update_cent(); /* step 3 - 更新重心 */
sse2=update_table(&ch_pt); /* step 4 - 更新對應表 */
}while(iter<MAX_ITER && sse1!=sse2 && ch_pt>MIN_PT); // 收斂條件
print_cent(); // 顯示最後重心位置
printf("sse = %.2lf\n", sse2);
printf("ch_pt = %d\n", ch_pt);
printf("iter = %d\n", iter);
return 0;
}
// ------------------------------------
// 計算二點距離
double cal_dis(double *x, double *y, int dim)
{
int i;
double t, sum=0.0;
for(i=0; i<dim; ++i)
t=x[i]-y[i], sum+=t*t;
return sum;
}
// ------------------------------------
// 取得資料,此處以隨機給
void get_data()
{
int i, j;
for(i=0; i<DCNT; ++i)
for(j=0; j<DIM; ++j)
data[i][j] = \
LOW + (double)rand()*(UP-LOW) / RAND_MAX;
}
// ------------------------------------
// 演算化初始化
void kmeans_init()
{
int i, j, k, rnd;
int pick[K];
// 隨機找K 個不同資料點
for(k=0; k<K; ++k){
rnd = rand() % DCNT; // 隨機取一筆
for(i=0; i<k && pick[i]!=rnd; ++i);
if(i==k) pick[k]=rnd; // 沒重覆
else --k; // 有重覆, 再找一次
}
// 將K 個資料點內容複制到重心cent
for(k=0; k<K; ++k)
for(j=0; j<DIM; ++j)
cent[k][j] = data[pick[k]][j];
}
// ------------------------------------
// 更新table, 傳回sse, 存入點之變動數
double update_table(int* ch_pt)
{
int i, j, k, min_k;
double dis, min_dis, t_sse=0.0;
*ch_pt=0; // 變動點數設0
memset(cent_c, 0, sizeof(cent_c)); // 各叢聚資料數清0
memset(dis_k, 0, sizeof(dis_k)); // 各叢聚距離和清0
for(i=0; i<DCNT; ++i){
// 尋找所屬重心
min_dis = cal_dis(data[i], cent[0], DIM);
min_k = 0;
for(k=1;k<K; ++k){
dis = cal_dis(data[i], cent[k], DIM);
if(dis < min_dis)
min_dis=dis, min_k = k;
}
*ch_pt+=(table[i]!=min_k); // 更新變動點數
table[i] = min_k; // 更新所屬重心
++cent_c[min_k]; // 累計重心資料數
t_sse += min_dis; // 累計總重心距離
for(j=0; j<DIM; ++j) // 更新各叢聚總距離
dis_k[min_k][j]+=data[i][j];
}
return t_sse;
}
// ------------------------------------
// 更新重心位置
void update_cent()
{
int k, j;
for(k=0; k<K; ++k)
for(j=0; j<DIM; ++j)
cent[k][j]=dis_k[k][j]/cent_c[k];
}
// ------------------------------------
// 顯示重心位置
void print_cent()
{
int j, k;
for(k=0; k<K; ++k) {
for(j=0; j<DIM; ++j)
printf("%6.2lf ", cent[k][j]);
putchar('\n');
}
}