datasette-big-local/datasette_big_local/__init__.py
from cachetools import TTLCache
from datasette import hookimpl
from datasette.database import Database
from datasette.utils.asgi import Response
import asyncio
import base64
import html
import httpx
import pathlib
import functools
import uuid
import csv as csv_std
import datetime
import threading
import sqlite_utils
from sqlite_utils.utils import TypeTracker
from urllib.parse import urlencode
import re

ALLOWED = "abcdefghijklmnopqrstuvwxyz" "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "0123456789"
split_re = re.compile("(_[0-9a-f]+_)")


class Settings:
    def __init__(self, root_dir, graphql_url, csv_size_limit_mb, login_redirect_url):
        self.root_dir = root_dir
        self.graphql_url = graphql_url
        self.csv_size_limit_mb = csv_size_limit_mb
        self.login_redirect_url = login_redirect_url


def get_settings(datasette):
    plugin_config = datasette.plugin_config("datasette-big-local") or {}
    return Settings(
        root_dir=plugin_config.get("root_dir") or ".",
        graphql_url=plugin_config.get("graphql_url")
        or "https://api.biglocalnews.org/graphql",
        csv_size_limit_mb=plugin_config.get("csv_size_limit_mb") or 100,
        login_redirect_url=plugin_config.get("login_redirect_url")
        or "https://biglocalnews.org/#/datasette?",
    )


@hookimpl
def forbidden(request, datasette):
    database_name = request.url_vars["database"]
    project_id = project_uuid_to_id(database_name)
    url = get_settings(datasette).login_redirect_url + urlencode(
        {
            "redirect_path": request.full_path,
            "project_id": project_id,
        }
    )
    return Response.redirect(url)


def get_cache(datasette):
    cache = getattr(datasette, "big_local_cache", None)
    if cache is None:
        datasette.big_local_cache = cache = TTLCache(maxsize=100, ttl=60 * 5)
    return cache


@hookimpl
def permission_allowed(datasette, actor, action, resource):
    async def inner():
        if action not in ("view-database", "execute-sql"):
            # No opinion
            return
        if resource.startswith("_"):
            # _internal / _memory etc
            return
        if not actor:
            return False
        cache = get_cache(datasette)
        actor_id = actor["id"]
        database_name = resource
        # Check cache to see if actor is allowed to access this database
        key = (actor_id, database_name)
        result = cache.get(key)
        if result is not None:
            return result
        # Not cached - hit GraphQL API and cache the result
        remember_token = actor["token"]
        # Figure out project ID from UUID database name
        project_id = project_uuid_to_id(database_name)
        try:
            await get_project(datasette, project_id, remember_token)
            result = True
        except (ProjectPermissionError, ProjectNotFoundError):
            result = False
        # Store in cache
        cache[key] = result
        return result

    return inner


class ProjectPermissionError(Exception):
    pass


class ProjectNotFoundError(Exception):
    pass


FILES = """
files(first:100) {
    edges {
        node {
            name
            size
        }
    }
}
"""


async def get_project(datasette, project_id, remember_token, files=False):
    async with httpx.AsyncClient() as client:
        response = await client.post(
            get_settings(datasette).graphql_url,
            json={
                "variables": {"id": project_id},
                "query": """
                query Node($id: ID!) {
                    node(id: $id) {
                        ... on Project {
                            id
                            name
                            FILES
                        }
                    }
                }
                """.replace(
                    "FILES", FILES if files else ""
                ),
            },
            cookies={"remember_token": remember_token},
            timeout=30,
        )
        if response.status_code != 200:
            raise ProjectPermissionError(response.text)
        data = response.json()["data"]
        if data["node"] is None:
            raise ProjectNotFoundError("Project not found")
        project = data["node"]
        # Clean up the files nested data
        if files:
            files_edges = project.pop("files")
            project["files"] = [edge["node"] for edge in files_edges["edges"]]
        return project
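# The cleaned-up dict returned by get_project() looks something like this
# (illustrative values, reusing the demo project ID from the form below):
#
#     {
#         "id": "UHJvamVjdDpmZjAxNTBjNi1iNjM0LTQ3MmEtODFiMi1lZjJlMGMwMWQyMjQ=",
#         "name": "Example project",
#         "files": [{"name": "universities_final.csv", "size": 1024}],
#     }
#
# "files" is only present when files=True; a non-200 response raises
# ProjectPermissionError and a null node raises ProjectNotFoundError.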
def alnum_encode(s):
    # Encode any character not in ALLOWED as _hex_ so the result is safe to
    # use as a SQLite table name
    encoded = []
    for char in s:
        if char in ALLOWED:
            encoded.append(char)
        else:
            encoded.append("_" + hex(ord(char))[2:] + "_")
    return "".join(encoded)


def alnum_decode(s):
    decoded = []
    for bit in split_re.split(s):
        if bit.startswith("_"):
            hexbit = bit[1:-1]
            decoded.append(chr(int(hexbit, 16)))
        else:
            decoded.append(bit)
    return "".join(decoded)
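# Round-trip sketch for the two helpers above, using the demo filename from
# the form below ("_" is 0x5f, "." is 0x2e):
#
#     >>> alnum_encode("universities_final.csv")
#     'universities_5f_final_2e_csv'
#     >>> alnum_decode("universities_5f_final_2e_csv")
#     'universities_final.csv'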
class OpenError(Exception):
    pass


async def open_project_file(datasette, project_id, filename, remember_token):
    graphql_endpoint = get_settings(datasette).graphql_url
    body = {
        "operationName": "CreateFileDownloadURI",
        "variables": {
            "input": {
                "fileName": filename,
                "projectId": project_id,
            }
        },
        "query": """
        mutation CreateFileDownloadURI($input: FileURIInput!) {
            createFileDownloadUri(input: $input) {
                ok {
                    name
                    uri
                    __typename
                }
                err
                __typename
            }
        }
        """,
    }
    async with httpx.AsyncClient() as client:
        response = await client.post(
            graphql_endpoint,
            json=body,
            cookies={"remember_token": remember_token},
            timeout=30,
        )
        if response.status_code != 200:
            raise OpenError(response.text)
        data = response.json()["data"]
        if data["createFileDownloadUri"]["err"]:
            raise OpenError(data["createFileDownloadUri"]["err"])
        # We need to do a HEAD request because the GraphQL endpoint doesn't
        # check if the file exists, it just signs whatever filename we sent
        uri = data["createFileDownloadUri"]["ok"]["uri"]
        head_response = await client.head(uri)
        if head_response.status_code != 200:
            raise OpenError("File not found")
        return (
            uri,
            head_response.headers["etag"],
            int(head_response.headers["content-length"]),
        )


def project_id_to_uuid(project_id):
    return base64.b64decode(project_id).decode("utf-8").split("Project:")[-1]


def project_uuid_to_id(project_uuid):
    return base64.b64encode("Project:{}".format(project_uuid).encode("utf-8")).decode(
        "utf-8"
    )
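# These two helpers round-trip between the base64-encoded GraphQL node ID and
# the bare project UUID that is used as the Datasette database name:
#
#     >>> project_id_to_uuid("UHJvamVjdDpmZjAxNTBjNi1iNjM0LTQ3MmEtODFiMi1lZjJlMGMwMWQyMjQ=")
#     'ff0150c6-b634-472a-81b2-ef2e0c01d224'
#     >>> project_uuid_to_id("ff0150c6-b634-472a-81b2-ef2e0c01d224")
#     'UHJvamVjdDpmZjAxNTBjNi1iNjM0LTQ3MmEtODFiMi1lZjJlMGMwMWQyMjQ='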
def ensure_database(datasette, project_uuid):
    # Create a database of that name if one does not exist already
    try:
        db = datasette.get_database(project_uuid)
    except KeyError:
        root_dir = pathlib.Path(get_settings(datasette).root_dir)
        # Create empty file
        db_path = str(root_dir / "{}.db".format(project_uuid))
        sqlite_utils.Database(db_path).vacuum()
        db = datasette.add_database(Database(datasette, path=db_path, is_mutable=True))
    return db


async def big_local_open_private(request, datasette):
    # Same as big_local_open but reads remember_token from the actor cookie
    if not request.actor or not request.actor.get("token"):
        return Response.text("Forbidden", status=403)
    return await big_local_open_implementation(
        request, datasette, request.actor["token"]
    )


async def big_local_open(request, datasette):
    return await big_local_open_implementation(request, datasette)


async def big_local_open_implementation(request, datasette, remember_token=None):
    if request.method == "GET":
        return Response.html(
            """
            <form action="/-/big-local-open" method="POST">
            <p><label>Project ID: <input name="project_id" value="UHJvamVjdDpmZjAxNTBjNi1iNjM0LTQ3MmEtODFiMi1lZjJlMGMwMWQyMjQ="></label></p>
            <p><label>Filename: <input name="filename" value="universities_final.csv"></label></p>
            <p><label>remember_token: <input name="remember_token" value=""></label></p>
            <p><input type="submit"></p>
            </form>
            """
        )
    post = await request.post_vars()
    required_keys = ["filename", "project_id"]
    if remember_token is None:
        required_keys.append("remember_token")
    bad_keys = [key for key in required_keys if not post.get(key)]
    if bad_keys:
        return Response.html(
            "filename, project_id and remember_token POST variables are required",
            status=400,
        )
    filename = post["filename"]
    project_id = post["project_id"]
    remember_token = remember_token or post["remember_token"]
    # Turn project ID into a UUID
    project_uuid = project_id_to_uuid(project_id)
    db = ensure_database(datasette, project_uuid)
    # Use GraphQL to check permissions and get the signed URL for this resource
    try:
        uri, etag, length = await open_project_file(
            datasette, project_id, filename, remember_token
        )
    except OpenError as e:
        return Response.html(
            "Could not open file: {}".format(html.escape(str(e))), status=400
        )
    csv_size_limit_mb = get_settings(datasette).csv_size_limit_mb
    if length > csv_size_limit_mb * 1024 * 1024:
        return Response.html(
            "File exceeds size limit of {}MB".format(csv_size_limit_mb), status=400
        )
    # uri is valid - do we have the table already?
    table_name = alnum_encode(filename)
    if not await db.table_exists(table_name):
        await import_csv(db, uri, table_name)
        # Give it a moment to create the progress table and start running
        await asyncio.sleep(0.5)
    response = Response.redirect("/{}/{}".format(project_uuid, table_name))
    # Set a cookie so that the user can access this database in future -
    # they might be signed in already
    if not (request.actor and request.actor.get("token") == remember_token):
        # Look up user and set cookie
        actor = await get_big_local_user(datasette, remember_token)
        if not actor:
            return Response.html("Invalid token", status=400)
        # Rename displayName to display
        actor["display"] = actor.pop("displayName")
        actor["token"] = remember_token
        response.set_cookie(
            "ds_actor",
            datasette.sign({"a": actor}, "actor"),
        )
    return response


async def get_big_local_user(datasette, remember_token):
    graphql_endpoint = get_settings(datasette).graphql_url
    query = """
    query {
        user {
            id
            displayName
            username
        }
    }
    """.strip()
    async with httpx.AsyncClient() as client:
        response = await client.post(
            graphql_endpoint,
            json={"query": query},
            cookies={"remember_token": remember_token},
            timeout=30,
        )
        if response.status_code != 200:
            return None
        return response.json()["data"]["user"]


async def big_local_project(datasette, request):
    if request.method == "GET":
        return Response.html(
            """
            <form action="/-/big-local-project" method="POST">
            <p><label>Project ID: <input name="project_id" value="UHJvamVjdDpmZjAxNTBjNi1iNjM0LTQ3MmEtODFiMi1lZjJlMGMwMWQyMjQ="></label></p>
            <p><label>remember_token: <input name="remember_token" value=""></label></p>
            <p><input type="submit"></p>
            </form>
            """
        )
    post = await request.post_vars()
    bad_keys = [key for key in ("project_id", "remember_token") if not post.get(key)]
    if bad_keys:
        return Response.html(
            "project_id and remember_token POST variables are required",
            status=400,
        )
    project_id = post["project_id"]
    remember_token = post["remember_token"]
    redirect_path = post.get("redirect_path")
    if redirect_path is not None and not redirect_path.startswith("/"):
        return Response.html("redirect_path must start with /", status=400)
    actor = None
    should_set_cookie = False
    if request.actor and (request.actor.get("token") == remember_token):
        # The token matches the current actor's token
        actor = request.actor
    else:
        # Check remember_token is for a valid actor
        actor = await get_big_local_user(datasette, remember_token)
        if not actor:
            return Response.html("<h1>Invalid token</h1>", status=403)
        actor["display"] = actor.pop("displayName")
        actor["token"] = remember_token
        should_set_cookie = True
    # Can the actor access the project?
    try:
        project = await get_project(datasette, project_id, actor["token"], True)
    except (ProjectPermissionError, ProjectNotFoundError):
        return Response.html("<h1>Cannot access project</h1>", status=403)
    # Figure out UUID for project
    project_uuid = project_id_to_uuid(project_id)
    # Stash project files in the cache
    cache = get_cache(datasette)
    cache_key = "project-files-{}".format(project_id)
    cache[cache_key] = project["files"]
    # Ensure database for project exists
    ensure_database(datasette, project_uuid)
    # Redirect user
    response = Response.redirect(redirect_path or "/{}".format(project_uuid))
    if should_set_cookie:
        response.set_cookie(
            "ds_actor",
            datasette.sign({"a": actor}, "actor"),
        )
    return response


@hookimpl
def extra_template_vars(datasette, view_name, database):
    async def inner():
        if view_name != "database":
            return {}
        cache = get_cache(datasette)
        cache_key = "project-files-{}".format(project_uuid_to_id(database))
        files = cache.get(cache_key) or []
        if not files:
            return {}
        # Filter to just the CSVs that have not yet been imported
        db = datasette.get_database(database)
        table_names = set(await db.table_names())
        available_files = [
            file
            for file in files
            if file["name"].endswith(".csv")
            and alnum_encode(file["name"]) not in table_names
            and file["size"] < get_settings(datasette).csv_size_limit_mb * 1024 * 1024
        ]
        return {
            "available_files": available_files,
            "project_id": project_uuid_to_id(database),
        }

    return inner


@hookimpl
def register_routes():
    return [
        (r"^/-/big-local-open$", big_local_open),
        (r"^/-/big-local-open-private$", big_local_open_private),
        (r"^/-/big-local-project$", big_local_project),
    ]


@hookimpl
def skip_csrf(scope):
    return scope["path"] in ("/-/big-local-open", "/-/big-local-project")


async def import_csv(db, url, table_name):
    task_id = str(uuid.uuid4())

    def insert_initial_record(conn):
        database = sqlite_utils.Database(conn)
        if "_import_progress_" not in database.table_names():
            database["_import_progress_"].create(
                {
                    "id": str,
                    "table": str,
                    "bytes_todo": int,
                    "bytes_done": int,
                    "rows_done": int,
                    "started": str,
                    "completed": str,
                },
                pk="id",
            )
        database["_import_progress_"].insert(
            {
                "id": task_id,
                "table": table_name,
                "bytes_todo": None,
                "bytes_done": 0,
                "rows_done": 0,
                "started": str(datetime.datetime.utcnow()),
                "completed": None,
            }
        )

    await db.execute_write_fn(insert_initial_record)
    # Run the download in a daemon thread to avoid blocking the event loop
    thread = threading.Thread(
        target=functools.partial(
            fetch_and_insert_csv_in_thread,
            task_id,
            url,
            db,
            table_name,
            asyncio.get_event_loop(),
        ),
        daemon=True,
    )
    thread.start()
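# While an import is running, _import_progress_ holds one row per task, along
# these lines (illustrative values):
#
#     {"id": "1f0e7f0a-...", "table": "universities_5f_final_2e_csv",
#      "bytes_todo": 1048576, "bytes_done": 524288, "rows_done": 5000,
#      "started": "2022-01-01 00:00:00.000000", "completed": None}
#
# PROGRESS_BAR_JS further down polls this table via the JSON API to animate
# a progress bar on the table page.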
BATCH_SIZE = 100


def fetch_and_insert_csv_in_thread(task_id, url, database, table_name, loop):
    bytes_todo = None
    bytes_done = 0
    tracker = TypeTracker()

    def stream_lines():
        nonlocal bytes_todo, bytes_done
        with httpx.stream("GET", url) as r:
            try:
                bytes_todo = int(r.headers["content-length"])
            except (KeyError, TypeError):
                # The header may be missing entirely, not just unparseable
                bytes_todo = None
            for line in r.iter_lines():
                bytes_done += len(line)
                yield line

    reader = csv_std.reader(stream_lines())
    headers = next(reader)
    docs = (dict(zip(headers, row)) for row in reader)

    # This function runs in a worker thread, so writes are scheduled onto the
    # event loop with run_coroutine_threadsafe() - unlike asyncio.ensure_future()
    # it is safe to call from a thread other than the one running the loop
    def update_progress(data):
        asyncio.run_coroutine_threadsafe(
            database.execute_write_fn(
                lambda conn: sqlite_utils.Database(conn)["_import_progress_"].update(
                    task_id, data
                ),
                block=True,
            ),
            loop,
        )

    def write_batch(docs):
        asyncio.run_coroutine_threadsafe(
            database.execute_write_fn(
                lambda conn: sqlite_utils.Database(conn)[table_name].insert_all(
                    docs, alter=True
                ),
                block=True,
            ),
            loop,
        )

    gathered = []
    i = 0
    for doc in tracker.wrap(docs):
        gathered.append(doc)
        i += 1
        if len(gathered) >= BATCH_SIZE:
            write_batch(gathered)
            gathered = []
            # Update progress table
            update_progress(
                {
                    "rows_done": i,
                    "bytes_todo": bytes_todo,
                    "bytes_done": bytes_done,
                }
            )
    if gathered:
        # Write any remaining rows
        write_batch(gathered)
        gathered = []
    # Mark as complete in the table
    update_progress(
        {
            "rows_done": i,
            "bytes_done": bytes_todo,
            "completed": str(datetime.datetime.utcnow()),
        }
    )
    # Update the table's schema types
    types = tracker.types
    if not all(v == "text" for v in types.values()):
        # Transform!
        asyncio.run_coroutine_threadsafe(
            database.execute_write_fn(
                lambda conn: sqlite_utils.Database(conn)[table_name].transform(
                    types=types
                ),
                block=False,
            ),
            loop,
        )


PROGRESS_BAR_JS = """
const PROGRESS_BAR_CSS = `
progress {
    -webkit-appearance: none;
    appearance: none;
    border: none;
    width: 100%;
    height: 2em;
    margin-top: 1em;
    margin-bottom: 1em;
}
progress::-webkit-progress-bar {
    background-color: #ddd;
}
progress::-webkit-progress-value {
    background-color: #124d77;
}
`;
(function() {
    // Add CSS
    const style = document.createElement("style");
    style.innerHTML = PROGRESS_BAR_CSS;
    document.head.appendChild(style);
    // Append progress bar
    const progress = document.createElement('progress');
    progress.setAttribute('value', 0);
    progress.innerHTML = 'Importing...';
    progress.style.display = 'none';
    const table = document.querySelector('table.rows-and-columns');
    table.parentNode.insertBefore(progress, table);
    console.log('progress', progress);
    // Figure out the polling URL
    let parts = location.href.split('/');
    let table_name = parts.pop();
    parts.push("_import_progress_.json");
    let pollUrl = parts.join('/') + (
        '?_col=bytes_todo&_col=bytes_done&table=' + table_name +
        '&_sort_desc=started&_shape=array&_size=1'
    );
    // Start polling
    let first = true;
    function pollNext() {
        fetch(pollUrl).then(r => r.json()).then(d => {
            let current = d[0].bytes_done;
            let total = d[0].bytes_todo;
            if (first) {
                progress.setAttribute('max', total);
                progress.style.display = 'block';
                first = false;
            }
            progress.setAttribute('value', current);
            if (current < total) {
                setTimeout(pollNext, 2000);
            } else {
                progress.parentNode.removeChild(progress);
            }
        });
    }
    pollNext();
})();
"""


@hookimpl
def extra_body_script(view_name):
    if view_name == "table":
        return PROGRESS_BAR_JS
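# Example plugin configuration in Datasette's metadata.json; every key is
# optional and falls back to the defaults in get_settings(). The "/data"
# path is purely illustrative:
#
#     {
#         "plugins": {
#             "datasette-big-local": {
#                 "root_dir": "/data",
#                 "graphql_url": "https://api.biglocalnews.org/graphql",
#                 "csv_size_limit_mb": 100,
#                 "login_redirect_url": "https://biglocalnews.org/#/datasette?"
#             }
#         }
#     }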