datasette-socrata/datasette_socrata/__init__.py

from datasette import hookimpl, Response, Forbidden
from datasette_low_disk_space_hook import space_is_running_low
import asyncio
import datetime
import httpx
import sqlite_utils
from sqlite_utils.utils import TypeTracker, maximize_csv_field_size_limit
import urllib.parse
import re
import json
import io
import csv
import textwrap
 
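# Lift the csv module's default field size limit, since Socrata exports can
# contain very large cell values.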
maximize_csv_field_size_limit()
 
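# Socrata dataset identifiers are "4x4" codes - two groups of four word
# characters, e.g. "24uj-dj8v"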
is_valid_id = re.compile(r"^\w{4}-\w{4}$").match
 
 
class DatasetError(Exception):
    pass
 
 
class ParseError(DatasetError):
    pass
 
 
def parse_url(url):
    # Given a URL like this:
    # https://data.edmonton.ca/Urban-Planning-Economy/General-Building-Permits/24uj-dj8v
    # either returns (domain, id) or raises a ParseError
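    # e.g. the URL above yields ("data.edmonton.ca", "24uj-dj8v")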
    bits = urllib.parse.urlparse(url)
    if not bits.netloc:
        raise ParseError("Missing domain")
    domain = bits.netloc
    potential_id = bits.path.split("/")[-1]
    if not is_valid_id(potential_id):
        raise ParseError("Last element of path was not a valid ID")
    return domain, potential_id
 
 
class MetadataError(DatasetError):
    pass
 
 
class DiskSpaceLow(Exception):
    pass
 
 
async def parse_url_fetch_metadata(url):
    domain, id = parse_url(url)
    return domain, id, await fetch_metadata(domain, id)
 
 
async def fetch_metadata(domain, id):
    metadata_url = "https://{}/api/views/{}.json".format(domain, id)
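    # e.g. https://data.edmonton.ca/api/views/24uj-dj8v.json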
    try:
        async with httpx.AsyncClient() as client:
            metadata_response = await client.get(metadata_url)
            if metadata_response.status_code != 200:
                raise MetadataError("Dataset not found")
            else:
                return metadata_response.json()
    except httpx.HTTPError as e:
        raise MetadataError("HTTP error fetching metadata for dataset: {}".format(e))
 
 
async def get_row_count(domain, id):
    # Fetch the row count as well - errors are ignored and the count stays None
    count_url = "https://{}/resource/{}.json?$select=count(*)".format(domain, id)
    try:
        async with httpx.AsyncClient() as client:
            count_response = await client.get(count_url)
    except httpx.HTTPError:
        return None
    if count_response.status_code == 200:
        count_data = count_response.json()
        # The SODA API returns a one-element list, e.g. [{"count": "12345"}];
        # the exact key name can vary, hence the startswith() check
        if (
            isinstance(count_data, list)
            and len(count_data) == 1
            and list(count_data[0].keys())[0].startswith("count")
        ):
            return int(list(count_data[0].values())[0])
    return None
 
 
async def import_socrata(request, datasette):
    if not await datasette.permission_allowed(
        request.actor, "import-socrata", default=False
    ):
        raise Forbidden("Permission denied for import-socrata")
 
    async def _error(message, status=400):
        return Response.html(
            await datasette.render_template(
                "datasette_socrata_error.html",
                {
                    "error": message,
                },
                request=request,
            ),
            status=status,
        )
 
    # Config can be used to restrict to a named database
    config = datasette.plugin_config("datasette-socrata") or {}
    configured_database = config.get("database")
 
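    # Only offer databases that are mutable, not Datasette's internal database,
    # and in WAL mode - WAL lets readers keep working while the import writes.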
    supported_databases = [
        db
        for db in datasette.databases.values()
        if db.is_mutable
        and db.name != "_internal"
        and await db.execute_write_fn(
            lambda conn: sqlite_utils.Database(conn).journal_mode == "wal"
        )
    ]
    if not supported_databases and not configured_database:
        return await _error(
            "There are no attached databases which can be written to and are running in WAL mode."
        )
 
    if configured_database:
        if configured_database not in [db.name for db in supported_databases]:
            return await _error(
                "Configured database '{}' is not both writable and running in WAL mode.".format(
                    configured_database
                )
            )
        supported_databases = [
            db for db in supported_databases if db.name == configured_database
        ]
 
    if request.method != "POST":
        url = request.args.get("url") or ""
        error = None
        metadata = None
        row_count = None
        if url:
            try:
                domain, id, metadata = await parse_url_fetch_metadata(url)
            except DatasetError as e:
                error = str(e)
            else:
                row_count = await get_row_count(domain, id)
        return Response.html(
            await datasette.render_template(
                "datasette_socrata.html",
                {
                    "error": error,
                    "url": url,
                    "fetched_metadata": metadata,
                    "row_count": row_count,
                    "databases": [db.name for db in supported_databases],
                },
                request=request,
            )
        )
 
    vars = await request.post_vars()
    url = vars.get("url") or ""
    error = None
    row_count = None
    try:
        domain, id, metadata = await parse_url_fetch_metadata(url)
    except DatasetError as e:
        error = str(e)
    else:
        row_count = await get_row_count(domain, id)
 
    if error:
        return Response.html(
            await datasette.render_template(
                "datasette_socrata.html", {"error": error}, request=request
            )
        )
 
    # Which database?
    if len(supported_databases) == 1:
        database = supported_databases[0]
    else:
        database_name = vars.get("database")
        filtered = [db for db in supported_databases if db.name == database_name]
        if not filtered:
            return await _error("You need to pick a database.")
        database = filtered[0]
 
    # Ensure table exists
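    # socrata_imports records one row per import: the progress JavaScript
    # polls it and refresh_in_memory_socrata_metadata() reads it back.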
    await database.execute_write(
        textwrap.dedent(
            """
        create table if not exists socrata_imports (
            id text primary key,
            name text,
            url text,
            metadata text,
            row_count integer,
            row_progress integer,
            import_started text,
            import_complete text,
            error text
        );
        """
        )
    )
 
    # Run the actual import here
    # First, write the metadata to the `socrata_imports` table:
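    # replace=True overwrites any previous record for this dataset ID;
    # alter=True adds any columns missing from an older copy of the table.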
    await database.execute_write_fn(
        lambda conn: sqlite_utils.Database(conn)["socrata_imports"].insert(
            {
                "id": metadata["id"],
                "name": metadata["name"],
                "url": url,
                "metadata": json.dumps(metadata),
                "row_count": row_count,
                "row_progress": 0,
                "import_started": datetime.datetime.utcnow().isoformat() + "Z",
                "error": None,
            },
            replace=True,
            alter=True,
        )
    )
    await refresh_in_memory_socrata_metadata(datasette)
 
    # Now start the import, in a task which runs after this request has returned
    table_name = "socrata_" + id.replace("-", "_")
 
    # If the table exists already, delete it
    await database.execute_write_fn(
        lambda conn: sqlite_utils.Database(conn)[table_name].drop(ignore=True)
    )
 
    async def run_the_import_catch_errors():
        try:
            await run_the_import()
        except Exception as error:
            await database.execute_write_fn(
                lambda conn: sqlite_utils.Database(conn)["socrata_imports"].update(
                    id,
                    {"error": str(error)},
                )
            )
 
    async def run_the_import():
        csv_url = "https://{}/api/views/{}/rows.csv".format(domain, id)
 
        async def write_batch(rows):
            if await space_is_running_low(datasette):
                raise DiskSpaceLow("Disk space is running low")
 
            def _write(conn):
                db = sqlite_utils.Database(conn)
                with db.conn:
                    db[table_name].insert_all(rows)
                    db.execute(
                        "update socrata_imports set row_progress = row_progress + ? where id = ?",
                        (len(rows), id),
                    )
 
            return await database.execute_write_fn(_write, block=True)
 
        async with httpx.AsyncClient() as client:
            async with client.stream("GET", csv_url) as response:
                reader = AsyncDictReader(response.aiter_lines())
                tracker = TypeTracker()
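                # tracker.wrap() observes each row's string values as they
                # stream past, so column types can be detected and applied
                # once the import has finished.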
                batch = []
                async for row in reader:
                    batch.append(row)
                    if len(batch) >= 100:
                        # Write this batch to disk
                        await write_batch(list(tracker.wrap(batch)))
                        batch = []
                if batch:
                    await write_batch(list(tracker.wrap(batch)))
 
                # Convert columns to detected types
                await database.execute_write_fn(
                    lambda conn: sqlite_utils.Database(conn)[table_name].transform(
                        types=tracker.types
                    )
                )
 
                # Record import completion time
                await database.execute_write_fn(
                    lambda conn: sqlite_utils.Database(conn)["socrata_imports"].update(
                        id,
                        {
                            "import_complete": datetime.datetime.utcnow().isoformat()
                            + "Z"
                        },
                    )
                )
 
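    # Schedule the import as a fire-and-forget task: ensure_future() returns
    # immediately, so the work continues after this response has been sent.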
    asyncio.ensure_future(run_the_import_catch_errors())
 
    # Wait for up to 1 second for the table to exist, then redirect to it
    i = 0
    while i < 10:
        if (
            await database.execute(
                "select 1 from sqlite_master where tbl_name = ?", [table_name]
            )
        ).first():
            return Response.redirect(
                datasette.urls.table(database=database.name, table=table_name)
            )
        i += 1
        await asyncio.sleep(0.1)
    # Set a message about the table being created and redirect to database page
    datasette.add_message(request, "Import has started", datasette.INFO)
    return Response.redirect(datasette.urls.database(database=database.name))
 
 
class AsyncDictReader:
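    """
    Adapt an async iterator of CSV lines - such as the one returned by
    httpx's response.aiter_lines() - into an async iterator of dicts, by
    feeding one line at a time through an in-memory buffer to csv.DictReader.
    Assumes each line arrives with its trailing newline intact, so an empty
    string signals the end of the input.
    """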
    def __init__(self, async_line_iterator):
        self.async_line_iterator = async_line_iterator
        self.buffer = io.StringIO()
        self.reader = csv.DictReader(self.buffer)
        self.line_num = 0
 
    def __aiter__(self):
        return self
 
    async def __anext__(self):
        if self.line_num == 0:
            header = await self.async_line_iterator.__anext__()
            self.buffer.write(header)
 
        line = await self.async_line_iterator.__anext__()
 
        if not line:
            raise StopAsyncIteration
 
        self.buffer.write(line)
        self.buffer.seek(0)
 
        try:
            result = next(self.reader)
        except StopIteration as e:
            raise StopAsyncIteration from e
 
        self.buffer.seek(0)
        self.buffer.truncate(0)
        self.line_num = self.reader.line_num
 
        return result
 
 
@hookimpl
def register_routes():
    return [(r"^/-/import-socrata$", import_socrata)]
 
 
@hookimpl
def permission_allowed(actor, action):
    if action == "import-socrata" and actor and actor.get("id") == "root":
        return True
 
 
@hookimpl
def menu_links(datasette, actor):
    async def inner():
        if await datasette.permission_allowed(actor, "import-socrata", default=False):
            return [
                {
                    "href": datasette.urls.path("/-/import-socrata"),
                    "label": "Import from Socrata",
                },
            ]
 
    return inner
 
 
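# Client-side progress bar, injected into the table page while an import is
# still running; the !POLL_URL! placeholder is filled in by
# extra_body_script() below.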
PROGRESS_JS = """
const SOCRATA_PROGRESS_CSS = `
.datasette-socrata-progress {
    position: relative;
}
.datasette-socrata-progress progress {
    -webkit-appearance: none;
    appearance: none;
    border: none;
    width: 100%;
    height: 2em;
    margin-top: 1em;
    margin-bottom: 1em;
}
.datasette-socrata-progress div {
    position: absolute;
    top: 0px;
    left: 10px;
    line-height: 2em;
    color: #fff;
    text-shadow: 1px 1px 3px #2f2c2c;
    font-size: 14px;
}
.datasette-socrata-progress progress::-webkit-progress-bar {
    background-color: #ddd;
}
.datasette-socrata-progress progress::-webkit-progress-value {
    background-color: #124d77;
}
`;
(function() {
    const style = document.createElement("style");
    style.innerHTML = SOCRATA_PROGRESS_CSS;
    const pollUrl = !POLL_URL!;
    document.head.appendChild(style);
    const wrapper = document.createElement('div');
    wrapper.setAttribute('class', 'datasette-socrata-progress');
    const progress = document.createElement('progress');
    wrapper.appendChild(progress);
    const text = document.createElement('div');
    wrapper.appendChild(text);
    progress.setAttribute('value', 0);
    progress.setAttribute('max', 100);
    progress.style.display = 'none';
    progress.innerHTML = 'Importing from Socrata...';
    const table = document.querySelector('table.rows-and-columns');
    table.parentNode.insertBefore(wrapper, table);
    // Only show message about completion if we have polled a different
    // value at least once.
    let hasPolled = false;
    /* Start polling */
    function pollNext() {
        fetch(pollUrl).then(r => r.json()).then(rows => {
            if (rows.length) {
                const row = rows[0];
                if (row.error) {
                    alert(`Error: ${row.error}`);
                    progress.style.display = 'none';
                    return;
                }
                if (row.row_count > row.row_progress) {
                    progress.style.display = 'block';
                    text.innerText = `${row.row_progress.toLocaleString()} / ${row.row_count.toLocaleString()}`;
                    progress.setAttribute('value', row.row_progress);
                    progress.setAttribute('max', row.row_count);
                    setTimeout(pollNext, 2000);
                    hasPolled = true;
                } else {
                    if (hasPolled) {
                        text.innerText = `${row.row_progress.toLocaleString()} rows loaded - refresh the page to see them`;
                    } else {
                        progress.style.display = 'none';
                    }
                }
            }
        });
    }
    pollNext();
})();
"""
 
 
@hookimpl
def extra_body_script(view_name, table, database, datasette):
    if not table:
        return
    dataset_id = table.replace("socrata_", "").replace("_", "-")
    if (
        view_name == "table"
        and table.startswith("socrata_")
        and is_valid_id(dataset_id)
    ):
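        # The resulting poll URL looks something like this, assuming a
        # database named "data":
        # /data/socrata_imports.json?id=24uj-dj8v&_shape=array&_col=row_count&_col=row_progress&_col=error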
        return PROGRESS_JS.replace(
            "!POLL_URL!",
            json.dumps(
                "{}.json?id={}&_shape=array&_col=row_count&_col=row_progress&_col=error".format(
                    datasette.urls.table(database=database, table="socrata_imports"),
                    dataset_id,
                )
            ),
        )
 
 
async def refresh_in_memory_socrata_metadata(datasette):
    datasette._socrata_metadata = {}
    databases = {}
    for database_name, db in datasette.databases.items():
        tables = {}
        if "socrata_imports" in await db.table_names():
            table_metadata = {
                row["id"]: dict(row, metadata=json.loads(row["metadata"]))
                for row in await db.execute("select * from socrata_imports")
            }
        else:
            continue
        for table_id, info in table_metadata.items():
            table_metadata = {
                "title": info["name"],
                "source": info["metadata"].get("attribution")
                or info["url"].split("//")[1].split("/")[0],
                "source_url": info["url"],
            }
            description = info["metadata"].get("description")
            if description:
                table_metadata["description"] = description
            column_descriptions = {
                c["name"]: c.get("description")
                for c in (info.get("metadata") or {}).get("columns") or []
                if c.get("description")
            }
            if column_descriptions:
                table_metadata["columns"] = column_descriptions
            if info["metadata"].get("license"):
                license_name = info["metadata"]["license"].get("name")
                license_url = info["metadata"]["license"].get("termsLink")
                if license_name and license_url:
                    table_metadata["license"] = license_name
                    table_metadata["license_url"] = license_url
            tables["socrata_{}".format(table_id.replace("-", "_"))] = table_metadata
        databases[database_name] = {
            "tables": tables,
        }
    datasette._socrata_metadata = {"databases": databases}
 
 
# Populate datasette._socrata_metadata on startup
@hookimpl
def startup(datasette):
    async def inner():
        await refresh_in_memory_socrata_metadata(datasette)
 
    return inner
 
 
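# Expose the imported datasets' metadata to Datasette, giving the tables
# their titles, descriptions, column docs and license details.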
@hookimpl
def get_metadata(datasette):
    return getattr(datasette, "_socrata_metadata", None)
 