135 lines
4.4 KiB
Python
135 lines
4.4 KiB
Python
import os
|
|
import ast
|
|
import yaml
|
|
import mysql.connector
|
|
from keybert import KeyBERT
|
|
from sentence_transformers import SentenceTransformer
|
|
from collections import Counter
|
|
|
|
# === Load multilingual model for KeyBERT ===
# The sentence-embedding backbone is built once at import time and handed to
# KeyBERT; smart_label() uses kw_model to pull keyphrases out of subjects.
_EMBEDDING_MODEL_NAME = "paraphrase-multilingual-MiniLM-L12-v2"
model = SentenceTransformer(_EMBEDDING_MODEL_NAME)
kw_model = KeyBERT(model=model)
|
|
|
|
# === Load label hierarchy from YAML ===
# The path can be overridden with LABEL_CONFIG_PATH; defaults to ./labels.yml.
LABEL_FILE = os.getenv("LABEL_CONFIG_PATH", "labels.yml")

with open(LABEL_FILE, "r", encoding="utf-8") as f:
    # safe_load (not yaml.load) so the config file cannot construct arbitrary
    # Python objects. An empty file loads as None; fall back to an empty tree
    # so label matching simply finds nothing instead of failing per email.
    label_config = yaml.safe_load(f) or {}

# match_labels() iterates the config with .items(), so anything other than a
# mapping would make every subject lookup fail; reject it once, up front.
if not isinstance(label_config, dict):
    raise ValueError(
        f"{LABEL_FILE}: expected a mapping of label -> {{keywords, children}}, "
        f"got {type(label_config).__name__}"
    )
|
|
|
|
# === DB Credentials ===
# All connection settings come from the environment; the non-secret ones have
# local-development defaults.
DB_HOST = os.getenv("DB_HOST", "localhost")
DB_PORT = int(os.getenv("DB_PORT", "3306"))
DB_USER = os.getenv("DB_USER", "emailuser")
# Never commit a real password as a default: set DB_PASSWORD in the
# environment. An empty default makes a missing variable fail at login with an
# authentication error rather than silently using a leaked credential.
DB_PASSWORD = os.getenv("DB_PASSWORD", "")
DB_NAME = os.getenv("DB_NAME", "emailassistant")
|
|
|
|
# === Connect to DB ===
# One connection is shared by the whole script. dictionary=True makes every
# fetched row a dict keyed by column name, which is how smart_label() reads it.
_connection_settings = {
    "host": DB_HOST,
    "port": DB_PORT,
    "user": DB_USER,
    "password": DB_PASSWORD,
    "database": DB_NAME,
}
conn = mysql.connector.connect(**_connection_settings)
cursor = conn.cursor(dictionary=True)
|
|
|
|
# === Logging Helper ===
def log_event(cursor, level, source, message):
    """Insert one row into the ``logs`` table, best effort.

    Args:
        cursor: DB-API cursor used to run the INSERT. This function does not
            commit; the caller's transaction decides when the row is persisted.
        level: Severity string, e.g. "INFO" or "ERROR".
        source: Name of the component emitting the event.
        message: Free-text message.

    A failure to log must not abort the labeling run, so database errors are
    printed to stdout instead of being raised.
    """
    try:
        cursor.execute(
            "INSERT INTO logs (level, source, message) VALUES (%s, %s, %s)",
            (level, source, message),
        )
    except Exception as exc:  # deliberate best-effort logging
        # Exception (not a bare except) so Ctrl-C / SystemExit still stop the
        # script; the cause is included so the logging failure is diagnosable.
        print(f"[LOG ERROR] {level} from {source}: {message} ({exc})")
|
|
|
|
# === Recursive label matcher ===
def match_labels(keywords, label_tree, prefix=""):
    """Return the deepest label path whose keywords appear in *keywords*.

    *label_tree* maps label name -> {"keywords": [...], "children": {...}}.
    Labels are tried in mapping order; the first label with at least one of
    its keywords in *keywords* wins, and its children are searched the same
    way. The result is the "/"-joined path of the deepest match, e.g.
    "finance/invoices".

    Args:
        keywords: Collection of lower-cased keyphrases (membership test only).
        label_tree: Label mapping as loaded from YAML. ``None``, a non-mapping,
            or a label entry with no body is treated as having no keywords.
        prefix: Path of the parent label; used while recursing.

    Returns:
        The matched label path, or ``None`` when no label matches.
    """
    if not isinstance(label_tree, dict):
        return None
    for label, data in label_tree.items():
        # A YAML entry such as "promo:" with no body loads as None.
        if not isinstance(data, dict):
            continue
        full_label = f"{prefix}/{label}".strip("/")
        raw_keywords = data.get("keywords") or []
        # "keywords: bank" (a scalar) must not be iterated character by character.
        if isinstance(raw_keywords, str):
            raw_keywords = [raw_keywords]
        # YAML keywords may be mixed case or non-strings (e.g. numbers).
        label_keywords = [str(kw).lower() for kw in raw_keywords]
        if any(kw in keywords for kw in label_keywords):
            # "children: null" in YAML loads as None; treat it as no children.
            child_match = match_labels(keywords, data.get("children") or {}, prefix=full_label)
            return child_match or full_label
    return None
|
|
|
|
# === Smart Label Aggregator ===

# Sender-address substrings that vote for a label (the address is lower-cased).
_BANK_SENDER_HINTS = ("paypal", "bankofamerica", "chase")
_JOB_SENDER_HINTS = ("indeed", "hiring")

# Gmail system labels in priority order: only the first one present casts a vote.
_GMAIL_LABEL_VOTES = (
    ("CATEGORY_PROMOTIONS", "promo"),
    ("CATEGORY_SOCIAL", "social"),
    ("CATEGORY_UPDATES", "work"),
    ("IMPORTANT", "work"),
)


def _text_field(email, key):
    """Return ``email[key]`` as text; a missing key or SQL NULL (None) becomes ""."""
    value = email.get(key)
    return "" if value is None else str(value)


def _parse_gmail_labels(raw_label):
    """Return the stored Gmail labels as a list of upper-cased strings.

    Accepts an already-decoded list/tuple/set, or its Python-literal text form
    such as ``"['INBOX', 'IMPORTANT']"`` (bytes are decoded as UTF-8 first).
    A literal holding a single string is treated as a one-label list, and
    non-string items are ignored. Anything unparseable yields ``[]``.
    """
    if not raw_label:
        return []
    if isinstance(raw_label, (bytes, bytearray)):
        raw_label = raw_label.decode("utf-8", errors="replace")
    if isinstance(raw_label, str):
        try:
            # literal_eval only evaluates literals, so DB contents cannot run code.
            parsed = ast.literal_eval(raw_label)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return []
    else:
        parsed = raw_label
    if isinstance(parsed, str):
        parsed = [parsed]
    if not isinstance(parsed, (list, tuple, set, frozenset)):
        return []
    return [item.upper() for item in parsed if isinstance(item, str)]


def _subject_label(subject):
    """Return the label-tree path matched by the subject's keyphrases, or None."""
    if not subject.strip():
        return None
    try:
        keywords = kw_model.extract_keywords(
            subject, keyphrase_ngram_range=(1, 2), stop_words="english", top_n=5
        )
    except ValueError:
        # NOTE: KeyBERT's vectorizer step can raise ValueError (e.g. an empty
        # vocabulary when the subject is only stop words); treat that as "no
        # keyphrases" so the other heuristics still get to vote.
        return None
    # extract_keywords yields (keyphrase, score) pairs; only the text is matched.
    keyword_set = {pair[0].lower() for pair in keywords}
    return match_labels(keyword_set, label_config)


def smart_label(email):
    """Pick a category for one email row by majority vote over several heuristics.

    Args:
        email: Row mapping with optional keys "sender", "subject",
            "ai_summary" and "labels". Missing keys and None (SQL NULL) values
            are treated as empty.

    Returns:
        The label with the most votes (ties go to the label voted first), or
        "unlabeled" when no heuristic fires.
    """
    votes = []

    # 1. FROM address rules
    from_addr = _text_field(email, "sender").lower()
    if any(hint in from_addr for hint in _BANK_SENDER_HINTS):
        votes.append("bank")
    if any(hint in from_addr for hint in _JOB_SENDER_HINTS):
        votes.append("job")

    # 2. Subject keyword analysis (KeyBERT keyphrases -> YAML label tree)
    label_from_subject = _subject_label(_text_field(email, "subject"))
    if label_from_subject:
        votes.append(label_from_subject)

    # 3. AI summary matching
    summary = _text_field(email, "ai_summary").lower()
    if "payment" in summary or "transaction" in summary:
        votes.append("bank")
    if "your order" in summary or "delivered" in summary:
        votes.append("promo")

    # 4. Gmail label logic (from "labels" column): first match in priority order
    gmail_labels = _parse_gmail_labels(email.get("labels"))
    for gmail_label, vote in _GMAIL_LABEL_VOTES:
        if gmail_label in gmail_labels:
            votes.append(vote)
            break

    # 5. Count votes; Counter.most_common keeps first-voted order among ties.
    if not votes:
        return "unlabeled"
    return Counter(votes).most_common(1)[0][0]
|
|
|
|
# === Fetch emails ===
# Every row is fetched and re-labeled, not only unlabeled ones.
# NOTE(review): the UPDATE below also resets is_ai_reviewed to FALSE for every
# email, discarding any earlier review flag -- confirm this is intended.

# Commit in batches so a crash or lost connection mid-run keeps finished work
# instead of rolling back the whole run.
COMMIT_EVERY = 100

try:
    cursor.execute("SELECT id, sender, subject, ai_summary, labels, ai_category FROM emails")
    emails = cursor.fetchall()
    print(f"📬 Found {len(emails)} total emails for re-labeling")

    # === Main Labeling Loop ===
    for processed, email in enumerate(emails, start=1):
        email_id = email["id"]
        try:
            label = smart_label(email)
            cursor.execute(
                """
                UPDATE emails
                SET ai_category = %s,
                    ai_label_source = %s,
                    is_ai_reviewed = FALSE
                WHERE id = %s
                """,
                (label, "smart_labeler", email_id),
            )
            log_event(cursor, "INFO", "smart_labeler", f"Labeled email {email_id} as '{label}'")
            print(f"🏷️ Email {email_id} labeled as: {label}")
        except Exception as e:
            # One bad row must not stop the run; record it and move on.
            log_event(cursor, "ERROR", "smart_labeler", f"Error labeling email {email_id}: {e}")
            print(f"❌ Error labeling email {email_id}: {e}")

        if processed % COMMIT_EVERY == 0:
            conn.commit()

    # === Commit remaining work ===
    conn.commit()
finally:
    # Release DB resources even if the run is aborted; anything not yet
    # committed is discarded with the connection.
    cursor.close()
    conn.close()
|