mirror of
https://github.com/mealie-recipes/mealie.git
synced 2026-02-09 09:23:12 -05:00
prs-fleshgolem-2070: feat: sqlalchemy 2.0 (#2096)
* upgrade sqlalchemy to 2.0 * rewrite all db models to sqla 2.0 mapping api * fix some importing and typing weirdness * fix types of a lot of nullable columns * remove get_ref methods * fix issues found by tests * rewrite all queries in repository_recipe to 2.0 style * rewrite all repository queries to 2.0 api * rewrite all remaining queries to 2.0 api * remove now-unneeded __allow_unmapped__ flag * remove and fix some unneeded cases of "# type: ignore" * fix formatting * bump black version * run black * can this please be the last one. okay. just. okay. * fix repository errors * remove return * drop open API validator --------- Co-authored-by: Sören Busch <fleshgolem@gmx.net>
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
from functools import cached_property
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from mealie.db.models.group import Group, GroupMealPlan, ReportEntryModel, ReportModel
|
||||
@@ -70,13 +72,16 @@ PK_GROUP_ID = "group_id"
|
||||
|
||||
|
||||
class RepositoryCategories(RepositoryGeneric[CategoryOut, Category]):
|
||||
def get_empty(self):
|
||||
return self.session.query(Category).filter(~Category.recipes.any()).all()
|
||||
def get_empty(self) -> Sequence[Category]:
|
||||
stmt = select(Category).filter(~Category.recipes.any())
|
||||
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
class RepositoryTags(RepositoryGeneric[TagOut, Tag]):
|
||||
def get_empty(self):
|
||||
return self.session.query(Tag).filter(~Tag.recipes.any()).all()
|
||||
def get_empty(self) -> Sequence[Tag]:
|
||||
stmt = select(Tag).filter(~Tag.recipes.any())
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
class AllRepositories:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from pydantic import UUID4
|
||||
from sqlalchemy import select
|
||||
|
||||
from mealie.db.models.recipe.ingredient import IngredientFoodModel
|
||||
from mealie.schema.recipe.recipe_ingredient import IngredientFood
|
||||
@@ -7,15 +8,13 @@ from .repository_generic import RepositoryGeneric
|
||||
|
||||
|
||||
class RepositoryFood(RepositoryGeneric[IngredientFood, IngredientFoodModel]):
|
||||
def _get_food(self, id: UUID4) -> IngredientFoodModel:
|
||||
stmt = select(self.model).filter_by(**self._filter_builder(**{"id": id}))
|
||||
return self.session.execute(stmt).scalars().one()
|
||||
|
||||
def merge(self, from_food: UUID4, to_food: UUID4) -> IngredientFood | None:
|
||||
|
||||
from_model: IngredientFoodModel = (
|
||||
self.session.query(self.model).filter_by(**self._filter_builder(**{"id": from_food})).one()
|
||||
)
|
||||
|
||||
to_model: IngredientFoodModel = (
|
||||
self.session.query(self.model).filter_by(**self._filter_builder(**{"id": to_food})).one()
|
||||
)
|
||||
from_model = self._get_food(from_food)
|
||||
to_model = self._get_food(to_food)
|
||||
|
||||
to_model.ingredients += from_model.ingredients
|
||||
|
||||
@@ -29,4 +28,4 @@ class RepositoryFood(RepositoryGeneric[IngredientFood, IngredientFoodModel]):
|
||||
return self.get_one(to_food)
|
||||
|
||||
def by_group(self, group_id: UUID4) -> "RepositoryFood":
|
||||
return super().by_group(group_id) # type: ignore
|
||||
return super().by_group(group_id)
|
||||
|
||||
@@ -6,21 +6,19 @@ from typing import Any, Generic, TypeVar
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import UUID4, BaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Query
|
||||
from sqlalchemy import Select, delete, func, select
|
||||
from sqlalchemy.orm.session import Session
|
||||
from sqlalchemy.sql import sqltypes
|
||||
|
||||
from mealie.core.root_logger import get_logger
|
||||
from mealie.schema.response.pagination import (
|
||||
OrderDirection,
|
||||
PaginationBase,
|
||||
PaginationQuery,
|
||||
)
|
||||
from mealie.db.models._model_base import SqlAlchemyBase
|
||||
from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
|
||||
from mealie.schema.response.query_filter import QueryFilter
|
||||
|
||||
Schema = TypeVar("Schema", bound=BaseModel)
|
||||
Model = TypeVar("Model")
|
||||
Model = TypeVar("Model", bound=SqlAlchemyBase)
|
||||
|
||||
T = TypeVar("T", bound="RepositoryGeneric")
|
||||
|
||||
|
||||
class RepositoryGeneric(Generic[Schema, Model]):
|
||||
@@ -33,6 +31,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
|
||||
user_id: UUID4 | None = None
|
||||
group_id: UUID4 | None = None
|
||||
session: Session
|
||||
|
||||
def __init__(self, session: Session, primary_key: str, sql_model: type[Model], schema: type[Schema]) -> None:
|
||||
self.session = session
|
||||
@@ -42,11 +41,11 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
|
||||
self.logger = get_logger()
|
||||
|
||||
def by_user(self, user_id: UUID4) -> RepositoryGeneric[Schema, Model]:
|
||||
def by_user(self: T, user_id: UUID4) -> T:
|
||||
self.user_id = user_id
|
||||
return self
|
||||
|
||||
def by_group(self, group_id: UUID4) -> RepositoryGeneric[Schema, Model]:
|
||||
def by_group(self: T, group_id: UUID4) -> T:
|
||||
self.group_id = group_id
|
||||
return self
|
||||
|
||||
@@ -55,7 +54,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
self.logger.error(e)
|
||||
|
||||
def _query(self):
|
||||
return self.session.query(self.model)
|
||||
return select(self.model)
|
||||
|
||||
def _filter_builder(self, **kwargs) -> dict[str, Any]:
|
||||
dct = {}
|
||||
@@ -98,8 +97,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
|
||||
except AttributeError:
|
||||
self.logger.info(f'Attempted to sort by unknown sort property "{order_by}"; ignoring')
|
||||
|
||||
return [eff_schema.from_orm(x) for x in q.offset(start).limit(limit).all()]
|
||||
result = self.session.execute(q.offset(start).limit(limit)).scalars().all()
|
||||
return [eff_schema.from_orm(x) for x in result]
|
||||
|
||||
def multi_query(
|
||||
self,
|
||||
@@ -120,7 +119,9 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
order_attr = order_attr.desc()
|
||||
q = q.order_by(order_attr)
|
||||
|
||||
return [eff_schema.from_orm(x) for x in q.offset(start).limit(limit).all()]
|
||||
q = q.offset(start).limit(limit)
|
||||
result = self.session.execute(q).scalars().all()
|
||||
return [eff_schema.from_orm(x) for x in result]
|
||||
|
||||
def _query_one(self, match_value: str | int | UUID4, match_key: str | None = None) -> Model:
|
||||
"""
|
||||
@@ -131,14 +132,14 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
match_key = self.primary_key
|
||||
|
||||
fltr = self._filter_builder(**{match_key: match_value})
|
||||
return self._query().filter_by(**fltr).one()
|
||||
return self.session.execute(self._query().filter_by(**fltr)).scalars().one()
|
||||
|
||||
def get_one(
|
||||
self, value: str | int | UUID4, key: str | None = None, any_case=False, override_schema=None
|
||||
) -> Schema | None:
|
||||
key = key or self.primary_key
|
||||
|
||||
q = self.session.query(self.model)
|
||||
q = self._query()
|
||||
|
||||
if any_case:
|
||||
search_attr = getattr(self.model, key)
|
||||
@@ -146,7 +147,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
else:
|
||||
q = q.filter_by(**self._filter_builder(**{key: value}))
|
||||
|
||||
result = q.one_or_none()
|
||||
result = self.session.execute(q).scalars().one_or_none()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
@@ -156,7 +157,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
|
||||
def create(self, data: Schema | BaseModel | dict) -> Schema:
|
||||
data = data if isinstance(data, dict) else data.dict()
|
||||
new_document = self.model(session=self.session, **data) # type: ignore
|
||||
new_document = self.model(session=self.session, **data)
|
||||
self.session.add(new_document)
|
||||
self.session.commit()
|
||||
self.session.refresh(new_document)
|
||||
@@ -167,7 +168,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
new_documents = []
|
||||
for document in data:
|
||||
document = document if isinstance(document, dict) else document.dict()
|
||||
new_document = self.model(session=self.session, **document) # type: ignore
|
||||
new_document = self.model(session=self.session, **document)
|
||||
new_documents.append(new_document)
|
||||
|
||||
self.session.add_all(new_documents)
|
||||
@@ -191,7 +192,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
new_data = new_data if isinstance(new_data, dict) else new_data.dict()
|
||||
|
||||
entry = self._query_one(match_value=match_value)
|
||||
entry.update(session=self.session, **new_data) # type: ignore
|
||||
entry.update(session=self.session, **new_data)
|
||||
|
||||
self.session.commit()
|
||||
return self.schema.from_orm(entry)
|
||||
@@ -202,7 +203,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
document_data = document if isinstance(document, dict) else document.dict()
|
||||
document_data_by_id[document_data["id"]] = document_data
|
||||
|
||||
documents_to_update = self._query().filter(self.model.id.in_(list(document_data_by_id.keys()))) # type: ignore
|
||||
documents_to_update_query = self._query().filter(self.model.id.in_(list(document_data_by_id.keys())))
|
||||
documents_to_update = self.session.execute(documents_to_update_query).scalars().all()
|
||||
|
||||
updated_documents = []
|
||||
for document_to_update in documents_to_update:
|
||||
@@ -226,7 +228,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
def delete(self, value, match_key: str | None = None) -> Schema:
|
||||
match_key = match_key or self.primary_key
|
||||
|
||||
result = self._query().filter_by(**{match_key: value}).one()
|
||||
result = self.session.execute(self._query().filter_by(**{match_key: value})).scalars().one()
|
||||
results_as_model = self.schema.from_orm(result)
|
||||
|
||||
try:
|
||||
@@ -239,7 +241,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
return results_as_model
|
||||
|
||||
def delete_many(self, values: Iterable) -> Schema:
|
||||
results = self._query().filter(self.model.id.in_(values)) # type: ignore
|
||||
query = self._query().filter(self.model.id.in_(values)) # type: ignore
|
||||
results = self.session.execute(query).scalars().all()
|
||||
results_as_model = [self.schema.from_orm(result) for result in results]
|
||||
|
||||
try:
|
||||
@@ -256,14 +259,14 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
return results_as_model # type: ignore
|
||||
|
||||
def delete_all(self) -> None:
|
||||
self._query().delete()
|
||||
delete(self.model)
|
||||
self.session.commit()
|
||||
|
||||
def count_all(self, match_key=None, match_value=None) -> int:
|
||||
if None in [match_key, match_value]:
|
||||
return self._query().count()
|
||||
else:
|
||||
return self._query().filter_by(**{match_key: match_value}).count()
|
||||
q = select(func.count(self.model.id))
|
||||
if None not in [match_key, match_value]:
|
||||
q = q.filter_by(**{match_key: match_value})
|
||||
return self.session.scalar(q)
|
||||
|
||||
def _count_attribute(
|
||||
self,
|
||||
@@ -274,12 +277,12 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
) -> int | list[Schema]: # sourcery skip: assign-if-exp
|
||||
eff_schema = override_schema or self.schema
|
||||
|
||||
q = self._query().filter(attribute_name == attr_match)
|
||||
|
||||
if count:
|
||||
return q.count()
|
||||
q = select(func.count(self.model.id)).filter(attribute_name == attr_match)
|
||||
return self.session.scalar(q)
|
||||
else:
|
||||
return [eff_schema.from_orm(x) for x in q.all()]
|
||||
q = self._query().filter(attribute_name == attr_match)
|
||||
return [eff_schema.from_orm(x) for x in self.session.execute(q).scalars().all()]
|
||||
|
||||
def page_all(self, pagination: PaginationQuery, override=None) -> PaginationBase[Schema]:
|
||||
"""
|
||||
@@ -293,14 +296,14 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
"""
|
||||
eff_schema = override or self.schema
|
||||
|
||||
q = self.session.query(self.model)
|
||||
q = self._query()
|
||||
|
||||
fltr = self._filter_builder()
|
||||
q = q.filter_by(**fltr)
|
||||
q, count, total_pages = self.add_pagination_to_query(q, pagination)
|
||||
|
||||
try:
|
||||
data = q.all()
|
||||
data = self.session.execute(q).scalars().all()
|
||||
except Exception as e:
|
||||
self._log_exception(e)
|
||||
self.session.rollback()
|
||||
@@ -314,7 +317,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
items=[eff_schema.from_orm(s) for s in data],
|
||||
)
|
||||
|
||||
def add_pagination_to_query(self, query: Query, pagination: PaginationQuery) -> tuple[Query, int, int]:
|
||||
def add_pagination_to_query(self, query: Select, pagination: PaginationQuery) -> tuple[Select, int, int]:
|
||||
"""
|
||||
Adds pagination data to an existing query.
|
||||
|
||||
@@ -333,7 +336,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
self.logger.error(e)
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
count = query.count()
|
||||
count_query = select(func.count()).select_from(query)
|
||||
count = self.session.scalar(count_query)
|
||||
|
||||
# interpret -1 as "get_all"
|
||||
if pagination.per_page == -1:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from pydantic import UUID4
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from mealie.db.models.group import Group
|
||||
from mealie.db.models.recipe.category import Category
|
||||
@@ -9,21 +10,26 @@ from mealie.db.models.users.users import User
|
||||
from mealie.schema.group.group_statistics import GroupStatistics
|
||||
from mealie.schema.user.user import GroupInDB
|
||||
|
||||
from ..db.models._model_base import SqlAlchemyBase
|
||||
from .repository_generic import RepositoryGeneric
|
||||
|
||||
|
||||
class RepositoryGroup(RepositoryGeneric[GroupInDB, Group]):
|
||||
def get_by_name(self, name: str, limit=1) -> GroupInDB | Group | None:
|
||||
dbgroup = self.session.query(self.model).filter_by(**{"name": name}).one_or_none()
|
||||
def get_by_name(self, name: str) -> GroupInDB | None:
|
||||
dbgroup = self.session.execute(select(self.model).filter_by(name=name)).scalars().one_or_none()
|
||||
if dbgroup is None:
|
||||
return None
|
||||
return self.schema.from_orm(dbgroup)
|
||||
|
||||
def statistics(self, group_id: UUID4) -> GroupStatistics:
|
||||
def model_count(model: type[SqlAlchemyBase]) -> int:
|
||||
stmt = select(func.count(model.id)).filter_by(group_id=group_id)
|
||||
return self.session.scalar(stmt)
|
||||
|
||||
return GroupStatistics(
|
||||
total_recipes=self.session.query(RecipeModel).filter_by(group_id=group_id).count(),
|
||||
total_users=self.session.query(User).filter_by(group_id=group_id).count(),
|
||||
total_categories=self.session.query(Category).filter_by(group_id=group_id).count(),
|
||||
total_tags=self.session.query(Tag).filter_by(group_id=group_id).count(),
|
||||
total_tools=self.session.query(Tool).filter_by(group_id=group_id).count(),
|
||||
total_recipes=model_count(RecipeModel),
|
||||
total_users=model_count(User),
|
||||
total_categories=model_count(Category),
|
||||
total_tags=model_count(Tag),
|
||||
total_tools=model_count(Tool),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import or_, select
|
||||
|
||||
from mealie.db.models.group.mealplan import GroupMealPlanRules
|
||||
from mealie.schema.meal_plan.plan_rules import PlanRulesDay, PlanRulesOut, PlanRulesType
|
||||
@@ -10,10 +10,10 @@ from .repository_generic import RepositoryGeneric
|
||||
|
||||
class RepositoryMealPlanRules(RepositoryGeneric[PlanRulesOut, GroupMealPlanRules]):
|
||||
def by_group(self, group_id: UUID) -> "RepositoryMealPlanRules":
|
||||
return super().by_group(group_id) # type: ignore
|
||||
return super().by_group(group_id)
|
||||
|
||||
def get_rules(self, day: PlanRulesDay, entry_type: PlanRulesType) -> list[PlanRulesOut]:
|
||||
qry = self.session.query(GroupMealPlanRules).filter(
|
||||
stmt = select(GroupMealPlanRules).filter(
|
||||
or_(
|
||||
GroupMealPlanRules.day == day,
|
||||
GroupMealPlanRules.day.is_(None),
|
||||
@@ -26,4 +26,6 @@ class RepositoryMealPlanRules(RepositoryGeneric[PlanRulesOut, GroupMealPlanRules
|
||||
),
|
||||
)
|
||||
|
||||
return [self.schema.from_orm(x) for x in qry.all()]
|
||||
rules = self.session.execute(stmt).scalars().all()
|
||||
|
||||
return [self.schema.from_orm(x) for x in rules]
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from datetime import date
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from mealie.db.models.group import GroupMealPlan
|
||||
from mealie.schema.meal_plan.new_meal import ReadPlanEntry
|
||||
|
||||
@@ -9,10 +11,10 @@ from .repository_generic import RepositoryGeneric
|
||||
|
||||
class RepositoryMeals(RepositoryGeneric[ReadPlanEntry, GroupMealPlan]):
|
||||
def by_group(self, group_id: UUID) -> "RepositoryMeals":
|
||||
return super().by_group(group_id) # type: ignore
|
||||
return super().by_group(group_id)
|
||||
|
||||
def get_today(self, group_id: UUID) -> list[ReadPlanEntry]:
|
||||
today = date.today()
|
||||
qry = self.session.query(GroupMealPlan).filter(GroupMealPlan.date == today, GroupMealPlan.group_id == group_id)
|
||||
|
||||
return [self.schema.from_orm(x) for x in qry.all()]
|
||||
stmt = select(GroupMealPlan).filter(GroupMealPlan.date == today, GroupMealPlan.group_id == group_id)
|
||||
plans = self.session.execute(stmt).scalars().all()
|
||||
return [self.schema.from_orm(x) for x in plans]
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from collections.abc import Sequence
|
||||
from random import randint
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import UUID4
|
||||
from slugify import slugify
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
@@ -18,13 +18,14 @@ from mealie.schema.cookbook.cookbook import ReadCookBook
|
||||
from mealie.schema.recipe import Recipe
|
||||
from mealie.schema.recipe.recipe import (
|
||||
RecipeCategory,
|
||||
RecipePagination,
|
||||
RecipeSummary,
|
||||
RecipeSummaryWithIngredients,
|
||||
RecipeTag,
|
||||
RecipeTool,
|
||||
)
|
||||
from mealie.schema.recipe.recipe_category import CategoryBase, TagBase
|
||||
from mealie.schema.response.pagination import PaginationBase, PaginationQuery
|
||||
from mealie.schema.response.pagination import PaginationQuery
|
||||
|
||||
from .repository_generic import RepositoryGeneric
|
||||
|
||||
@@ -46,34 +47,31 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
raise
|
||||
|
||||
def by_group(self, group_id: UUID) -> "RepositoryRecipes":
|
||||
return super().by_group(group_id) # type: ignore
|
||||
return super().by_group(group_id)
|
||||
|
||||
def get_all_public(self, limit: int | None = None, order_by: str | None = None, start=0, override_schema=None):
|
||||
eff_schema = override_schema or self.schema
|
||||
|
||||
if order_by:
|
||||
order_attr = getattr(self.model, str(order_by))
|
||||
|
||||
return [
|
||||
eff_schema.from_orm(x)
|
||||
for x in self.session.query(self.model)
|
||||
stmt = (
|
||||
select(self.model)
|
||||
.join(RecipeSettings)
|
||||
.filter(RecipeSettings.public == True) # noqa: 711
|
||||
.order_by(order_attr.desc())
|
||||
.offset(start)
|
||||
.limit(limit)
|
||||
.all()
|
||||
]
|
||||
)
|
||||
return [eff_schema.from_orm(x) for x in self.session.execute(stmt).scalars().all()]
|
||||
|
||||
return [
|
||||
eff_schema.from_orm(x)
|
||||
for x in self.session.query(self.model)
|
||||
stmt = (
|
||||
select(self.model)
|
||||
.join(RecipeSettings)
|
||||
.filter(RecipeSettings.public == True) # noqa: 711
|
||||
.offset(start)
|
||||
.limit(limit)
|
||||
.all()
|
||||
]
|
||||
)
|
||||
return [eff_schema.from_orm(x) for x in self.session.execute(stmt).scalars().all()]
|
||||
|
||||
def update_image(self, slug: str, _: str | None = None) -> int:
|
||||
entry: RecipeModel = self._query_one(match_value=slug)
|
||||
@@ -100,7 +98,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
|
||||
def summary(
|
||||
self, group_id, start=0, limit=99999, load_foods=False, order_by="created_at", order_descending=True
|
||||
) -> Any:
|
||||
) -> Sequence[RecipeModel]:
|
||||
args = [
|
||||
joinedload(RecipeModel.recipe_category),
|
||||
joinedload(RecipeModel.tags),
|
||||
@@ -126,15 +124,15 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
else:
|
||||
order_attr = order_attr.asc()
|
||||
|
||||
return (
|
||||
self.session.query(RecipeModel)
|
||||
stmt = (
|
||||
select(RecipeModel)
|
||||
.options(*args)
|
||||
.filter(RecipeModel.group_id == group_id)
|
||||
.order_by(order_attr)
|
||||
.offset(start)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def page_all(
|
||||
self,
|
||||
@@ -145,8 +143,8 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
categories: list[UUID4 | str] | None = None,
|
||||
tags: list[UUID4 | str] | None = None,
|
||||
tools: list[UUID4 | str] | None = None,
|
||||
) -> PaginationBase[RecipeSummary]:
|
||||
q = self.session.query(self.model)
|
||||
) -> RecipePagination:
|
||||
q = select(self.model)
|
||||
|
||||
args = [
|
||||
joinedload(RecipeModel.recipe_category),
|
||||
@@ -154,6 +152,8 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
joinedload(RecipeModel.tools),
|
||||
]
|
||||
|
||||
item_class: type[RecipeSummary | RecipeSummaryWithIngredients]
|
||||
|
||||
if load_food:
|
||||
args.append(joinedload(RecipeModel.recipe_ingredient).options(joinedload(RecipeIngredient.food)))
|
||||
args.append(joinedload(RecipeModel.recipe_ingredient).options(joinedload(RecipeIngredient.unit)))
|
||||
@@ -205,14 +205,14 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
q, count, total_pages = self.add_pagination_to_query(q, pagination)
|
||||
|
||||
try:
|
||||
data = q.all()
|
||||
data = self.session.execute(q).scalars().unique().all()
|
||||
except Exception as e:
|
||||
self._log_exception(e)
|
||||
self.session.rollback()
|
||||
raise e
|
||||
|
||||
items = [item_class.from_orm(item) for item in data]
|
||||
return PaginationBase(
|
||||
return RecipePagination(
|
||||
page=pagination.page,
|
||||
per_page=pagination.per_page,
|
||||
total=count,
|
||||
@@ -226,14 +226,12 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
"""
|
||||
|
||||
ids = [x.id for x in categories]
|
||||
|
||||
return [
|
||||
RecipeSummary.from_orm(x)
|
||||
for x in self.session.query(RecipeModel)
|
||||
stmt = (
|
||||
select(RecipeModel)
|
||||
.join(RecipeModel.recipe_category)
|
||||
.filter(RecipeModel.recipe_category.any(Category.id.in_(ids)))
|
||||
.all()
|
||||
]
|
||||
)
|
||||
return [RecipeSummary.from_orm(x) for x in self.session.execute(stmt).unique().scalars().all()]
|
||||
|
||||
def _category_tag_filters(
|
||||
self,
|
||||
@@ -284,8 +282,8 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
fltr = self._category_tag_filters(
|
||||
categories, tags, tools, require_all_categories, require_all_tags, require_all_tools
|
||||
)
|
||||
|
||||
return [self.schema.from_orm(x) for x in self.session.query(RecipeModel).filter(*fltr).all()]
|
||||
stmt = select(RecipeModel).filter(*fltr)
|
||||
return [self.schema.from_orm(x) for x in self.session.execute(stmt).scalars().all()]
|
||||
|
||||
def get_random_by_categories_and_tags(
|
||||
self, categories: list[RecipeCategory], tags: list[RecipeTag]
|
||||
@@ -300,33 +298,27 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
# - https://stackoverflow.com/questions/60805/getting-random-row-through-sqlalchemy
|
||||
|
||||
filters = self._category_tag_filters(categories, tags) # type: ignore
|
||||
|
||||
return [
|
||||
self.schema.from_orm(x)
|
||||
for x in self.session.query(RecipeModel)
|
||||
.filter(and_(*filters))
|
||||
.order_by(func.random()) # Postgres and SQLite specific
|
||||
.limit(1)
|
||||
]
|
||||
stmt = (
|
||||
select(RecipeModel).filter(and_(*filters)).order_by(func.random()).limit(1) # Postgres and SQLite specific
|
||||
)
|
||||
return [self.schema.from_orm(x) for x in self.session.execute(stmt).scalars().all()]
|
||||
|
||||
def get_random(self, limit=1) -> list[Recipe]:
|
||||
return [
|
||||
self.schema.from_orm(x)
|
||||
for x in self.session.query(RecipeModel)
|
||||
stmt = (
|
||||
select(RecipeModel)
|
||||
.filter(RecipeModel.group_id == self.group_id)
|
||||
.order_by(func.random()) # Postgres and SQLite specific
|
||||
.limit(limit)
|
||||
]
|
||||
)
|
||||
return [self.schema.from_orm(x) for x in self.session.execute(stmt).scalars().all()]
|
||||
|
||||
def get_by_slug(self, group_id: UUID4, slug: str, limit=1) -> Recipe | None:
|
||||
dbrecipe = (
|
||||
self.session.query(RecipeModel)
|
||||
.filter(RecipeModel.group_id == group_id, RecipeModel.slug == slug)
|
||||
.one_or_none()
|
||||
)
|
||||
stmt = select(RecipeModel).filter(RecipeModel.group_id == group_id, RecipeModel.slug == slug)
|
||||
dbrecipe = self.session.execute(stmt).scalars().one_or_none()
|
||||
if dbrecipe is None:
|
||||
return None
|
||||
return self.schema.from_orm(dbrecipe)
|
||||
|
||||
def all_ids(self, group_id: UUID4) -> list[UUID4]:
|
||||
return [tpl[0] for tpl in self.session.query(RecipeModel.id).filter(RecipeModel.group_id == group_id).all()]
|
||||
def all_ids(self, group_id: UUID4) -> Sequence[UUID4]:
|
||||
stmt = select(RecipeModel.id).filter(RecipeModel.group_id == group_id)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from pydantic import UUID4
|
||||
from sqlalchemy import select
|
||||
|
||||
from mealie.db.models.recipe.ingredient import IngredientUnitModel
|
||||
from mealie.schema.recipe.recipe_ingredient import IngredientUnit
|
||||
@@ -7,15 +8,13 @@ from .repository_generic import RepositoryGeneric
|
||||
|
||||
|
||||
class RepositoryUnit(RepositoryGeneric[IngredientUnit, IngredientUnitModel]):
|
||||
def _get_unit(self, id: UUID4) -> IngredientUnitModel:
|
||||
stmt = select(self.model).filter_by(**self._filter_builder(**{"id": id}))
|
||||
return self.session.execute(stmt).scalars().one()
|
||||
|
||||
def merge(self, from_unit: UUID4, to_unit: UUID4) -> IngredientUnit | None:
|
||||
|
||||
from_model: IngredientUnitModel = (
|
||||
self.session.query(self.model).filter_by(**self._filter_builder(**{"id": from_unit})).one()
|
||||
)
|
||||
|
||||
to_model: IngredientUnitModel = (
|
||||
self.session.query(self.model).filter_by(**self._filter_builder(**{"id": to_unit})).one()
|
||||
)
|
||||
from_model = self._get_unit(from_unit)
|
||||
to_model = self._get_unit(to_unit)
|
||||
|
||||
to_model.ingredients += from_model.ingredients
|
||||
|
||||
@@ -29,4 +28,4 @@ class RepositoryUnit(RepositoryGeneric[IngredientUnit, IngredientUnitModel]):
|
||||
return self.get_one(to_unit)
|
||||
|
||||
def by_group(self, group_id: UUID4) -> "RepositoryUnit":
|
||||
return super().by_group(group_id) # type: ignore
|
||||
return super().by_group(group_id)
|
||||
|
||||
@@ -2,6 +2,7 @@ import random
|
||||
import shutil
|
||||
|
||||
from pydantic import UUID4
|
||||
from sqlalchemy import select
|
||||
|
||||
from mealie.assets import users as users_assets
|
||||
from mealie.schema.user.user import PrivateUser, User
|
||||
@@ -35,12 +36,14 @@ class RepositoryUsers(RepositoryGeneric[PrivateUser, User]):
|
||||
entry = super().delete(value, match_key)
|
||||
# Delete the user's directory
|
||||
shutil.rmtree(PrivateUser.get_directory(value))
|
||||
return entry # type: ignore
|
||||
return entry
|
||||
|
||||
def get_by_username(self, username: str) -> PrivateUser | None:
|
||||
dbuser = self.session.query(User).filter(User.username == username).one_or_none()
|
||||
stmt = select(User).filter(User.username == username)
|
||||
dbuser = self.session.execute(stmt).scalars().one_or_none()
|
||||
return None if dbuser is None else self.schema.from_orm(dbuser)
|
||||
|
||||
def get_locked_users(self) -> list[PrivateUser]:
|
||||
results = self.session.query(User).filter(User.locked_at != None).all() # noqa E711
|
||||
stmt = select(User).filter(User.locked_at != None) # noqa E711
|
||||
results = self.session.execute(stmt).scalars().all()
|
||||
return [self.schema.from_orm(x) for x in results]
|
||||
|
||||
Reference in New Issue
Block a user