AI-Discord-Bot/src/ai.py


# ai.py
# This file handles all AI interactions, including loading/unloading models,
# generating responses, and injecting personas using the Ollama API.
import os
import requests
import re
from dotenv import load_dotenv
from personality import load_persona
from logger import setup_logger
# Set up logger specifically for AI operations
logger = setup_logger("ai")
# Load environment variables from .env file
load_dotenv()
# Base API setup from .env (e.g., http://localhost:11434/api)
BASE_API = (os.getenv("OLLAMA_API") or "").rstrip("/")  # Remove trailing slash just in case; avoid crashing on None before the check below
# API endpoints for different Ollama operations
GEN_ENDPOINT = f"{BASE_API}/generate"
PULL_ENDPOINT = f"{BASE_API}/pull"
# UNLOAD_ENDPOINT is not used because unloading is done via `generate` with keep_alive=0
TAGS_ENDPOINT = f"{BASE_API}/tags"
# Startup model and debug toggle from .env
MODEL_NAME = os.getenv("MODEL_NAME", "llama3:latest")
SHOW_THINKING_BLOCKS = os.getenv("SHOW_THINKING_BLOCKS", "false").lower() == "true"
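# Example .env values (illustrative sketch; adjust to your own setup):
#   OLLAMA_API=http://localhost:11434/api
#   MODEL_NAME=llama3:latest
#   SHOW_THINKING_BLOCKS=false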
# Ensure API base is configured
if not BASE_API:
    logger.error("❌ OLLAMA_API not set.")
    raise ValueError("❌ OLLAMA_API not set.")
# Returns current model from env/config
def get_model_name():
    return MODEL_NAME
# Removes <think>...</think> blocks from the LLM response (used by some models)
def strip_thinking_block(text: str) -> str:
    return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL)
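# Illustrative example:
#   strip_thinking_block("<think>chain of thought...</think>Hi there!") -> "Hi there!"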
# Check if a model exists locally by calling /tags
def model_exists_locally(model_name: str) -> bool:
    try:
        resp = requests.get(TAGS_ENDPOINT)
        return model_name in resp.text
    except Exception as e:
        logger.error(f"❌ Failed to check local models: {e}")
        return False
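# A stricter variant (sketch, assuming /tags returns the documented JSON shape
# {"models": [{"name": "..."}]}) would parse the body instead of substring-matching:
#   names = [m.get("name", "") for m in resp.json().get("models", [])]
#   return model_name in names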
# Attempt to pull (load) a model via Ollama's /pull endpoint
def load_model(model_name: str) -> bool:
    try:
        logger.info(f"🧠 Preloading model: {model_name}")
        resp = requests.post(PULL_ENDPOINT, json={"name": model_name})
        logger.info(f"📨 Ollama pull response: {resp.status_code} - {resp.text}")
        return resp.status_code == 200
    except Exception as e:
        logger.error(f"❌ Exception during model load: {str(e)}")
        return False
# Send an empty prompt to unload a model from VRAM safely using keep_alive: 0
def unload_model(model_name: str) -> bool:
    try:
        logger.info(f"🧹 Sending safe unload request for `{model_name}`")
        payload = {
            "model": model_name,
            "prompt": "",     # ✅ Required to make the request valid
            "keep_alive": 0   # ✅ Unload from VRAM but keep on disk
        }
        resp = requests.post(GEN_ENDPOINT, json=payload)
        logger.info(f"🧽 Ollama unload response: {resp.status_code} - {resp.text}")
        return resp.status_code == 200
    except Exception as e:
        logger.error(f"❌ Exception during soft-unload: {str(e)}")
        return False
# Shortcut for getting the current model (can be expanded later for dynamic switching)
def get_current_model():
    return get_model_name()
# Main LLM interaction — injects personality and sends prompt to Ollama
def get_ai_response(user_prompt):
    model_name = get_model_name()
    load_model(model_name)  # Ensures the model is pulled and ready
    persona = load_persona()
    if persona:
        # Clean fancy quotes and build final prompt with character injection
        safe_inject = persona["prompt_inject"].replace("“", "\"").replace("”", "\"").replace("’", "'")
full_prompt = f"{safe_inject}\nUser: {user_prompt}\n{persona['name']}:"
else:
full_prompt = user_prompt # fallback to raw prompt if no persona loaded
payload = {
"model": model_name, # 🔧 Suggested fix: previously hardcoded to MODEL_NAME
"prompt": full_prompt,
"stream": False
# optional: add "keep_alive": 300 to keep model warm
}
logger.info("🛰️ SENDING TO OLLAMA /generate")
logger.info(f"Payload: {payload}")
try:
response = requests.post(GEN_ENDPOINT, json=payload)
logger.info(f"📨 Raw response: {response.text}")
if response.status_code == 200:
result = response.json()
response_text = result.get("response", "[No message in response]")
return strip_thinking_block(response_text) if not SHOW_THINKING_BLOCKS else response_text
else:
return f"[Error {response.status_code}] {response.text}"
except Exception as e:
return f"[Exception] {str(e)}"