Source code for sqlalchemy_searchable.vectorizers
from collections.abc import Callable
from functools import wraps
from inspect import isclass
from typing import Any
import sqlalchemy as sa
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql.type_api import TypeEngine
# Signature of a vectorizer function: a callable that receives a column
# reference (``sa.ColumnClause``) and returns a string-typed SQL expression
# (``sa.ColumnElement[str]``).
VectorizerFunc = Callable[[sa.ColumnClause[Any]], sa.ColumnElement[str]]
[docs]
class Vectorizer:
    """Registry of vectorizer functions.

    Vectorizer functions can be registered either for a SQLAlchemy type
    class or for one specific column.  When looking a column up, a
    vectorizer registered for that exact column takes precedence over one
    registered for the class of the column's type.
    """

    def __init__(
        self,
        type_vectorizers: dict[type[TypeEngine[Any]], VectorizerFunc] | None = None,
        column_vectorizers: dict[sa.Column[Any], VectorizerFunc] | None = None,
    ):
        """Create a registry, optionally seeded with existing mappings.

        :param type_vectorizers: mapping from type class to vectorizer
            function; a new empty dict is used when ``None``
        :param column_vectorizers: mapping from column to vectorizer
            function; a new empty dict is used when ``None``
        """
        # The given dicts are stored as-is (not copied).
        self.type_vectorizers = (
            type_vectorizers if type_vectorizers is not None else {}
        )
        self.column_vectorizers = (
            column_vectorizers if column_vectorizers is not None else {}
        )

    def clear(self) -> None:
        """Drop every registered vectorizer.

        Both registries are rebound to fresh empty dicts; dicts that were
        previously assigned are not mutated.
        """
        self.type_vectorizers = {}
        self.column_vectorizers = {}

    def contains_tsvector(self, tsvector_column: sa.Column[Any]) -> bool:
        """Tell whether any source column of *tsvector_column* is registered.

        :param tsvector_column: column whose type may expose a ``columns``
            attribute listing column names of the same table
        :return: ``False`` when the column type has no ``columns``
            attribute; otherwise ``True`` as soon as one of the named
            columns has a vectorizer in this registry
        """
        tsvector_type = tsvector_column.type
        if not hasattr(tsvector_type, "columns"):
            return False

        # The table is resolved per name (not hoisted) so that nothing is
        # touched when the list of names is empty.
        for column_name in tsvector_type.columns:
            if getattr(tsvector_column.table.c, column_name) in self:
                return True
        return False

    def __contains__(self, column: sa.Column[Any]) -> bool:
        """Return ``True`` when a vectorizer can be resolved for *column*."""
        try:
            self[column]
        except KeyError:
            return False
        else:
            return True

    def __getitem__(self, column: sa.Column[Any]) -> VectorizerFunc:
        """Resolve the vectorizer function for *column*.

        :param column: the column to look up
        :return: the vectorizer registered for the column itself, or else
            the one registered for the exact class of the column's type
        :raises KeyError: when neither registry has a matching entry
        """
        per_column = self.column_vectorizers
        if column in per_column:
            return per_column[column]

        # Only the exact class of the type is consulted; base classes of
        # that class are not searched.
        column_type = column.type.__class__
        per_type = self.type_vectorizers
        if column_type not in per_type:
            raise KeyError(column)
        return per_type[column_type]

    def __call__(
        self,
        type_or_column: type[TypeEngine[Any]]
        | sa.Column[Any]
        | InstrumentedAttribute[Any],
    ) -> Callable[[VectorizerFunc], VectorizerFunc]:
        """Build a decorator that registers a function as a vectorizer.

        :param type_or_column: the SQLAlchemy type class, the column, or
            the instrumented attribute to register the decorated function
            for; for an instrumented attribute the first column of its
            column property is used as the key
        :return: a decorator; applying it stores a thin wrapper around the
            decorated function in the matching registry and returns that
            wrapper
        :raises TypeError: (when the decorator is applied) if the argument
            is none of the supported kinds, or is an instrumented attribute
            whose property is not a ``ColumnProperty``
        """

        def decorator(func: VectorizerFunc) -> VectorizerFunc:
            @wraps(func)
            def vectorize(
                column_reference: sa.ColumnClause[Any],
            ) -> sa.ColumnElement[str]:
                return func(column_reference)

            target = type_or_column

            # Case 1: a type class -> register in the per-type registry.
            if isclass(target) and issubclass(target, TypeEngine):
                self.type_vectorizers[target] = vectorize
                return vectorize

            # Case 2: a Column object -> register for that column.
            if isinstance(target, sa.Column):
                self.column_vectorizers[target] = vectorize
                return vectorize

            # Anything that is not an instrumented attribute is rejected.
            # NOTE(review): the message lists ColumnProperty, but no branch
            # here accepts a bare ColumnProperty -- confirm the intent.
            if not isinstance(target, InstrumentedAttribute):
                raise TypeError(
                    "First argument should be either valid SQLAlchemy type, "
                    "Column, ColumnProperty or InstrumentedAttribute object."
                )

            # Case 3: an instrumented attribute -> key on the first column
            # of the property it wraps.
            prop = target.property
            if not isinstance(prop, sa.orm.ColumnProperty):
                raise TypeError(
                    "Given InstrumentedAttribute does not wrap "
                    "ColumnProperty. Only instances of ColumnProperty are "
                    "supported for vectorizer."
                )
            self.column_vectorizers[prop.columns[0]] = vectorize
            return vectorize

        return decorator