AI-Discord-Bot/src/ai.py

# ai.py
# This file handles all AI interactions, including loading/unloading models,
# generating responses, and injecting personas using the Ollama API.
import os
import requests
import re
from dotenv import load_dotenv
from personality import load_persona
from user_profiles import format_profile_for_block
from logger import setup_logger, generate_req_id, log_llm_request, log_llm_response
# Load environment variables from .env file before reading any of them
load_dotenv()
debug_mode = os.getenv("DEBUG_MODE", "false").lower() == "true"
# Set up logger specifically for AI operations
logger = setup_logger("ai")
# Base API setup from .env (e.g., http://localhost:11434/api)
# Normalize to ensure the configured base includes the `/api` prefix so
# endpoints like `/generate` and `/tags` are reachable even if the user
# sets `OLLAMA_API` without `/api`.
raw_api = os.getenv("OLLAMA_API") or ""
raw_api = raw_api.rstrip("/")
if raw_api == "":
    BASE_API = ""
else:
    BASE_API = raw_api if raw_api.endswith("/api") else f"{raw_api}/api"
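# Examples of the normalization above (hypothetical values):
#   OLLAMA_API=http://localhost:11434      -> BASE_API=http://localhost:11434/api
#   OLLAMA_API=http://localhost:11434/api/ -> BASE_API=http://localhost:11434/api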
# API endpoints for different Ollama operations
GEN_ENDPOINT = f"{BASE_API}/generate"
PULL_ENDPOINT = f"{BASE_API}/pull"
# UNLOAD_ENDPOINT is not used because unloading is done via `generate` with keep_alive=0
TAGS_ENDPOINT = f"{BASE_API}/tags"
# Startup model and debug toggle from .env
MODEL_NAME = os.getenv("MODEL_NAME", "llama3:latest")
SHOW_THINKING_BLOCKS = os.getenv("SHOW_THINKING_BLOCKS", "false").lower() == "true"
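# .env keys read in this module: OLLAMA_API, MODEL_NAME, SHOW_THINKING_BLOCKS, DEBUG_MODE.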
# Ensure API base is configured
if not BASE_API:
logger.error("❌ OLLAMA_API not set.")
raise ValueError("❌ OLLAMA_API not set.")
# Returns current model from env/config
def get_model_name():
    return MODEL_NAME
# Removes <think>...</think> blocks from the LLM response (used by some models)
def strip_thinking_block(text: str) -> str:
    return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL)
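# Example: strip_thinking_block("<think>internal reasoning</think>Hello!") -> "Hello!"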
# Check if a model exists locally by calling /tags
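# Note: this is a plain substring match against the /tags JSON listing, so a
# partial name (e.g. "llama3") also matches a fully tagged model ("llama3:latest").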
def model_exists_locally(model_name: str) -> bool:
    try:
        resp = requests.get(TAGS_ENDPOINT)
        return model_name in resp.text
    except Exception as e:
        logger.error(f"❌ Failed to check local models: {e}")
        return False
# Attempt to pull (load) a model via Ollama's /pull endpoint
def load_model(model_name: str) -> bool:
    try:
        logger.info(f"🧠 Preloading model: {model_name}")
        resp = requests.post(PULL_ENDPOINT, json={"name": model_name})
        if debug_mode:
            logger.debug(f"📨 Ollama pull response: {resp.status_code} - {resp.text}")
        elif resp.status_code == 200:
            logger.info("📦 Model pull started successfully.")
        else:
            logger.warning(f"⚠️ Model pull returned {resp.status_code}: {resp.text[:100]}...")
        return resp.status_code == 200
    except Exception as e:
        logger.error(f"❌ Exception during model load: {str(e)}")
        return False
# Send an empty prompt to unload a model from VRAM safely using keep_alive: 0
def unload_model(model_name: str) -> bool:
    try:
        logger.info(f"🧹 Sending safe unload request for `{model_name}`")
        payload = {
            "model": model_name,
            "prompt": "",     # ✅ Required to make the request valid
            "keep_alive": 0,  # ✅ Unload from VRAM but keep on disk
        }
        resp = requests.post(GEN_ENDPOINT, json=payload)
        logger.info(f"🧽 Ollama unload response: {resp.status_code} - {resp.text}")
        return resp.status_code == 200
    except Exception as e:
        logger.error(f"❌ Exception during soft-unload: {str(e)}")
        return False
# Shortcut for getting the current model (can be expanded later for dynamic switching)
def get_current_model():
    return get_model_name()
# Main LLM interaction — injects personality and sends prompt to Ollama
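# The assembled prompt has roughly this shape (empty blocks are omitted):
#   <persona prompt_inject>
#   [User Instruction]
#   <user custom_prompt>
#   [Recent Conversation]
#   <context>
#   User: <user_prompt>
#   <persona name>:   (or "Response:" when no persona is loaded)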
def get_ai_response(user_prompt, context=None, user_profile=None):
    model_name = get_model_name()
    load_model(model_name)
    persona = load_persona()
    # Build prompt pieces
    safe_inject = ""
    if persona:
        # Replace curly quotes/apostrophes in the persona prompt with plain ASCII
        safe_inject = persona["prompt_inject"].replace("“", '"').replace("”", '"').replace("’", "'")
user_block = ""
if user_profile and user_profile.get("custom_prompt"):
2025-09-20 12:00:55 -04:00
user_block = f"[User Instruction]\n{user_profile['custom_prompt']}\n"
context_block = f"[Recent Conversation]\n{context}\n" if context else ""
2025-05-14 20:27:49 -04:00
if persona:
2025-09-20 12:00:55 -04:00
full_prompt = f"{safe_inject}\n{user_block}{context_block}\nUser: {user_prompt}\n{persona['name']}:"
else:
2025-09-20 12:00:55 -04:00
full_prompt = f"{user_block}{context_block}\nUser: {user_prompt}\nResponse:"
2025-09-20 12:00:55 -04:00
payload = {"model": model_name, "prompt": full_prompt, "stream": False}
2025-09-20 12:00:55 -04:00
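    # stream=False asks Ollama's /generate endpoint for a single JSON object
    # (instead of streamed chunks), so response.json() below holds the full reply.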
    # Logging: concise info plus debug for full payload/response
    req_id = generate_req_id("llm-")
    user_label = user_profile.get("display_name") if user_profile else None
    log_llm_request(logger, req_id, model_name, user_label, len(context.splitlines()) if context else 0)
    logger.debug("%s Sending payload to Ollama: model=%s user=%s", req_id, model_name, user_label)
    logger.debug("%s Payload size=%d chars", req_id, len(full_prompt))
    import time
    start = time.perf_counter()
    try:
        response = requests.post(GEN_ENDPOINT, json=payload)
        duration = time.perf_counter() - start
        # Log raw response only at DEBUG to avoid clutter
        logger.debug("%s Raw response status=%s", req_id, response.status_code)
        logger.debug("%s Raw response body=%s", req_id, getattr(response, "text", ""))
        if response.status_code == 200:
            result = response.json()
            short = (result.get("response") or "").replace("\n", " ")[:240]
            log_llm_response(logger, req_id, model_name, duration, short, raw=result)
            return result.get("response", "[No message in response]")
        else:
            # Include status in logs and return an error string
            log_llm_response(logger, req_id, model_name, duration, f"[Error {response.status_code}]", raw=response.text)
            return f"[Error {response.status_code}] {response.text}"
    except Exception as e:
        duration = time.perf_counter() - start
        logger.exception("%s Exception during LLM call", req_id)
        log_llm_response(logger, req_id, model_name, duration, f"[Exception] {e}")
        return f"[Exception] {str(e)}"