AI-Discord-Bot/src/ai.py

# ai.py
# This file handles all AI interactions, including loading/unloading models,
# generating responses, and injecting personas using the Ollama API.
import os
import requests
import re
import yaml
from dotenv import load_dotenv
from personality import load_persona
from user_profiles import format_profile_for_block
from logger import setup_logger, generate_req_id, log_llm_request, log_llm_response
from modelfile import load_modfile_if_exists, parse_mod_file
# Set up logger specifically for AI operations
logger = setup_logger("ai")
# Load environment variables from .env file (must happen before reading DEBUG_MODE)
load_dotenv()
debug_mode = os.getenv("DEBUG_MODE", "false").lower() == "true"
# Load settings.yml to fetch ai.modfile config
try:
    settings_path = os.path.join(os.path.dirname(__file__), "settings.yml")
    with open(settings_path, "r", encoding="utf-8") as f:
        SETTINGS = yaml.safe_load(f)
except Exception:
    SETTINGS = {}
# Modelfile config
AI_USE_MODFILE = SETTINGS.get("ai", {}).get("use_modfile", False)
AI_MODFILE_PATH = SETTINGS.get("ai", {}).get("modfile_path")
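# Illustrative settings.yml shape for the two keys read above (an assumption
# inferred from the lookups in this file, not a documented schema):
#
#   ai:
#     use_modfile: true
#     modfile_path: ../examples/delta.mod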
MODFILE = None
if AI_USE_MODFILE and AI_MODFILE_PATH:
    try:
        MODFILE = load_modfile_if_exists(AI_MODFILE_PATH)
        if MODFILE:
            # Resolve includes (best-effort): merge params and append system/template
            def _resolve_includes(mod):
                merged = dict(mod)
                src = merged.get('_source_path')
                includes = merged.get('includes', []) or []
                base_dir = os.path.dirname(src) if src else os.path.dirname(__file__)
                for inc in includes:
                    try:
                        # Resolve relative to base_dir
                        cand = inc if os.path.isabs(inc) else os.path.normpath(os.path.join(base_dir, inc))
                        if not os.path.exists(cand):
                            continue
                        inc_mod = parse_mod_file(cand)
                        # Merge params (included params do not override main ones)
                        inc_params = inc_mod.get('params', {}) or {}
                        for k, v in inc_params.items():
                            if k not in merged.get('params', {}):
                                merged.setdefault('params', {})[k] = v
                        # Append system text if main doesn't have one
                        if not merged.get('system') and inc_mod.get('system'):
                            merged['system'] = inc_mod.get('system')
                        # If main has no template, adopt included template
                        if not merged.get('template') and inc_mod.get('template'):
                            merged['template'] = inc_mod.get('template')
                    except Exception:
                        continue
                return merged

            MODFILE = _resolve_includes(MODFILE)
            logger.info(f"🔁 Modelfile loaded: {AI_MODFILE_PATH}")
        else:
            logger.warning(f"⚠️ Modelfile not found or failed to parse: {AI_MODFILE_PATH}")
    except Exception as e:
        logger.exception("⚠️ Exception while loading modelfile: %s", e)

# If no modelfile explicitly configured, attempt to auto-load a `delta.mod` or
# `delta.json` in common example/persona locations so the bot has a default persona.
if not MODFILE:
    for candidate in [
        os.path.join(os.path.dirname(__file__), '..', 'examples', 'delta.mod'),
        os.path.join(os.path.dirname(__file__), '..', 'examples', 'delta.json'),
        os.path.join(os.path.dirname(__file__), '..', 'personas', 'delta.mod'),
    ]:
        try:
            mod = load_modfile_if_exists(candidate)
            if mod:
                MODFILE = mod
                logger.info(f"🔁 Auto-loaded default modelfile: {candidate}")
                break
        except Exception:
            continue
def list_modelfiles(search_dirs=None):
    """Return a list of candidate modelfile paths from common locations."""
    base_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), '..'))
    if search_dirs is None:
        search_dirs = [
            os.path.join(base_dir, 'examples'),
            os.path.join(base_dir, 'personas'),
            os.path.join(base_dir, 'src'),
            base_dir,
        ]
    results = []
    for d in search_dirs:
        try:
            if not os.path.isdir(d):
                continue
            for fname in os.listdir(d):
                if fname.endswith('.mod') or fname.endswith('.json'):
                    results.append(os.path.join(d, fname))
        except Exception:
            continue
    return sorted(results)
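# Sketch of expected usage (the paths shown are hypothetical, included only to
# illustrate the return shape of list_modelfiles()):
#
#   >>> list_modelfiles()
#   ['/path/to/AI-Discord-Bot/examples/delta.mod',
#    '/path/to/AI-Discord-Bot/personas/custom.json']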
# Base API setup from .env (e.g., http://localhost:11434/api)
# Normalize to ensure the configured base includes the `/api` prefix so
# endpoints like `/generate` and `/tags` are reachable even if the user
# sets `OLLAMA_API` without `/api`.
raw_api = os.getenv("OLLAMA_API") or ""
raw_api = raw_api.rstrip("/")
if raw_api == "":
    BASE_API = ""
else:
    BASE_API = raw_api if raw_api.endswith("/api") else f"{raw_api}/api"
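# Examples of the normalization above (assuming these .env values; both end up
# pointing at the same API root):
#
#   OLLAMA_API=http://localhost:11434       -> BASE_API=http://localhost:11434/api
#   OLLAMA_API=http://localhost:11434/api/  -> BASE_API=http://localhost:11434/api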
# API endpoints for different Ollama operations
GEN_ENDPOINT = f"{BASE_API}/generate"
PULL_ENDPOINT = f"{BASE_API}/pull"
# UNLOAD_ENDPOINT is not used because unloading is done via `generate` with keep_alive=0
TAGS_ENDPOINT = f"{BASE_API}/tags"
# Startup model and debug toggle from .env
MODEL_NAME = os.getenv("MODEL_NAME", "llama3:latest")
SHOW_THINKING_BLOCKS = os.getenv("SHOW_THINKING_BLOCKS", "false").lower() == "true"
AI_INCLUDE_CONTEXT = os.getenv("AI_INCLUDE_CONTEXT", "true").lower() == "true"
# Ensure API base is configured
if not BASE_API:
    logger.error("❌ OLLAMA_API not set.")
    raise ValueError("❌ OLLAMA_API not set.")
# Returns current model from env/config
def get_model_name():
    return MODEL_NAME

# Removes <think>...</think> blocks from the LLM response (used by some models)
def strip_thinking_block(text: str) -> str:
    return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL)
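# Example (illustrative input/output for strip_thinking_block):
#
#   >>> strip_thinking_block("<think>weighing options...</think>Final answer.")
#   'Final answer.'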
# Check if a model exists locally by calling /tags
def model_exists_locally(model_name: str) -> bool:
    try:
        resp = requests.get(TAGS_ENDPOINT)
        return model_name in resp.text
    except Exception as e:
        logger.error(f"❌ Failed to check local models: {e}")
        return False
# Attempt to pull (load) a model via Ollama's /pull endpoint
def load_model(model_name: str) -> bool:
    try:
        logger.info(f"🧠 Preloading model: {model_name}")
        resp = requests.post(PULL_ENDPOINT, json={"name": model_name})
        if debug_mode:
            logger.debug(f"📨 Ollama pull response: {resp.status_code} - {resp.text}")
        else:
            if resp.status_code == 200:
                logger.info("📦 Model pull started successfully.")
            else:
                logger.warning(f"⚠️ Model pull returned {resp.status_code}: {resp.text[:100]}...")
        return resp.status_code == 200
    except Exception as e:
        logger.error(f"❌ Exception during model load: {str(e)}")
        return False
# Send an empty prompt to unload a model from VRAM safely using keep_alive: 0
def unload_model(model_name: str) -> bool:
    try:
        logger.info(f"🧹 Sending safe unload request for `{model_name}`")
        payload = {
            "model": model_name,
            "prompt": "",     # ✅ Required to make the request valid
            "keep_alive": 0   # ✅ Unload from VRAM but keep on disk
        }
        resp = requests.post(GEN_ENDPOINT, json=payload)
        logger.info(f"🧽 Ollama unload response: {resp.status_code} - {resp.text}")
        return resp.status_code == 200
    except Exception as e:
        logger.error(f"❌ Exception during soft-unload: {str(e)}")
        return False
# Shortcut for getting the current model (can be expanded later for dynamic switching)
def get_current_model():
    return get_model_name()
# Main LLM interaction — injects personality and sends prompt to Ollama
def get_ai_response(user_prompt, context=None, user_profile=None):
    model_name = get_model_name()
    load_model(model_name)
    persona = load_persona()
    # Build prompt pieces.
    # If a modelfile is active and provides a SYSTEM, prefer it over persona prompt_inject.
    system_inject = ""
    if MODFILE and MODFILE.get('system'):
        system_inject = MODFILE.get('system')
    elif persona:
        # Normalize curly quotes/apostrophes to their ASCII equivalents
        system_inject = (
            persona["prompt_inject"]
            .replace("\u201c", '"')
            .replace("\u201d", '"')
            .replace("\u2019", "'")
        )
    user_block = ""
    if user_profile and user_profile.get("custom_prompt"):
        user_block = f"[User Instruction]\n{user_profile['custom_prompt']}\n"
    context_block = f"[Recent Conversation]\n{context}\n" if (context and AI_INCLUDE_CONTEXT) else ""
    # If a modelfile is active and defines a template, render it (best-effort)
    full_prompt = None
    if MODFILE:
        tpl = MODFILE.get('template')
        if tpl:
            # Simple template handling: remove simple Go-style conditionals
            tpl_work = re.sub(r"\{\{\s*if\s+\.System\s*\}\}", "", tpl)
            tpl_work = re.sub(r"\{\{\s*end\s*\}\}", "", tpl_work)
            # Build the prompt body we want to inject as .Prompt
            prompt_body = f"{user_block}{context_block}User: {user_prompt}\n"
            # Replace common placeholders
            tpl_work = tpl_work.replace("{{ .System }}", system_inject)
            tpl_work = tpl_work.replace("{{ .Prompt }}", prompt_body)
            tpl_work = tpl_work.replace("{{ .User }}", user_block)
            full_prompt = tpl_work.strip()
        else:
            # No template: use system_inject and do not append persona name
            full_prompt = f"{system_inject}\n{user_block}{context_block}User: {user_prompt}\nResponse:"
    else:
        # No modelfile active: fall back to persona behaviour (include persona name)
        if persona:
            full_prompt = f"{system_inject}\n{user_block}{context_block}\nUser: {user_prompt}\n{persona['name']}:"
        else:
            full_prompt = f"{user_block}{context_block}\nUser: {user_prompt}\nResponse:"
    # Build base payload and merge modelfile params if present
    payload = {"model": model_name, "prompt": full_prompt, "stream": False}
    if MODFILE and MODFILE.get('params'):
        for k, v in MODFILE.get('params', {}).items():
            payload[k] = v
    # Logging: concise info plus debug for full payload/response
    req_id = generate_req_id("llm-")
    user_label = user_profile.get("display_name") if user_profile else None
    log_llm_request(logger, req_id, model_name, user_label, len(context.splitlines()) if context else 0)
    logger.debug("%s Sending payload to Ollama: model=%s user=%s", req_id, model_name, user_label)
    logger.debug("%s Payload size=%d chars", req_id, len(full_prompt))
    import time
    start = time.perf_counter()
    try:
        response = requests.post(GEN_ENDPOINT, json=payload)
        duration = time.perf_counter() - start
        # Log raw response only at DEBUG to avoid clutter
        logger.debug("%s Raw response status=%s", req_id, response.status_code)
        logger.debug("%s Raw response body=%s", req_id, getattr(response, "text", ""))
        if response.status_code == 200:
            result = response.json()
            short = (result.get("response") or "").replace("\n", " ")[:240]
            log_llm_response(logger, req_id, model_name, duration, short, raw=result)
            return result.get("response", "[No message in response]")
        else:
            # Include status in logs and return an error string
            log_llm_response(logger, req_id, model_name, duration, f"[Error {response.status_code}]", raw=response.text)
            return f"[Error {response.status_code}] {response.text}"
    except Exception as e:
        duration = time.perf_counter() - start
        logger.exception("%s Exception during LLM call", req_id)
        log_llm_response(logger, req_id, model_name, duration, f"[Exception] {e}")
        return f"[Exception] {str(e)}"
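# Sketch of how a caller might use get_ai_response (the argument values below are
# hypothetical; kept as a comment so importing this module stays side-effect free):
#
#   reply = get_ai_response(
#       "What can you do?",
#       context="User: hi\nBot: hello!",
#       user_profile={"display_name": "Alice", "custom_prompt": "Keep answers short."},
#   )
#   print(strip_thinking_block(reply))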
# Runtime modelfile management APIs -------------------------------------------------
def load_modelfile(path: str = None) -> bool:
    """Load (or reload) a modelfile at runtime.

    If `path` is provided, update the configured modelfile path and attempt
    to load from that location. Returns True on success.
    """
    global MODFILE, AI_MODFILE_PATH, AI_USE_MODFILE
    if path:
        AI_MODFILE_PATH = path
    try:
        # Enable modelfile usage if it was disabled
        AI_USE_MODFILE = True
        if not AI_MODFILE_PATH:
            logger.warning("⚠️ No modelfile path configured to load.")
            return False
        mod = load_modfile_if_exists(AI_MODFILE_PATH)
        MODFILE = mod
        if MODFILE:
            logger.info(f"🔁 Modelfile loaded: {AI_MODFILE_PATH}")
            return True
        else:
            logger.warning(f"⚠️ Modelfile not found or failed to parse: {AI_MODFILE_PATH}")
            return False
    except Exception as e:
        logger.exception("⚠️ Exception while loading modelfile: %s", e)
        return False

def unload_modelfile() -> bool:
    """Disable/unload the currently active modelfile so persona injection
    falls back to the standard `persona.json` mechanism."""
    global MODFILE, AI_USE_MODFILE
    MODFILE = None
    AI_USE_MODFILE = False
    logger.info("🔁 Modelfile unloaded/disabled at runtime.")
    return True

def get_modelfile_info() -> dict | None:
    """Return a small diagnostic dict about the currently loaded modelfile,
    or None if no modelfile is active."""
    if not MODFILE:
        return None
    return {
        "_source_path": MODFILE.get("_source_path"),
        "base_model": MODFILE.get("base_model"),
        "params": MODFILE.get("params"),
        "system_preview": (MODFILE.get("system") or "")[:300],
    }
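# Sketch of runtime modelfile management (the path is hypothetical):
#
#   if load_modelfile("personas/delta.mod"):
#       print(get_modelfile_info())   # source path, base model, params, system preview
#   unload_modelfile()                # fall back to persona.json injection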
def build_dryrun_payload(user_prompt, context=None, user_profile=None) -> dict:
    """Build and return the assembled prompt and payload that would be
    sent to the model, without performing any HTTP calls. Useful for
    inspecting template rendering and merged modelfile params.

    Returns: { 'prompt': str, 'payload': dict }
    """
    model_name = get_model_name()
    # Reuse main prompt building logic but avoid calling load_model()
    persona = load_persona()
    # Build prompt pieces (same logic as `get_ai_response`)
    system_inject = ""
    if MODFILE and MODFILE.get('system'):
        system_inject = MODFILE.get('system')
    elif persona:
        # Normalize curly quotes/apostrophes to their ASCII equivalents
        system_inject = (
            persona["prompt_inject"]
            .replace("\u201c", '"')
            .replace("\u201d", '"')
            .replace("\u2019", "'")
        )
    user_block = ""
    if user_profile and user_profile.get("custom_prompt"):
        user_block = f"[User Instruction]\n{user_profile['custom_prompt']}\n"
    context_block = f"[Recent Conversation]\n{context}\n" if (context and AI_INCLUDE_CONTEXT) else ""
    if MODFILE:
        tpl = MODFILE.get('template')
        if tpl:
            tpl_work = re.sub(r"\{\{\s*if\s+\.System\s*\}\}", "", tpl)
            tpl_work = re.sub(r"\{\{\s*end\s*\}\}", "", tpl_work)
            prompt_body = f"{user_block}{context_block}User: {user_prompt}\n"
            tpl_work = tpl_work.replace("{{ .System }}", system_inject)
            tpl_work = tpl_work.replace("{{ .Prompt }}", prompt_body)
            tpl_work = tpl_work.replace("{{ .User }}", user_block)
            full_prompt = tpl_work.strip()
        else:
            full_prompt = f"{system_inject}\n{user_block}{context_block}User: {user_prompt}\nResponse:"
    else:
        if persona:
            full_prompt = f"{system_inject}\n{user_block}{context_block}\nUser: {user_prompt}\n{persona['name']}:"
        else:
            full_prompt = f"{user_block}{context_block}\nUser: {user_prompt}\nResponse:"
    # Build payload and merge modelfile params
    payload = {"model": model_name, "prompt": full_prompt, "stream": False}
    if MODFILE and MODFILE.get('params'):
        for k, v in MODFILE.get('params', {}).items():
            payload[k] = v
    return {"prompt": full_prompt, "payload": payload}