prs-fleshgolem-2070: feat: sqlalchemy 2.0 (#2096)

* upgrade sqlalchemy to 2.0

* rewrite all db models to sqla 2.0 mapping api

* fix some importing and typing weirdness

* fix types of a lot of nullable columns

* remove get_ref methods

* fix issues found by tests

* rewrite all queries in repository_recipe to 2.0 style

* rewrite all repository queries to 2.0 api

* rewrite all remaining queries to 2.0 api

* remove now-unneeded __allow_unmapped__ flag

* remove and fix some unneeded cases of "# type: ignore"

* fix formatting

* bump black version

* run black

* can this please be the last one. okay. just. okay.

* fix repository errors

* remove return

* drop open API validator

---------

Co-authored-by: Sören Busch <fleshgolem@gmx.net>
This commit is contained in:
Hayden
2023-02-06 18:43:12 -09:00
committed by GitHub
parent 91cd00976a
commit 9e77a9f367
86 changed files with 1776 additions and 1572 deletions

View File

@@ -1,5 +1,7 @@
from collections.abc import Sequence
from functools import cached_property
from sqlalchemy import select
from sqlalchemy.orm import Session
from mealie.db.models.group import Group, GroupMealPlan, ReportEntryModel, ReportModel
@@ -70,13 +72,16 @@ PK_GROUP_ID = "group_id"
class RepositoryCategories(RepositoryGeneric[CategoryOut, Category]):
def get_empty(self):
return self.session.query(Category).filter(~Category.recipes.any()).all()
def get_empty(self) -> Sequence[Category]:
stmt = select(Category).filter(~Category.recipes.any())
return self.session.execute(stmt).scalars().all()
class RepositoryTags(RepositoryGeneric[TagOut, Tag]):
def get_empty(self):
return self.session.query(Tag).filter(~Tag.recipes.any()).all()
def get_empty(self) -> Sequence[Tag]:
stmt = select(Tag).filter(~Tag.recipes.any())
return self.session.execute(stmt).scalars().all()
class AllRepositories:

View File

@@ -1,4 +1,5 @@
from pydantic import UUID4
from sqlalchemy import select
from mealie.db.models.recipe.ingredient import IngredientFoodModel
from mealie.schema.recipe.recipe_ingredient import IngredientFood
@@ -7,15 +8,13 @@ from .repository_generic import RepositoryGeneric
class RepositoryFood(RepositoryGeneric[IngredientFood, IngredientFoodModel]):
def _get_food(self, id: UUID4) -> IngredientFoodModel:
stmt = select(self.model).filter_by(**self._filter_builder(**{"id": id}))
return self.session.execute(stmt).scalars().one()
def merge(self, from_food: UUID4, to_food: UUID4) -> IngredientFood | None:
from_model: IngredientFoodModel = (
self.session.query(self.model).filter_by(**self._filter_builder(**{"id": from_food})).one()
)
to_model: IngredientFoodModel = (
self.session.query(self.model).filter_by(**self._filter_builder(**{"id": to_food})).one()
)
from_model = self._get_food(from_food)
to_model = self._get_food(to_food)
to_model.ingredients += from_model.ingredients
@@ -29,4 +28,4 @@ class RepositoryFood(RepositoryGeneric[IngredientFood, IngredientFoodModel]):
return self.get_one(to_food)
def by_group(self, group_id: UUID4) -> "RepositoryFood":
return super().by_group(group_id) # type: ignore
return super().by_group(group_id)

View File

@@ -6,21 +6,19 @@ from typing import Any, Generic, TypeVar
from fastapi import HTTPException
from pydantic import UUID4, BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Query
from sqlalchemy import Select, delete, func, select
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import sqltypes
from mealie.core.root_logger import get_logger
from mealie.schema.response.pagination import (
OrderDirection,
PaginationBase,
PaginationQuery,
)
from mealie.db.models._model_base import SqlAlchemyBase
from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
from mealie.schema.response.query_filter import QueryFilter
Schema = TypeVar("Schema", bound=BaseModel)
Model = TypeVar("Model")
Model = TypeVar("Model", bound=SqlAlchemyBase)
T = TypeVar("T", bound="RepositoryGeneric")
class RepositoryGeneric(Generic[Schema, Model]):
@@ -33,6 +31,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
user_id: UUID4 | None = None
group_id: UUID4 | None = None
session: Session
def __init__(self, session: Session, primary_key: str, sql_model: type[Model], schema: type[Schema]) -> None:
self.session = session
@@ -42,11 +41,11 @@ class RepositoryGeneric(Generic[Schema, Model]):
self.logger = get_logger()
def by_user(self, user_id: UUID4) -> RepositoryGeneric[Schema, Model]:
def by_user(self: T, user_id: UUID4) -> T:
self.user_id = user_id
return self
def by_group(self, group_id: UUID4) -> RepositoryGeneric[Schema, Model]:
def by_group(self: T, group_id: UUID4) -> T:
self.group_id = group_id
return self
@@ -55,7 +54,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
self.logger.error(e)
def _query(self):
return self.session.query(self.model)
return select(self.model)
def _filter_builder(self, **kwargs) -> dict[str, Any]:
dct = {}
@@ -98,8 +97,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
except AttributeError:
self.logger.info(f'Attempted to sort by unknown sort property "{order_by}"; ignoring')
return [eff_schema.from_orm(x) for x in q.offset(start).limit(limit).all()]
result = self.session.execute(q.offset(start).limit(limit)).scalars().all()
return [eff_schema.from_orm(x) for x in result]
def multi_query(
self,
@@ -120,7 +119,9 @@ class RepositoryGeneric(Generic[Schema, Model]):
order_attr = order_attr.desc()
q = q.order_by(order_attr)
return [eff_schema.from_orm(x) for x in q.offset(start).limit(limit).all()]
q = q.offset(start).limit(limit)
result = self.session.execute(q).scalars().all()
return [eff_schema.from_orm(x) for x in result]
def _query_one(self, match_value: str | int | UUID4, match_key: str | None = None) -> Model:
"""
@@ -131,14 +132,14 @@ class RepositoryGeneric(Generic[Schema, Model]):
match_key = self.primary_key
fltr = self._filter_builder(**{match_key: match_value})
return self._query().filter_by(**fltr).one()
return self.session.execute(self._query().filter_by(**fltr)).scalars().one()
def get_one(
self, value: str | int | UUID4, key: str | None = None, any_case=False, override_schema=None
) -> Schema | None:
key = key or self.primary_key
q = self.session.query(self.model)
q = self._query()
if any_case:
search_attr = getattr(self.model, key)
@@ -146,7 +147,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
else:
q = q.filter_by(**self._filter_builder(**{key: value}))
result = q.one_or_none()
result = self.session.execute(q).scalars().one_or_none()
if not result:
return None
@@ -156,7 +157,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
def create(self, data: Schema | BaseModel | dict) -> Schema:
data = data if isinstance(data, dict) else data.dict()
new_document = self.model(session=self.session, **data) # type: ignore
new_document = self.model(session=self.session, **data)
self.session.add(new_document)
self.session.commit()
self.session.refresh(new_document)
@@ -167,7 +168,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
new_documents = []
for document in data:
document = document if isinstance(document, dict) else document.dict()
new_document = self.model(session=self.session, **document) # type: ignore
new_document = self.model(session=self.session, **document)
new_documents.append(new_document)
self.session.add_all(new_documents)
@@ -191,7 +192,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
new_data = new_data if isinstance(new_data, dict) else new_data.dict()
entry = self._query_one(match_value=match_value)
entry.update(session=self.session, **new_data) # type: ignore
entry.update(session=self.session, **new_data)
self.session.commit()
return self.schema.from_orm(entry)
@@ -202,7 +203,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
document_data = document if isinstance(document, dict) else document.dict()
document_data_by_id[document_data["id"]] = document_data
documents_to_update = self._query().filter(self.model.id.in_(list(document_data_by_id.keys()))) # type: ignore
documents_to_update_query = self._query().filter(self.model.id.in_(list(document_data_by_id.keys())))
documents_to_update = self.session.execute(documents_to_update_query).scalars().all()
updated_documents = []
for document_to_update in documents_to_update:
@@ -226,7 +228,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
def delete(self, value, match_key: str | None = None) -> Schema:
match_key = match_key or self.primary_key
result = self._query().filter_by(**{match_key: value}).one()
result = self.session.execute(self._query().filter_by(**{match_key: value})).scalars().one()
results_as_model = self.schema.from_orm(result)
try:
@@ -239,7 +241,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
return results_as_model
def delete_many(self, values: Iterable) -> Schema:
results = self._query().filter(self.model.id.in_(values)) # type: ignore
query = self._query().filter(self.model.id.in_(values)) # type: ignore
results = self.session.execute(query).scalars().all()
results_as_model = [self.schema.from_orm(result) for result in results]
try:
@@ -256,14 +259,14 @@ class RepositoryGeneric(Generic[Schema, Model]):
return results_as_model # type: ignore
def delete_all(self) -> None:
self._query().delete()
delete(self.model)
self.session.commit()
def count_all(self, match_key=None, match_value=None) -> int:
if None in [match_key, match_value]:
return self._query().count()
else:
return self._query().filter_by(**{match_key: match_value}).count()
q = select(func.count(self.model.id))
if None not in [match_key, match_value]:
q = q.filter_by(**{match_key: match_value})
return self.session.scalar(q)
def _count_attribute(
self,
@@ -274,12 +277,12 @@ class RepositoryGeneric(Generic[Schema, Model]):
) -> int | list[Schema]: # sourcery skip: assign-if-exp
eff_schema = override_schema or self.schema
q = self._query().filter(attribute_name == attr_match)
if count:
return q.count()
q = select(func.count(self.model.id)).filter(attribute_name == attr_match)
return self.session.scalar(q)
else:
return [eff_schema.from_orm(x) for x in q.all()]
q = self._query().filter(attribute_name == attr_match)
return [eff_schema.from_orm(x) for x in self.session.execute(q).scalars().all()]
def page_all(self, pagination: PaginationQuery, override=None) -> PaginationBase[Schema]:
"""
@@ -293,14 +296,14 @@ class RepositoryGeneric(Generic[Schema, Model]):
"""
eff_schema = override or self.schema
q = self.session.query(self.model)
q = self._query()
fltr = self._filter_builder()
q = q.filter_by(**fltr)
q, count, total_pages = self.add_pagination_to_query(q, pagination)
try:
data = q.all()
data = self.session.execute(q).scalars().all()
except Exception as e:
self._log_exception(e)
self.session.rollback()
@@ -314,7 +317,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
items=[eff_schema.from_orm(s) for s in data],
)
def add_pagination_to_query(self, query: Query, pagination: PaginationQuery) -> tuple[Query, int, int]:
def add_pagination_to_query(self, query: Select, pagination: PaginationQuery) -> tuple[Select, int, int]:
"""
Adds pagination data to an existing query.
@@ -333,7 +336,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
self.logger.error(e)
raise HTTPException(status_code=400, detail=str(e)) from e
count = query.count()
count_query = select(func.count()).select_from(query)
count = self.session.scalar(count_query)
# interpret -1 as "get_all"
if pagination.per_page == -1:

View File

@@ -1,4 +1,5 @@
from pydantic import UUID4
from sqlalchemy import func, select
from mealie.db.models.group import Group
from mealie.db.models.recipe.category import Category
@@ -9,21 +10,26 @@ from mealie.db.models.users.users import User
from mealie.schema.group.group_statistics import GroupStatistics
from mealie.schema.user.user import GroupInDB
from ..db.models._model_base import SqlAlchemyBase
from .repository_generic import RepositoryGeneric
class RepositoryGroup(RepositoryGeneric[GroupInDB, Group]):
def get_by_name(self, name: str, limit=1) -> GroupInDB | Group | None:
dbgroup = self.session.query(self.model).filter_by(**{"name": name}).one_or_none()
def get_by_name(self, name: str) -> GroupInDB | None:
dbgroup = self.session.execute(select(self.model).filter_by(name=name)).scalars().one_or_none()
if dbgroup is None:
return None
return self.schema.from_orm(dbgroup)
def statistics(self, group_id: UUID4) -> GroupStatistics:
def model_count(model: type[SqlAlchemyBase]) -> int:
stmt = select(func.count(model.id)).filter_by(group_id=group_id)
return self.session.scalar(stmt)
return GroupStatistics(
total_recipes=self.session.query(RecipeModel).filter_by(group_id=group_id).count(),
total_users=self.session.query(User).filter_by(group_id=group_id).count(),
total_categories=self.session.query(Category).filter_by(group_id=group_id).count(),
total_tags=self.session.query(Tag).filter_by(group_id=group_id).count(),
total_tools=self.session.query(Tool).filter_by(group_id=group_id).count(),
total_recipes=model_count(RecipeModel),
total_users=model_count(User),
total_categories=model_count(Category),
total_tags=model_count(Tag),
total_tools=model_count(Tool),
)

View File

@@ -1,6 +1,6 @@
from uuid import UUID
from sqlalchemy import or_
from sqlalchemy import or_, select
from mealie.db.models.group.mealplan import GroupMealPlanRules
from mealie.schema.meal_plan.plan_rules import PlanRulesDay, PlanRulesOut, PlanRulesType
@@ -10,10 +10,10 @@ from .repository_generic import RepositoryGeneric
class RepositoryMealPlanRules(RepositoryGeneric[PlanRulesOut, GroupMealPlanRules]):
def by_group(self, group_id: UUID) -> "RepositoryMealPlanRules":
return super().by_group(group_id) # type: ignore
return super().by_group(group_id)
def get_rules(self, day: PlanRulesDay, entry_type: PlanRulesType) -> list[PlanRulesOut]:
qry = self.session.query(GroupMealPlanRules).filter(
stmt = select(GroupMealPlanRules).filter(
or_(
GroupMealPlanRules.day == day,
GroupMealPlanRules.day.is_(None),
@@ -26,4 +26,6 @@ class RepositoryMealPlanRules(RepositoryGeneric[PlanRulesOut, GroupMealPlanRules
),
)
return [self.schema.from_orm(x) for x in qry.all()]
rules = self.session.execute(stmt).scalars().all()
return [self.schema.from_orm(x) for x in rules]

View File

@@ -1,6 +1,8 @@
from datetime import date
from uuid import UUID
from sqlalchemy import select
from mealie.db.models.group import GroupMealPlan
from mealie.schema.meal_plan.new_meal import ReadPlanEntry
@@ -9,10 +11,10 @@ from .repository_generic import RepositoryGeneric
class RepositoryMeals(RepositoryGeneric[ReadPlanEntry, GroupMealPlan]):
def by_group(self, group_id: UUID) -> "RepositoryMeals":
return super().by_group(group_id) # type: ignore
return super().by_group(group_id)
def get_today(self, group_id: UUID) -> list[ReadPlanEntry]:
today = date.today()
qry = self.session.query(GroupMealPlan).filter(GroupMealPlan.date == today, GroupMealPlan.group_id == group_id)
return [self.schema.from_orm(x) for x in qry.all()]
stmt = select(GroupMealPlan).filter(GroupMealPlan.date == today, GroupMealPlan.group_id == group_id)
plans = self.session.execute(stmt).scalars().all()
return [self.schema.from_orm(x) for x in plans]

View File

@@ -1,10 +1,10 @@
from collections.abc import Sequence
from random import randint
from typing import Any
from uuid import UUID
from pydantic import UUID4
from slugify import slugify
from sqlalchemy import and_, func
from sqlalchemy import and_, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload
@@ -18,13 +18,14 @@ from mealie.schema.cookbook.cookbook import ReadCookBook
from mealie.schema.recipe import Recipe
from mealie.schema.recipe.recipe import (
RecipeCategory,
RecipePagination,
RecipeSummary,
RecipeSummaryWithIngredients,
RecipeTag,
RecipeTool,
)
from mealie.schema.recipe.recipe_category import CategoryBase, TagBase
from mealie.schema.response.pagination import PaginationBase, PaginationQuery
from mealie.schema.response.pagination import PaginationQuery
from .repository_generic import RepositoryGeneric
@@ -46,34 +47,31 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
raise
def by_group(self, group_id: UUID) -> "RepositoryRecipes":
return super().by_group(group_id) # type: ignore
return super().by_group(group_id)
def get_all_public(self, limit: int | None = None, order_by: str | None = None, start=0, override_schema=None):
eff_schema = override_schema or self.schema
if order_by:
order_attr = getattr(self.model, str(order_by))
return [
eff_schema.from_orm(x)
for x in self.session.query(self.model)
stmt = (
select(self.model)
.join(RecipeSettings)
.filter(RecipeSettings.public == True) # noqa: 711
.order_by(order_attr.desc())
.offset(start)
.limit(limit)
.all()
]
)
return [eff_schema.from_orm(x) for x in self.session.execute(stmt).scalars().all()]
return [
eff_schema.from_orm(x)
for x in self.session.query(self.model)
stmt = (
select(self.model)
.join(RecipeSettings)
.filter(RecipeSettings.public == True) # noqa: 711
.offset(start)
.limit(limit)
.all()
]
)
return [eff_schema.from_orm(x) for x in self.session.execute(stmt).scalars().all()]
def update_image(self, slug: str, _: str | None = None) -> int:
entry: RecipeModel = self._query_one(match_value=slug)
@@ -100,7 +98,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
def summary(
self, group_id, start=0, limit=99999, load_foods=False, order_by="created_at", order_descending=True
) -> Any:
) -> Sequence[RecipeModel]:
args = [
joinedload(RecipeModel.recipe_category),
joinedload(RecipeModel.tags),
@@ -126,15 +124,15 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
else:
order_attr = order_attr.asc()
return (
self.session.query(RecipeModel)
stmt = (
select(RecipeModel)
.options(*args)
.filter(RecipeModel.group_id == group_id)
.order_by(order_attr)
.offset(start)
.limit(limit)
.all()
)
return self.session.execute(stmt).scalars().all()
def page_all(
self,
@@ -145,8 +143,8 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
categories: list[UUID4 | str] | None = None,
tags: list[UUID4 | str] | None = None,
tools: list[UUID4 | str] | None = None,
) -> PaginationBase[RecipeSummary]:
q = self.session.query(self.model)
) -> RecipePagination:
q = select(self.model)
args = [
joinedload(RecipeModel.recipe_category),
@@ -154,6 +152,8 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
joinedload(RecipeModel.tools),
]
item_class: type[RecipeSummary | RecipeSummaryWithIngredients]
if load_food:
args.append(joinedload(RecipeModel.recipe_ingredient).options(joinedload(RecipeIngredient.food)))
args.append(joinedload(RecipeModel.recipe_ingredient).options(joinedload(RecipeIngredient.unit)))
@@ -205,14 +205,14 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
q, count, total_pages = self.add_pagination_to_query(q, pagination)
try:
data = q.all()
data = self.session.execute(q).scalars().unique().all()
except Exception as e:
self._log_exception(e)
self.session.rollback()
raise e
items = [item_class.from_orm(item) for item in data]
return PaginationBase(
return RecipePagination(
page=pagination.page,
per_page=pagination.per_page,
total=count,
@@ -226,14 +226,12 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
"""
ids = [x.id for x in categories]
return [
RecipeSummary.from_orm(x)
for x in self.session.query(RecipeModel)
stmt = (
select(RecipeModel)
.join(RecipeModel.recipe_category)
.filter(RecipeModel.recipe_category.any(Category.id.in_(ids)))
.all()
]
)
return [RecipeSummary.from_orm(x) for x in self.session.execute(stmt).unique().scalars().all()]
def _category_tag_filters(
self,
@@ -284,8 +282,8 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
fltr = self._category_tag_filters(
categories, tags, tools, require_all_categories, require_all_tags, require_all_tools
)
return [self.schema.from_orm(x) for x in self.session.query(RecipeModel).filter(*fltr).all()]
stmt = select(RecipeModel).filter(*fltr)
return [self.schema.from_orm(x) for x in self.session.execute(stmt).scalars().all()]
def get_random_by_categories_and_tags(
self, categories: list[RecipeCategory], tags: list[RecipeTag]
@@ -300,33 +298,27 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
# - https://stackoverflow.com/questions/60805/getting-random-row-through-sqlalchemy
filters = self._category_tag_filters(categories, tags) # type: ignore
return [
self.schema.from_orm(x)
for x in self.session.query(RecipeModel)
.filter(and_(*filters))
.order_by(func.random()) # Postgres and SQLite specific
.limit(1)
]
stmt = (
select(RecipeModel).filter(and_(*filters)).order_by(func.random()).limit(1) # Postgres and SQLite specific
)
return [self.schema.from_orm(x) for x in self.session.execute(stmt).scalars().all()]
def get_random(self, limit=1) -> list[Recipe]:
return [
self.schema.from_orm(x)
for x in self.session.query(RecipeModel)
stmt = (
select(RecipeModel)
.filter(RecipeModel.group_id == self.group_id)
.order_by(func.random()) # Postgres and SQLite specific
.limit(limit)
]
)
return [self.schema.from_orm(x) for x in self.session.execute(stmt).scalars().all()]
def get_by_slug(self, group_id: UUID4, slug: str, limit=1) -> Recipe | None:
dbrecipe = (
self.session.query(RecipeModel)
.filter(RecipeModel.group_id == group_id, RecipeModel.slug == slug)
.one_or_none()
)
stmt = select(RecipeModel).filter(RecipeModel.group_id == group_id, RecipeModel.slug == slug)
dbrecipe = self.session.execute(stmt).scalars().one_or_none()
if dbrecipe is None:
return None
return self.schema.from_orm(dbrecipe)
def all_ids(self, group_id: UUID4) -> list[UUID4]:
return [tpl[0] for tpl in self.session.query(RecipeModel.id).filter(RecipeModel.group_id == group_id).all()]
def all_ids(self, group_id: UUID4) -> Sequence[UUID4]:
stmt = select(RecipeModel.id).filter(RecipeModel.group_id == group_id)
return self.session.execute(stmt).scalars().all()

View File

@@ -1,4 +1,5 @@
from pydantic import UUID4
from sqlalchemy import select
from mealie.db.models.recipe.ingredient import IngredientUnitModel
from mealie.schema.recipe.recipe_ingredient import IngredientUnit
@@ -7,15 +8,13 @@ from .repository_generic import RepositoryGeneric
class RepositoryUnit(RepositoryGeneric[IngredientUnit, IngredientUnitModel]):
def _get_unit(self, id: UUID4) -> IngredientUnitModel:
stmt = select(self.model).filter_by(**self._filter_builder(**{"id": id}))
return self.session.execute(stmt).scalars().one()
def merge(self, from_unit: UUID4, to_unit: UUID4) -> IngredientUnit | None:
from_model: IngredientUnitModel = (
self.session.query(self.model).filter_by(**self._filter_builder(**{"id": from_unit})).one()
)
to_model: IngredientUnitModel = (
self.session.query(self.model).filter_by(**self._filter_builder(**{"id": to_unit})).one()
)
from_model = self._get_unit(from_unit)
to_model = self._get_unit(to_unit)
to_model.ingredients += from_model.ingredients
@@ -29,4 +28,4 @@ class RepositoryUnit(RepositoryGeneric[IngredientUnit, IngredientUnitModel]):
return self.get_one(to_unit)
def by_group(self, group_id: UUID4) -> "RepositoryUnit":
return super().by_group(group_id) # type: ignore
return super().by_group(group_id)

View File

@@ -2,6 +2,7 @@ import random
import shutil
from pydantic import UUID4
from sqlalchemy import select
from mealie.assets import users as users_assets
from mealie.schema.user.user import PrivateUser, User
@@ -35,12 +36,14 @@ class RepositoryUsers(RepositoryGeneric[PrivateUser, User]):
entry = super().delete(value, match_key)
# Delete the user's directory
shutil.rmtree(PrivateUser.get_directory(value))
return entry # type: ignore
return entry
def get_by_username(self, username: str) -> PrivateUser | None:
dbuser = self.session.query(User).filter(User.username == username).one_or_none()
stmt = select(User).filter(User.username == username)
dbuser = self.session.execute(stmt).scalars().one_or_none()
return None if dbuser is None else self.schema.from_orm(dbuser)
def get_locked_users(self) -> list[PrivateUser]:
results = self.session.query(User).filter(User.locked_at != None).all() # noqa E711
stmt = select(User).filter(User.locked_at != None) # noqa E711
results = self.session.execute(stmt).scalars().all()
return [self.schema.from_orm(x) for x in results]