# Source: AI-Discord-Bot/src/ai.py

# ai.py
import os
import requests
import re
from dotenv import load_dotenv
from personality import load_persona
from logger import setup_logger
logger = setup_logger("ai")
load_dotenv()

# Validate OLLAMA_API *before* using it: os.getenv returns None when the
# variable is unset, and calling .rstrip() on None would raise an opaque
# AttributeError before the clear error message below could ever fire.
_raw_api = os.getenv("OLLAMA_API")
if not _raw_api:
    logger.error("❌ OLLAMA_API not set.")
    raise ValueError("❌ OLLAMA_API not set.")

BASE_API = _raw_api.rstrip("/")  # drop trailing slash so endpoint joins don't double up
GEN_ENDPOINT = f"{BASE_API}/generate"
PULL_ENDPOINT = f"{BASE_API}/pull"
#UNLOAD_ENDPOINT = f"{BASE_API}/unload"
TAGS_ENDPOINT = f"{BASE_API}/tags"
MODEL_NAME = os.getenv("MODEL_NAME", "llama3:latest")
# Opt-in flag: when true, <think>...</think> blocks are passed through to Discord.
SHOW_THINKING_BLOCKS = os.getenv("SHOW_THINKING_BLOCKS", "false").lower() == "true"
def get_model_name():
    """Return the Ollama model name configured via the MODEL_NAME env var."""
    return MODEL_NAME
def strip_thinking_block(text: str) -> str:
    """Strip every <think>...</think> block (plus trailing whitespace) from *text*."""
    thinking = re.compile(r"<think>.*?</think>\s*", re.DOTALL)
    return thinking.sub("", text)
def model_exists_locally(model_name: str) -> bool:
    """Return True if *model_name* is already present in the local Ollama registry.

    Queries the /api/tags endpoint and matches against the reported model
    names (exact name, or name plus a tag suffix such as ``llama3:latest``)
    instead of a raw substring search of the response body, which could
    false-positive on partial names.

    Returns False on any network or parse error (fail closed).
    """
    try:
        # Timeout prevents an unreachable Ollama host from hanging the bot.
        resp = requests.get(TAGS_ENDPOINT, timeout=10)
        resp.raise_for_status()
        models = resp.json().get("models", [])
        return any(
            m.get("name") == model_name or m.get("name", "").startswith(f"{model_name}:")
            for m in models
        )
    except Exception as e:
        logger.error(f"❌ Failed to check local models: {e}")
        return False
def load_model(model_name: str) -> bool:
    """Pull/preload *model_name* via the Ollama /api/pull endpoint.

    Requests a non-streaming response ("stream": False) so resp.text is a
    single JSON status object rather than a concatenated progress stream.

    Returns True on HTTP 200, False on any error (logged, never raised).
    """
    try:
        logger.info(f"🧠 Preloading model: {model_name}")
        # Generous timeout: a cold pull may download gigabytes of weights.
        resp = requests.post(
            PULL_ENDPOINT,
            json={"name": model_name, "stream": False},
            timeout=600,
        )
        logger.info(f"📨 Ollama pull response: {resp.status_code} - {resp.text}")
        return resp.status_code == 200
    except Exception as e:
        logger.error(f"❌ Exception during model load: {str(e)}")
        return False
def unload_model(model_name: str) -> bool:
    """Ask Ollama to release *model_name* from VRAM.

    Uses the /api/generate "keep_alive: 0" idiom with an empty prompt, which
    tells the server to evict the model immediately after this request.

    Returns True on HTTP 200, False on any error (logged, never raised).
    """
    try:
        logger.info(f"🧹 Soft-unloading model from VRAM: {model_name}")
        # Timeout so a wedged server can't block the caller indefinitely.
        resp = requests.post(
            GEN_ENDPOINT,
            json={
                "model": model_name,
                "keep_alive": 0,
                "prompt": ""
            },
            timeout=30,
        )
        logger.info(f"🧽 Ollama unload response: {resp.status_code} - {resp.text}")
        return resp.status_code == 200
    except Exception as e:
        logger.error(f"❌ Exception during soft-unload: {str(e)}")
        return False
def get_current_model():
    """Alias of get_model_name(), kept for callers using the older name."""
    return MODEL_NAME
def get_ai_response(user_prompt):
    """Generate a reply from the configured Ollama model for *user_prompt*.

    Ensures the model is loaded, prepends the persona prompt injection (if a
    persona is configured), and POSTs a non-streaming /api/generate request.

    Returns the model's reply text (thinking blocks stripped unless
    SHOW_THINKING_BLOCKS is set), or a bracketed error string on failure.
    """
    model_name = get_model_name()
    load_model(model_name)
    persona = load_persona()
    if persona:
        # Normalize Unicode smart quotes to ASCII in the persona inject.
        # BUGFIX: the previous code called .replace("", '"'), i.e. replaced
        # the EMPTY string — which inserts a quote between every character
        # and corrupts the entire prompt.
        safe_inject = (
            persona["prompt_inject"]
            .replace("\u201c", "\"")   # left double curly quote
            .replace("\u201d", "\"")   # right double curly quote
            .replace("\u2019", "'")    # right single quote / apostrophe
        )
        full_prompt = f"{safe_inject}\nUser: {user_prompt}\n{persona['name']}:"
    else:
        full_prompt = user_prompt
    payload = {
        "model": model_name,
        "prompt": full_prompt,
        "stream": False
    }
    logger.info("🛰️ SENDING TO OLLAMA /generate")
    logger.info(f"Payload: {payload}")
    try:
        # Timeout so a stalled generation can't hang the Discord handler forever.
        response = requests.post(GEN_ENDPOINT, json=payload, timeout=120)
        logger.info(f"📨 Raw response: {response.text}")
        if response.status_code == 200:
            result = response.json()
            response_text = result.get("response", "[No message in response]")
            return strip_thinking_block(response_text) if not SHOW_THINKING_BLOCKS else response_text
        else:
            return f"[Error {response.status_code}] {response.text}"
    except Exception as e:
        return f"[Exception] {str(e)}"