速度數百倍之差,有人斷言KNN面臨淘汰,更快更強的ANN將取而代之 - 知乎

語言: CN / TW / HK
資料科學經典演算法 KNN 已被嫌慢,ANN 比它快 380 倍。

選自towardsdatascience,作者:Marie Stephen Leo,機器之心編譯,編輯:小舟、杜偉。

在模式識別領域中,K - 近鄰演算法(K-Nearest Neighbor, KNN)是一種用於分類和迴歸的非引數統計方法。K - 近鄰演算法非常簡單而有效,它的模型表示就是整個訓練資料集。就原理而言,對新資料點的預測結果是通過在整個訓練集上搜索與該資料點最相似的 K 個例項(近鄰)並且總結這 K 個例項的輸出變數而得出的。KNN 可能需要大量的記憶體或空間來儲存所有資料,並且使用距離或接近程度的度量方法可能會在維度非常高的情況下(有許多輸入變數)崩潰,這可能會對演算法在你的問題上的效能產生負面影響。這就是所謂的維數災難。

近似最近鄰演算法(Approximate Nearest Neighbor, ANN)則是一種通過犧牲精度來換取時間和空間的方式從大量樣本中獲取最近鄰的方法,並以其儲存空間少、查詢效率高等優點引起了人們的廣泛關注。

近日,一家技術公司的資料科學主管 Marie Stephen Leo 撰文對 KNN 與 ANN 進行了比較,結果表明,在搜尋到最近鄰的相似度為 99.3% 的情況下,ANN 比 sklearn 上的 KNN 快了 380 倍。

作者表示,幾乎每門資料科學課程中都會講授 KNN 演算法,但它正在走向「淘汰」!

KNN 簡述

在機器學習社群中,找到給定項的「K」個相似項被稱為相似性搜尋或最近鄰(NN)搜尋。最廣為人知的 NN 搜尋演算法是 KNN 演算法。在 KNN 中,給定諸如手機電商目錄之類的物件集合,則對於任何新的搜尋查詢,我們都可以從整個目錄中找到少量(K 個)最近鄰。例如,在下面示例中,如果設定 K = 3,則每個「iPhone」的 3 個最近鄰是另一個「iPhone」。同樣,每個「Samsung」的 3 個最近鄰也都是「Samsung」。

KNN 存在的問題

儘管 KNN 擅長查詢相似項,但它使用詳細的成對距離計算來查詢鄰居。如果你的資料包含 1000 個項,如若找出新產品的 K=3 最近鄰,則演算法需要對資料庫中所有其他產品執行 1000 次新產品距離計算。這還不算太糟糕,但是想象一下,現實世界中的客戶對客戶(Customer-to-Customer,C2C)市場,其中的資料庫包含數百萬種產品,每天可能會上傳數千種新產品。將每個新產品與全部數百萬種產品進行比較是不划算的,而且耗時良久,也就是說這種方法根本無法擴充套件。

解決方案

將最近鄰演算法擴充套件至大規模資料的方法是徹底避開暴力距離計算,使用 ANN 演算法。

近似最近距離演算法(ANN)

嚴格地講,ANN 是一種在 NN 搜尋過程中允許少量誤差的演算法。但在實際的 C2C 市場中,真實的鄰居數量比被搜尋的 K 近鄰數量要多。與暴力 KNN 相比,人工神經網路可以在短時間內獲得卓越的準確性。ANN 演算法有以下幾種:

  • Spotify 的 ANNOY
  • Google 的 ScaNN
  • Facebook 的 Faiss
  • HNSW

分層的可導航小世界(Hierarchical Navigable Small World, HNSW)

在 HNSW 中,作者描述了一種使用多層圖的 ANN 演算法。在插入元素階段,通過指數衰減概率分佈隨機選擇每個元素的最大層,逐步構建 HNSW 圖。這確保 layer=0 時有很多元素能夠實現精細搜尋,而 layer=2 時支援粗放搜尋的元素數量少了 e^-2。最近鄰搜尋從最上層開始進行粗略搜尋,然後逐步向下處理,直至最底層。使用貪心圖路徑演算法遍歷圖,並找到所需鄰居數量。

HNSW 圖結構。最近鄰搜尋從最頂層開始(粗放搜尋),在最底層結束(精細搜尋)。

HNSW Python 包

整個 HNSW 演算法程式碼已經用帶有 Python 繫結的 C++ 實現了,使用者可以通過鍵入以下命令將其安裝在機器上:pip install hnswlib。安裝並匯入軟體包之後,建立 HNSW 圖需要執行一些步驟,這些步驟已經被封裝到了以下函式中:

import hnswlib
import numpy as npdef fit_hnsw_index(features, ef=100, M=16, save_index_file=False):
    # Convenience function to create HNSW graph
    # features : list of lists containing the embeddings
    # ef, M: parameters to tune the HNSW algorithm
    
    num_elements = len(features)
    labels_index = np.arange(num_elements)    EMBEDDING_SIZE = len(features[0])    # Declaring index
    # possible space options are l2, cosine or ip
    p = hnswlib.Index(space='l2', dim=EMBEDDING_SIZE)    # Initing index - the maximum number of elements should be known
    p.init_index(max_elements=num_elements, ef_construction=ef, M=M)    # Element insertion
    int_labels = p.add_items(features, labels_index)    # Controlling the recall by setting ef
    # ef should always be > k
    p.set_ef(ef) 
    
    # If you want to save the graph to a file
    if save_index_file:
         p.save_index(save_index_file)
    
    return p

建立 HNSW 索引後,查詢「K」個最近鄰就僅需以下這一行程式碼:

ann_neighbor_indices, ann_distances = p.knn_query(features, k)

KNN 和 ANN 基準實驗

計劃

首先下載一個 500K + 行的大型資料集。然後將使用預訓練 fasttext 句子向量將文字列轉換為 300d 嵌入向量。然後將在不同長度的輸入資料 [1000. 10000, 100000, len(data)] 上訓練 KNN 和 HNSW ANN 模型,以度量資料大小對速度的影響。最後將查詢兩個模型中的 K=10 和 K=100 時的最近鄰,以度量「K」對速度的影響。首先匯入必要的包和模型。這需要一些時間,因為需要從網路上下載 fasttext 模型。

# Imports
# For input data pre-processing
import json
import gzip
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import fasttext.util
fasttext.util.download_model('en', if_exists='ignore') # English pre-trained model
ft = fasttext.load_model('cc.en.300.bin')# For KNN vs ANN benchmarking
from datetime import datetime
from tqdm import tqdm
from sklearn.neighbors import NearestNeighbors
import hnswlib

資料

使用亞[馬遜產品資料集],其中包含「手機及配件」類別中的 527000 種產品。然後執行以下程式碼將其轉換為資料框架。記住僅需要產品 title 列,因為將使用它來搜尋相似的產品。

# Data: http://deepyeti.ucsd.edu/jianmo/amazon/
data = []
with gzip.open('meta_Cell_Phones_and_Accessories.json.gz') as f:
    for l in f:
        data.append(json.loads(l.strip()))# Pre-Processing: https://colab.research.google.com/drive/1Zv6MARGQcrBbLHyjPVVMZVnRWsRnVMpV#scrollTo=LgWrDtZ94w89
# Convert list into pandas dataframe
df = pd.DataFrame.from_dict(data)
df.fillna('', inplace=True)# Filter unformatted rows
df = df[~df.title.str.contains('getTime')]# Restrict to just 'Cell Phones and Accessories'
df = df[df['main_cat']=='Cell Phones & Accessories']# Reset index
df.reset_index(inplace=True, drop=True)# Only keep the title columns
df = df[['title']]# Check the df
print(df.shape)
df.head()

如果全部都可以執行精細搜尋,你將看到如下輸出:

亞馬遜產品資料集。

2嵌入

要對文字資料進行相似性搜尋,則必須首先將其轉換為數字向量。一種快速便捷的方法是使用經過預訓練的網路嵌入層,例如 Facebook [FastText] 提供的嵌入層。由於希望所有行都具有相同的長度向量,而與 title 中的單詞數目無關,所以將在 df 中的 title 列呼叫 get_sentence_vector 方法。

嵌入完成後,將 emb 列作為一個 list 輸入到 NN 演算法中。理想情況下可以在此步驟之前進行一些文字清理預處理。同樣,使用微調的嵌入模型也是一個好主意。

# Title Embedding using FastText Sentence Embedding
df['emb'] = df['title'].apply(ft.get_sentence_vector)# Extract out the embeddings column as a list of lists for input to our NN algos
X = [item.tolist() for item in df['emb'].values]

基準

有了演算法的輸入,下一步進行基準測試。具體而言,在搜尋空間中的產品數量和正在搜尋的 K 個最近鄰之間進行迴圈測試。在每次迭代中,除了記錄每種演算法的耗時以外,還要檢查 pct_overlap,因為一定比例的 KNN 最近鄰也被挑選為 ANN 最近鄰。

注意整個測試在一臺全天候執行的 8 核、30GB RAM 機器上執行大約 6 天,這有些耗時。理想情況下,你可以通過多程序來加快執行速度,因為每次執行都相互獨立。

# Number of products for benchmark loop
n_products = [1000, 10000, 100000, len(X)]# Number of neighbors for benchmark loop
n_neighbors = [10, 100]# Dictionary to save metric results for each iteration
metrics = {'products':[], 'k':[], 'knn_time':[], 'ann_time':[], 'pct_overlap':[]}for products in tqdm(n_products):
    # "products" number of products included in the search space
    features = X[:products]
    
    for k in tqdm(n_neighbors):   
        # "K" Nearest Neighbor search
        # KNN 
        knn_start = datetime.now()
        nbrs = NearestNeighbors(n_neighbors=k, metric='euclidean').fit(features)
        knn_distances, knn_neighbor_indices = nbrs.kneighbors(X)
        knn_end = datetime.now()
        metrics['knn_time'].append((knn_end - knn_start).total_seconds())
        
        # HNSW ANN
        ann_start = datetime.now()
        p = fit_hnsw_index(features, ef=k*10)
        ann_neighbor_indices, ann_distances = p.knn_query(features, k)
        ann_end = datetime.now()
        metrics['ann_time'].append((ann_end - ann_start).total_seconds())
        
        # Average Percent Overlap in Nearest Neighbors across all "products"
        metrics['pct_overlap'].append(np.mean([len(np.intersect1d(knn_neighbor_indices[i], ann_neighbor_indices[i]))/k for i in range(len(features))]))
        
        metrics['products'].append(products)
        metrics['k'].append(k)
        
metrics_df = pd.DataFrame(metrics)
metrics_df.to_csv('metrics_df.csv', index=False)
metrics_df

執行結束時輸出如下所示。從表中已經能夠看出,HNSW ANN 完全超越了 KNN。

以表格形式呈現的結果。

結果

以圖表的形式檢視基準測試的結果,以真正瞭解二者之間的差異,其中使用標準的 matplotlib 程式碼來繪製這些圖表。這種差距是驚人的。根據查詢 K=10 和 K=100 最近鄰所需的時間,HNSW ANN 將 KNN 徹底淘汰。當搜尋空間包含約 50 萬個產品時,在 ANN 上搜索 100 個最近鄰的速度是 KNN 的 380 倍,同時兩者搜尋到最近鄰的相似度為 99.3%。

在搜尋空間包含 500K 個元素,搜尋空間中每個元素找到 K=100 最近鄰時,HNSW ANN 的速度比 Sklearn 的 KNN 快 380 倍。

在搜尋空間包含 500K 個元素,搜尋空間中每個元素找到 K=100 最近鄰時,HNSW ANN 和 KNN 搜尋到最近鄰的相似度為 99.3%。

基於以上結果,作者認為可以大膽地說:「KNN 已死」。

本篇文章的程式碼作者已在 GitHub 上給出:https://github.com/stephenleo/adventures-with-ann/blob/main/knn_is_dead.ipynb

原文連結:https://medium.com/towards-artificial-intelligence/knn-k-nearest-neighbors-is-dead-fc16507eb3e2