home

Menu
  • ripgrep search

datasette-faiss/datasette_faiss/__init__.py

from datasette import hookimpl
import faiss
import json
import numpy as np
import struct
 
# FAISS index built for each table, keyed by (database, table).
indexes = {}
# Row IDs for each index, keyed by (database, table). The list is parallel to
# the vectors added to the index: position i holds the ID of vector i.
index_ids = {}
 
 
@hookimpl
def startup(datasette):
    """Return a coroutine function that indexes every configured table.

    The "datasette-faiss" plugin configuration is read when the returned
    function runs. Each entry of its "tables" setting is unpacked as a
    (database, table) pair and handed to ``populate_index``. A missing or
    empty configuration results in no work being done.
    """

    async def build_configured_indexes():
        settings = datasette.plugin_config("datasette-faiss") or {}
        for database, table in settings.get("tables") or []:
            await populate_index(datasette, database, table)

    return build_configured_indexes
 
 
def faiss_search_with_scores(database, table, embedding, k):
    """Return the closest rows to ``embedding`` together with their distances.

    Args:
        database: Database name the index was registered under.
        table: Table name the index was registered under.
        embedding: Query vector as a packed-float blob (see ``decode``).
        k: Maximum number of neighbours to return.

    Returns:
        A JSON string holding a list of ``[id, distance]`` pairs in the
        order produced by the index search. Fewer than ``k`` pairs are
        returned when the index holds fewer than ``k`` vectors.

    Raises:
        KeyError: If no index has been built for ``(database, table)``.
    """
    index = indexes[(database, table)]
    ids = index_ids[(database, table)]
    # FAISS operates on float32 matrices; be explicit rather than relying on
    # the wrapper to convert the float64 array numpy would build by default.
    query = np.array([decode(embedding)], dtype="float32")
    distances, labels = index.search(query, k)
    # FAISS pads missing results with label -1. Without this filter ids[-1]
    # would silently report the last row as a match with a sentinel distance.
    hits = [(ids[i], d) for i, d in zip(labels[0], distances[0]) if i >= 0]
    # default=float turns numpy float32 distances into JSON numbers.
    return json.dumps(hits, default=float)
 
 
def faiss_search(database, table, embedding, k):
    """Return the IDs of the closest rows to ``embedding``.

    Args:
        database: Database name the index was registered under.
        table: Table name the index was registered under.
        embedding: Query vector as a packed-float blob (see ``decode``).
        k: Maximum number of neighbours to return.

    Returns:
        A JSON string holding a list of row IDs in the order produced by the
        index search. Fewer than ``k`` IDs are returned when the index holds
        fewer than ``k`` vectors.

    Raises:
        KeyError: If no index has been built for ``(database, table)``.
    """
    index = indexes[(database, table)]
    ids = index_ids[(database, table)]
    # FAISS operates on float32 matrices; be explicit rather than relying on
    # the wrapper to convert the float64 array numpy would build by default.
    query = np.array([decode(embedding)], dtype="float32")
    _, labels = index.search(query, k)
    # FAISS pads missing results with label -1. Without this filter ids[-1]
    # would silently report the last row as a match.
    return json.dumps([ids[i] for i in labels[0] if i >= 0])
 
 
@hookimpl
def prepare_connection(conn):
    """Register the plugin's SQL functions and aggregates on ``conn``.

    Scalar functions: ``faiss_search``, ``faiss_search_with_scores``,
    ``faiss_encode`` (JSON array text -> packed blob) and ``faiss_decode``
    (packed blob -> JSON array text). Aggregates: ``faiss_agg`` and
    ``faiss_agg_with_scores``.
    """

    def encode_from_json(text):
        return encode(json.loads(text))

    def decode_to_json(blob):
        return json.dumps(decode(blob))

    scalar_functions = (
        ("faiss_search", 4, faiss_search),
        ("faiss_search_with_scores", 4, faiss_search_with_scores),
        ("faiss_encode", 1, encode_from_json),
        ("faiss_decode", 1, decode_to_json),
    )
    for sql_name, arity, func in scalar_functions:
        conn.create_function(sql_name, arity, func)

    aggregates = (
        ("faiss_agg", 4, FaissAgg),
        ("faiss_agg_with_scores", 4, FaissAggWithScores),
    )
    for sql_name, arity, aggregate_class in aggregates:
        conn.create_aggregate(sql_name, arity, aggregate_class)
 
 
async def populate_index(datasette, database, table):
    """Build a FAISS L2 index from one table and register it for searching.

    Reads the ``id`` and ``embedding`` columns of ``table``, decodes every
    embedding blob, and stores the resulting index in ``indexes`` and the
    matching row IDs in ``index_ids`` under the key ``(database, table)``.

    An empty table cannot be indexed because the vector dimensionality is
    taken from the first row. Such a table is skipped, and any previously
    registered index for it is dropped, instead of raising ``IndexError``.
    Searches against a skipped table raise ``KeyError``.

    Args:
        datasette: Datasette instance used to look up the database.
        database: Name of the database containing the table.
        table: Name of the table to index.
    """
    db = datasette.get_database(database)
    key = (database, table)

    # For the moment assumes id, embedding
    def _populate(conn):
        rows = conn.execute("select id, embedding from [{}]".format(table)).fetchall()
        if not rows:
            # Nothing to index and no way to know the dimensionality; make
            # sure a stale index from an earlier run is not left behind.
            indexes.pop(key, None)
            index_ids.pop(key, None)
            return
        ids = [row[0] for row in rows]
        embeddings = [decode(row[1]) for row in rows]
        index = faiss.IndexFlatL2(len(embeddings[0]))
        # FAISS operates on float32 matrices; numpy would default to float64.
        index.add(np.array(embeddings, dtype="float32"))
        indexes[key] = index
        index_ids[key] = ids

    await db.execute_fn(_populate)
 
 
def decode(blob):
    """Unpack a blob of native-format 4-byte floats into a tuple of floats.

    The number of floats is ``len(blob) // 4``; ``struct.error`` is raised
    when the blob length is not an exact multiple of four bytes.
    """
    float_count = len(blob) // 4
    layout = "{}f".format(float_count)
    return struct.unpack(layout, blob)
 
 
def encode(vector):
    """Pack a sequence of numbers into a blob of native-format floats.

    This is the inverse of ``decode``: each element becomes one "f" field.
    """
    layout = "{}f".format(len(vector))
    return struct.pack(layout, *vector)
 
 
class FaissAgg:
    """SQLite aggregate that ranks the aggregated rows against a query vector.

    Called as ``faiss_agg(id, embedding, compare_embedding, k)``. Every row
    contributes its ID and embedding; the query vector and ``k`` are taken
    from the first row only. ``finalize`` builds a throwaway L2 index over
    the collected embeddings and returns the best matches as JSON.
    """

    # When True the JSON result holds [id, distance] pairs instead of bare IDs.
    with_scores = False

    def __init__(self):
        self.ids = []
        self.embeddings = []
        self.compare_embedding = None
        self.k = None
        self.first = True

    def step(self, id, embedding, compare_embedding, k):
        """Collect one row; remember the query vector and k from the first."""
        if self.first:
            self.first = False
            self.compare_embedding = decode(compare_embedding)
            self.k = k
        self.ids.append(id)
        self.embeddings.append(decode(embedding))

    def finalize(self):
        """Search the collected embeddings and return the matches as JSON."""
        index = faiss.IndexFlatL2(len(self.compare_embedding))
        # FAISS operates on float32 matrices; numpy would default to float64.
        index.add(np.array(self.embeddings, dtype="float32"))
        query = np.array([self.compare_embedding], dtype="float32")
        distances, labels = index.search(query, self.k)
        return self._to_json(labels[0], distances[0])

    def _to_json(self, labels, distances):
        """Serialize one row of search output, dropping FAISS's -1 padding.

        FAISS fills the label row with -1 when fewer than k vectors were
        indexed. Indexing ``self.ids`` with -1 would wrongly report the last
        aggregated row, so those slots are skipped.
        """
        hits = [
            (self.ids[i], distance)
            for i, distance in zip(labels, distances)
            if i >= 0
        ]
        if self.with_scores:
            # default=float turns numpy float32 distances into JSON numbers.
            return json.dumps(hits, default=float)
        return json.dumps([row_id for row_id, _ in hits])
 
 
class FaissAggWithScores(FaissAgg):
    """Variant of ``FaissAgg`` whose JSON result pairs each ID with its distance."""

    with_scores = True
 
Powered by Datasette