簡介與概述

 

此處指的 KMean 名詞上可能與它處不同,如 KMeans、k-means、KMeans 等,中文譯為 K 平均。較不建議譯為中文,很容易與股票裡之 K 平均線作為聯想。

假設有 X 筆資料,現欲將此 X 筆資料分類成 K 類,條件是分配到同一類的資料,必須盡可能的在各方面條件都相仿,此時便適用 KMean。

KMean 演算法屬 clustering 之演算法,以 C 語言而言本身並不難以實現,但實際在應用時往往很少「單純」拿它來應用。但 KMean 改善之演算法非常多 ( 近年來不知有沒有人繼續做這方面改善 ),如 k-mean++、KHS (沒記錯的話應是 KHS) 演算法等。

 

粗略流程

 

step 0 : 讀入 DCNT 個資料點,設定相關參數。

step 1 : 初始化。從 DCNT 個資料隨機取出 K 個不同之資料當初始重心。

step 2 : 更新 (資料-叢聚) 對應表。

step 3 : 依據 (資料-叢聚) 對應表 ,重新計算各叢聚之重心。

step 4 : 若未達收斂條件,回到 step 2。

 

所謂收斂條件,筆者採用三種

1. 限定最大迭代次數。

2. 限定最小變動點 (必設) - 該資料點在迭代過程中,改變所屬叢聚的個動稱之。

3. 限定總距離之變動 - 每筆資料到各所屬叢聚重心之距離總合,若已沒發生變化也視為收斂條件。

 

一般而言大多是以變動點個數作為收斂條件。但演算法到後期,可能會有資料點,會一直不斷地在兩個叢聚中一直輪流跑,於是再新增限定最大迭代次數作為收斂條件。

 

定義距離

 

在 Data mining 中,距離定義其實不少,較常見的為歐氏距離。現假設資料有 DIM 個屬性,則 x 到 y 的距離為

sum = (xi - yi)^2 , for  0 <= i < DIM
distance(x, y) = sum ^ (1/2) 

簡單的說即是將 x, y 每個屬性的差值平方後,加總,最後開根號。

但,在 KMean 整個流程跑下來,最後會發現,上面的「開根號」是沒必要的,因它不影響資料到叢聚距離的順序關係,故筆者在寫類似程式時,並不會特別再開根號上去。故此處所定義的距離為所有屬性差值之平方,為 SSE (sum of squared errors)。

 

所需變數

 

先假設資料有 DATA_CNT 筆,屬性 DIM 個,欲分成 K 叢。

雖 KMean 算簡單,但對新手而言也非一次就能順利完成。資料結構使用上全選陣列即可。至少必須包含下面三項

1. src_data[DATA_CNT][DIM] :原始資料,共有 DATA_CNT 筆,每筆屬性(維度、構面)有 DIM 個。如原始資料可能為學生成績,50位學生、3科成績,此時 DATA_CNT=50,DIM=3。

2. center[K][DIM]:各叢聚重心,因有 K 個叢聚,且每個叢聚重心也來自其他資料,故也有 DIM 個屬性。

3. table[DATA_CNT]:紀錄每個資料點所屬之重心。

至於下面二項是建議也要用上的,因可以加快程式之進行。

4. dis_k[K][DIM]:紀錄此叢聚裡,屬於此叢聚之所有資料各屬性之總合。如,假設資料有 2 個屬性(DIM=2),且第 1 叢聚有 3 筆資料,分別為 {1,2}, {3,4}, {5,6},則 dis_k[1][0] = 1+3+5= 9,dis_k[1][1] = 2+4+6 = 12。

5. cent_c[K]:紀錄第 k 叢聚裡,共有幾筆資料數。

 

程式碼

 

這只是一份 present code,拋開記憶體配置管理,使用最簡單之方式撰之。

 

/*******************************************************************/
/*                                                                 */
/*     filename : KMeans.c                                         */
/*     author   : edison.shih/edisonx                              */
/*     compiler : Visual C++ 2008                                  */
/*     date     : 2010.03.07                                       */
/*                                                                 */
/*******************************************************************/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

// ------------------------------------
// param define
#define DCNT     100 /*
資料個數   */
#define DIM      3   /* 資料維度   */
#define K        5   /* 叢聚個數   */
#define MAX_ITER 20  /* 最大迭代   */
#define MIN_PT   0   /* 最小變動點 */
#define LOW      20  /* 資料下限   */
#define UP       300 /* 資料上限   */

// ------------------------------------
// variable
double data[DCNT][DIM]; /* 原始資料   */
double cent[DCNT][DIM]; /* 重心       */
double dis_k[K][DIM];   /* 叢聚距離   */
int table[DCNT];        /* 資料所屬叢聚*/
int cent_c[K];          /* 該叢聚資料數*/

// ------------------------------------
// function declare
double cal_dis(double *x, double *y, int dim);
void   get_data();               //
取得資料
void   kmeans_init();            // 演算法初始化
double update_table(int* ch_pt); // 更新table
void   update_cent();            // 更新重心位置
void   print_cent();             // 顯示重心位置
// ------------------------------------
// main function
int main()
{
     int     ch_pt;         /*
紀錄變動之點 */
     int     iter=0;        /* 迭代計數器   */
     double sse1;           /* 上一迭代之sse */
     double sse2;           /* 此次迭代之sse */

     srand((unsigned)time(NULL));    
     get_data();                      /* step 0 -
取得資料            */
     kmeans_init();                   /* step 1 -
初始化,隨機取得重心 */
     sse2 = update_table(&ch_pt);     /* step 2 - 更新一次對應表      */
     do{
           sse1 = sse2, ++iter;
           update_cent();             /* step 3 -
更新重心            */
           sse2=update_table(&ch_pt); /* step 4 - 更新對應表          */
     }while(iter<MAX_ITER && sse1!=sse2 && ch_pt>MIN_PT); // 收斂條件

    print_cent(); // 顯示最後重心位置
    printf("sse   = %.2lf\n", sse2);
    printf("ch_pt = %d\n", ch_pt);
    printf("iter = %d\n", iter);
    return 0;
}

// ------------------------------------
//
計算二點距離
double cal_dis(double *x, double *y, int dim)
{
     int i;
     double t, sum=0.0;
     for(i=0; i<dim; ++i)
           t=x[i]-y[i], sum+=t*t;
     return sum;
}

// ------------------------------------
//
取得資料,此處以隨機給
void get_data()
{
     int i, j;
     for(i=0; i<DCNT; ++i)
           for(j=0; j<DIM; ++j)
                data[i][j] = \
                LOW + (double)rand()*(UP-LOW) / RAND_MAX;
}
// ------------------------------------
//
演算化初始化
void   kmeans_init()
{
     int i, j, k, rnd;
     int pick[K];
     //
隨機找K 個不同資料點
     for(k=0; k<K; ++k){
           rnd = rand() % DCNT; //
隨機取一筆
           for(i=0; i<k && pick[i]!=rnd; ++i);
           if(i==k) pick[k]=rnd; //
沒重覆
           else --k; // 有重覆, 再找一次
     }
     //
將K 個資料點內容複制到重心cent
     for(k=0; k<K; ++k)
           for(j=0; j<DIM; ++j)
                cent[k][j] = data[pick[k]][j];
}

// ------------------------------------
//
更新table, 傳回sse, 存入點之變動數
double update_table(int* ch_pt)
{
     int i, j, k, min_k;
     double dis, min_dis, t_sse=0.0;

     *ch_pt=0;                          //
變動點數設0
     memset(cent_c, 0, sizeof(cent_c)); // 各叢聚資料數清0
     memset(dis_k, 0, sizeof(dis_k));   // 各叢聚距離和清0

     for(i=0; i<DCNT; ++i){
           //
尋找所屬重心
           min_dis = cal_dis(data[i], cent[0], DIM);
           min_k   = 0;
           for(k=1;k<K; ++k){
                dis = cal_dis(data[i], cent[k], DIM);
                if(dis < min_dis)
                     min_dis=dis, min_k = k;
           }
           *ch_pt+=(table[i]!=min_k); //
更新變動點數
           table[i] = min_k;          // 更新所屬重心
           ++cent_c[min_k];           // 累計重心資料數        
           t_sse += min_dis;          // 累計總重心距離
           for(j=0; j<DIM; ++j)       // 更新各叢聚總距離
                dis_k[min_k][j]+=data[i][j];        
     }
     return t_sse;
}

// ------------------------------------
//
更新重心位置
void update_cent()
{
     int k, j;
     for(k=0; k<K; ++k)
           for(j=0; j<DIM; ++j)
                cent[k][j]=dis_k[k][j]/cent_c[k];
}

// ------------------------------------
//
顯示重心位置
void   print_cent()
{
     int j, k;
     for(k=0; k<K; ++k) {
           for(j=0; j<DIM; ++j)
                printf("%6.2lf ", cent[k][j]);
           putchar('\n');
     }
}

arrow
arrow
    全站熱搜

    edisonx 發表在 痞客邦 留言(2) 人氣()