CS 231 n knn python授業後の作業

2844 ワード

knnプログラム
http://cs231n.github.io/classification/ L 1 distance
d 1(I 1,I 2)=Σp_Ip 1−Ip 2|
CS231n knn python 课后作业_第1张图片
def predict(self, X):
""" X is N x D where each row is an example we wish to predict label for """
num_test = X.shape[0]
# lets make sure that the output type matches the input type
Ypred = np.zeros(num_test, dtype = self.ytr.dtype)

# loop over all test rows
for i in xrange(num_test):
  # find the nearest training image to the i'th test image num_test 10000
  # using the L1 distance (sum of absolute value differences)       (    ,  10000 )[i:](3072 )    
  distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)
  #self.Xtr:50000*3072; X[i,:]:1*3072,    ,      ,  distances   50000*1,         argmin
  min_index = np.argmin(distances) # get the index with smallest distance
  Ypred[i] = self.ytr[min_index] # predict the label of the nearest example

return Ypred
2018.5.11
def compute_distances_no_loops(self,X):
    num_test = X.shape[0]
    num_train = self.X_train.shape[0]
    dists = np.zeros((num_test, num_train))
    test_sum=np.sum(np.square(X),axis=1)
    train_sum=np.sum(np.square(self.X_train),axis=1)
    inner_product=np.dot(X,self.X_train.T)
    dists=np.sqrt(-2*inner_product+test_sum.reshape(-1,1)+train_sum)
    return dists
循環しないで計算します.参照してください.
https://blog.csdn.net/zhyh1435589631/article/details/54236643
CS231n knn python 课后作业_第2张图片
(a−b)2‾恏恏恏恏恏恏恏恏√=a 2+b 2−2 a⑨ρ⑩ρ⑩ρ⑩⑩⑩⑩⑩ρ⑩ρ⑩ρ⑩√(a 2+b)2=a+2−2 a
broadcastのために、最後にM*Nを実現したいです.test_sumは1*Mで、trin_sumは1*Nなので、test(u)をsumを入れ替えるだけでいいです.他は変えずに最後にM*N行列を出力します.
predicat_ケーブル
def predict_labels(self, dists, k=1):
https://blog.csdn.net/guangtishai4957/article/details/79950117
predicat_labels関数の下から2行目のy_pred[i]=np.argmax(np.bincount)の使い方説明
# bincount       
x = np.array([0, 1, 1, 3, 3, 3, 3, 5])  
# bincount            ,        x     ,  
#      x      5, bincount           0 5,                      (   ,         )  
y = np.bincount(x)  
print(y)  
    -》 [1 2 0 4 0 1]  
# numpy  argmax                   ,y     4,       3,        3  
#                                       !!!  
z = np.argmax(y)  
      3