K-means Clustering
Project link: https://github.com/Wchenguang/gglearn/blob/master/KmeansClassifier/講解/KmeansClassifier.ipynb
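- First, randomly pick K initial points as centroids (they need not be points from the dataset).
- Then assign every point in the dataset to a cluster: for each point, find the nearest centroid and assign the point to that centroid's cluster. Once this step is complete, update each cluster's centroid to the mean of all points in that cluster.
- Repeat the process until every point in the dataset is closest to the centroid of its own cluster, that is, until no assignment changes between iterations.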
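The three steps above can be sketched compactly with NumPy broadcasting. This is only an illustration, not the project's implementation: the function name `kmeans_sketch`, the `max_iter` safety cap, and the choice to leave an empty cluster's centroid where it is are all assumptions made for this example, and the initial centroids are passed in explicitly so the run is deterministic.

```python
import numpy as np

def kmeans_sketch(x, centers, max_iter=100):
    """Illustrative helper (hypothetical, not part of KmeansClassifier)."""
    x = np.asarray(x, dtype=float)
    centers = np.asarray(centers, dtype=float).copy()
    labels = np.full(len(x), -1)  # -1 means "not assigned yet"
    for _ in range(max_iter):
        # Assignment step: Euclidean distance from every point to every centroid.
        dists = np.linalg.norm(x[:, None, :] - centers[None, :, :], axis=2)
        new_labels = dists.argmin(axis=1)
        # Stop once no point changes cluster.
        if np.array_equal(new_labels, labels):
            break
        labels = new_labels
        # Update step: move each centroid to the mean of its points.
        for j in range(len(centers)):
            members = x[labels == j]
            if len(members) > 0:  # assumption: an empty cluster keeps its centroid
                centers[j] = members.mean(axis=0)
    return labels, centers

# Two well-separated pairs of points and two hand-picked initial centroids.
points = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
labels, centers = kmeans_sketch(points, centers=[[1.0, 1.0], [9.0, 9.0]])
print(labels)   # [0 0 1 1]
print(centers)  # centroids end at (0, 0.5) and (10, 10.5)
```

On this toy input the first pass assigns the two lower points to the first centroid and the two upper points to the second; the second pass produces the same assignment, so the loop stops.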
import numpy as np
import matplotlib.pyplot as plt


class KmeansClassifier:
    def __init__(self, k, distance_method="o", random_select=True,
                 plot=False):
        self.k = k
        self.distance_method = distance_method
        self.random_select = random_select
        self.plot = plot
        # "o" selects Euclidean distance; any other value selects Manhattan.
        if distance_method == "o":
            self._dist = self._euler_dist
        else:
            self._dist = self._manhattan_distance

    def _euler_dist(self, x1, x):
        # Euclidean (L2) distance between two points.
        return np.sqrt(np.multiply(x - x1, x - x1).sum())

    def _manhattan_distance(self, x1, x):
        # Manhattan (L1) distance between two points.
        return np.abs(x - x1).sum()

    def _get_nearest(self, x, center_list, dist):
        # Index of the centroid closest to x.
        dists = []
        for center in center_list:
            dists.append(dist(x, center))
        return dists.index(min(dists))

    def _fit(self, x, y, dist, x_center_index_list, center_list):
        # The last two columns hold the previous and the new assignment.
        xy_map = np.hstack((x, x_center_index_list, x_center_index_list))
        for row in xy_map:
            row[-1] = self._get_nearest(row[:-2], center_list, dist)
        # Converged when no point changed cluster.
        flag = np.all(xy_map[:, -1] == xy_map[:, -2])
        return flag, xy_map[:, -1].reshape(-1, 1), center_list

    def _random_center_list(self, x, k):
        # Draw k centroids uniformly inside the per-feature range of the data.
        center_list = np.zeros((x.shape[1], k))
        for col in range(x.shape[1]):
            col_max = np.max(x[:, col])
            col_min = np.min(x[:, col])
            center_list[col, :] = col_min + (col_max - col_min) * np.random.rand(1, k)
        return center_list.T

    def _updata_center_list(self, x, x_center_index_list, center_list):
        new_center_list = []
        for index in range(len(center_list)):
            part_x = x[np.where(
                x_center_index_list[:, -1] == index)]
            if 0 != part_x.size:
                new_center_list.append(np.mean(part_x, axis=0))
            else:
                # An empty cluster gets its centroid reset to the origin.
                new_center_list.append(np.zeros(part_x.shape[1]))
        return new_center_list

    def _plot(self, x, x_center_index_list, center_list):
        '''
        Plot the data; only 2-D data is supported.
        '''
        center_array = np.array(center_list)
        for index in range(len(center_list)):
            part_x = x[np.where(
                x_center_index_list[:, -1] == index)]
            plt.scatter(part_x[:, 0], part_x[:, 1])
        plt.scatter(center_array[:, 0], center_array[:, 1], marker="+")
        plt.show()

    def fit(self, x, y, center_list=None):
        # "is None" rather than "not center_list", which raises for NumPy arrays.
        if center_list is None:
            center_list = self._random_center_list(x, self.k)
        # Start from -1 ("unassigned") so the first pass never looks converged.
        x_center_index_list = np.full((x.shape[0], 1), -1.0)
        flag = False
        while True:
            flag, x_center_index_list, center_list = self._fit(
                x, y, self._dist, x_center_index_list, center_list)
            if flag:
                break
            center_list = self._updata_center_list(x, x_center_index_list,
                                                   center_list)
        if self.plot:
            self._plot(x, x_center_index_list, center_list)
        return self

x = np.random.randint(1, 100, (50, 2))
# fit() accepts y but never uses it; it is sized to one row per sample here.
y = np.random.randint(1, 4, (50, 1))
kmeans_clf = KmeansClassifier(k=4, plot=True)
kmeans_clf.fit(x, y)
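
KmeansClassifier picks its metric from distance_method: "o" maps to _euler_dist (Euclidean), and any other value maps to _manhattan_distance. As a small worked example of how the two differ, the snippet below computes both for a single pair of points. The points `a` and `b` are made up for illustration; the formulas mirror the two helper methods.

```python
import numpy as np

# Hypothetical points chosen so the arithmetic is easy to check by hand.
a = np.array([0.0, 0.0])
b = np.array([3.0, 4.0])

euclidean = np.sqrt(((a - b) ** 2).sum())  # straight-line distance: sqrt(9 + 16)
manhattan = np.abs(a - b).sum()            # sum of per-axis differences: 3 + 4

print(euclidean, manhattan)  # 5.0 7.0
```

Because Manhattan distance is never smaller than Euclidean distance for the same pair of points, switching the metric can change which centroid counts as nearest, and with it the resulting clusters.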