sqlite-comprehend/sqlite_comprehend/utils.py
import click
import boto3
from html.parser import HTMLParser
import json
import configparser
def common_boto3_options(fn):
    """Attach the shared AWS connection/credential options to a click command.

    The options added are ``--access-key``, ``--secret-key``,
    ``--session-token``, ``--endpoint-url`` and ``-a/--auth``.

    Args:
        fn: The callable to decorate.

    Returns:
        ``fn`` after every option decorator has been applied to it.
    """
    shared_options = (
        click.option(
            "--access-key",
            help="AWS access key ID",
        ),
        click.option(
            "--secret-key",
            help="AWS secret access key",
        ),
        click.option(
            "--session-token",
            help="AWS session token",
        ),
        click.option(
            "--endpoint-url",
            help="Custom endpoint URL",
        ),
        click.option(
            "-a",
            "--auth",
            type=click.File("r"),
            help="Path to JSON/INI file containing credentials",
        ),
    )
    # Apply last-to-first, which is the order stacked ``@decorator`` lines
    # written top-to-bottom would be applied in.
    for apply_option in reversed(shared_options):
        fn = apply_option(fn)
    return fn
def make_client(service, access_key, secret_key, session_token, endpoint_url, auth):
    """Create a boto3 client for ``service`` using the supplied credentials.

    Credentials come either from the explicit arguments or from ``auth``, a
    readable text file object containing JSON or INI formatted credentials.
    Content starting with ``{`` is treated as JSON with the keys
    ``AccessKeyId``, ``SecretAccessKey`` and ``SessionToken``; anything else
    is parsed as INI, and the first section containing ``aws_access_key_id``
    supplies ``aws_access_key_id``, ``aws_secret_access_key`` and
    ``aws_session_token``.

    Only truthy values are forwarded to ``boto3.client``; anything left unset
    is simply not passed.

    Args:
        service: Service name passed as the first argument to ``boto3.client``.
        access_key: AWS access key ID, or a falsy value.
        secret_key: AWS secret access key, or a falsy value.
        session_token: AWS session token, or a falsy value.
        endpoint_url: Custom endpoint URL, or a falsy value.
        auth: Open text file object with credentials, or a falsy value.

    Returns:
        The object returned by ``boto3.client(service, **kwargs)``.

    Raises:
        click.ClickException: If ``auth`` is combined with ``access_key``,
            ``secret_key`` or ``session_token``, or if the content of
            ``auth`` cannot be parsed as JSON / INI.
    """
    if auth:
        if access_key or secret_key or session_token:
            raise click.ClickException(
                "--auth cannot be used with --access-key, --secret-key or --session-token"
            )
        auth_content = auth.read().strip()
        if auth_content.startswith("{"):
            # Treat as JSON
            try:
                decoded = json.loads(auth_content)
            except json.JSONDecodeError as ex:
                # Report a malformed file as a CLI error, not a raw traceback.
                raise click.ClickException(
                    f"--auth file could not be parsed as JSON: {ex}"
                ) from ex
            access_key = decoded.get("AccessKeyId")
            secret_key = decoded.get("SecretAccessKey")
            session_token = decoded.get("SessionToken")
        else:
            # Treat as INI
            config = configparser.ConfigParser()
            try:
                config.read_string(auth_content)
                # Use the first section that has an aws_access_key_id
                for section in config.sections():
                    if "aws_access_key_id" in config[section]:
                        access_key = config[section].get("aws_access_key_id")
                        secret_key = config[section].get("aws_secret_access_key")
                        session_token = config[section].get("aws_session_token")
                        break
            except configparser.Error as ex:
                # Covers missing section headers, duplicate sections/options
                # and interpolation failures while reading values.
                raise click.ClickException(
                    f"--auth file could not be parsed as INI: {ex}"
                ) from ex
    kwargs = {}
    if access_key:
        kwargs["aws_access_key_id"] = access_key
    if secret_key:
        kwargs["aws_secret_access_key"] = secret_key
    if session_token:
        kwargs["aws_session_token"] = session_token
    if endpoint_url:
        kwargs["endpoint_url"] = endpoint_url
    return boto3.client(service, **kwargs)
# Adapted from Django's strip_tags() implementation
class MLStripper(HTMLParser):
    """HTML parser that collects text content and discards markup.

    Character and entity references are not decoded; they are re-emitted in
    their original ``&name;`` / ``&#num;`` form.
    """

    def __init__(self) -> None:
        # convert_charrefs=False routes references to handle_entityref /
        # handle_charref instead of decoding them into the data stream.
        super().__init__(convert_charrefs=False)
        self.reset()
        # Text fragments in the order the parser reported them.
        self.accumulated = []

    def handle_data(self, d) -> None:
        """Record a run of plain text."""
        self.accumulated.append(d)

    def handle_entityref(self, name) -> None:
        """Record a named entity reference in its source form."""
        self.accumulated.append(f"&{name};")

    def handle_charref(self, name) -> None:
        """Record a numeric character reference in its source form."""
        self.accumulated.append(f"&#{name};")

    def get_string(self) -> str:
        """Return all recorded fragments concatenated into one string."""
        return "".join(self.accumulated)
def _strip_once(value):
    """Run ``value`` through a fresh MLStripper once and return the text it kept."""
    stripper = MLStripper()
    stripper.feed(value)
    # close() makes the parser process any input it was still buffering.
    stripper.close()
    text_only = stripper.get_string()
    return text_only
def strip_tags(value):
    """Return ``str(value)`` with HTML tags removed.

    Stripping is repeated because removing one layer of markup can leave text
    that itself looks like a tag. The loop ends when the text no longer
    contains both ``<`` and ``>``, or when a pass removes no ``<`` characters.
    """
    text = str(value)
    while "<" in text and ">" in text:
        stripped = _strip_once(text)
        made_progress = text.count("<") != stripped.count("<")
        if not made_progress:
            # _strip_once wasn't able to detect more tags.
            return text
        text = stripped
    return text