mirror of
https://github.com/mealie-recipes/mealie.git
synced 2026-02-02 22:13:11 -05:00
Co-authored-by: Michael Genson <71845777+michael-genson@users.noreply.github.com> Co-authored-by: Michael Genson <genson.michael@gmail.com>
This commit is contained in:
@@ -412,6 +412,11 @@ class AppSettings(AppLoggingSettings):
|
||||
"""
|
||||
The number of seconds to wait for an OpenAI request to complete before cancelling the request
|
||||
"""
|
||||
OPENAI_CUSTOM_PROMPT_DIR: str | None = None
|
||||
"""
|
||||
Path to a folder containing custom prompt files;
|
||||
files are individually optional, each prompt name will fall back to the default if no custom file exists
|
||||
"""
|
||||
|
||||
@property
|
||||
def OPENAI_FEATURE(self) -> FeatureDetails:
|
||||
|
||||
@@ -10,11 +10,14 @@ from openai import NOT_GIVEN, AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from mealie.core import root_logger
|
||||
from mealie.core.config import get_app_settings
|
||||
from mealie.pkgs import img
|
||||
|
||||
from .._base_service import BaseService
|
||||
|
||||
logger = root_logger.get_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIDataInjection(BaseModel):
|
||||
description: str
|
||||
@@ -85,6 +88,7 @@ class OpenAIService(BaseService):
|
||||
self.workers = settings.OPENAI_WORKERS
|
||||
self.send_db_data = settings.OPENAI_SEND_DATABASE_DATA
|
||||
self.enable_image_services = settings.OPENAI_ENABLE_IMAGE_SERVICES
|
||||
self.custom_prompt_dir = settings.OPENAI_CUSTOM_PROMPT_DIR
|
||||
|
||||
self.get_client = lambda: AsyncOpenAI(
|
||||
base_url=settings.OPENAI_BASE_URL,
|
||||
@@ -96,8 +100,64 @@ class OpenAIService(BaseService):
|
||||
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def get_prompt(cls, name: str, data_injections: list[OpenAIDataInjection] | None = None) -> str:
|
||||
def _get_prompt_file_candidates(self, name: str) -> list[Path]:
|
||||
"""
|
||||
Returns a list of prompt file path candidates.
|
||||
First optional entry is the users custom prompt file, if configured and existing,
|
||||
second one (or only one) is the systems default prompt file
|
||||
"""
|
||||
tree = name.split(".")
|
||||
relative_path = Path(*tree[:-1], tree[-1] + ".txt")
|
||||
default_prompt_file = Path(self.PROMPTS_DIR, relative_path)
|
||||
|
||||
try:
|
||||
# Only include custom files if the custom_dir is configured, is a directory, and the prompt file exists
|
||||
custom_dir = Path(self.custom_prompt_dir) if self.custom_prompt_dir else None
|
||||
if custom_dir and not custom_dir.is_dir():
|
||||
custom_dir = None
|
||||
except Exception:
|
||||
custom_dir = None
|
||||
|
||||
if custom_dir:
|
||||
custom_prompt_file = Path(custom_dir, relative_path)
|
||||
if custom_prompt_file.exists():
|
||||
logger.debug(f"Found valid custom prompt file: {custom_prompt_file}")
|
||||
return [custom_prompt_file, default_prompt_file]
|
||||
else:
|
||||
logger.debug(f"Custom prompt file doesn't exist: {custom_prompt_file}")
|
||||
else:
|
||||
logger.debug(f"Custom prompt dir doesn't exist: {custom_dir}")
|
||||
|
||||
# Otherwise, only return the default internal prompt file
|
||||
return [default_prompt_file]
|
||||
|
||||
def _load_prompt_from_file(self, name: str) -> str:
|
||||
"""Attempts to load custom prompt, otherwise falling back to the default"""
|
||||
prompt_file_candidates = self._get_prompt_file_candidates(name)
|
||||
content = None
|
||||
last_error = None
|
||||
for prompt_file in prompt_file_candidates:
|
||||
try:
|
||||
logger.debug(f"Trying to load prompt file: {prompt_file}")
|
||||
with open(prompt_file) as f:
|
||||
content = f.read()
|
||||
if content:
|
||||
logger.debug(f"Successfully read prompt from {prompt_file}")
|
||||
break
|
||||
except OSError as e:
|
||||
last_error = e
|
||||
|
||||
if not content:
|
||||
if last_error:
|
||||
raise OSError(f"Unable to load prompt {name}") from last_error
|
||||
else:
|
||||
# This handles the case where the list was empty (no existing candidates found)
|
||||
attempted_paths = ", ".join(map(str, prompt_file_candidates))
|
||||
raise OSError(f"Unable to load prompt '{name}'. No valid content found in files: {attempted_paths}")
|
||||
|
||||
return content
|
||||
|
||||
def get_prompt(self, name: str, data_injections: list[OpenAIDataInjection] | None = None) -> str:
|
||||
"""
|
||||
Load stored prompt and inject data into it.
|
||||
|
||||
@@ -109,13 +169,7 @@ class OpenAIService(BaseService):
|
||||
if not name:
|
||||
raise ValueError("Prompt name cannot be empty")
|
||||
|
||||
tree = name.split(".")
|
||||
prompt_dir = os.path.join(cls.PROMPTS_DIR, *tree[:-1], tree[-1] + ".txt")
|
||||
try:
|
||||
with open(prompt_dir) as f:
|
||||
content = f.read()
|
||||
except OSError as e:
|
||||
raise OSError(f"Unable to load prompt {name}") from e
|
||||
content = self._load_prompt_from_file(name)
|
||||
|
||||
if not data_injections:
|
||||
return content
|
||||
|
||||
Reference in New Issue
Block a user