feat: advanced filtering API (#1468)

* created query filter classes

* extended pagination to include query filtering

* added filtering tests

* type improvements

* move type help to dev dependency

* minor type and perf fixes

* breakup test cases

Co-authored-by: Hayden <64056131+hay-kot@users.noreply.github.com>
This commit is contained in:
Michael Genson
2022-07-09 23:57:09 -05:00
committed by GitHub
parent c64da1fdb7
commit 7f50071312
8 changed files with 480 additions and 353 deletions

View File

@@ -1,6 +1,7 @@
from math import ceil
from typing import Any, Generic, TypeVar, Union
from fastapi import HTTPException
from pydantic import UUID4, BaseModel
from sqlalchemy import func
from sqlalchemy.orm.session import Session
@@ -8,6 +9,7 @@ from sqlalchemy.sql import sqltypes
from mealie.core.root_logger import get_logger
from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
from mealie.schema.response.query_filter import QueryFilter
Schema = TypeVar("Schema", bound=BaseModel)
Model = TypeVar("Model")
@@ -236,7 +238,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
are filtered by the user and group id when applicable.
NOTE: When you provide an override you'll need to manually type the result of this method
as the override, as the type system, is not able to infer the result of this method.
as the override, as the type system is not able to infer the result of this method.
"""
eff_schema = override or self.schema
@@ -244,6 +246,15 @@ class RepositoryGeneric(Generic[Schema, Model]):
fltr = self._filter_builder()
q = q.filter_by(**fltr)
if pagination.query_filter:
try:
qf = QueryFilter(pagination.query_filter)
q = qf.filter_query(q, model=self.model)
except ValueError as e:
self.logger.error(e)
raise HTTPException(status_code=400, detail=str(e))
count = q.count()
# interpret -1 as "get_all"

View File

@@ -3,6 +3,7 @@ from random import randint
from typing import Any, Optional
from uuid import UUID
from fastapi import HTTPException
from pydantic import UUID4
from slugify import slugify
from sqlalchemy import and_, func
@@ -20,6 +21,7 @@ from mealie.schema.recipe import Recipe
from mealie.schema.recipe.recipe import RecipeCategory, RecipePagination, RecipeSummary, RecipeTag, RecipeTool
from mealie.schema.recipe.recipe_category import CategoryBase, TagBase
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
from mealie.schema.response.query_filter import QueryFilter
from .repository_generic import RepositoryGeneric
@@ -147,6 +149,15 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
fltr = self._filter_builder()
q = q.filter_by(**fltr)
if pagination.query_filter:
try:
qf = QueryFilter(pagination.query_filter)
q = qf.filter_query(q, model=self.model)
except ValueError as e:
self.logger.error(e)
raise HTTPException(status_code=400, detail=str(e))
count = q.count()
# interpret -1 as "get_all"

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import datetime
import enum
from typing import Optional, Union
from uuid import UUID, uuid4
@@ -27,6 +28,8 @@ class SaveIngredientFood(CreateIngredientFood):
class IngredientFood(CreateIngredientFood):
id: UUID4
label: Optional[MultiPurposeLabelSummary] = None
created_at: Optional[datetime.datetime]
update_at: Optional[datetime.datetime]
class Config:
orm_mode = True
@@ -48,6 +51,8 @@ class SaveIngredientUnit(CreateIngredientUnit):
class IngredientUnit(CreateIngredientUnit):
id: UUID4
created_at: Optional[datetime.datetime]
update_at: Optional[datetime.datetime]
class Config:
orm_mode = True

View File

@@ -21,6 +21,7 @@ class PaginationQuery(MealieModel):
per_page: int = 50
order_by: str = "created_at"
order_direction: OrderDirection = OrderDirection.desc
query_filter: str = None
class PaginationBase(GenericModel, Generic[DataT]):

View File

@@ -0,0 +1,235 @@
from __future__ import annotations
import re
from enum import Enum
from typing import Any, TypeVar, Union, cast
from dateutil import parser as date_parser
from dateutil.parser import ParserError
from humps import decamelize
from sqlalchemy import bindparam, text
from sqlalchemy.orm.query import Query
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.expression import BindParameter
Model = TypeVar("Model")
class RelationalOperator(Enum):
    """Comparison operators that may appear between an attribute name and a value.

    Each member's value is the literal operator text as written in a filter string.
    """

    # equality / inequality
    EQ = "="
    NOTEQ = "<>"

    # strict ordering
    GT = ">"
    LT = "<"

    # inclusive ordering
    GTE = ">="
    LTE = "<="
class LogicalOperator(Enum):
    """Boolean connectives that may join relational statements in a filter string.

    Each member's value is the upper-case keyword text.
    """

    AND = "AND"
    OR = "OR"
class QueryFilterComponent:
    """A single relational statement: an attribute name, a relational operator and a value."""

    def __init__(self, attribute_name: str, relational_operator: RelationalOperator, value: str) -> None:
        """Store the statement, normalising the attribute name and the value.

        Args:
            attribute_name: attribute name as written in the filter string; it is passed
                through ``decamelize`` before being stored.
            relational_operator: the operator comparing the attribute to the value.
            value: the raw value text; one pair of encasing double quotes is removed.
        """
        self.attribute_name = decamelize(attribute_name)
        self.relational_operator = relational_operator

        # remove encasing quotes; the length check is `>= 2` so that an empty quoted
        # value (`""`) is stored as an empty string instead of two literal quote characters
        if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
            value = value[1:-1]

        self.value = value

    def __repr__(self) -> str:
        """Return the statement as ``[attribute operator value]``."""
        return f"[{self.attribute_name} {self.relational_operator.value} {self.value}]"
class QueryFilter:
    """Parse a filter string into components and apply them to a query.

    The filter string is broken up in three stages: first at parenthesis, then at
    logical and relational operators, and finally the resulting tokens are assembled
    into ``QueryFilterComponent`` / ``LogicalOperator`` / separator entries.
    """

    lsep: str = "("
    rsep: str = ")"
    seps: set[str] = {lsep, rsep}

    def __init__(self, filter_string: str) -> None:
        """Parse ``filter_string`` and store the result in ``self.filter_components``.

        Raises:
            ValueError: if the parenthesis are unbalanced or a relational operator
                is missing one of its operands.
        """
        # parse filter string
        components = QueryFilter._break_filter_string_into_components(filter_string)
        base_components = QueryFilter._break_components_into_base_components(components)
        if base_components.count(QueryFilter.lsep) != base_components.count(QueryFilter.rsep):
            raise ValueError("invalid filter string: parenthesis are unbalanced")

        # parse base components into a filter group
        self.filter_components = QueryFilter._parse_base_components_into_filter_components(base_components)

    def __repr__(self) -> str:
        """Return the parsed components joined by spaces and wrapped in ``<<`` / ``>>``."""
        rendered = " ".join(
            str(component.value if isinstance(component, LogicalOperator) else component)
            for component in self.filter_components
        )
        return f"<<{rendered}>>"

    def filter_query(self, query: Query, model: type[Model]) -> Query:
        """Apply the parsed filter to ``query`` and return the filtered query.

        The components are rendered into one textual SQL fragment in which every value
        is a bound parameter typed after the matching attribute on ``model``.

        Raises:
            ValueError: if an attribute does not exist on ``model``, exists but exposes
                no ``type`` to compare against, or if a date/datetime or boolean value
                cannot be interpreted.
        """
        segments: list[str] = []
        params: list[BindParameter] = []
        for i, component in enumerate(self.filter_components):
            if component in QueryFilter.seps:
                segments.append(component)  # type: ignore
                continue

            if isinstance(component, LogicalOperator):
                segments.append(component.value)
                continue

            # for some reason typing doesn't like the lsep and rsep literals, so we explicitly mark this as a filter component instead
            # cast doesn't actually do anything at runtime
            component = cast(QueryFilterComponent, component)

            if not hasattr(model, component.attribute_name):
                raise ValueError(f"invalid query string: '{component.attribute_name}' does not exist on this schema")

            attr = getattr(model, component.attribute_name)

            # attributes without a `type` (e.g. methods) cannot be compared to a value;
            # report them as a bad filter instead of letting an AttributeError escape
            attr_type = getattr(attr, "type", None)
            if attr_type is None:
                raise ValueError(f"invalid query string: '{component.attribute_name}' cannot be used in a filter")

            # convert values to their proper types
            value: Any = component.value

            if isinstance(attr_type, (sqltypes.Date, sqltypes.DateTime)):
                try:
                    value = date_parser.parse(component.value)
                except ParserError as e:
                    raise ValueError(
                        f"invalid query string: unknown date or datetime format '{component.value}'"
                    ) from e

            if isinstance(attr_type, sqltypes.Boolean):
                try:
                    # truthy when the value starts with "t"/"y" (any case) or is exactly "1"
                    value = component.value.lower()[0] in ["t", "y"] or component.value == "1"
                except IndexError as e:
                    # an empty value has no first character to inspect
                    raise ValueError("invalid query string") from e

            # the component index keeps every parameter key unique within the statement
            paramkey = f"P{i+1}"
            segments.append(" ".join([component.attribute_name, component.relational_operator.value, f":{paramkey}"]))
            params.append(bindparam(paramkey, value, attr_type))

        qs = text(" ".join(segments)).bindparams(*params)
        query = query.filter(qs)
        return query

    @staticmethod
    def _break_filter_string_into_components(filter_string: str) -> list[str]:
        """Repeatedly break the filter string into components based on parenthesis groupings.

        Parenthesis found between double quotes are left untouched. The passes repeat
        until a pass no longer changes the list of components.
        """
        components = [filter_string]
        # quote state is initialised once and carried across components and passes
        in_quotes = False
        while True:
            subcomponents = []
            for component in components:
                # don't parse components comprised of only a separator
                if component in QueryFilter.seps:
                    subcomponents.append(component)
                    continue

                # construct a component until it hits the right separator
                new_component = ""
                for c in component:
                    # ignore characters in-between quotes
                    if c == '"':
                        in_quotes = not in_quotes

                    if c in QueryFilter.seps and not in_quotes:
                        if new_component:
                            subcomponents.append(new_component)

                        subcomponents.append(c)
                        new_component = ""
                        continue

                    new_component += c

                if new_component:
                    subcomponents.append(new_component.strip())

            if components == subcomponents:
                break

            components = subcomponents

        return components

    @staticmethod
    def _break_components_into_base_components(components: list[str]) -> list[str]:
        """Further break down components by splitting at relational and logical operators"""
        # logical operators only count as whole words; without the word boundaries an
        # attribute or value that merely contains "or" / "and" (e.g. "category") would be split
        pattern = "|".join(rf"\b{re.escape(operator.value)}\b" for operator in LogicalOperator)
        logical_operators = re.compile(f"({pattern})", flags=re.IGNORECASE)

        # longest operators first since some operators overlap (e.g. >= and >)
        relational_operators = sorted((operator.value for operator in RelationalOperator), key=len, reverse=True)

        base_components = []
        for component in components:
            offset = 0
            subcomponents = component.split('"')
            for i, subcomponent in enumerate(subcomponents):
                # don't parse components comprised of only a separator
                if subcomponent in QueryFilter.seps:
                    offset += 1
                    base_components.append(subcomponent)
                    continue

                # this subcomponent was surrounded in quotes, so we keep it as-is
                if (i + offset) % 2:
                    base_components.append(f'"{subcomponent.strip()}"')
                    continue

                # if the final subcomponent has quotes, it creates an extra empty subcomponent at the end
                if not subcomponent:
                    continue

                # parse out logical operators
                new_components = [
                    base_component.strip() for base_component in logical_operators.split(subcomponent) if base_component
                ]

                # parse out relational operators; each new component has exactly zero or one relational operator
                for new_component in new_components:
                    if not new_component:
                        continue

                    for rel_op in relational_operators:
                        if rel_op in new_component:
                            operands = [operand.strip() for operand in new_component.split(rel_op) if operand]
                            operands.insert(1, rel_op)
                            base_components.extend(operands)
                            break
                    else:
                        # no relational operator found: keep the component untouched
                        base_components.append(new_component)

        return base_components

    @staticmethod
    def _parse_base_components_into_filter_components(
        base_components: list[str],
    ) -> list[Union[str, QueryFilterComponent, LogicalOperator]]:
        """Walk through base components and construct filter collections

        Separators are kept as strings, logical operators become ``LogicalOperator``
        members, and every relational operator is combined with the entries on either
        side of it into a ``QueryFilterComponent``. Other entries are only consumed as
        operands.

        Raises:
            ValueError: if a relational operator is the first or the last entry, i.e.
                it is missing its attribute name or its value.
        """
        relational_operators = [op.value for op in RelationalOperator]
        logical_operators = [op.value for op in LogicalOperator]

        # parse QueryFilterComponents and logical operators
        components: list[Union[str, QueryFilterComponent, LogicalOperator]] = []
        for i, base_component in enumerate(base_components):
            if base_component in QueryFilter.seps:
                components.append(base_component)

            elif base_component in relational_operators:
                # an operator at either end has no neighbour to use as operand; index -1
                # would silently wrap around and index len() would raise IndexError
                if i == 0 or i + 1 >= len(base_components):
                    raise ValueError(
                        f"invalid filter string: '{base_component}' is missing an attribute name or a value"
                    )

                components.append(
                    QueryFilterComponent(
                        attribute_name=base_components[i - 1],
                        relational_operator=RelationalOperator(base_components[i]),
                        value=base_components[i + 1],
                    )
                )

            elif base_component.upper() in logical_operators:
                components.append(LogicalOperator(base_component.upper()))

        return components