mirror of
https://github.com/mealie-recipes/mealie.git
synced 2026-05-26 11:40:27 -04:00
feat: In-app AI Provider Configuration (#7650)
This commit is contained in:
206
tests/unit_tests/schema_tests/test_ai_providers.py
Normal file
206
tests/unit_tests/schema_tests/test_ai_providers.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from mealie.schema.group.ai_providers import (
|
||||
AIProviderCreate,
|
||||
AIProviderSettingsOut,
|
||||
AIProviderSummary,
|
||||
)
|
||||
|
||||
|
||||
class AIProviderCreateTests:
|
||||
def test_valid_create(self):
|
||||
provider = AIProviderCreate(name="test", api_key="key", model="gpt-4o")
|
||||
assert provider.name == "test"
|
||||
assert provider.model == "gpt-4o"
|
||||
assert provider.timeout == 300
|
||||
assert provider.base_url is None
|
||||
|
||||
@pytest.mark.parametrize("field", ["name", "api_key", "model"])
|
||||
def test_empty_field_raises(self, field: str):
|
||||
data: dict = {"name": "test", "api_key": "key", "model": "gpt-4o", field: ""}
|
||||
with pytest.raises(ValidationError):
|
||||
AIProviderCreate(**data)
|
||||
|
||||
@pytest.mark.parametrize("timeout", [-1, -100])
|
||||
def test_negative_timeout_raises(self, timeout: int):
|
||||
with pytest.raises(ValidationError):
|
||||
AIProviderCreate(name="test", api_key="key", model="gpt-4o", timeout=timeout)
|
||||
|
||||
def test_zero_timeout_is_valid(self):
|
||||
provider = AIProviderCreate(name="test", api_key="key", model="gpt-4o", timeout=0)
|
||||
assert provider.timeout == 0
|
||||
|
||||
@pytest.mark.parametrize("base_url", ["", None])
|
||||
def test_base_url_empty_becomes_none(self, base_url: str | None):
|
||||
provider = AIProviderCreate(name="test", api_key="key", model="gpt-4o", base_url=base_url)
|
||||
assert provider.base_url is None
|
||||
|
||||
def test_api_key_excluded_from_serialization(self):
|
||||
provider = AIProviderCreate(name="test", api_key="secret", model="gpt-4o")
|
||||
dumped = provider.model_dump()
|
||||
assert "api_key" not in dumped
|
||||
|
||||
def test_api_key_excluded_from_json(self):
|
||||
provider = AIProviderCreate(name="test", api_key="secret", model="gpt-4o")
|
||||
json_str = provider.model_dump_json()
|
||||
assert "api_key" not in json_str
|
||||
assert "secret" not in json_str
|
||||
|
||||
|
||||
class AIProviderSettingsOutTests:
|
||||
def _make_settings(
|
||||
self,
|
||||
*,
|
||||
default_provider_id=None,
|
||||
audio_provider_id=None,
|
||||
image_provider_id=None,
|
||||
providers=None,
|
||||
) -> AIProviderSettingsOut:
|
||||
if providers is None:
|
||||
providers = []
|
||||
return AIProviderSettingsOut(
|
||||
default_provider_id=default_provider_id,
|
||||
audio_provider_id=audio_provider_id,
|
||||
image_provider_id=image_provider_id,
|
||||
providers=providers,
|
||||
)
|
||||
|
||||
# --- ai_enabled ---
|
||||
|
||||
def test_ai_enabled_false_when_no_default(self):
|
||||
s = self._make_settings()
|
||||
assert not s.ai_enabled
|
||||
|
||||
def test_ai_enabled_true_when_default_set(self):
|
||||
pid = uuid4()
|
||||
s = self._make_settings(default_provider_id=pid, providers=[AIProviderSummary(id=pid, name="p")])
|
||||
assert s.ai_enabled
|
||||
|
||||
# --- audio_provider_enabled ---
|
||||
|
||||
def test_audio_provider_disabled_when_no_default(self):
|
||||
audio_id = uuid4()
|
||||
s = self._make_settings(
|
||||
audio_provider_id=audio_id,
|
||||
providers=[AIProviderSummary(id=audio_id, name="audio")],
|
||||
)
|
||||
# audio_provider_id is valid, but validate_providers sets audio_provider_id to None
|
||||
# because without default_provider_id, it would be fine; let's test audio_provider_enabled
|
||||
# which requires ai_enabled to be True
|
||||
assert not s.ai_enabled
|
||||
assert not s.audio_provider_enabled
|
||||
|
||||
def test_audio_provider_disabled_when_only_default_set(self):
|
||||
pid = uuid4()
|
||||
s = self._make_settings(default_provider_id=pid, providers=[AIProviderSummary(id=pid, name="p")])
|
||||
assert s.ai_enabled
|
||||
assert not s.audio_provider_enabled
|
||||
|
||||
def test_audio_provider_enabled_when_both_set(self):
|
||||
pid = uuid4()
|
||||
audio_id = uuid4()
|
||||
s = self._make_settings(
|
||||
default_provider_id=pid,
|
||||
audio_provider_id=audio_id,
|
||||
providers=[AIProviderSummary(id=pid, name="p"), AIProviderSummary(id=audio_id, name="audio")],
|
||||
)
|
||||
assert s.ai_enabled
|
||||
assert s.audio_provider_enabled
|
||||
|
||||
# --- image_provider_enabled ---
|
||||
|
||||
def test_image_provider_disabled_when_no_default(self):
|
||||
image_id = uuid4()
|
||||
s = self._make_settings(
|
||||
image_provider_id=image_id,
|
||||
providers=[AIProviderSummary(id=image_id, name="img")],
|
||||
)
|
||||
assert not s.ai_enabled
|
||||
assert not s.image_provider_enabled
|
||||
|
||||
def test_image_provider_disabled_when_only_default_set(self):
|
||||
pid = uuid4()
|
||||
s = self._make_settings(default_provider_id=pid, providers=[AIProviderSummary(id=pid, name="p")])
|
||||
assert s.ai_enabled
|
||||
assert not s.image_provider_enabled
|
||||
|
||||
def test_image_provider_enabled_when_both_set(self):
|
||||
pid = uuid4()
|
||||
image_id = uuid4()
|
||||
s = self._make_settings(
|
||||
default_provider_id=pid,
|
||||
image_provider_id=image_id,
|
||||
providers=[AIProviderSummary(id=pid, name="p"), AIProviderSummary(id=image_id, name="img")],
|
||||
)
|
||||
assert s.ai_enabled
|
||||
assert s.image_provider_enabled
|
||||
|
||||
# --- validate_providers model validator ---
|
||||
|
||||
def test_validate_providers_strips_unknown_default(self):
|
||||
s = self._make_settings(default_provider_id=uuid4(), providers=[])
|
||||
assert s.default_provider_id is None
|
||||
assert not s.ai_enabled
|
||||
|
||||
def test_validate_providers_strips_unknown_audio(self):
|
||||
pid = uuid4()
|
||||
providers = [AIProviderSummary(id=pid, name="p")]
|
||||
s = self._make_settings(default_provider_id=pid, audio_provider_id=uuid4(), providers=providers)
|
||||
assert s.default_provider_id == pid
|
||||
assert s.audio_provider_id is None
|
||||
|
||||
def test_validate_providers_strips_unknown_image(self):
|
||||
pid = uuid4()
|
||||
providers = [AIProviderSummary(id=pid, name="p")]
|
||||
s = self._make_settings(default_provider_id=pid, image_provider_id=uuid4(), providers=providers)
|
||||
assert s.default_provider_id == pid
|
||||
assert s.image_provider_id is None
|
||||
|
||||
def test_validate_providers_keeps_valid_ids(self):
|
||||
pid = uuid4()
|
||||
audio_id = uuid4()
|
||||
image_id = uuid4()
|
||||
providers = [
|
||||
AIProviderSummary(id=pid, name="p"),
|
||||
AIProviderSummary(id=audio_id, name="audio"),
|
||||
AIProviderSummary(id=image_id, name="img"),
|
||||
]
|
||||
s = self._make_settings(
|
||||
default_provider_id=pid,
|
||||
audio_provider_id=audio_id,
|
||||
image_provider_id=image_id,
|
||||
providers=providers,
|
||||
)
|
||||
assert s.default_provider_id == pid
|
||||
assert s.audio_provider_id == audio_id
|
||||
assert s.image_provider_id == image_id
|
||||
|
||||
def test_validate_providers_strips_all_if_empty_list(self):
|
||||
pid = uuid4()
|
||||
s = self._make_settings(
|
||||
default_provider_id=pid,
|
||||
audio_provider_id=uuid4(),
|
||||
image_provider_id=uuid4(),
|
||||
providers=[],
|
||||
)
|
||||
assert s.default_provider_id is None
|
||||
assert s.audio_provider_id is None
|
||||
assert s.image_provider_id is None
|
||||
|
||||
def test_validate_providers_partial_strip(self):
|
||||
"""Only the IDs pointing to missing providers are stripped."""
|
||||
pid = uuid4()
|
||||
audio_id = uuid4()
|
||||
providers = [AIProviderSummary(id=pid, name="p"), AIProviderSummary(id=audio_id, name="audio")]
|
||||
s = self._make_settings(
|
||||
default_provider_id=pid,
|
||||
audio_provider_id=audio_id,
|
||||
image_provider_id=uuid4(), # not in list → stripped
|
||||
providers=providers,
|
||||
)
|
||||
assert s.default_provider_id == pid
|
||||
assert s.audio_provider_id == audio_id
|
||||
assert s.image_provider_id is None
|
||||
@@ -49,6 +49,12 @@ def test_openai_parser(
|
||||
|
||||
monkeypatch.setattr(OpenAIService, "get_response", mock_get_response)
|
||||
|
||||
def mock_openai_init(self, repos):
|
||||
self.repos = repos
|
||||
self.custom_prompt_dir = None
|
||||
|
||||
monkeypatch.setattr(OpenAIService, "__init__", mock_openai_init)
|
||||
|
||||
with session_context() as session:
|
||||
loop = asyncio.get_event_loop()
|
||||
parser = get_parser(RegisteredParser.openai, unique_local_group_id, session, get_locale_provider())
|
||||
@@ -69,7 +75,7 @@ def test_openai_parser_sanitize_output(
|
||||
parsed_ingredient_data: tuple[list[IngredientFood], list[IngredientUnit]], # required so database is populated
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
async def mock_get_raw_response(self, prompt: str, content: list[dict], response_schema) -> MagicMock:
|
||||
async def mock_get_raw_response(self, prompt: str, content: list[dict], response_schema, provider) -> MagicMock:
|
||||
# Create data with null character in JSON to test preprocessing
|
||||
data = OpenAIIngredients(
|
||||
ingredients=[
|
||||
@@ -91,6 +97,17 @@ def test_openai_parser_sanitize_output(
|
||||
# Mock the raw response here since we want to make sure our service executes processing before loading the model
|
||||
monkeypatch.setattr(OpenAIService, "_get_raw_response", mock_get_raw_response)
|
||||
|
||||
def mock_openai_init(self, repos):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
self.repos = repos
|
||||
self.custom_prompt_dir = None
|
||||
self.default_provider = MagicMock()
|
||||
self.audio_provider = None
|
||||
self.image_provider = None
|
||||
|
||||
monkeypatch.setattr(OpenAIService, "__init__", mock_openai_init)
|
||||
|
||||
with session_context() as session:
|
||||
loop = asyncio.get_event_loop()
|
||||
parser = get_parser(RegisteredParser.openai, unique_local_group_id, session, get_locale_provider())
|
||||
|
||||
@@ -49,7 +49,7 @@ def test_html_with_recipe_data():
|
||||
url = "https://www.bbc.co.uk/food/recipes/healthy_pasta_bake_60759"
|
||||
translator = get_locale_provider()
|
||||
|
||||
open_graph_strategy = RecipeScraperOpenGraph(url, translator)
|
||||
open_graph_strategy = RecipeScraperOpenGraph(url, translator, None) # type: ignore[arg-type]
|
||||
|
||||
recipe_data = open_graph_strategy.get_recipe_fields(path.read_text())
|
||||
|
||||
@@ -78,7 +78,7 @@ def test_clean_scraper_preserves_notes():
|
||||
html = RecipeScraperPackage.ld_json_to_html(ld_json)
|
||||
scraped = scrape_html(html, org_url="https://example.com", supported_only=False)
|
||||
translator = get_locale_provider()
|
||||
strategy = RecipeScraperPackage("https://example.com", translator)
|
||||
strategy = RecipeScraperPackage("https://example.com", translator, None) # type: ignore[arg-type]
|
||||
|
||||
recipe, _ = strategy.clean_scraper(scraped, "https://example.com")
|
||||
|
||||
|
||||
@@ -1,23 +1,28 @@
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
import mealie.services.openai.openai as openai_module
|
||||
from mealie.services.openai.openai import OpenAIService
|
||||
|
||||
|
||||
def _make_mock_repos() -> MagicMock:
|
||||
provider_settings = MagicMock()
|
||||
provider_settings.ai_enabled = True
|
||||
provider_settings.default_provider_id = uuid4()
|
||||
provider_settings.audio_provider_id = None
|
||||
provider_settings.image_provider_id = None
|
||||
|
||||
repos = MagicMock()
|
||||
repos.group_id = uuid4()
|
||||
repos.group_ai_provider_settings.get_one.return_value = provider_settings
|
||||
repos.group_ai_providers.get_one.return_value = MagicMock()
|
||||
return repos
|
||||
|
||||
|
||||
class _SettingsStub:
|
||||
OPENAI_ENABLED = True
|
||||
OPENAI_MODEL = "gpt-4o"
|
||||
OPENAI_AUDIO_MODEL = "whisper-1"
|
||||
OPENAI_WORKERS = 1
|
||||
OPENAI_SEND_DATABASE_DATA = False
|
||||
OPENAI_ENABLE_IMAGE_SERVICES = True
|
||||
OPENAI_ENABLE_TRANSCRIPTION_SERVICES = True
|
||||
OPENAI_CUSTOM_PROMPT_DIR: str | None = None
|
||||
OPENAI_BASE_URL: str | None = None
|
||||
OPENAI_API_KEY = "dummy"
|
||||
OPENAI_REQUEST_TIMEOUT = 30
|
||||
OPENAI_CUSTOM_HEADERS: dict = {}
|
||||
OPENAI_CUSTOM_PARAMS: dict = {}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -39,7 +44,7 @@ def settings_stub(tmp_path, monkeypatch):
|
||||
|
||||
|
||||
def test_get_prompt_default_only(settings_stub):
|
||||
svc = OpenAIService()
|
||||
svc = OpenAIService(_make_mock_repos())
|
||||
out = svc.get_prompt("recipes.parse-recipe-ingredients")
|
||||
assert out == "DEFAULT PROMPT"
|
||||
|
||||
@@ -51,7 +56,7 @@ def test_get_prompt_custom_dir_used(settings_stub, tmp_path):
|
||||
|
||||
settings_stub.OPENAI_CUSTOM_PROMPT_DIR = str(custom_dir)
|
||||
|
||||
svc = OpenAIService()
|
||||
svc = OpenAIService(_make_mock_repos())
|
||||
out = svc.get_prompt("recipes.parse-recipe-ingredients")
|
||||
assert out == "CUSTOM PROMPT"
|
||||
|
||||
@@ -62,7 +67,7 @@ def test_get_prompt_custom_empty_falls_back_to_default(settings_stub, tmp_path):
|
||||
(custom_dir / "recipes" / "parse-recipe-ingredients.txt").write_text("")
|
||||
|
||||
settings_stub.OPENAI_CUSTOM_PROMPT_DIR = str(custom_dir)
|
||||
svc = OpenAIService()
|
||||
svc = OpenAIService(_make_mock_repos())
|
||||
out = svc.get_prompt("recipes.parse-recipe-ingredients")
|
||||
assert out == "DEFAULT PROMPT"
|
||||
|
||||
@@ -73,7 +78,7 @@ def test_get_prompt_raises_when_no_files(settings_stub, monkeypatch):
|
||||
for p in prompts_dir.rglob("*.txt"):
|
||||
p.unlink()
|
||||
|
||||
svc = OpenAIService()
|
||||
svc = OpenAIService(_make_mock_repos())
|
||||
with pytest.raises(OSError) as ei:
|
||||
svc.get_prompt("recipes.parse-recipe-ingredients")
|
||||
assert "Unable to load prompt" in str(ei.value)
|
||||
|
||||
@@ -352,7 +352,6 @@ def test_oidc_settings_validation(data: OIDCValidationCase, monkeypatch: pytest.
|
||||
def test_sensitive_settings_mask(monkeypatch: pytest.MonkeyPatch):
|
||||
sensitive_settings = [
|
||||
"LDAP_QUERY_PASSWORD",
|
||||
"OPENAI_API_KEY",
|
||||
"SMTP_USER",
|
||||
"SMTP_PASSWORD",
|
||||
"OIDC_CLIENT_SECRET",
|
||||
|
||||
Reference in New Issue
Block a user