"""Storage backend based on ScyllaDB.
ScyllaDB is a low-latency distributed key-value store, compatible with the Apache Cassandra
protocol. For details see https://www.scylladb.com/.
"""
import asyncio
import contextlib
import random
import re
from typing import Dict, Iterable, List, Optional, Union
import cassandra.cluster # type: ignore
import cassandra.query # type: ignore
from narrow_down.storage import StorageBackend
# Maximum number of document ids fetched per IN-clause query in query_documents();
# larger requests are split into batches of this size and fetched concurrently.
QUERY_BATCH_SIZE = 50
def _wrap_future(f: cassandra.cluster.ResponseFuture):
"""Wrap a cassandra Future into an asyncio.Future object.
Based on https://stackoverflow.com/questions/49350346/how-to-wrap-custom-future-to-use-with-asyncio-in-python.
Args:
f: future to wrap
Returns:
And asyncio.Future object which can be awaited.
"""
loop = asyncio.get_event_loop()
aio_future = loop.create_future()
def on_result(result):
loop.call_soon_threadsafe(aio_future.set_result, result)
def on_error(exception, *_):
loop.call_soon_threadsafe(aio_future.set_exception, exception)
f.add_callback(on_result)
f.add_errback(on_error)
return aio_future
class ScyllaDBStore(StorageBackend):
    """Storage backend for a SimilarityStore using ScyllaDB."""

    def __init__(
        self,
        cluster_or_session: Union[cassandra.cluster.Cluster, cassandra.cluster.Session],
        keyspace: str,
        table_prefix: Optional[str] = None,
    ) -> None:
        """Create a new empty database or connect to an existing ScyllaDB keyspace.

        Args:
            cluster_or_session: Can be a cassandra cluster or a session object.
            keyspace: Name of the keyspace to use.
            table_prefix: A prefix to use for all table names in the database.

        Raises:
            ValueError: When the keyspace name or the table_prefix is invalid.
        """
        # The keyspace and table prefix are interpolated directly into CQL
        # strings below, so restrict them to plain identifier characters to
        # rule out CQL injection.
        if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", keyspace):
            raise ValueError(f"Invalid keyspace name: {keyspace}")
        if table_prefix and not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", table_prefix):
            raise ValueError(f"Invalid table_prefix: {table_prefix}")
        if isinstance(cluster_or_session, cassandra.cluster.Cluster):
            # Connecting is deferred to the first query (see _session()).
            self._scylla_cluster = cluster_or_session
            self._scylla_session = None
        else:
            self._scylla_cluster = None
            self._scylla_session = cluster_or_session
        self._keyspace = keyspace
        self._table_prefix = table_prefix or ""
        # Filled by initialize(); maps a short name to its prepared statement.
        self._prepared_statements: Dict[str, cassandra.query.PreparedStatement] = {}

    @contextlib.contextmanager
    def _session(self) -> cassandra.cluster.Session:
        """Get or create a cassandra session."""
        # Lazily connect when the object was constructed from a Cluster.
        if (not self._scylla_session) and self._scylla_cluster:
            self._scylla_session = self._scylla_cluster.connect()
        yield self._scylla_session

    async def _execute(self, session, query, parameters=None, timeout=None):
        """Execute a cassandra query with asyncio.

        Args:
            session: The cassandra session to execute the query with.
            query: A CQL string, SimpleStatement or prepared statement.
            parameters: Optional bind parameters for the query.
            timeout: Timeout in seconds; the driver default is used when None.

        Returns:
            The query result.
        """
        # The driver distinguishes "no timeout given" from an explicit None via
        # a private sentinel, hence the protected-member access.
        return await _wrap_future(
            session.execute_async(
                query=query,
                parameters=parameters,
                timeout=timeout or cassandra.cluster._NOT_SET,  # pylint: disable=protected-access
            )
        )

    async def initialize(
        self,
    ) -> "ScyllaDBStore":
        """Initialize the tables in the ScyllaDB keyspace and prepare all statements.

        Returns:
            self
        """
        # Note: CQL does not know unsigned integers.
        # So we need to take the 64bit (signed) bigint to hold a 32bit unsigned int safely.
        create_settings = cassandra.query.SimpleStatement(
            f"CREATE TABLE IF NOT EXISTS {self._keyspace}.{self._table_prefix}settings ("
            " key TEXT, "
            " value TEXT, "
            " PRIMARY KEY(key)"
            ");",
            is_idempotent=True,
        )
        create_documents = cassandra.query.SimpleStatement(
            f"CREATE TABLE IF NOT EXISTS {self._keyspace}.{self._table_prefix}documents ("
            " id bigint, "
            " doc blob, "
            " PRIMARY KEY(id)"
            ");",
            is_idempotent=True,
        )
        create_buckets = cassandra.query.SimpleStatement(
            f"CREATE TABLE IF NOT EXISTS {self._keyspace}.{self._table_prefix}buckets ("
            " bucket bigint, "
            " hash bigint, "
            " doc_id bigint, "
            " PRIMARY KEY((bucket, hash), doc_id)"
            ");",
            is_idempotent=True,
        )
        with self._session() as session:
            # Schema changes can be slow on a busy cluster; allow a generous timeout.
            await self._execute(session, create_settings, timeout=30)
            await self._execute(session, create_documents, timeout=30)
            await self._execute(session, create_buckets, timeout=30)
            self._prepared_statements["set_setting"] = session.prepare(
                f"INSERT INTO {self._keyspace}.{self._table_prefix}"
                "settings(key,value) VALUES (?,?);"
            )
            self._prepared_statements["get_setting"] = session.prepare(
                f"SELECT value FROM {self._keyspace}.{self._table_prefix}settings WHERE key=?;"
            )
            self._prepared_statements["set_doc"] = session.prepare(
                f"INSERT INTO {self._keyspace}.{self._table_prefix}documents(id,doc) VALUES (?,?);"
            )
            # Lightweight transaction: only insert when the id is still free.
            self._prepared_statements["set_doc_checked"] = session.prepare(
                f"INSERT INTO {self._keyspace}.{self._table_prefix}"
                "documents(id,doc) VALUES (?,?) IF NOT EXISTS;"
            )
            self._prepared_statements["get_doc"] = session.prepare(
                f"SELECT doc FROM {self._keyspace}.{self._table_prefix}documents WHERE id=?;"
            )
            self._prepared_statements["del_doc"] = session.prepare(
                f"DELETE FROM {self._keyspace}.{self._table_prefix}documents WHERE id=?;"
            )
            self._prepared_statements["add_doc_to_bucket"] = session.prepare(
                f"INSERT INTO {self._keyspace}.{self._table_prefix}"
                "buckets(bucket,hash,doc_id) VALUES (?,?,?);"
            )
            self._prepared_statements["get_docs_from_bucket"] = session.prepare(
                f"SELECT doc_id FROM {self._keyspace}.{self._table_prefix}"
                "buckets WHERE bucket=? AND hash=?;"
            )
            self._prepared_statements["del_doc_from_bucket"] = session.prepare(
                f"DELETE FROM {self._keyspace}.{self._table_prefix}"
                "buckets WHERE bucket=? AND hash=? AND doc_id=?;"
            )
        # All statements above are safe to retry, so mark them idempotent to
        # let the driver retry them transparently.
        for statement in self._prepared_statements.values():
            statement.is_idempotent = True
        return self

    async def insert_setting(self, key: str, value: str):
        """Store a setting as key-value pair."""
        with self._session() as session:
            await self._execute(session, self._prepared_statements["set_setting"], (key, value))

    async def query_setting(self, key: str) -> Optional[str]:
        """Query a setting with the given key.

        Args:
            key: The identifier of the setting

        Returns:
            A string with the value. If the key does not exist or the storage is uninitialized
            None is returned.

        Raises:
            cassandra.DriverException: In case the database query fails for any reason.
        """  # pylint: disable=missing-raises-doc
        with self._session() as session:
            try:
                result_list = await self._execute(
                    session, self._prepared_statements["get_setting"], (key,)
                )
                return None if not result_list else result_list[0].value
            except KeyError as e:
                # A missing prepared statement means initialize() was never
                # called, which is specified to yield None.
                if "get_setting" in e.args:
                    return None
                raise  # Don't swallow unknown errors

    async def insert_document(self, document: bytes, document_id: Optional[int] = None) -> int:
        """Add the data of a document to the storage and return its ID.

        Args:
            document: The document data as bytes object.
            document_id: Optional id under which to store the document. When omitted,
                a free random 32-bit id is chosen.

        Returns:
            The ID under which the document was stored.

        Raises:
            RuntimeError: When no free random ID could be found after several attempts.
        """
        with self._session() as session:
            # Explicit None-check: 0 is a valid document id and must not fall
            # through to the random-id branch.
            if document_id is not None:
                await self._execute(
                    session,
                    self._prepared_statements["set_doc"],
                    (document_id, document),
                )
                return document_id
            for _ in range(10):
                # Upper bound 2**32 - 1: the largest 32-bit unsigned value,
                # which fits safely into the table's bigint column.
                doc_id = random.randint(0, 2**32 - 1)  # noqa: S311
                result = await self._execute(
                    session,
                    self._prepared_statements["set_doc_checked"],
                    (doc_id, document),
                )
                inserted_successfully = result[0].applied
                # Also accept the case where the same document already sits
                # under the chosen id (idempotent re-insert).
                if inserted_successfully or (
                    result[0].id == doc_id and result[0].doc == document
                ):
                    return doc_id
            raise RuntimeError("Unable to find an ID for a document. This should never happen.")

    async def query_document(self, document_id: int) -> bytes:
        """Get the data belonging to a document.

        Args:
            document_id: The id of the document. This ID is created and returned by the
                `insert_document` method.

        Returns:
            The document stored under the key `document_id` as bytes object.

        Raises:
            KeyError: If the document is not stored.
        """
        with self._session() as session:
            docs = await self._execute(
                session, self._prepared_statements["get_doc"], (document_id,)
            )
            if not docs:
                raise KeyError(f"No document with id {document_id}")
            return docs[0].doc

    async def query_documents(self, document_ids: List[int]) -> List[bytes]:
        """Get the data belonging to multiple documents.

        Args:
            document_ids: Key under which the data is stored.

        Returns:
            The documents stored under the key `document_id` as bytes object.

        Raises:
            KeyError: If no document was found for at least one of the ids.
        """
        if len(document_ids) > QUERY_BATCH_SIZE:
            # Many ids: split into IN-clause batches and fetch them concurrently.
            with self._session() as session:
                result_doc_dicts = await asyncio.gather(
                    *[
                        self._query_document_batch(session, document_ids[i : i + QUERY_BATCH_SIZE])
                        for i in range(0, len(document_ids), QUERY_BATCH_SIZE)
                    ]
                )
                doc_dicts = {id_: doc for d in result_doc_dicts for id_, doc in d.items()}
                # Indexing raises KeyError for any id that was not found.
                return [doc_dicts[i] for i in document_ids]
        else:
            # Few ids: individual point queries in parallel.
            docs: List[bytes] = await asyncio.gather(
                *[self.query_document(id_) for id_ in document_ids]
            )
            return docs

    async def _query_document_batch(self, session, doc_id_batch):
        """Fetch one batch of documents with a single IN-clause query.

        The ids are coerced through int() before interpolation, so only numeric
        literals can end up in the CQL string.
        """
        doc_ids_str = ",".join(map(str, map(int, doc_id_batch)))
        query = (
            f"select id, doc from {self._keyspace}.{self._table_prefix}documents "
            f"where id IN ({doc_ids_str});"
        )
        result_docs = {r.id: r.doc for r in await self._execute(session, query)}
        return result_docs

    async def remove_document(self, document_id: int):
        """Remove a document given by ID from the list of documents."""
        with self._session() as session:
            await self._execute(session, self._prepared_statements["del_doc"], (document_id,))

    async def add_document_to_bucket(self, bucket_id: int, document_hash: int, document_id: int):
        """Link a document to a bucket."""
        with self._session() as session:
            await self._execute(
                session,
                self._prepared_statements["add_doc_to_bucket"],
                (bucket_id, document_hash, document_id),
            )

    async def query_ids_from_bucket(self, bucket_id, document_hash: int) -> Iterable[int]:
        """Get all document IDs stored in a bucket for a certain hash value."""
        with self._session() as session:
            rows = await self._execute(
                session,
                self._prepared_statements["get_docs_from_bucket"],
                (bucket_id, document_hash),
            )
            return [r.doc_id for r in rows]

    async def remove_id_from_bucket(self, bucket_id: int, document_hash: int, document_id: int):
        """Remove a document from a bucket."""
        with self._session() as session:
            await self._execute(
                session,
                self._prepared_statements["del_doc_from_bucket"],
                (bucket_id, document_hash, document_id),
            )