feat: Generalize Search to Other Models (#2472)

* generalized search logic to SearchFilter

* added default search behavior for all models

* fix for schema overrides

* added search support to several models

* fix for label search

* tests and fixes

* add config for normalizing characters

* dramatically simplified search tests

* bark bark

* fix normalization bug

* tweaked tests

* maybe this time?

---------

Co-authored-by: Hayden <64056131+hay-kot@users.noreply.github.com>
This commit is contained in:
Michael Genson
2023-08-20 13:30:21 -05:00
committed by GitHub
parent 76ae0bafc7
commit 99372aa2b6
16 changed files with 521 additions and 250 deletions

View File

@@ -16,6 +16,7 @@ from mealie.db.models._model_base import SqlAlchemyBase
from mealie.schema._mealie import MealieModel
from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
from mealie.schema.response.query_filter import QueryFilter
from mealie.schema.response.query_search import SearchFilter
Schema = TypeVar("Schema", bound=MealieModel)
Model = TypeVar("Model", bound=SqlAlchemyBase)
@@ -291,7 +292,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
q = self._query(override_schema=eff_schema).filter(attribute_name == attr_match)
return [eff_schema.from_orm(x) for x in self.session.execute(q).scalars().all()]
def page_all(self, pagination: PaginationQuery, override=None) -> PaginationBase[Schema]:
def page_all(self, pagination: PaginationQuery, override=None, search: str | None = None) -> PaginationBase[Schema]:
"""
pagination is a method to interact with the filtered database table and return a paginated result
using the PaginationBase that provides several data points that are needed to manage pagination
@@ -302,12 +303,16 @@ class RepositoryGeneric(Generic[Schema, Model]):
as the override, as the type system is not able to infer the result of this method.
"""
eff_schema = override or self.schema
# Copy this, because calling methods (e.g. tests) might rely on it not getting mutated
pagination_result = pagination.copy()
q = self._query(override_schema=eff_schema, with_options=False)
fltr = self._filter_builder()
q = q.filter_by(**fltr)
q, count, total_pages = self.add_pagination_to_query(q, pagination)
if search:
q = self.add_search_to_query(q, eff_schema, search)
q, count, total_pages = self.add_pagination_to_query(q, pagination_result)
# Apply options late, so they do not get used for counting
q = q.options(*eff_schema.loader_options())
@@ -318,8 +323,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
self.session.rollback()
raise e
return PaginationBase(
page=pagination.page,
per_page=pagination.per_page,
page=pagination_result.page,
per_page=pagination_result.per_page,
total=count,
total_pages=total_pages,
items=[eff_schema.from_orm(s) for s in data],
@@ -392,3 +397,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
query = query.order_by(case_stmt)
return query.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page), count, total_pages
def add_search_to_query(self, query: Select, schema: type[Schema], search: str) -> Select:
    """
    Filter `query` by a user-supplied search string.

    Builds a `SearchFilter` (which picks fuzzy vs. tokenized matching based on
    the database dialect and on quoted phrases in `search`) and delegates to the
    schema's `filter_search_query`. The schema's `_normalize_search` class flag
    controls whether the search string is unidecoded/lowercased first.
    """
    search_filter = SearchFilter(self.session, search, schema._normalize_search)
    return search_filter.filter_query_by_search(query, schema, self.model)

View File

@@ -5,10 +5,9 @@ from uuid import UUID
from pydantic import UUID4
from slugify import slugify
from sqlalchemy import Select, and_, desc, func, or_, select, text
from sqlalchemy import and_, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload
from text_unidecode import unidecode
from mealie.db.models.recipe.category import Category
from mealie.db.models.recipe.ingredient import RecipeIngredientModel
@@ -18,13 +17,7 @@ from mealie.db.models.recipe.tag import Tag
from mealie.db.models.recipe.tool import Tool
from mealie.schema.cookbook.cookbook import ReadCookBook
from mealie.schema.recipe import Recipe
from mealie.schema.recipe.recipe import (
RecipeCategory,
RecipePagination,
RecipeSummary,
RecipeTag,
RecipeTool,
)
from mealie.schema.recipe.recipe import RecipeCategory, RecipePagination, RecipeSummary, RecipeTag, RecipeTool
from mealie.schema.recipe.recipe_category import CategoryBase, TagBase
from mealie.schema.response.pagination import PaginationQuery
@@ -151,98 +144,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
additional_ids = self.session.execute(select(model.id).filter(model.slug.in_(slugs))).scalars().all()
return ids + additional_ids
def _add_search_to_query(self, query: Select, search: str) -> Select:
    """
    Filter a recipe query by a raw search string (legacy, recipe-specific search).

    0. fuzzy search (postgres only) and tokenized search are performed separately
    1. take search string and do a little pre-normalization
    2. look for internal quoted strings and keep them together as "literal" parts of the search
    3. remove special characters from each non-literal search string
    4. token search looks for any individual exact hit in name, description, and ingredients
    5. fuzzy search looks for trigram hits in name, description, and ingredients
    6. Sort order is determined by closeness to the recipe name
    Should search also look at tags?
    """

    normalized_search = unidecode(search).lower().strip()
    punctuation = "!\#$%&()*+,-./:;<=>?@[\\]^_`{|}~"  # string.punctuation with ' & " removed

    # keep quoted phrases together as literal portions of the search string
    literal = False
    quoted_regex = re.compile(r"""(["'])(?:(?=(\\?))\2.)*?\1""")  # thank you stack exchange!
    removequotes_regex = re.compile(r"""['"](.*)['"]""")
    if quoted_regex.search(normalized_search):
        # any quoted phrase forces "literal" mode, which disables fuzzy search below
        literal = True
        temp = normalized_search
        quoted_search_list = [match.group() for match in quoted_regex.finditer(temp)]  # all quoted strings
        quoted_search_list = [removequotes_regex.sub("\\1", x) for x in quoted_search_list]  # remove outer quotes
        temp = quoted_regex.sub("", temp)  # remove all quoted strings, leaving just non-quoted
        temp = temp.translate(
            str.maketrans(punctuation, " " * len(punctuation))
        )  # punctuation->spaces for splitting, but only on unquoted strings
        unquoted_search_list = temp.split()  # all unquoted strings
        normalized_search_list = quoted_search_list + unquoted_search_list
    else:
        # no quoted phrases: strip punctuation everywhere and split on whitespace
        normalized_search = normalized_search.translate(str.maketrans(punctuation, " " * len(punctuation)))
        normalized_search_list = normalized_search.split()
    normalized_search_list = [x.strip() for x in normalized_search_list]  # remove padding whitespace inside quotes

    # I would prefer to just do this in the recipe_ingredient.any part of the main query, but it turns out
    # that at least sqlite wont use indexes for that correctly anymore and takes a big hit, so prefiltering it is
    # NOTE(review): `&` on booleans behaves like `and` here, but `and` would be clearer
    if (self.session.get_bind().name == "postgresql") & (literal is False):  # fuzzy search
        ingredient_ids = (
            self.session.execute(
                select(RecipeIngredientModel.id).filter(
                    or_(
                        RecipeIngredientModel.note_normalized.op("%>")(normalized_search),
                        RecipeIngredientModel.original_text_normalized.op("%>")(normalized_search),
                    )
                )
            )
            .scalars()
            .all()
        )
    else:  # exact token search
        ingredient_ids = (
            self.session.execute(
                select(RecipeIngredientModel.id).filter(
                    or_(
                        *[RecipeIngredientModel.note_normalized.like(f"%{ns}%") for ns in normalized_search_list],
                        *[
                            RecipeIngredientModel.original_text_normalized.like(f"%{ns}%")
                            for ns in normalized_search_list
                        ],
                    )
                )
            )
            .scalars()
            .all()
        )

    if (self.session.get_bind().name == "postgresql") & (literal is False):  # fuzzy search
        # default = 0.7 is too strict for effective fuzzing
        self.session.execute(text("set pg_trgm.word_similarity_threshold = 0.5;"))
        q = query.filter(
            or_(
                RecipeModel.name_normalized.op("%>")(normalized_search),
                RecipeModel.description_normalized.op("%>")(normalized_search),
                RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)),
            )
        ).order_by(  # trigram ordering could be too slow on million record db, but is fine with thousands.
            func.least(
                RecipeModel.name_normalized.op("<->>")(normalized_search),
            )
        )
    else:  # exact token search
        q = query.filter(
            or_(
                *[RecipeModel.name_normalized.like(f"%{ns}%") for ns in normalized_search_list],
                *[RecipeModel.description_normalized.like(f"%{ns}%") for ns in normalized_search_list],
                RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)),
            )
        ).order_by(desc(RecipeModel.name_normalized.like(f"%{normalized_search}%")))
    return q
def page_all(
def page_all( # type: ignore
self,
pagination: PaginationQuery,
override=None,
@@ -299,7 +201,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
)
q = q.filter(*filters)
if search:
q = self._add_search_to_query(q, search)
q = self.add_search_to_query(q, self.schema, search)
q, count, total_pages = self.add_pagination_to_query(q, pagination_result)

View File

@@ -41,10 +41,11 @@ class MultiPurposeLabelsController(BaseUserController):
return HttpRepo(self.repo, self.logger, self.registered_exceptions, self.t("generic.server-error"))
@router.get("", response_model=MultiPurposeLabelPagination)
def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
response = self.repo.page_all(
pagination=q,
override=MultiPurposeLabelSummary,
search=search,
)
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())

View File

@@ -38,11 +38,12 @@ class RecipeCategoryController(BaseCrudController):
return HttpRepo(self.repo, self.logger)
@router.get("", response_model=RecipeCategoryPagination)
def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
"""Returns a list of available categories in the database"""
response = self.repo.page_all(
pagination=q,
override=RecipeCategory,
search=search,
)
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())

View File

@@ -27,11 +27,12 @@ class TagController(BaseCrudController):
return HttpRepo(self.repo, self.logger)
@router.get("", response_model=RecipeTagPagination)
async def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
async def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
"""Returns a list of available tags in the database"""
response = self.repo.page_all(
pagination=q,
override=RecipeTag,
search=search,
)
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())

View File

@@ -25,10 +25,11 @@ class RecipeToolController(BaseUserController):
return HttpRepo[RecipeToolCreate, RecipeTool, RecipeToolCreate](self.repo, self.logger)
@router.get("", response_model=RecipeToolPagination)
def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
response = self.repo.page_all(
pagination=q,
override=RecipeTool,
search=search,
)
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())

View File

@@ -45,10 +45,11 @@ class IngredientFoodsController(BaseUserController):
raise HTTPException(500, "Failed to merge foods") from e
@router.get("", response_model=IngredientFoodPagination)
def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
response = self.repo.page_all(
pagination=q,
override=IngredientFood,
search=search,
)
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())

View File

@@ -45,10 +45,11 @@ class IngredientUnitsController(BaseUserController):
raise HTTPException(500, "Failed to merge units") from e
@router.get("", response_model=IngredientUnitPagination)
def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
response = self.repo.page_all(
pagination=q,
override=IngredientUnit,
search=search,
)
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())

View File

@@ -1,7 +1,8 @@
# This file is auto-generated by gen_schema_exports.py
from .mealie_model import HasUUID, MealieModel
from .mealie_model import HasUUID, MealieModel, SearchType
__all__ = [
"HasUUID",
"MealieModel",
"SearchType",
]

View File

@@ -1,16 +1,34 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Protocol, TypeVar
from enum import Enum
from typing import ClassVar, Protocol, TypeVar
from humps.main import camelize
from pydantic import UUID4, BaseModel
from sqlalchemy import Select, desc, func, or_, text
from sqlalchemy.orm import InstrumentedAttribute, Session
from sqlalchemy.orm.interfaces import LoaderOption
from mealie.db.models._model_base import SqlAlchemyBase
T = TypeVar("T", bound=BaseModel)
class SearchType(Enum):
    """Search strategy: postgres trigram fuzzy matching, or tokenized LIKE matching."""

    fuzzy = "fuzzy"
    tokenized = "tokenized"
class MealieModel(BaseModel):
_fuzzy_similarity_threshold: ClassVar[float] = 0.5
_normalize_search: ClassVar[bool] = False
_searchable_properties: ClassVar[list[str]] = []
"""
Searchable properties for the search API.
The first property will be used for sorting (order_by)
"""
class Config:
alias_generator = camelize
allow_population_by_field_name = True
@@ -59,6 +77,40 @@ class MealieModel(BaseModel):
def loader_options(cls) -> list[LoaderOption]:
return []
@classmethod
def filter_search_query(
    cls,
    db_model: type[SqlAlchemyBase],
    query: Select,
    session: Session,
    search_type: SearchType,
    search: str,
    search_list: list[str],
) -> Select:
    """
    Filter ``query`` against this schema's ``_searchable_properties``.

    Fuzzy search uses postgres trigram matching ("%>") against each searchable
    column and orders by trigram distance to the first one; tokenized search
    LIKE-matches every token in ``search_list`` against every column and sorts
    rows whose first column contains the whole search string to the top.

    Can be overridden to support a more advanced search.

    Raises AttributeError when the schema declares no searchable properties.
    """
    if not cls._searchable_properties:
        raise AttributeError("Not Implemented")

    columns: list[InstrumentedAttribute] = [getattr(db_model, name) for name in cls._searchable_properties]
    primary_column = columns[0]  # the first searchable property drives ordering

    if search_type is SearchType.fuzzy:
        session.execute(text(f"set pg_trgm.word_similarity_threshold = {cls._fuzzy_similarity_threshold};"))
        conditions = [column.op("%>")(search) for column in columns]
        # trigram ordering by the first searchable property
        return query.filter(or_(*conditions)).order_by(func.least(primary_column.op("<->>")(search)))

    # tokenized: any token may hit any searchable column
    conditions = [column.like(f"%{s}%") for column in columns for s in search_list]
    # order by how close the result is to the first searchable property
    return query.filter(or_(*conditions)).order_by(desc(primary_column.like(f"%{search}%")))
class HasUUID(Protocol):
id: UUID4

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
from typing import ClassVar
from pydantic import UUID4
from mealie.schema._mealie import MealieModel
@@ -20,7 +22,7 @@ class MultiPurposeLabelUpdate(MultiPurposeLabelSave):
class MultiPurposeLabelSummary(MultiPurposeLabelUpdate):
pass
_searchable_properties: ClassVar[list[str]] = ["name"]
class Config:
orm_mode = True
@@ -31,14 +33,5 @@ class MultiPurposeLabelPagination(PaginationBase):
class MultiPurposeLabelOut(MultiPurposeLabelUpdate):
# shopping_list_items: list[ShoppingListItemOut] = []
# foods: list[IngredientFood] = []
class Config:
orm_mode = True
# from mealie.schema.recipe.recipe_ingredient import IngredientFood
# from mealie.schema.group.group_shopping_list import ShoppingListItemOut
# MultiPurposeLabelOut.update_forward_refs()

View File

@@ -2,16 +2,17 @@ from __future__ import annotations
import datetime
from pathlib import Path
from typing import Any
from typing import Any, ClassVar
from uuid import uuid4
from pydantic import UUID4, BaseModel, Field, validator
from slugify import slugify
from sqlalchemy.orm import joinedload, selectinload
from sqlalchemy import Select, desc, func, or_, select, text
from sqlalchemy.orm import Session, joinedload, selectinload
from sqlalchemy.orm.interfaces import LoaderOption
from mealie.core.config import get_app_dirs
from mealie.schema._mealie import MealieModel
from mealie.schema._mealie import MealieModel, SearchType
from mealie.schema.response.pagination import PaginationBase
from ...db.models.recipe import (
@@ -37,6 +38,8 @@ class RecipeTag(MealieModel):
name: str
slug: str
_searchable_properties: ClassVar[list[str]] = ["name"]
class Config:
orm_mode = True
@@ -78,6 +81,7 @@ class CreateRecipe(MealieModel):
class RecipeSummary(MealieModel):
id: UUID4 | None
_normalize_search: ClassVar[bool] = True
user_id: UUID4 = Field(default_factory=uuid4)
group_id: UUID4 = Field(default_factory=uuid4)
@@ -259,6 +263,69 @@ class Recipe(RecipeSummary):
selectinload(RecipeModel.notes),
]
@classmethod
def filter_search_query(
    cls, db_model, query: Select, session: Session, search_type: SearchType, search: str, search_list: list[str]
) -> Select:
    """
    Recipe-specific override of the generic search.

    1. token search looks for any individual exact hit in name, description, and ingredients
    2. fuzzy search looks for trigram hits in name, description, and ingredients
    3. Sort order is determined by closeness to the recipe name
    Should search also look at tags?
    """
    if search_type is SearchType.fuzzy:
        # I would prefer to just do this in the recipe_ingredient.any part of the main query,
        # but it turns out that at least sqlite wont use indexes for that correctly anymore and
        # takes a big hit, so prefiltering it is
        ingredient_ids = (
            session.execute(
                select(RecipeIngredientModel.id).filter(
                    or_(
                        RecipeIngredientModel.note_normalized.op("%>")(search),
                        RecipeIngredientModel.original_text_normalized.op("%>")(search),
                    )
                )
            )
            .scalars()
            .all()
        )

        # loosen the trigram threshold (postgres default 0.7 is too strict for effective fuzzing)
        session.execute(text(f"set pg_trgm.word_similarity_threshold = {cls._fuzzy_similarity_threshold};"))
        return query.filter(
            or_(
                RecipeModel.name_normalized.op("%>")(search),
                RecipeModel.description_normalized.op("%>")(search),
                RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)),
            )
        ).order_by(  # trigram ordering could be too slow on million record db, but is fine with thousands.
            func.least(
                RecipeModel.name_normalized.op("<->>")(search),
            )
        )
    else:
        # prefilter ingredient ids that LIKE-match any individual token
        ingredient_ids = (
            session.execute(
                select(RecipeIngredientModel.id).filter(
                    or_(
                        *[RecipeIngredientModel.note_normalized.like(f"%{ns}%") for ns in search_list],
                        *[RecipeIngredientModel.original_text_normalized.like(f"%{ns}%") for ns in search_list],
                    )
                )
            )
            .scalars()
            .all()
        )

        return query.filter(
            or_(
                *[RecipeModel.name_normalized.like(f"%{ns}%") for ns in search_list],
                *[RecipeModel.description_normalized.like(f"%{ns}%") for ns in search_list],
                RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)),
            )
        ).order_by(desc(RecipeModel.name_normalized.like(f"%{search}%")))
class RecipeLastMade(BaseModel):
timestamp: datetime.datetime

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import datetime
import enum
from fractions import Fraction
from typing import ClassVar
from uuid import UUID, uuid4
from pydantic import UUID4, Field, validator
@@ -50,6 +51,8 @@ class IngredientFood(CreateIngredientFood):
created_at: datetime.datetime | None
update_at: datetime.datetime | None
_searchable_properties: ClassVar[list[str]] = ["name", "description"]
class Config:
orm_mode = True
getter_dict = ExtrasGetterDict
@@ -78,6 +81,8 @@ class IngredientUnit(CreateIngredientUnit):
created_at: datetime.datetime | None
update_at: datetime.datetime | None
_searchable_properties: ClassVar[list[str]] = ["name", "abbreviation", "description"]
class Config:
orm_mode = True

View File

@@ -0,0 +1,67 @@
import re
from sqlalchemy import Select
from sqlalchemy.orm import Session
from text_unidecode import unidecode
from ...db.models._model_base import SqlAlchemyBase
from .._mealie import MealieModel, SearchType
class SearchFilter:
    """
    Schema-agnostic search helper shared by all repositories.

    0. fuzzy search (postgres only) and tokenized search are performed separately
    1. take search string and do a little pre-normalization
    2. look for internal quoted strings and keep them together as "literal" parts of the search
    3. remove special characters from each non-literal search string
    """

    # string.punctuation with ' & " removed. The leading "\\#" replaces the
    # original invalid "\#" escape (SyntaxWarning on modern CPython) while
    # keeping the runtime string byte-identical (a literal backslash, then "#").
    punctuation = "!\\#$%&()*+,-./:;<=>?@[\\]^_`{|}~"
    # matches a single- or double-quoted phrase, honoring escaped quotes
    quoted_regex = re.compile(r"""(["'])(?:(?=(\\?))\2.)*?\1""")
    # captures the contents inside the outer pair of quotes
    remove_quotes_regex = re.compile(r"""['"](.*)['"]""")

    @classmethod
    def _normalize_search(cls, search: str, normalize_characters: bool) -> str:
        """Replace punctuation with spaces; optionally unidecode and lowercase."""
        search = search.translate(str.maketrans(cls.punctuation, " " * len(cls.punctuation)))
        if normalize_characters:
            search = unidecode(search).lower().strip()
        else:
            search = search.strip()

        return search

    @classmethod
    def _build_search_list(cls, search: str) -> list[str]:
        """Split the search string into tokens, keeping quoted phrases intact."""
        if cls.quoted_regex.search(search):
            # all quoted strings
            quoted_search_list = [match.group() for match in cls.quoted_regex.finditer(search)]

            # remove outer quotes
            quoted_search_list = [cls.remove_quotes_regex.sub("\\1", x) for x in quoted_search_list]

            # punctuation->spaces for splitting, but only on unquoted strings
            search = cls.quoted_regex.sub("", search)  # remove all quoted strings, leaving just non-quoted
            search = search.translate(str.maketrans(cls.punctuation, " " * len(cls.punctuation)))

            # all unquoted strings
            unquoted_search_list = search.split()

            search_list = quoted_search_list + unquoted_search_list
        else:
            search_list = search.translate(str.maketrans(cls.punctuation, " " * len(cls.punctuation))).split()

        # remove padding whitespace inside quotes
        return [x.strip() for x in search_list]

    def __init__(self, session: Session, search: str, normalize_characters: bool = False) -> None:
        """
        :param session: active database session; used only to detect the dialect
        :param search: raw user-supplied search string
        :param normalize_characters: when True, unidecode/lowercase the search before matching
        """
        # quoted phrases force exact (tokenized) matching; fuzzy search is postgres-only
        if session.get_bind().name != "postgresql" or self.quoted_regex.search(search.strip()):
            self.search_type = SearchType.tokenized
        else:
            self.search_type = SearchType.fuzzy

        self.session = session
        self.search = self._normalize_search(search, normalize_characters)
        self.search_list = self._build_search_list(self.search)

    def filter_query_by_search(self, query: Select, schema: type[MealieModel], model: type[SqlAlchemyBase]) -> Select:
        """Delegate to the schema's ``filter_search_query`` with the precomputed tokens."""
        return schema.filter_search_query(model, query, self.session, self.search_type, self.search, self.search_list)