KNN cosine 余弦相似度计算

代码如下:
# coding: utf-8
import collections
import numpy as np
import os
from sklearn.neighbors import NearestNeighbors
 
 
def cos(vector1, vector2):
    """Return the cosine similarity of two numeric sequences.

    Components are paired with ``zip``, so when the inputs differ in length
    only the leading components they have in common take part in the result.

    Returns ``None`` when the squared norm accumulated for either input is
    zero, since the cosine is undefined in that case.
    """
    inner = 0.0
    sq_norm1 = 0.0
    sq_norm2 = 0.0
    for u, v in zip(vector1, vector2):
        inner += u * v
        sq_norm1 += u ** 2
        sq_norm2 += v ** 2

    # Guard clause: a zero vector has no direction to compare.
    if sq_norm1 == 0.0 or sq_norm2 == 0.0:
        return None
    return inner / ((sq_norm1 * sq_norm2) ** 0.5)
 
 
def iterbrowse(path):
    """Lazily yield the path of every file found beneath ``path``.

    The tree is traversed with ``os.walk``; each yielded value is the
    directory reported by the walk joined with a file name in it.
    Directories themselves are never yielded.
    """
    found = (
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(path)
        for name in names
    )
    for file_path in found:
        yield file_path
 
 
def get_data(filename, n_fields=78, n_meta=3):
    """Parse a tab-separated feature file into rows of floats.

    Every non-empty line must contain exactly ``n_fields`` tab-separated
    fields. The first ``n_meta`` fields of each line are discarded and the
    remaining ones are converted with ``float``.

    Args:
        filename: Path of the text file to read.
        n_fields: Required number of tab-separated fields per line.
        n_meta: Number of leading fields to drop before conversion.

    Returns:
        A list with one list of floats per data line, in file order.

    Raises:
        ValueError: If a line has the wrong number of fields (the message
            names the file, line number and offending line), or if a kept
            field cannot be converted by ``float``.
    """
    rows = []
    with open(filename) as f:
        # Iterate the file object directly instead of materialising every
        # line with readlines().
        for lineno, line in enumerate(f, 1):
            if not line.strip("\r\n"):
                # Tolerate empty lines (typically a trailing newline at EOF)
                # rather than aborting the whole load.
                continue
            fields = line.split("\t")
            if len(fields) != n_fields:
                raise ValueError(
                    "%s:%d: expected %d tab-separated fields, got %d: %r"
                    % (filename, lineno, n_fields, len(fields), line)
                )
            rows.append([float(n) for n in fields[n_meta:]])
    return rows
 
# Feature numbers removed by get_wanted_data() when filtering is requested.
unwanted_features = {6, 7, 8, 41,42,43,67,68,69,70,71,72,73,74,75}


def get_wanted_data(x, drop_unwanted=False):
    """Return the feature rows to use, optionally dropping unwanted columns.

    By default the input is returned untouched, which is what the function
    has always done: the filtering code used to sit unreachably after the
    ``return`` statement. That filter is now available on request.

    Args:
        x: Sequence of feature rows (each row a sequence of values).
        drop_unwanted: When true, build new rows that keep only the values
            whose position ``i`` satisfies ``i + 6 not in unwanted_features``
            (the same offset the disabled code used).

    Returns:
        ``x`` itself when ``drop_unwanted`` is false; otherwise a new list
        of filtered rows.
    """
    if not drop_unwanted:
        return x
    return [
        [value for i, value in enumerate(row) if i + 6 not in unwanted_features]
        for row in x
    ]
 
 
def _load_training_matrix(path):
    """Load the training matrix from a ``.txt`` (np.genfromtxt) or ``.npy`` file.

    Raises:
        IOError: If ``path`` is not an existing file.
        ValueError: If the extension is neither ``.txt`` nor ``.npy``.
    """
    # Fail early with a clear message instead of fitting on an empty list or
    # hitting an unbound variable later on.
    if not os.path.isfile(path):
        raise IOError("training file not found: %s" % path)
    if path.endswith('.txt'):
        return np.genfromtxt(path)
    if path.endswith('.npy'):
        return np.load(path)
    raise ValueError("unsupported training file type: %s" % path)


def _report_nearest(neigh, train, queries, title):
    """Print the nearest training row of every query plus its cosine similarity.

    Args:
        neigh: Fitted ``NearestNeighbors`` model (built with n_neighbors=1).
        train: The rows the model was fitted on, in the same order.
        queries: Feature rows to look up.
        title: Heading printed before the results.
    """
    print(title)
    if len(queries) == 0:
        # Nothing to look up for this set.
        return
    nearest_points = neigh.kneighbors(queries)
    print(nearest_points)
    for i, x in enumerate(queries):
        row_ids = nearest_points[1][i]
        # nearest_points[1] holds row numbers into `train`, not feature
        # values: fetch the neighbour's vector before comparing. Passing the
        # index array straight to cos() compares a single pair of numbers and
        # therefore always yields +/-1.0.
        print(i, row_ids, "cosine:", cos(x, train[row_ids[0]]))


def main():
    """Fit a 1-NN cosine index on the black samples and report matches."""
    neg_file = "cc_data/black/black_all.txt"
    # NOTE: only the black samples are indexed; the white samples are used
    # solely as a verification set below.
    X = list(_load_training_matrix(neg_file))
    print("len of X:", len(X))
    # Drop the first three columns, as get_data() does for the verify files.
    X = [x[3:] for x in X]
    X = get_wanted_data(X)

    black_verify = []
    for f in iterbrowse("todo/top"):
        print(f)
        black_verify += get_data(f)
    black_verify = get_wanted_data(black_verify)

    white_verify = get_wanted_data(get_data("todo/white_verify.txt"))
    unknown_verify = get_wanted_data(get_data("todo/pek_feature74.txt"))
    # Run through get_wanted_data like every other set so all sets keep the
    # same columns if filtering is ever enabled.
    bd_verify = get_wanted_data(get_data("guzhaoshen_pek_out.txt"))

    neigh = NearestNeighbors(n_neighbors=1, metric='cosine')
    neigh.fit(X)

    _report_nearest(neigh, X, black_verify, "neigh.kneighbors(black_verify)")
    _report_nearest(neigh, X, white_verify, "neigh.kneighbors(white_verify)")
    _report_nearest(neigh, X, unknown_verify, "neigh.kneighbors(unknown_verify)")

    # Sanity check: query the index with some of its own rows.
    print("neigh.kneighbors(self)")
    print(neigh.kneighbors(X[:3]))

    # The neighbours of bd_verify are queried and printed once (the original
    # ran the identical query twice).
    _report_nearest(neigh, X, bd_verify, "neigh.kneighbors(bd pek)")


if __name__ == "__main__":
    main()

 输出示例:

neigh.kneighbors(white_verify)
(array([[ 0.01140831],
       [ 0.0067373 ],
       [ 0.00198682],
       [ 0.00686728],
       [ 0.00210445],
       [ 0.00061413],
       [ 0.00453888]]), array([[11032],
       [  967],
       [11091],
       [13149],
       [11091],
       [19041],
       [13068]]))
(0, array([11032]), 'cosine:', 1.0)
(1, array([967]), 'cosine:', 1.0)
(2, array([11091]), 'cosine:', 1.0)
(3, array([13149]), 'cosine:', 1.0)
(4, array([11091]), 'cosine:', 1.0)
(5, array([19041]), 'cosine:', 1.0)
(6, array([13068]), 'cosine:', 1.0)

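样本质量堪忧啊!!!

注意:上面输出里的 cosine 全为 1.0 并不代表真实相似度——这份示例输出是把查询向量与"最近邻的下标数组"(长度为 1)一起传给 cos() 得到的,只有一对数参与计算时余弦恒为 ±1。真实的余弦相似度应为 1 − 余弦距离(即第一个数组中的值),例如第一条为 1 − 0.01140831 ≈ 0.9886。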

 

注意:如果使用常规 KNN(例如基于欧氏距离),并且各个维度特征的衡量单位(量纲)不一样,计算距离之前记得先做标准化:
from sklearn import preprocessing

# Fit the scaler on the training matrix only, then reuse that one fitted
# scaler for the training data and for every verification set, so all of
# them are transformed with the same statistics.
scaler = preprocessing.StandardScaler()
scaler.fit(X)

X = scaler.transform(X)
print("standard X sample:", X[:3])

black_verify = scaler.transform(black_verify)
print(black_verify)

white_verify = scaler.transform(white_verify)
print(white_verify)

unknown_verify = scaler.transform(unknown_verify)
print(unknown_verify)

 

posted @   bonelee  阅读(3075)  评论(0编辑  收藏  举报
编辑推荐:
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」
点击右上角即可分享
微信分享提示