datasette-big-local/datasette_big_local/__init__.py
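"""
Datasette plugin that integrates with the Big Local News platform
(biglocalnews.org): it authenticates users with their BLN remember_token,
gates access to per-project SQLite databases via the BLN GraphQL API, and
imports project CSV files in a background thread with progress reporting.
"""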
from cachetools import TTLCache
from datasette import hookimpl
from datasette.database import Database
from datasette.utils.asgi import Response
import asyncio
import base64
import html
import httpx
import pathlib
import functools
import uuid
import csv as csv_std
import datetime
import threading
import sqlite_utils
from sqlite_utils.utils import TypeTracker
from urllib.parse import urlencode
import re
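# Characters allowed to pass through unescaped in generated table names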
ALLOWED = "abcdefghijklmnopqrstuvwxyz" "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "0123456789"
split_re = re.compile("(_[0-9a-f]+_)")
class Settings:
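    """Container for this plugin's configuration options."""
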
def __init__(self, root_dir, graphql_url, csv_size_limit_mb, login_redirect_url):
self.root_dir = root_dir
self.graphql_url = graphql_url
self.csv_size_limit_mb = csv_size_limit_mb
self.login_redirect_url = login_redirect_url
def get_settings(datasette):
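    """Read this plugin's configuration from Datasette, applying defaults."""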
plugin_config = datasette.plugin_config("datasette-big-local") or {}
return Settings(
root_dir=plugin_config.get("root_dir") or ".",
graphql_url=plugin_config.get("graphql_url")
or "https://api.biglocalnews.org/graphql",
csv_size_limit_mb=plugin_config.get("csv_size_limit_mb") or 100,
login_redirect_url=plugin_config.get("login_redirect_url")
or "https://biglocalnews.org/#/datasette?",
)
@hookimpl
def forbidden(request, datasette):
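    # Access was denied - redirect to the Big Local News login page, with a
    # redirect_path pointing back at the page the user was trying to reach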
    database_name = request.url_vars.get("database")
    if not database_name:
        # Not a database page - fall back to Datasette's default behavior
        return None
    project_id = project_uuid_to_id(database_name)
url = get_settings(datasette).login_redirect_url + urlencode(
{
"redirect_path": request.full_path,
"project_id": project_id,
}
)
return Response.redirect(url)
def get_cache(datasette):
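    # Lazily create a shared cache (100 keys, five minute TTL) used for both
    # permission checks and project file listings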
cache = getattr(datasette, "big_local_cache", None)
if cache is None:
datasette.big_local_cache = cache = TTLCache(maxsize=100, ttl=60 * 5)
return cache
@hookimpl
def permission_allowed(datasette, actor, action, resource):
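    # Allow view-database/execute-sql only if the actor's BLN token can
    # access the project this database belongs to (result cached for 5 min)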
async def inner():
if action not in ("view-database", "execute-sql"):
# No opinion
return
if resource.startswith("_"):
# _internal / _memory etc
return
if not actor:
return False
cache = get_cache(datasette)
actor_id = actor["id"]
database_name = resource
# Check cache to see if actor is allowed to access this database
key = (actor_id, database_name)
result = cache.get(key)
if result is not None:
return result
        # Not cached - hit the GraphQL API and cache the result
        remember_token = actor.get("token")
        if not remember_token:
            return False
# Figure out project ID from UUID database name
project_id = project_uuid_to_id(database_name)
try:
await get_project(datasette, project_id, remember_token)
result = True
except (ProjectPermissionError, ProjectNotFoundError):
result = False
# Store in cache
cache[key] = result
return result
return inner
class ProjectPermissionError(Exception):
pass
class ProjectNotFoundError(Exception):
pass
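
# GraphQL fragment optionally spliced into the project query below to fetch
# the first 100 files in a project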
FILES = """
files(first:100) {
edges {
node {
name
size
}
}
}
"""
async def get_project(datasette, project_id, remember_token, files=False):
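    """Fetch a project node from the BLN GraphQL API.

    Raises ProjectPermissionError on a non-200 response and
    ProjectNotFoundError if no matching node is returned.
    """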
async with httpx.AsyncClient() as client:
response = await client.post(
get_settings(datasette).graphql_url,
json={
"variables": {"id": project_id},
"query": """
query Node($id: ID!) {
node(id: $id) {
... on Project {
id
name
FILES
}
}
}
""".replace(
"FILES", FILES if files else ""
),
},
cookies={"remember_token": remember_token},
timeout=30,
)
        if response.status_code != 200:
            raise ProjectPermissionError(response.text)
        data = response.json()["data"]
        if data["node"] is None:
            raise ProjectNotFoundError("Project not found")
        project = data["node"]
        # Flatten the nested files edges into a simple list
        if files:
            files_edges = project.pop("files")
            project["files"] = [edge["node"] for edge in files_edges["edges"]]
        return project
def alnum_encode(s):
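    # Encode a filename as a safe table name: any character outside
    # [a-zA-Z0-9] becomes _hex_, e.g. "my file.csv" -> "my_20_file_2e_csv"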
encoded = []
for char in s:
if char in ALLOWED:
encoded.append(char)
else:
encoded.append("_" + hex(ord(char))[2:] + "_")
return "".join(encoded)
def alnum_decode(s):
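    # Reverse alnum_encode(): turn each _hex_ sequence back into a character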
decoded = []
for bit in split_re.split(s):
if bit.startswith("_"):
hexbit = bit[1:-1]
decoded.append(chr(int(hexbit, 16)))
else:
decoded.append(bit)
return "".join(decoded)
class OpenError(Exception):
pass
async def open_project_file(datasette, project_id, filename, remember_token):
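    """Ask the BLN API to sign a download URI for a project file.

    Returns (uri, etag, content_length). Raises OpenError if the API reports
    an error or the signed URI does not resolve to an actual file.
    """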
graphql_endpoint = get_settings(datasette).graphql_url
body = {
"operationName": "CreateFileDownloadURI",
"variables": {
"input": {
"fileName": filename,
"projectId": project_id,
}
},
"query": """
mutation CreateFileDownloadURI($input: FileURIInput!) {
createFileDownloadUri(input: $input) {
ok {
name
uri
__typename
}
err
__typename
}
}
""",
}
async with httpx.AsyncClient() as client:
response = await client.post(
graphql_endpoint,
json=body,
cookies={"remember_token": remember_token},
timeout=30,
)
if response.status_code != 200:
raise OpenError(response.text)
data = response.json()["data"]
if data["createFileDownloadUri"]["err"]:
raise OpenError(data["createFileDownloadUri"]["err"])
# We need to do a HEAD request because the GraphQL endpoint doesn't
# check if the file exists, it just signs whatever filename we sent
uri = data["createFileDownloadUri"]["ok"]["uri"]
head_response = await client.head(uri)
if head_response.status_code != 200:
raise OpenError("File not found")
return (
uri,
head_response.headers["etag"],
int(head_response.headers["content-length"]),
)
def project_id_to_uuid(project_id):
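    # BLN GraphQL node IDs are base64-encoded "Project:<uuid>" strings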
return base64.b64decode(project_id).decode("utf-8").split("Project:")[-1]
def project_uuid_to_id(project_uuid):
return base64.b64encode("Project:{}".format(project_uuid).encode("utf-8")).decode(
"utf-8"
)
def ensure_database(datasette, project_uuid):
# Create a database of that name if one does not exist already
try:
db = datasette.get_database(project_uuid)
except KeyError:
root_dir = pathlib.Path(get_settings(datasette).root_dir)
        # Create an empty database file on disk (VACUUM forces the write)
db_path = str(root_dir / "{}.db".format(project_uuid))
sqlite_utils.Database(db_path).vacuum()
db = datasette.add_database(Database(datasette, path=db_path, is_mutable=True))
return db
async def big_local_open_private(request, datasette):
# Same as big_local_open but reads remember_token from a cookie
if not request.actor or not request.actor.get("token"):
return Response.text("Forbidden", status=403)
return await big_local_open_implementation(
request, datasette, request.actor["token"]
)
async def big_local_open(request, datasette):
return await big_local_open_implementation(request, datasette)
async def big_local_open_implementation(request, datasette, remember_token=None):
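    """Import a file from a BLN project into that project's database.

    remember_token is either passed in by big_local_open_private (from the
    actor cookie) or must be supplied as a POST variable.
    """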
if request.method == "GET":
return Response.html(
"""
<form action="/-/big-local-open" method="POST">
<p><label>Project ID: <input name="project_id" value="UHJvamVjdDpmZjAxNTBjNi1iNjM0LTQ3MmEtODFiMi1lZjJlMGMwMWQyMjQ="></label></p>
<p><label>Filename: <input name="filename" value="universities_final.csv"></label></p>
<p><label>remember_token: <input name="remember_token" value=""></label></p>
<p><input type="submit"></p>
</form>
"""
)
post = await request.post_vars()
required_keys = ["filename", "project_id"]
if remember_token is None:
required_keys.append("remember_token")
    bad_keys = [key for key in required_keys if not post.get(key)]
    if bad_keys:
        return Response.html(
            "Missing required POST variables: {}".format(", ".join(bad_keys)),
            status=400,
        )
filename = post["filename"]
project_id = post["project_id"]
remember_token = remember_token or post["remember_token"]
# Turn project ID into a UUID
project_uuid = project_id_to_uuid(project_id)
db = ensure_database(datasette, project_uuid)
# Use GraphQL to check permissions and get the signed URL for this resource
try:
uri, etag, length = await open_project_file(
datasette, project_id, filename, remember_token
)
except OpenError as e:
return Response.html(
"Could not open file: {}".format(html.escape(str(e))), status=400
)
csv_size_limit_mb = get_settings(datasette).csv_size_limit_mb
if length > csv_size_limit_mb * 1024 * 1024:
return Response.html(
"File exceeds size limit of {}MB".format(csv_size_limit_mb), status=400
)
# uri is valid, do we have the table already?
table_name = alnum_encode(filename)
if not await db.table_exists(table_name):
await import_csv(db, uri, table_name)
# Give it a moment to create the progress table and start running
await asyncio.sleep(0.5)
response = Response.redirect("/{}/{}".format(project_uuid, table_name))
# Set a cookie so that the user can access this database in future
# They might be signed in already
    if not (request.actor and request.actor.get("token") == remember_token):
        # Not already signed in with this token - look up the user and set a cookie
actor = await get_big_local_user(datasette, remember_token)
if not actor:
return Response.html("Invalid token", status=400)
# Rename displayName to display
actor["display"] = actor.pop("displayName")
actor["token"] = remember_token
response.set_cookie(
"ds_actor",
datasette.sign({"a": actor}, "actor"),
)
return response
async def get_big_local_user(datasette, remember_token):
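    # Resolve a remember_token to a BLN user, or None if the token is invalid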
graphql_endpoint = get_settings(datasette).graphql_url
query = """
query {
user {
id
displayName
username
}
}
""".strip()
async with httpx.AsyncClient() as client:
response = await client.post(
graphql_endpoint,
json={"query": query},
cookies={"remember_token": remember_token},
timeout=30,
)
if response.status_code != 200:
return None
return response.json()["data"]["user"]
async def big_local_project(datasette, request):
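    """Validate the token, check project access, cache the project's file
    list, ensure the project database exists, then redirect the user."""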
if request.method == "GET":
return Response.html(
"""
<form action="/-/big-local-project" method="POST">
<p><label>Project ID: <input name="project_id" value="UHJvamVjdDpmZjAxNTBjNi1iNjM0LTQ3MmEtODFiMi1lZjJlMGMwMWQyMjQ="></label></p>
<p><label>remember_token: <input name="remember_token" value=""></label></p>
<p><input type="submit"></p>
</form>
"""
)
post = await request.post_vars()
bad_keys = [key for key in ("project_id", "remember_token") if not post.get(key)]
if bad_keys:
return Response.html(
"project_id and remember_token POST variables are required",
status=400,
)
project_id = post["project_id"]
remember_token = post["remember_token"]
redirect_path = post.get("redirect_path")
    # Require a local path, rejecting protocol-relative //host URLs as well
    if redirect_path is not None and (
        not redirect_path.startswith("/") or redirect_path.startswith("//")
    ):
        return Response.html("redirect_path must start with /", status=400)
actor = None
should_set_cookie = False
    if request.actor and request.actor.get("token") == remember_token:
        # Token matches the current actor's token - reuse the actor
        actor = request.actor
else:
# Check remember_token is for a valid actor
actor = await get_big_local_user(datasette, remember_token)
if not actor:
return Response.html("<h1>Invalid token</h1>", status=403)
actor["display"] = actor.pop("displayName")
actor["token"] = remember_token
should_set_cookie = True
# Can the actor access the project?
try:
project = await get_project(datasette, project_id, actor["token"], True)
except (ProjectPermissionError, ProjectNotFoundError):
return Response.html("<h1>Cannot access project</h1>", status=403)
# Figure out UUID for project
project_uuid = project_id_to_uuid(project_id)
# Stash project files in the cache
cache = get_cache(datasette)
cache_key = "project-files-{}".format(project_id)
cache[cache_key] = project["files"]
# Ensure database for project exists
ensure_database(datasette, project_uuid)
# Redirect user
response = Response.redirect(redirect_path or "/{}".format(project_uuid))
if should_set_cookie:
response.set_cookie(
"ds_actor",
datasette.sign({"a": actor}, "actor"),
)
return response
@hookimpl
def extra_template_vars(datasette, view_name, database):
async def inner():
if view_name != "database":
return {}
cache = get_cache(datasette)
cache_key = "project-files-{}".format(project_uuid_to_id(database))
files = cache.get(cache_key) or []
if not files:
return {}
        # Show only CSV files that are under the size limit and have not
        # yet been imported
db = datasette.get_database(database)
table_names = set(await db.table_names())
available_files = [
file
for file in files
if file["name"].endswith(".csv")
and alnum_encode(file["name"]) not in table_names
and file["size"] < get_settings(datasette).csv_size_limit_mb * 1024 * 1024
]
return {
"available_files": available_files,
"project_id": project_uuid_to_id(database),
}
return inner
@hookimpl
def register_routes():
return [
(r"^/-/big-local-open$", big_local_open),
(r"^/-/big-local-open-private$", big_local_open_private),
(r"^/-/big-local-project$", big_local_project),
]
@hookimpl
def skip_csrf(scope):
return scope["path"] in ("/-/big-local-open", "/-/big-local-project")
async def import_csv(db, url, table_name):
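    """Record a row in _import_progress_, then fetch and import the CSV in a
    background daemon thread so the HTTP request is not blocked."""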
task_id = str(uuid.uuid4())
def insert_initial_record(conn):
database = sqlite_utils.Database(conn)
if "_import_progress_" not in database.table_names():
database["_import_progress_"].create(
{
"id": str,
"table": str,
"bytes_todo": int,
"bytes_done": int,
"rows_done": int,
"started": str,
"completed": str,
},
pk="id",
)
database["_import_progress_"].insert(
{
"id": task_id,
"table": table_name,
"bytes_todo": None,
"bytes_done": 0,
"rows_done": 0,
"started": str(datetime.datetime.utcnow()),
"completed": None,
}
)
await db.execute_write_fn(insert_initial_record)
# We run this in a thread to avoid blocking
thread = threading.Thread(
target=functools.partial(
fetch_and_insert_csv_in_thread,
task_id,
url,
db,
table_name,
            asyncio.get_running_loop(),
),
daemon=True,
)
thread.start()
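
# Number of CSV rows to accumulate before each batched insert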
BATCH_SIZE = 100
def fetch_and_insert_csv_in_thread(task_id, url, database, table_name, loop):
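    """Runs in a worker thread: stream the CSV over HTTP, batch-insert rows
    into the table, and report progress by scheduling writes back onto
    Datasette's event loop."""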
bytes_todo = None
bytes_done = 0
tracker = TypeTracker()
def stream_lines():
nonlocal bytes_todo, bytes_done
        with httpx.stream("GET", url) as r:
            try:
                bytes_todo = int(r.headers["content-length"])
            except (KeyError, ValueError):
                # Header missing or not an integer
                bytes_todo = None
            for line in r.iter_lines():
                # Approximate progress: iter_lines() yields decoded text
                # without trailing newlines, so this slightly undercounts
                bytes_done += len(line)
                yield line
reader = csv_std.reader(stream_lines())
headers = next(reader)
docs = (dict(zip(headers, row)) for row in reader)
    def update_progress(data):
        # Called from this worker thread - hand the write coroutine to
        # Datasette's event loop in a thread-safe way
        asyncio.run_coroutine_threadsafe(
            database.execute_write_fn(
                lambda conn: sqlite_utils.Database(conn)["_import_progress_"].update(
                    task_id, data
                ),
                block=True,
            ),
            loop,
        )
    def write_batch(docs):
        # Also called from the worker thread
        asyncio.run_coroutine_threadsafe(
            database.execute_write_fn(
                lambda conn: sqlite_utils.Database(conn)[table_name].insert_all(
                    docs, alter=True
                ),
                block=True,
            ),
            loop,
        )
gathered = []
i = 0
for doc in tracker.wrap(docs):
gathered.append(doc)
i += 1
if len(gathered) >= BATCH_SIZE:
write_batch(gathered)
gathered = []
# Update progress table
update_progress(
{
"rows_done": i,
"bytes_todo": bytes_todo,
"bytes_done": bytes_done,
}
)
    if gathered:
        # Write any remaining rows
        write_batch(gathered)
    # Mark the import as complete
    update_progress(
        {
            "rows_done": i,
            "bytes_done": bytes_todo if bytes_todo is not None else bytes_done,
            "completed": str(datetime.datetime.utcnow()),
        }
    )
    # Apply the column types detected by the TypeTracker to the table schema
    types = tracker.types
    if not all(v == "text" for v in types.values()):
        asyncio.run_coroutine_threadsafe(
            database.execute_write_fn(
                lambda conn: sqlite_utils.Database(conn)[table_name].transform(
                    types=types
                ),
                block=False,
            ),
            loop,
        )
PROGRESS_BAR_JS = """
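// Injected into table pages: polls the _import_progress_ table and shows a
// progress bar until the most recent import for this table completes.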
const PROGRESS_BAR_CSS = `
progress {
-webkit-appearance: none;
appearance: none;
border: none;
width: 100%;
height: 2em;
margin-top: 1em;
margin-bottom: 1em;
}
progress::-webkit-progress-bar {
background-color: #ddd;
}
progress::-webkit-progress-value {
background-color: #124d77;
}
`;
(function() {
// Add CSS
const style = document.createElement("style");
style.innerHTML = PROGRESS_BAR_CSS;
document.head.appendChild(style);
// Append progress bar
const progress = document.createElement('progress');
progress.setAttribute('value', 0);
progress.innerHTML = 'Importing...';
progress.style.display = 'none';
  const table = document.querySelector('table.rows-and-columns');
  if (!table) {
    // No results table on this page - nothing to attach to
    return;
  }
  table.parentNode.insertBefore(progress, table);
// Figure out the polling URL
  let parts = location.href.split('/');
  let table_name = parts.pop().split('?')[0];
  parts.push("_import_progress_.json");
  let pollUrl = parts.join('/') + (
    '?_col=bytes_todo&_col=bytes_done&_col=completed&table=' + table_name +
    '&_sort_desc=started&_shape=array&_size=1'
  );
// Start polling
let first = true;
  function pollNext() {
    fetch(pollUrl).then(r => r.json()).then(d => {
      if (!d.length || d[0].completed) {
        // No import recorded for this table, or the import has finished
        progress.remove();
        return;
      }
      let current = d[0].bytes_done;
      let total = d[0].bytes_todo;
      if (total === null) {
        // Total size not recorded yet - check again shortly
        setTimeout(pollNext, 2000);
        return;
      }
      if (first) {
        progress.setAttribute('max', total);
        progress.style.display = 'block';
        first = false;
      }
      progress.setAttribute('value', current);
      if (current < total) {
        setTimeout(pollNext, 2000);
      } else {
        progress.remove();
      }
    }).catch(() => {
      // Progress table may not exist yet, or the response was not JSON
      progress.remove();
    });
  }
pollNext();
})();
"""
@hookimpl
def extra_body_script(view_name):
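    # The injected script removes itself if no import is in progress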
if view_name == "table":
return PROGRESS_BAR_JS