datasette-graphql/datasette_graphql/utils.py
from base64 import b64decode, b64encode
from collections import namedtuple
from datasette.utils import await_me_maybe
from datasette.plugins import pm
from enum import Enum
import graphene
from graphene.types import generic
from graphql.language import ast  # provides ast.StringValue for Bytes.parse_literal
import json
import keyword
import re
import urllib.parse
import sqlite_utils
import textwrap
import time
TableMetadata = namedtuple(
"TableMetadata",
(
"columns",
"foreign_keys",
"fks_back",
"pks",
"supports_fts",
"is_view",
"graphql_name",
"graphql_columns",
),
)
class Bytes(graphene.Scalar):
# Replace with graphene.Base64 after graphene 3.0 is released
@staticmethod
def serialize(value):
if isinstance(value, bytes):
return b64encode(value).decode("utf-8")
return value
@classmethod
def parse_literal(cls, node):
if isinstance(node, ast.StringValue):
return cls.parse_value(node.value)
@staticmethod
def parse_value(value):
if isinstance(value, bytes):
return value
else:
return b64decode(value)
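# Round-trip sketch for the Bytes scalar (illustrative only):
#   Bytes.serialize(b"\x00\x01")   -> "AAE="
#   Bytes.parse_value("AAE=")      -> b"\x00\x01"
#   Bytes.parse_value(b"\x00\x01") -> b"\x00\x01" (bytes pass straight through)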
types = {
str: graphene.String(),
float: graphene.Float(),
int: graphene.Int(),
bytes: Bytes(),
}
class PageInfo(graphene.ObjectType):
endCursor = graphene.String()
hasNextPage = graphene.Boolean()
class DatabaseSchema:
def __init__(self, schema, table_classes, table_collection_classes):
self.schema = schema
self.table_classes = table_classes
self.table_collection_classes = table_collection_classes
# cache keys are (database, schema_version) tuples
_schema_cache = {}
async def schema_for_database_via_cache(datasette, database=None):
db = datasette.get_database(database)
schema_version = (await db.execute("PRAGMA schema_version")).first()[0]
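    # PRAGMA schema_version increments whenever the database schema changes,
    # so any DDL automatically invalidates the cached GraphQL schema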
cache_key = (database, schema_version)
if cache_key not in _schema_cache:
schema = await schema_for_database(datasette, database)
_schema_cache[cache_key] = schema
# Delete other cached versions of this database
to_delete = [
key for key in _schema_cache if key[0] == database and key != cache_key
]
for key in to_delete:
del _schema_cache[key]
return _schema_cache[cache_key]
async def schema_for_database(datasette, database=None):
db = datasette.get_database(database)
hidden_tables = await db.hidden_table_names()
# Perform all introspection in a single call to the execute_fn thread
table_metadata = await db.execute_fn(
lambda conn: introspect_tables(conn, datasette, db.name)
)
# Construct the tableFilter classes
table_filters = {
table: make_table_filter_class(table, meta)
for table, meta in table_metadata.items()
}
# And the table_collection_kwargs
table_collection_kwargs = {}
for table, meta in table_metadata.items():
column_names = meta.graphql_columns.values()
options = list(zip(column_names, column_names))
sort_enum = graphene.Enum.from_enum(
Enum("{}Sort".format(meta.graphql_name), options),
description="Sort by this column",
)
sort_desc_enum = graphene.Enum.from_enum(
Enum("{}SortDesc".format(meta.graphql_name), options),
description="Sort by this column descending",
)
kwargs = dict(
filter=graphene.List(
table_filters[table],
description='Filters e.g. {name: {eq: "datasette"}}',
),
where=graphene.String(
description="Extra SQL where clauses, e.g. \"name='datasette'\""
),
first=graphene.Int(description="Number of results to return"),
after=graphene.String(
description="Start at this pagination cursor (from tableInfo { endCursor })"
),
sort=sort_enum(),
sort_desc=sort_desc_enum(),
)
if meta.supports_fts:
kwargs["search"] = graphene.String(description="Search for this term")
table_collection_kwargs[table] = kwargs
    # For each table, expose a collection field plus a {table}_row field
to_add = []
table_classes = {}
table_collection_classes = {}
    for table, meta in table_metadata.items():
        if table in hidden_tables:
            continue
table_node_class = await make_table_node_class(
datasette,
db,
table,
table_classes,
table_filters,
table_metadata,
table_collection_classes,
table_collection_kwargs,
)
table_classes[table] = table_node_class
# We also need a table collection class - this is the thing with the
# nodes, edges, pageInfo and totalCount fields for that table
table_collection_class = make_table_collection_class(
table, table_node_class, meta
)
table_collection_classes[table] = table_collection_class
to_add.append(
(
meta.graphql_name,
graphene.Field(
table_collection_class,
**table_collection_kwargs[table],
description="Rows from the {} {}".format(
table, "view" if table_metadata[table].is_view else "table"
)
),
)
)
to_add.append(
(
"resolve_{}".format(meta.graphql_name),
make_table_resolver(
datasette, db.name, table, table_classes, table_metadata
),
)
)
# *_row field
table_row_kwargs = dict(table_collection_kwargs[table])
table_row_kwargs.pop("first")
# Add an argument for each primary key
for pk in meta.pks:
if pk == "rowid" and pk not in meta.columns:
pk_column_type = graphene.Int()
else:
pk_column_type = types[meta.columns[pk]]
table_row_kwargs[meta.graphql_columns.get(pk, pk)] = pk_column_type
to_add.append(
(
"{}_row".format(meta.graphql_name),
graphene.Field(
table_node_class,
args=table_row_kwargs,
description="Single row from the {} {}".format(
table, "view" if table_metadata[table].is_view else "table"
),
),
)
)
to_add.append(
(
"resolve_{}_row".format(meta.graphql_name),
make_table_resolver(
datasette,
db.name,
table,
table_classes,
table_metadata,
pk_args=meta.pks,
return_first_row=True,
),
)
)
if not to_add:
# Empty schema throws a 500 error, so add something here
to_add.append(("empty", graphene.String()))
to_add.append(("resolve_empty", lambda a, b: "schema"))
for extra_fields in pm.hook.graphql_extra_fields(
datasette=datasette, database=db.name
):
extra_tuples = await await_me_maybe(extra_fields)
if extra_tuples:
to_add.extend(extra_tuples)
Query = type(
"Query",
(graphene.ObjectType,),
{key: value for key, value in to_add},
)
database_schema = DatabaseSchema(
schema=graphene.Schema(
query=Query,
auto_camelcase=(datasette.plugin_config("datasette-graphql") or {}).get(
"auto_camelcase", False
),
),
table_classes=table_classes,
table_collection_classes=table_collection_classes,
)
return database_schema
def make_table_collection_class(table, table_class, meta):
table_name = meta.graphql_name
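    # The "parent" received by the resolvers below is the JSON dict produced by
    # make_table_resolver - it carries the "rows", "next" and
    # "filtered_table_rows_count" keys from Datasette's table JSON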
class _Edge(graphene.ObjectType):
cursor = graphene.String()
node = graphene.Field(table_class)
class Meta:
name = "{}Edge".format(table_name)
class _TableCollection(graphene.ObjectType):
totalCount = graphene.Int()
pageInfo = graphene.Field(PageInfo)
nodes = graphene.List(table_class)
edges = graphene.List(_Edge)
def resolve_totalCount(parent, info):
return parent["filtered_table_rows_count"]
def resolve_nodes(parent, info):
return parent["rows"]
def resolve_edges(parent, info):
return [
{
"cursor": path_from_row_pks(row, meta.pks, use_rowid=not meta.pks),
"node": row,
}
for row in parent["rows"]
]
def resolve_pageInfo(parent, info):
return {
"endCursor": parent["next"],
"hasNextPage": parent["next"] is not None,
}
class Meta:
name = "{}Collection".format(table_name)
return _TableCollection
class StringOperations(graphene.InputObjectType):
exact = graphene.String(name="eq", description="Exact match")
not_ = graphene.String(name="not", description="Not exact match")
contains = graphene.String(description="String contains")
endswith = graphene.String(description="String ends with")
startswith = graphene.String(description="String starts with")
gt = graphene.String(description="is greater than")
gte = graphene.String(description="is greater than or equal to")
lt = graphene.String(description="is less than")
lte = graphene.String(description="is less than or equal to")
like = graphene.String(description=r"is like (% for wildcards)")
notlike = graphene.String(description="is not like")
glob = graphene.String(description="glob matches (* for wildcards)")
in_ = graphene.List(graphene.String, name="in", description="in this list")
notin = graphene.List(graphene.String, description="not in this list")
arraycontains = graphene.String(description="JSON array contains this value")
date = graphene.String(description="Value is a datetime on this date")
isnull = graphene.Boolean(description="Value is null")
notnull = graphene.Boolean(description="Value is not null")
isblank = graphene.Boolean(description="Value is null or blank")
notblank = graphene.Boolean(description="Value is not null or blank")
class IntegerOperations(graphene.InputObjectType):
exact = graphene.Int(name="eq", description="Exact match")
not_ = graphene.Int(name="not", description="Not exact match")
gt = graphene.Int(description="is greater than")
gte = graphene.Int(description="is greater than or equal to")
lt = graphene.Int(description="is less than")
lte = graphene.Int(description="is less than or equal to")
in_ = graphene.List(graphene.Int, name="in", description="in this list")
notin = graphene.List(graphene.Int, description="not in this list")
arraycontains = graphene.Int(description="JSON array contains this value")
isnull = graphene.Boolean(description="Value is null")
notnull = graphene.Boolean(description="Value is not null")
isblank = graphene.Boolean(description="Value is null or blank")
notblank = graphene.Boolean(description="Value is not null or blank")
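# These operation classes surface as nested filter objects in GraphQL queries,
# e.g. (illustrative - the "repos" table and its columns are hypothetical):
#   {
#       repos(filter: {name: {contains: "data"}, stars: {gte: 100}}) {
#           nodes { name stars }
#       }
#   }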
types_to_operations = {
    str: StringOperations,
    int: IntegerOperations,
    # Floats reuse the integer operations - filter values are URL-encoded into
    # the Datasette querystring either way
    float: IntegerOperations,
}
def make_table_filter_class(table, meta):
return type(
"{}Filter".format(meta.graphql_name),
(graphene.InputObjectType,),
{
meta.graphql_columns[column]: (
types_to_operations.get(column_type) or StringOperations
)()
for column, column_type in meta.columns.items()
},
)
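# For a table with columns {"id": int, "name": str} and graphql_name "repos",
# make_table_filter_class builds an input type roughly equivalent to (sketch):
#   class reposFilter(graphene.InputObjectType):
#       id = IntegerOperations()
#       name = StringOperations()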
async def make_table_node_class(
datasette,
db,
table,
table_classes,
table_filters,
table_metadata,
table_collection_classes,
table_collection_kwargs,
):
meta = table_metadata[table]
fks_by_column = {fk.column: fk for fk in meta.foreign_keys}
table_plugin_config = datasette.plugin_config(
"datasette-graphql", database=db.name, table=table
)
json_columns = []
if table_plugin_config and "json_columns" in table_plugin_config:
json_columns = table_plugin_config["json_columns"]
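    # json_columns comes from per-table plugin configuration in metadata, e.g.
    # (sketch): databases -> DB -> tables -> TABLE -> plugins ->
    # datasette-graphql -> {"json_columns": ["tags"]}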
# Create a node class for this table
plain_columns = []
fk_columns = []
table_dict = {}
if meta.pks == ["rowid"]:
table_dict["rowid"] = graphene.Int()
plain_columns.append("rowid")
columns_to_graphql_names = {}
for colname, coltype in meta.columns.items():
graphql_name = meta.graphql_columns[colname]
columns_to_graphql_names[colname] = graphql_name
if colname in fks_by_column:
fk = fks_by_column[colname]
fk_columns.append((graphql_name, fk.other_table, fk.other_column))
table_dict[graphql_name] = graphene.Field(
make_table_getter(table_classes, fk.other_table)
)
table_dict["resolve_{}".format(graphql_name)] = make_fk_resolver(
db, table, table_classes, fk
)
else:
plain_columns.append(graphql_name)
if colname in json_columns:
table_dict[graphql_name] = generic.GenericScalar()
table_dict["resolve_{}".format(graphql_name)] = resolve_generic
else:
table_dict[graphql_name] = types[coltype]
# Now add the backwards foreign key fields for related items
fk_table_counts = {}
for fk in meta.fks_back:
fk_table_counts[fk.table] = fk_table_counts.get(fk.table, 0) + 1
for fk in meta.fks_back:
filter_class = table_filters[fk.table]
if fk_table_counts[fk.table] > 1:
field_name = "{}_by_{}_list".format(
table_metadata[fk.table].graphql_name,
table_metadata[fk.table].graphql_columns[fk.column],
)
field_description = "Related rows from the {} table (by {})".format(
fk.table, fk.column
)
else:
field_name = "{}_list".format(table_metadata[fk.table].graphql_name)
field_description = "Related rows from the {} table".format(fk.table)
table_dict[field_name] = graphene.Field(
make_table_collection_getter(table_collection_classes, fk.table),
**table_collection_kwargs[fk.table],
description=field_description
)
table_dict["resolve_{}".format(field_name)] = make_table_resolver(
datasette,
db.name,
fk.table,
table_classes,
table_metadata,
related_fk=fk,
)
table_dict["from_row"] = classmethod(
lambda cls, row: cls(
**dict([(meta.graphql_columns.get(k, k), v) for k, v in dict(row).items()])
)
)
table_dict["graphql_name_for_column"] = columns_to_graphql_names.get
async def example_query():
example_query_columns = [" {}".format(c) for c in plain_columns]
# Now add the foreign key columns
for graphql_name, other_table, other_column in fk_columns:
label_column = await db.label_column_for_table(other_table)
            # Need to find out the GraphQL names of other_column (the pk) and
            # label_column on other_table
other_table_obj = table_classes[other_table]
example_query_columns.append(
" %s {\n %s\n%s }"
% (
graphql_name,
other_table_obj.graphql_name_for_column(other_column),
(
" %s\n"
% other_table_obj.graphql_name_for_column(label_column)
)
if label_column
else "",
)
)
return (
textwrap.dedent(
"""
{
TABLE {
totalCount
pageInfo {
hasNextPage
endCursor
}
nodes {
COLUMNS
}
}
}
"""
)
.strip()
.replace("TABLE", meta.graphql_name)
.replace("COLUMNS", "\n".join(example_query_columns))
)
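    # For a table named "repos" with columns id and name, example_query()
    # renders roughly (illustrative):
    #   {
    #       repos {
    #           totalCount
    #           pageInfo {
    #               hasNextPage
    #               endCursor
    #           }
    #           nodes {
    #               id
    #               name
    #           }
    #       }
    #   }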
table_dict["example_query"] = example_query
return type(meta.graphql_name, (graphene.ObjectType,), table_dict)
def make_table_resolver(
datasette,
database_name,
table_name,
table_classes,
table_metadata,
related_fk=None,
pk_args=None,
return_first_row=False,
):
meta = table_metadata[table_name]
async def resolve_table(
root,
info,
filter=None,
where=None,
first=None,
after=None,
search=None,
sort=None,
sort_desc=None,
**kwargs
):
if first is None:
first = 10
if return_first_row:
first = 1
pairs = []
column_name_rev = {v: k for k, v in meta.graphql_columns.items()}
for filter_ in filter or []:
for column_name, operations in filter_.items():
for operation_name, value in operations.items():
if isinstance(value, list):
value = ",".join(map(str, value))
pairs.append(
[
"{}__{}".format(
column_name_rev[column_name], operation_name.rstrip("_")
),
value,
]
)
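        # e.g. the GraphQL filter [{name: {contains: "data"}}, {id: {in: [1, 2]}}]
        # becomes the pairs [["name__contains", "data"], ["id__in", "1,2"]] -
        # rstrip("_") maps the in_/not_ attribute names back to Datasette's
        # in/not filter operators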
if pk_args is not None:
for pk in pk_args:
if kwargs.get(pk) is not None:
pairs.append([pk, kwargs[pk]])
qs = {
"_nofacet": 1,
}
qs.update(pairs)
if after:
qs["_next"] = after
qs["_size"] = first
if search and meta.supports_fts:
qs["_search"] = search
if related_fk:
related_column = meta.graphql_columns.get(
related_fk.column, related_fk.column
)
related_other_column = table_metadata[
related_fk.other_table
].graphql_columns.get(related_fk.other_column, related_fk.other_column)
qs[related_column] = getattr(root, related_other_column)
if where:
qs["_where"] = where
if sort:
qs["_sort"] = column_name_rev[sort.value]
elif sort_desc:
qs["_sort_desc"] = column_name_rev[sort_desc.value]
path_with_query_string = "/{}/{}.json?{}".format(
database_name, table_name, urllib.parse.urlencode(qs)
)
context = info.context
if context and "time_started" in context:
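            # assert False is used deliberately below: graphene catches
            # resolver exceptions and reports them as GraphQL errors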
elapsed_ms = (time.monotonic() - context["time_started"]) * 1000
if context["time_limit_ms"] and elapsed_ms > context["time_limit_ms"]:
assert False, "Time limit exceeded: {:.2f}ms > {}ms - {}".format(
elapsed_ms, context["time_limit_ms"], path_with_query_string
)
context["num_queries_executed"] += 1
if (
context["num_queries_limit"]
and context["num_queries_executed"] > context["num_queries_limit"]
):
assert False, "Query limit exceeded: {} > {} - {}".format(
context["num_queries_executed"],
context["num_queries_limit"],
path_with_query_string,
)
data = (await datasette.client.get(path_with_query_string)).json()
data["rows"] = [dict(zip(data["columns"], row)) for row in data["rows"]]
# If any cells are $base64, decode them into bytes objects
for row in data["rows"]:
for key, value in row.items():
if isinstance(value, dict) and value.get("$base64"):
row[key] = b64decode(value["encoded"])
klass = table_classes[table_name]
data["rows"] = [klass.from_row(r) for r in data["rows"]]
if return_first_row:
try:
return data["rows"][0]
except IndexError:
return None
else:
return data
return resolve_table
def make_fk_resolver(db, table, table_classes, fk):
async def resolve_foreign_key(parent, info):
# retrieve the correct column from parent
pk = getattr(parent, fk.column)
sql = "select * from [{}] where [{}] = :v".format(
fk.other_table, fk.other_column
)
params = {"v": pk}
results = await db.execute(sql, params)
fk_class = table_classes[fk.other_table]
try:
return [fk_class.from_row(row) for row in results.rows][0]
except IndexError:
return None
return resolve_foreign_key
def make_table_getter(table_classes, table):
    # graphene.Field accepts a callable for its type, which lets tables
    # reference each other before every node class has been created
    def getter():
        return table_classes[table]
    return getter
def make_table_collection_getter(table_collection_classes, table):
def getter():
return table_collection_classes[table]
return getter
def path_from_row_pks(row, pks, use_rowid, quote=True):
"""Generate an optionally URL-quoted unique identifier
for a row from its primary keys."""
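    # e.g. pks ["id", "title"] with values (1, "a b") -> "1,a+b"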
if use_rowid:
bits = [row.rowid]
else:
bits = [getattr(row, pk) for pk in pks]
if quote:
bits = [urllib.parse.quote_plus(str(bit)) for bit in bits]
else:
bits = [str(bit) for bit in bits]
return ",".join(bits)
def introspect_tables(conn, datasette, db_name):
db = sqlite_utils.Database(conn)
table_names = db.table_names()
view_names = db.view_names()
table_metadata = {}
table_namer = Namer("t")
for table in table_names + view_names:
datasette_table_metadata = datasette.table_metadata(
table=table, database=db_name
)
columns = db[table].columns_dict
foreign_keys = []
pks = []
supports_fts = bool(datasette_table_metadata.get("fts_table"))
fks_back = []
if hasattr(db[table], "foreign_keys"):
# Views don't have .foreign_keys
foreign_keys = [
fk
for fk in db[table].foreign_keys
# filter out keys to tables that do not exist
if fk.other_table in table_names
]
pks = db[table].pks
supports_fts = bool(db[table].detect_fts()) or supports_fts
# Gather all foreign keys pointing back here
collected = []
for t in db.tables:
collected.extend(t.foreign_keys)
fks_back = [f for f in collected if f.other_table == table]
is_view = table in view_names
column_namer = Namer("c")
table_metadata[table] = TableMetadata(
columns=columns,
foreign_keys=foreign_keys,
fks_back=fks_back,
pks=pks,
supports_fts=supports_fts,
is_view=is_view,
graphql_name=table_namer.name(table),
graphql_columns={column: column_namer.name(column) for column in columns},
)
return table_metadata
def resolve_generic(root, info):
    json_string = getattr(root, info.field_name, None)
    if not json_string:
        # Null or empty columns resolve to None instead of raising in json.loads
        return None
    return json.loads(json_string)
_invalid_chars_re = re.compile(r"[^_a-zA-Z0-9]")
class Namer:
def __init__(self, underscore_prefix=""):
# Disallow all Python keywords
# https://github.com/simonw/datasette-graphql/issues/84
self.reserved = set(keyword.kwlist)
# 'description' confuses Graphene
# https://github.com/simonw/datasette-graphql/issues/85
self.reserved.add("description")
# Track names we have already issued:
self.names = set()
self.underscore_prefix = underscore_prefix
def name(self, value):
value = "_".join(value.split())
value = _invalid_chars_re.sub("_", value)
if not value:
value = "_"
if value[0].isdigit():
value = "_" + value
if value in self.reserved:
value += "_"
if value.startswith("__"):
value = "_0_" + value[2:]
suffix = 2
orig = value
if value.startswith("_") and value.endswith("_"):
value = self.underscore_prefix + value
while value in self.names:
value = orig + "_" + str(suffix)
suffix += 1
self.names.add(value)
return value
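# Illustrative Namer behavior (a sketch, not exercised by the module):
#   namer = Namer("c")
#   namer.name("class")      -> "class_"      (Python keyword)
#   namer.name("my column")  -> "my_column"
#   namer.name("my column")  -> "my_column_2" (duplicates get a numeric suffix)
#   namer.name("_rowid_")    -> "c_rowid_"    (underscore-wrapped names get the prefix)
#   namer.name("123abc")     -> "_123abc"     (leading digit gets an underscore)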