feat: User-specific Recipe Ratings (#3345)

This commit is contained in:
Michael Genson
2024-04-11 21:28:43 -05:00
committed by GitHub
parent 8ab09cf03b
commit 2a541f081a
50 changed files with 1497 additions and 443 deletions

View File

@@ -3,11 +3,11 @@ from collections.abc import Sequence
from random import randint
from uuid import UUID
import sqlalchemy as sa
from pydantic import UUID4
from slugify import slugify
from sqlalchemy import and_, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import InstrumentedAttribute, joinedload
from mealie.db.models.recipe.category import Category
from mealie.db.models.recipe.ingredient import RecipeIngredientModel
@@ -15,11 +15,12 @@ from mealie.db.models.recipe.recipe import RecipeModel
from mealie.db.models.recipe.settings import RecipeSettings
from mealie.db.models.recipe.tag import Tag
from mealie.db.models.recipe.tool import Tool
from mealie.db.models.users.user_to_recipe import UserToRecipe
from mealie.schema.cookbook.cookbook import ReadCookBook
from mealie.schema.recipe import Recipe
from mealie.schema.recipe.recipe import RecipeCategory, RecipePagination, RecipeSummary, RecipeTag, RecipeTool
from mealie.schema.recipe.recipe_category import CategoryBase, TagBase
from mealie.schema.response.pagination import PaginationQuery
from mealie.schema.response.pagination import OrderByNullPosition, OrderDirection, PaginationQuery
from ..db.models._model_base import SqlAlchemyBase
from ..schema._mealie.mealie_model import extract_uuids
@@ -51,7 +52,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
if order_by:
order_attr = getattr(self.model, str(order_by))
stmt = (
select(self.model)
sa.select(self.model)
.join(RecipeSettings)
.filter(RecipeSettings.public == True) # noqa: E712
.order_by(order_attr.desc())
@@ -61,7 +62,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
return [eff_schema.model_validate(x) for x in self.session.execute(stmt).scalars().all()]
stmt = (
select(self.model)
sa.select(self.model)
.join(RecipeSettings)
.filter(RecipeSettings.public == True) # noqa: E712
.offset(start)
@@ -121,7 +122,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
order_attr = order_attr.asc()
stmt = (
select(RecipeModel)
sa.select(RecipeModel)
.options(*args)
.filter(RecipeModel.group_id == group_id)
.order_by(order_attr)
@@ -145,9 +146,54 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
ids.append(i_as_uuid)
except ValueError:
slugs.append(i)
additional_ids = self.session.execute(select(model.id).filter(model.slug.in_(slugs))).scalars().all()
additional_ids = self.session.execute(sa.select(model.id).filter(model.slug.in_(slugs))).scalars().all()
return ids + additional_ids
def add_order_attr_to_query(
self,
query: sa.Select,
order_attr: InstrumentedAttribute,
order_dir: OrderDirection,
order_by_null: OrderByNullPosition | None,
) -> sa.Select:
"""Special handling for ordering recipes by rating"""
column_name = order_attr.key
if column_name != "rating" or not self.user_id:
return super().add_order_attr_to_query(query, order_attr, order_dir, order_by_null)
# calculate the effictive rating for the user by using the user's rating if it exists,
# falling back to the recipe's rating if it doesn't
effective_rating_column_name = "_effective_rating"
query = query.add_columns(
sa.case(
(
sa.exists().where(
UserToRecipe.recipe_id == self.model.id,
UserToRecipe.user_id == self.user_id,
UserToRecipe.rating is not None,
UserToRecipe.rating > 0,
),
sa.select(UserToRecipe.rating)
.where(UserToRecipe.recipe_id == self.model.id, UserToRecipe.user_id == self.user_id)
.scalar_subquery(),
),
else_=self.model.rating,
).label(effective_rating_column_name)
)
order_attr = effective_rating_column_name
if order_dir is OrderDirection.asc:
order_attr = sa.asc(order_attr)
elif order_dir is OrderDirection.desc:
order_attr = sa.desc(order_attr)
if order_by_null is OrderByNullPosition.first:
order_attr = sa.nulls_first(order_attr)
elif order_by_null is OrderByNullPosition.last:
order_attr = sa.nulls_last(order_attr)
return query.order_by(order_attr)
def page_all( # type: ignore
self,
pagination: PaginationQuery,
@@ -165,7 +211,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
) -> RecipePagination:
# Copy this, because calling methods (e.g. tests) might rely on it not getting mutated
pagination_result = pagination.model_copy()
q = select(self.model)
q = sa.select(self.model)
args = [
joinedload(RecipeModel.recipe_category),
@@ -236,7 +282,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
ids = [x.id for x in categories]
stmt = (
select(RecipeModel)
sa.select(RecipeModel)
.join(RecipeModel.recipe_category)
.filter(RecipeModel.recipe_category.any(Category.id.in_(ids)))
)
@@ -301,7 +347,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
require_all_tags=require_all_tags,
require_all_tools=require_all_tools,
)
stmt = select(RecipeModel).filter(*fltr)
stmt = sa.select(RecipeModel).filter(*fltr)
return [self.schema.model_validate(x) for x in self.session.execute(stmt).scalars().all()]
def get_random_by_categories_and_tags(
@@ -318,26 +364,29 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
filters = self._build_recipe_filter(extract_uuids(categories), extract_uuids(tags)) # type: ignore
stmt = (
select(RecipeModel).filter(and_(*filters)).order_by(func.random()).limit(1) # Postgres and SQLite specific
sa.select(RecipeModel)
.filter(sa.and_(*filters))
.order_by(sa.func.random())
.limit(1) # Postgres and SQLite specific
)
return [self.schema.model_validate(x) for x in self.session.execute(stmt).scalars().all()]
def get_random(self, limit=1) -> list[Recipe]:
stmt = (
select(RecipeModel)
sa.select(RecipeModel)
.filter(RecipeModel.group_id == self.group_id)
.order_by(func.random()) # Postgres and SQLite specific
.order_by(sa.func.random()) # Postgres and SQLite specific
.limit(limit)
)
return [self.schema.model_validate(x) for x in self.session.execute(stmt).scalars().all()]
def get_by_slug(self, group_id: UUID4, slug: str, limit=1) -> Recipe | None:
stmt = select(RecipeModel).filter(RecipeModel.group_id == group_id, RecipeModel.slug == slug)
def get_by_slug(self, group_id: UUID4, slug: str) -> Recipe | None:
stmt = sa.select(RecipeModel).filter(RecipeModel.group_id == group_id, RecipeModel.slug == slug)
dbrecipe = self.session.execute(stmt).scalars().one_or_none()
if dbrecipe is None:
return None
return self.schema.model_validate(dbrecipe)
def all_ids(self, group_id: UUID4) -> Sequence[UUID4]:
stmt = select(RecipeModel.id).filter(RecipeModel.group_id == group_id)
stmt = sa.select(RecipeModel.id).filter(RecipeModel.group_id == group_id)
return self.session.execute(stmt).scalars().all()