home

Menu
  • ripgrep search

datasette-llm/datasette_llm/__init__.py

import asyncio
from concurrent.futures import ThreadPoolExecutor
from datasette import hookimpl, Response, NotFound
from datasette.database import Database
from datasette.utils import sqlite3
import datetime
import json
import llm
from llm.cli import cli as llm_cli
import pathlib
from sqlite_utils import Database as SqliteUtilsDatabase
 
 
@hookimpl
def register_commands(cli):
    """Mount the ``llm`` click command group onto ``cli`` under the name ``llm``."""
    subcommand_name = "llm"
    cli.add_command(llm_cli, name=subcommand_name)
 
 
@hookimpl
def startup(datasette):
    """Attach the llm log database to Datasette under the name ``llm``.

    The file path is the plugin's ``db_path`` setting, falling back to
    ``logs.db`` inside ``llm.user_dir()``. Nothing is attached when a
    database called ``llm`` is already registered, or when the file does
    not exist.
    """
    try:
        datasette.get_database("llm")
    except KeyError:
        already_attached = False
    else:
        already_attached = True

    plugin_settings = datasette.plugin_config("llm") or {}
    configured_path = plugin_settings.get("db_path")
    log_path = pathlib.Path(configured_path or (llm.user_dir() / "logs.db"))
    if not already_attached and log_path.exists():
        log_db = Database(datasette, path=str(log_path))
        datasette.add_database(log_db, name="llm")
 
 
@hookimpl
def register_routes():
    """Return the ``(regex, view)`` pairs served by this plugin.

    The ``/-/llm/start`` route is listed ahead of the conversation page
    pattern because "start" also matches ``[0-9a-z]+``.
    """
    patterns_and_views = (
        (r"^/-/llm$", llm_index),
        (r"^/-/llm/start$", llm_start),
        (r"^/-/llm/ws/(?P<conversation_id>[0-9a-z]+)$", llm_conversation_ws),
        (r"^/-/llm/(?P<conversation_id>[0-9a-z]+)$", llm_conversation),
    )
    return list(patterns_and_views)
 
 
async def llm_conversation_ws(request, scope, receive, send, datasette):
    """Stream model output for one conversation over a WebSocket.

    The conversation must already have a row in the ``initiated`` table
    (written by ``llm_start``). On connect, if no responses have been
    logged for this conversation yet, the initiated prompt is run straight
    away together with its system prompt. Every later text frame must be
    JSON of the form ``{"prompt": "..."}``.

    Output chunks are sent as individual text frames. Each response, and
    each error, is terminated by a frame holding two newline characters.
    Completed responses are logged to the ``llm`` database.
    """
    if scope["type"] != "websocket":
        return Response.text("ws only", status=400)

    conversation_id = request.url_vars["conversation_id"]
    db = datasette.get_database("llm")
    # Must have been initiated already
    initiated = (
        await db.execute(
            "select * from initiated where id = :id", {"id": conversation_id}
        )
    ).first()
    if not initiated:
        return Response.text("Conversation not yet initiated", status=404)

    try:
        model = llm.get_model(initiated["model"])
    except llm.UnknownModelError:
        return Response.text("Unknown model", status=400)

    if model.needs_key:
        model.key = llm.get_key(None, model.needs_key, model.key_env_var)

    # Prompt through a Conversation that carries *this* id. llm's
    # Response.log_to_db() files a response under its conversation's id.
    # A bare model.prompt() response has no conversation, so each one would
    # land in a brand-new conversation and this page would never see it.
    # TODO: responses logged by earlier connections are not loaded back in
    # as context for the model.
    conversation = model.conversation()
    conversation.id = conversation_id

    # Only run the initiated prompt if nothing has been logged yet
    has_messages = (
        await db.execute(
            "select 1 from responses where conversation_id = :id limit 1",
            {"id": conversation_id},
        )
    ).rows
    first_prompt = None
    first_system = None
    if not has_messages:
        first_prompt = initiated["prompt"]
        first_system = initiated["system"] or None

    def run_in_thread(prompt, system=None):
        # Executed step by step on a worker thread by async_wrap, because
        # iterating the response blocks on the model
        response = conversation.prompt(prompt, system=system)
        yield from response
        yield {"end": response}

    async def execute_prompt(prompt, system=None):
        async for item in async_wrap(run_in_thread, prompt, system)():
            if "error" in item:
                await send({"type": "websocket.send", "text": item["error"]})
                continue
            payload = item["item"]
            if isinstance(payload, dict) and "end" in payload:
                # Final marker: log the finished response to the DB. Bind it
                # as a default so the deferred write cannot see a later value.
                finished = payload["end"]
                await db.execute_write_fn(
                    lambda conn, finished=finished: finished.log_to_db(
                        SqliteUtilsDatabase(conn)
                    ),
                    block=False,
                )
            else:
                # A streamed text chunk: forward it to the client
                await send({"type": "websocket.send", "text": payload})
        await send({"type": "websocket.send", "text": "\n\n"})

    while True:
        event = await receive()
        if event["type"] == "websocket.connect":
            await send({"type": "websocket.accept"})
            if first_prompt:
                prompt, system = first_prompt, first_system
                # The initiated prompt must only ever be run once
                first_prompt = first_system = None
                await execute_prompt(prompt, system)
        elif event["type"] == "websocket.receive":
            try:
                # "text" is absent/None for binary frames
                prompt = json.loads(event.get("text") or "")["prompt"]
            except (ValueError, TypeError, KeyError):
                await send({"type": "websocket.send", "text": "Invalid message"})
                await send({"type": "websocket.send", "text": "\n\n"})
                continue
            await execute_prompt(prompt)
        elif event["type"] == "websocket.disconnect":
            break
 
 
async def llm_conversation(request, datasette):
    """Render the HTML page for a single conversation.

    The id is accepted if it has a row in ``initiated`` (written by
    ``llm_start``) or in ``conversations``. The model id is taken from one
    of those rows, so when neither exists this raises ``NotFound``.
    """
    conversation_id = request.url_vars["conversation_id"]
    db = datasette.get_database("llm")
    params = {"id": conversation_id}
    # It may have just been initiated but not yet started:
    try:
        initiated = (
            await db.execute("select * from initiated where id = :id", params)
        ).first()
    except sqlite3.OperationalError:
        # Table has not been created yet
        initiated = None
    # Or there may be a conversation with responses:
    conversation = (
        await db.execute("select * from conversations where id = :id", params)
    ).first()
    responses = (
        await db.execute(
            "select * from responses where conversation_id = :id", params
        )
    ).rows
    # model_id below needs one of these two rows; without them a 404 is the
    # honest answer (responses alone would crash on initiated["model"])
    if not initiated and not conversation:
        raise NotFound("Conversation not found")

    model_id = conversation["model"] if conversation else initiated["model"]

    return Response.html(
        await datasette.render_template(
            "llm_conversation.html",
            {
                "model_id": model_id,
                "conversation_id": conversation_id,
                "conversation_title": (conversation["name"] if conversation else None)
                or "Untitled conversation",
                "responses": responses,
                "start_datetime_utc": responses[0]["datetime_utc"]
                if responses
                else None,
                "ws_path": datasette.urls.path("/-/llm/ws/{}".format(conversation_id)),
                # The WebSocket only streams the initiated prompt when no
                # responses are logged yet (same responses query), so only
                # then should an empty response placeholder be shown:
                "show_empty_response": bool(initiated) and not responses,
                "first_prompt": initiated["prompt"] if initiated else None,
                "first_system_prompt": initiated["system"] if initiated else None,
            },
            request=request,
        )
    )
 
 
async def llm_index(request, datasette):
    """Render the plugin home page with a model picker and past conversations.

    Up to 200 conversations are listed, highest id first. Each comes with
    its earliest response timestamp and its number of responses.
    """
    sql = """
        select
            conversations.id,
            conversations.name,
            conversations.model,
            min(responses.datetime_utc) as start_datetime_utc,
            count(responses.id) as num_responses
        from conversations
        left join responses on conversations.id = responses.conversation_id
        group by conversations.id, conversations.name, conversations.model
        order by conversations.id desc limit 200;
        """
    conversation_rows = (await datasette.get_database("llm").execute(sql)).rows

    model_choices = []
    for model_with_aliases in llm.get_models_with_aliases():
        chosen = model_with_aliases.model
        model_choices.append({"model_id": chosen.model_id, "name": str(chosen)})

    context = {
        "models": model_choices,
        "start_path": datasette.urls.path("/-/llm/start"),
        "previous_conversations": conversation_rows,
    }
    html = await datasette.render_template("llm.html", context, request=request)
    return Response.html(html)
 
 
async def llm_start(request, datasette):
    """Begin a new conversation from the index page form.

    Non-POST requests are redirected to ``/-/llm``. A POST must carry
    non-blank ``model_id`` and ``prompt`` fields; ``system`` is optional.
    The request is recorded in the ``initiated`` table under a new
    conversation id, and the user is redirected to that conversation's
    page. Invalid input redirects back to the index with an error message.
    """
    index_path = datasette.urls.path("/-/llm")
    if request.method != "POST":
        return Response.redirect(index_path)
    form_data = await request.post_vars()
    model_id = form_data.get("model_id")
    prompt = form_data.get("prompt")
    system = form_data.get("system")
    # A blank prompt would create a conversation that never starts: the
    # WebSocket handler only runs a truthy first prompt.
    if not model_id or not (prompt or "").strip():
        datasette.add_message(
            request, "Invalid start to the conversation", type=datasette.ERROR
        )
        return Response.redirect(index_path)
    try:
        model = llm.get_model(model_id)
    except llm.UnknownModelError:
        datasette.add_message(request, "Unknown model", type=datasette.ERROR)
        return Response.redirect(index_path)
    # Create the conversation (only its generated id is used here)
    conversation = model.conversation()
    row = {
        "id": conversation.id,
        "model": model_id,
        "prompt": prompt,
        "system": system,
        "actor_id": request.actor.get("id") if request.actor else None,
        # Naive UTC ISO timestamp, same format as the deprecated utcnow()
        "datetime_utc": datetime.datetime.now(datetime.timezone.utc)
        .replace(tzinfo=None)
        .isoformat(),
    }

    def store(conn):
        SqliteUtilsDatabase(conn)["initiated"].insert(row, pk="id")

    # Block so the row exists before the redirected page looks it up
    await datasette.get_database("llm").execute_write_fn(store, block=True)
    return Response.redirect(datasette.urls.path("/-/llm/{}".format(conversation.id)))
 
 
# Sentinel returned by the worker when the wrapped generator is exhausted
END_SIGNAL = object()


def async_wrap(generator_func, *args, **kwargs):
    """Adapt a blocking generator function into an async generator factory.

    Returns a zero-argument callable. Calling it yields an async generator
    that advances ``generator_func(*args, **kwargs)`` one ``next()`` at a
    time on a single worker thread, so the event loop is not blocked.
    Each value is yielded wrapped as ``{"item": value}``. If the generator
    raises an ``Exception``, one final ``{"error": str(exception)}`` is
    yielded and iteration stops.
    """

    def next_item(gen):
        # StopIteration cannot be propagated through an asyncio Future, so
        # exhaustion is reported with a sentinel instead
        try:
            return next(gen)
        except StopIteration:
            return END_SIGNAL

    async def async_generator():
        loop = asyncio.get_running_loop()
        generator = iter(generator_func(*args, **kwargs))
        # next() calls are strictly sequential, so a single worker suffices
        # and keeps the generator on one thread
        executor = ThreadPoolExecutor(max_workers=1)
        try:
            while True:
                try:
                    item = await loop.run_in_executor(executor, next_item, generator)
                except Exception as ex:
                    yield {"error": str(ex)}
                    break
                if item is END_SIGNAL:
                    break
                yield {"item": item}
        finally:
            # Don't wait: if this task is cancelled mid-next(), a waiting
            # shutdown would block the event loop until that call finished
            executor.shutdown(wait=False)

    return async_generator
 
Powered by Datasette