diff --git a/core/content_db.py b/core/content_db.py
index a5ccce4..f93d91a 100644
--- a/core/content_db.py
+++ b/core/content_db.py
@@ -1,188 +1,228 @@
-import sqlite3
-import time
-import threading
-import functools
-from utils import util
-
-def synchronized(func):
- @functools.wraps(func)
- def wrapper(self, *args, **kwargs):
- with self.lock:
- return func(self, *args, **kwargs)
- return wrapper
-
-__content_tb = None
-def new_instance():
- global __content_tb
- if __content_tb is None:
- __content_tb = Content_Db()
- __content_tb.init_db()
- return __content_tb
-
-class Content_Db:
-
- def __init__(self) -> None:
- self.lock = threading.Lock()
-
- # 初始化数据库
- def init_db(self):
- conn = sqlite3.connect('memory/fay.db')
- conn.text_factory = str
- c = conn.cursor()
- c.execute('''CREATE TABLE IF NOT EXISTS T_Msg
- (id INTEGER PRIMARY KEY AUTOINCREMENT,
- type CHAR(10),
- way CHAR(10),
- content TEXT NOT NULL,
- createtime INT,
- username TEXT DEFAULT 'User',
- uid INT);''')
- # 对话采纳记录表
- c.execute('''CREATE TABLE IF NOT EXISTS T_Adopted
- (id INTEGER PRIMARY KEY AUTOINCREMENT,
- msg_id INTEGER UNIQUE,
- adopted_time INT,
- FOREIGN KEY(msg_id) REFERENCES T_Msg(id));''')
- conn.commit()
- conn.close()
-
- # 添加对话
- @synchronized
- def add_content(self, type, way, content, username='User', uid=0):
- conn = sqlite3.connect("memory/fay.db")
- conn.text_factory = str
- cur = conn.cursor()
- try:
- cur.execute("INSERT INTO T_Msg (type, way, content, createtime, username, uid) VALUES (?, ?, ?, ?, ?, ?)",
- (type, way, content, int(time.time()), username, uid))
- conn.commit()
- last_id = cur.lastrowid
- except Exception as e:
- util.log(1, "请检查参数是否有误: {}".format(e))
- conn.close()
- return 0
- conn.close()
- return last_id
-
- # 更新对话内容
- @synchronized
- def update_content(self, msg_id, content):
- """
- 更新指定ID的消息内容
- :param msg_id: 消息ID
- :param content: 新的内容
- :return: 是否更新成功
- """
- conn = sqlite3.connect("memory/fay.db")
- conn.text_factory = str
- cur = conn.cursor()
- try:
- cur.execute("UPDATE T_Msg SET content = ? WHERE id = ?", (content, msg_id))
- conn.commit()
- affected_rows = cur.rowcount
- except Exception as e:
- util.log(1, f"更新消息内容失败: {e}")
- conn.close()
- return False
- conn.close()
- return affected_rows > 0
-
- # 根据ID查询对话记录
- @synchronized
- def get_content_by_id(self, msg_id):
- conn = sqlite3.connect("memory/fay.db")
- conn.text_factory = str
- cur = conn.cursor()
- cur.execute("SELECT * FROM T_Msg WHERE id = ?", (msg_id,))
- record = cur.fetchone()
- conn.close()
- return record
-
- # 添加对话采纳记录
- @synchronized
- def adopted_message(self, msg_id):
- conn = sqlite3.connect('memory/fay.db')
- conn.text_factory = str
- cur = conn.cursor()
- # 检查消息ID是否存在
- cur.execute("SELECT 1 FROM T_Msg WHERE id = ?", (msg_id,))
- if cur.fetchone() is None:
- util.log(1, "消息ID不存在")
- conn.close()
- return False
- try:
- cur.execute("INSERT INTO T_Adopted (msg_id, adopted_time) VALUES (?, ?)", (msg_id, int(time.time())))
- conn.commit()
- except sqlite3.IntegrityError:
- util.log(1, "该消息已被采纳")
- conn.close()
- return False
- conn.close()
- return True
-
- # 获取对话内容
- @synchronized
- def get_list(self, way, order, limit, uid=0):
- conn = sqlite3.connect("memory/fay.db")
- conn.text_factory = str
- cur = conn.cursor()
- where_uid = ""
- if int(uid) != 0:
- where_uid = f" AND T_Msg.uid = {uid} "
- base_query = f"""
- SELECT T_Msg.type, T_Msg.way, T_Msg.content, T_Msg.createtime,
- datetime(T_Msg.createtime, 'unixepoch', 'localtime') AS timetext,
- T_Msg.username,T_Msg.id,
- CASE WHEN T_Adopted.msg_id IS NOT NULL THEN 1 ELSE 0 END AS is_adopted
- FROM T_Msg
- LEFT JOIN T_Adopted ON T_Msg.id = T_Adopted.msg_id
- WHERE 1 {where_uid}
- """
- if way == 'all':
- query = base_query + f" ORDER BY T_Msg.id {order} LIMIT ?"
- cur.execute(query, (limit,))
- elif way == 'notappended':
- query = base_query + f" AND T_Msg.way != 'appended' ORDER BY T_Msg.id {order} LIMIT ?"
- cur.execute(query, (limit,))
- else:
- query = base_query + f" AND T_Msg.way = ? ORDER BY T_Msg.id {order} LIMIT ?"
- cur.execute(query, (way, limit))
- list = cur.fetchall()
- conn.close()
- return list
-
-
- @synchronized
- def get_recent_messages_by_user(self, username='User', limit=30):
- conn = sqlite3.connect("memory/fay.db")
- conn.text_factory = str
- cur = conn.cursor()
- cur.execute(
- """
- SELECT type, content
- FROM T_Msg
- WHERE username = ?
- ORDER BY id DESC
- LIMIT ?
- """,
- (username, limit),
- )
- rows = cur.fetchall()
- conn.close()
- rows.reverse()
- return rows
-
- @synchronized
- def get_previous_user_message(self, msg_id):
- conn = sqlite3.connect("memory/fay.db")
- cur = conn.cursor()
- cur.execute("""
- SELECT id, type, way, content, createtime, datetime(createtime, 'unixepoch', 'localtime') AS timetext, username
- FROM T_Msg
- WHERE id < ? AND type != 'fay'
- ORDER BY id DESC
- LIMIT 1
- """, (msg_id,))
- record = cur.fetchone()
- conn.close()
- return record
+import sqlite3
+import time
+import threading
+import functools
+from utils import util
+
+def synchronized(func):
+ @functools.wraps(func)
+ def wrapper(self, *args, **kwargs):
+ with self.lock:
+ return func(self, *args, **kwargs)
+ return wrapper
+
+__content_tb = None
+def new_instance():
+ global __content_tb
+ if __content_tb is None:
+ __content_tb = Content_Db()
+ __content_tb.init_db()
+ return __content_tb
+
+class Content_Db:
+
+ def __init__(self) -> None:
+ self.lock = threading.Lock()
+
+ # 初始化数据库
+ def init_db(self):
+ conn = sqlite3.connect('memory/fay.db')
+ conn.text_factory = str
+ c = conn.cursor()
+ c.execute('''CREATE TABLE IF NOT EXISTS T_Msg
+ (id INTEGER PRIMARY KEY AUTOINCREMENT,
+ type CHAR(10),
+ way CHAR(10),
+ content TEXT NOT NULL,
+ createtime INT,
+ username TEXT DEFAULT 'User',
+ uid INT);''')
+ # 对话采纳记录表
+ c.execute('''CREATE TABLE IF NOT EXISTS T_Adopted
+ (id INTEGER PRIMARY KEY AUTOINCREMENT,
+ msg_id INTEGER UNIQUE,
+ adopted_time INT,
+ FOREIGN KEY(msg_id) REFERENCES T_Msg(id));''')
+ conn.commit()
+ conn.close()
+
+ # 添加对话
+ @synchronized
+ def add_content(self, type, way, content, username='User', uid=0):
+ conn = sqlite3.connect("memory/fay.db")
+ conn.text_factory = str
+ cur = conn.cursor()
+ try:
+ cur.execute("INSERT INTO T_Msg (type, way, content, createtime, username, uid) VALUES (?, ?, ?, ?, ?, ?)",
+ (type, way, content, int(time.time()), username, uid))
+ conn.commit()
+ last_id = cur.lastrowid
+ except Exception as e:
+ util.log(1, "请检查参数是否有误: {}".format(e))
+ conn.close()
+ return 0
+ conn.close()
+ return last_id
+
+ # 更新对话内容
+ @synchronized
+ def update_content(self, msg_id, content):
+ """
+ 更新指定ID的消息内容
+ :param msg_id: 消息ID
+ :param content: 新的内容
+ :return: 是否更新成功
+ """
+ conn = sqlite3.connect("memory/fay.db")
+ conn.text_factory = str
+ cur = conn.cursor()
+ try:
+ cur.execute("UPDATE T_Msg SET content = ? WHERE id = ?", (content, msg_id))
+ conn.commit()
+ affected_rows = cur.rowcount
+ except Exception as e:
+ util.log(1, f"更新消息内容失败: {e}")
+ conn.close()
+ return False
+ conn.close()
+ return affected_rows > 0
+
+ # 根据ID查询对话记录
+ @synchronized
+ def get_content_by_id(self, msg_id):
+ conn = sqlite3.connect("memory/fay.db")
+ conn.text_factory = str
+ cur = conn.cursor()
+ cur.execute("SELECT * FROM T_Msg WHERE id = ?", (msg_id,))
+ record = cur.fetchone()
+ conn.close()
+ return record
+
+ # 添加对话采纳记录
+ @synchronized
+ def adopted_message(self, msg_id):
+ conn = sqlite3.connect('memory/fay.db')
+ conn.text_factory = str
+ cur = conn.cursor()
+ # 检查消息ID是否存在
+ cur.execute("SELECT 1 FROM T_Msg WHERE id = ?", (msg_id,))
+ if cur.fetchone() is None:
+ util.log(1, "消息ID不存在")
+ conn.close()
+ return False
+ try:
+ cur.execute("INSERT INTO T_Adopted (msg_id, adopted_time) VALUES (?, ?)", (msg_id, int(time.time())))
+ conn.commit()
+ except sqlite3.IntegrityError:
+ util.log(1, "该消息已被采纳")
+ conn.close()
+ return False
+ conn.close()
+ return True
+
+ # 取消采纳:删除采纳记录并返回相同clean_content的所有消息ID
+ @synchronized
+ def unadopt_message(self, msg_id, clean_content):
+ """
+ 取消采纳消息
+ :param msg_id: 消息ID
+ :param clean_content: 过滤掉think标签后的干净内容,用于匹配QA文件
+ :return: (success, same_content_ids)
+ """
+ import re
+ conn = sqlite3.connect('memory/fay.db')
+ conn.text_factory = str
+ cur = conn.cursor()
+
+ # 获取所有fay类型的消息,检查过滤think后的内容是否匹配
+ cur.execute("SELECT id, content FROM T_Msg WHERE type = 'fay'")
+ all_fay_msgs = cur.fetchall()
+
+ # 规范化目标内容:去掉换行符和首尾空格
+ clean_content_normalized = clean_content.replace('\n', '').replace('\r', '').strip()
+
+ same_content_ids = []
+ for row in all_fay_msgs:
+ row_id, row_content = row
+ # 过滤掉think标签内容后比较
+ row_clean = re.sub(r'[\s\S]*?', '', row_content, flags=re.IGNORECASE).strip()
+ # 规范化后比较
+ row_clean_normalized = row_clean.replace('\n', '').replace('\r', '').strip()
+ if row_clean_normalized == clean_content_normalized:
+ same_content_ids.append(row_id)
+
+ # 删除这些消息的采纳记录
+ if same_content_ids:
+ placeholders = ','.join('?' * len(same_content_ids))
+ cur.execute(f"DELETE FROM T_Adopted WHERE msg_id IN ({placeholders})", same_content_ids)
+ conn.commit()
+
+ conn.close()
+ return True, same_content_ids
+
+ # 获取对话内容
+ @synchronized
+ def get_list(self, way, order, limit, uid=0):
+ conn = sqlite3.connect("memory/fay.db")
+ conn.text_factory = str
+ cur = conn.cursor()
+ where_uid = ""
+ if int(uid) != 0:
+ where_uid = f" AND T_Msg.uid = {uid} "
+ base_query = f"""
+ SELECT T_Msg.type, T_Msg.way, T_Msg.content, T_Msg.createtime,
+ datetime(T_Msg.createtime, 'unixepoch', 'localtime') AS timetext,
+ T_Msg.username,T_Msg.id,
+ CASE WHEN T_Adopted.msg_id IS NOT NULL THEN 1 ELSE 0 END AS is_adopted
+ FROM T_Msg
+ LEFT JOIN T_Adopted ON T_Msg.id = T_Adopted.msg_id
+ WHERE 1 {where_uid}
+ """
+ if way == 'all':
+ query = base_query + f" ORDER BY T_Msg.id {order} LIMIT ?"
+ cur.execute(query, (limit,))
+ elif way == 'notappended':
+ query = base_query + f" AND T_Msg.way != 'appended' ORDER BY T_Msg.id {order} LIMIT ?"
+ cur.execute(query, (limit,))
+ else:
+ query = base_query + f" AND T_Msg.way = ? ORDER BY T_Msg.id {order} LIMIT ?"
+ cur.execute(query, (way, limit))
+ list = cur.fetchall()
+ conn.close()
+ return list
+
+
+ @synchronized
+ def get_recent_messages_by_user(self, username='User', limit=30):
+ conn = sqlite3.connect("memory/fay.db")
+ conn.text_factory = str
+ cur = conn.cursor()
+ cur.execute(
+ """
+ SELECT type, content
+ FROM T_Msg
+ WHERE username = ?
+ ORDER BY id DESC
+ LIMIT ?
+ """,
+ (username, limit),
+ )
+ rows = cur.fetchall()
+ conn.close()
+ rows.reverse()
+ return rows
+
+ @synchronized
+ def get_previous_user_message(self, msg_id):
+ conn = sqlite3.connect("memory/fay.db")
+ cur = conn.cursor()
+ cur.execute("""
+ SELECT id, type, way, content, createtime, datetime(createtime, 'unixepoch', 'localtime') AS timetext, username
+ FROM T_Msg
+ WHERE id < ? AND type != 'fay'
+ ORDER BY id DESC
+ LIMIT 1
+ """, (msg_id,))
+ record = cur.fetchone()
+ conn.close()
+ return record
diff --git a/core/qa_service.py b/core/qa_service.py
index ad94faf..4942995 100644
--- a/core/qa_service.py
+++ b/core/qa_service.py
@@ -1,112 +1,158 @@
-import os
-import csv
-import difflib
-import random
-from utils import config_util as cfg
-from scheduler.thread_manager import MyThread
-import shlex
-import subprocess
-import time
-from utils import util
-
-class QAService:
-
- def __init__(self):
- # 人设提问关键字
- self.attribute_keyword = [
- [['你叫什么名字', '你的名字是什么'], 'name'],
- [['你是男的还是女的', '你是男生还是女生', '你的性别是什么', '你是男生吗', '你是女生吗', '你是男的吗', '你是女的吗', '你是男孩子吗', '你是女孩子吗', ], 'gender', ],
- [['你今年多大了', '你多大了', '你今年多少岁', '你几岁了', '你今年几岁了', '你今年几岁了', '你什么时候出生', '你的生日是什么', '你的年龄'], 'age', ],
- [['你的家乡在哪', '你的家乡是什么', '你家在哪', '你住在哪', '你出生在哪', '你的出生地在哪', '你的出生地是什么', ], 'birth', ],
- [['你的生肖是什么', '你属什么', ], 'zodiac', ],
- [['你是什么座', '你是什么星座', '你的星座是什么', ], 'constellation', ],
- [['你是做什么的', '你的职业是什么', '你是干什么的', '你的职位是什么', '你的工作是什么', '你是做什么工作的'], 'job', ],
- [['你的爱好是什么', '你有爱好吗', '你喜欢什么', '你喜欢做什么'], 'hobby'],
- [['联系方式', '联系你们', '怎么联系客服', '有没有客服'], 'contact']
- ]
-
- self.command_keyword = [
- [['关闭', '再见', '你走吧'], 'stop'],
- [['静音', '闭嘴', '我想静静'], 'mute'],
- [['取消静音', '你在哪呢', '你可以说话了'], 'unmute'],
- [['换个性别', '换个声音'], 'changeVoice']
- ]
-
- def question(self, query_type, text):
- if query_type == 'qa':
- answer_dict = self.__read_qna(cfg.config['interact'].get('QnA'))
- answer, action = self.__get_keyword(answer_dict, text, query_type)
- if action:
- MyThread(target=self.__run, args=[action]).start()
- return answer, 'qa'
-
- elif query_type == 'Persona':
- answer_dict = self.attribute_keyword
- answer, action = self.__get_keyword(answer_dict, text, query_type)
- return answer, 'Persona'
- elif query_type == 'command':
- answer, action = self.__get_keyword(self.command_keyword, text, query_type)
- return answer, 'command'
- return None, None
-
- def __run(self, action):
- time.sleep(0.1)
- args = shlex.split(action) # 分割命令行参数
- subprocess.Popen(args)
-
- def __read_qna(self, filename):
- qna = []
- try:
- with open(filename, 'r', encoding='utf-8') as csvfile:
- reader = csv.reader(csvfile)
- next(reader) # 跳过表头
- for row in reader:
- if len(row) >= 2:
- qna.append([row[0].split(";"), row[1], row[2] if len(row) >= 3 else None])
- except Exception as e:
- pass
- return qna
-
- def record_qapair(self, question, answer):
- if not cfg.config['interact']['QnA'] or cfg.config['interact']['QnA'][-3:] != 'csv':
- util.log(1, 'qa文件没有指定,不记录大模型回复')
- return
- log_file = cfg.config['interact']['QnA'] # 指定 CSV 文件的名称或路径
- file_exists = os.path.isfile(log_file)
- with open(log_file, 'a', newline='', encoding='utf-8') as csvfile:
- writer = csv.writer(csvfile)
- if not file_exists:
- # 写入表头
- writer.writerow(['Question', 'Answer'])
- writer.writerow([question, answer])
-
- def __get_keyword(self, keyword_dict, text, query_type):
- threshold = 0.6
- candidates = []
-
- for qa in keyword_dict:
- if len(qa) < 2:
- continue
- for quest in qa[0]:
- similar = self.__string_similar(text, quest)
- if quest in text:
- similar += 0.3
- if similar >= threshold:
- action = qa[2] if (query_type == "qa" and len(qa) > 2) else None
- candidates.append((similar, qa[1], action))
-
- if not candidates:
- return None, None
-
- candidates.sort(key=lambda x: x[0], reverse=True)
-
- max_hits = max(1, int(len(keyword_dict) * 0.1))
- candidates = candidates[:max_hits]
-
- chosen = random.choice(candidates)
- return chosen[1], chosen[2]
-
- def __string_similar(self, s1, s2):
- return difflib.SequenceMatcher(None, s1, s2).quick_ratio()
-
-
+import os
+import csv
+import difflib
+import random
+from utils import config_util as cfg
+from scheduler.thread_manager import MyThread
+import shlex
+import subprocess
+import time
+from utils import util
+
+class QAService:
+
+ def __init__(self):
+ # 人设提问关键字
+ self.attribute_keyword = [
+ [['你叫什么名字', '你的名字是什么'], 'name'],
+ [['你是男的还是女的', '你是男生还是女生', '你的性别是什么', '你是男生吗', '你是女生吗', '你是男的吗', '你是女的吗', '你是男孩子吗', '你是女孩子吗', ], 'gender', ],
+ [['你今年多大了', '你多大了', '你今年多少岁', '你几岁了', '你今年几岁了', '你今年几岁了', '你什么时候出生', '你的生日是什么', '你的年龄'], 'age', ],
+ [['你的家乡在哪', '你的家乡是什么', '你家在哪', '你住在哪', '你出生在哪', '你的出生地在哪', '你的出生地是什么', ], 'birth', ],
+ [['你的生肖是什么', '你属什么', ], 'zodiac', ],
+ [['你是什么座', '你是什么星座', '你的星座是什么', ], 'constellation', ],
+ [['你是做什么的', '你的职业是什么', '你是干什么的', '你的职位是什么', '你的工作是什么', '你是做什么工作的'], 'job', ],
+ [['你的爱好是什么', '你有爱好吗', '你喜欢什么', '你喜欢做什么'], 'hobby'],
+ [['联系方式', '联系你们', '怎么联系客服', '有没有客服'], 'contact']
+ ]
+
+ self.command_keyword = [
+ [['关闭', '再见', '你走吧'], 'stop'],
+ [['静音', '闭嘴', '我想静静'], 'mute'],
+ [['取消静音', '你在哪呢', '你可以说话了'], 'unmute'],
+ [['换个性别', '换个声音'], 'changeVoice']
+ ]
+
+ def question(self, query_type, text):
+ if query_type == 'qa':
+ answer_dict = self.__read_qna(cfg.config['interact'].get('QnA'))
+ answer, action = self.__get_keyword(answer_dict, text, query_type)
+ if action:
+ MyThread(target=self.__run, args=[action]).start()
+ return answer, 'qa'
+
+ elif query_type == 'Persona':
+ answer_dict = self.attribute_keyword
+ answer, action = self.__get_keyword(answer_dict, text, query_type)
+ return answer, 'Persona'
+ elif query_type == 'command':
+ answer, action = self.__get_keyword(self.command_keyword, text, query_type)
+ return answer, 'command'
+ return None, None
+
+ def __run(self, action):
+ time.sleep(0.1)
+ args = shlex.split(action) # 分割命令行参数
+ subprocess.Popen(args)
+
+ def __read_qna(self, filename):
+ qna = []
+ try:
+ with open(filename, 'r', encoding='utf-8') as csvfile:
+ reader = csv.reader(csvfile)
+ next(reader) # 跳过表头
+ for row in reader:
+ if len(row) >= 2:
+ qna.append([row[0].split(";"), row[1], row[2] if len(row) >= 3 else None])
+ except Exception as e:
+ pass
+ return qna
+
+ def record_qapair(self, question, answer):
+ if not cfg.config['interact']['QnA'] or cfg.config['interact']['QnA'][-3:] != 'csv':
+ util.log(1, 'qa文件没有指定,不记录大模型回复')
+ return
+ log_file = cfg.config['interact']['QnA'] # 指定 CSV 文件的名称或路径
+ file_exists = os.path.isfile(log_file)
+ with open(log_file, 'a', newline='', encoding='utf-8') as csvfile:
+ writer = csv.writer(csvfile)
+ if not file_exists:
+ # 写入表头
+ writer.writerow(['Question', 'Answer'])
+ writer.writerow([question, answer])
+
+ def remove_qapair(self, answer):
+ """从QA文件中删除指定答案的记录"""
+ if not cfg.config['interact']['QnA'] or cfg.config['interact']['QnA'][-3:] != 'csv':
+ util.log(1, 'qa文件没有指定')
+ return False
+ log_file = cfg.config['interact']['QnA']
+ if not os.path.isfile(log_file):
+ util.log(1, 'qa文件不存在')
+ return False
+
+ try:
+ # 读取所有记录
+ rows = []
+ with open(log_file, 'r', encoding='utf-8') as csvfile:
+ reader = csv.reader(csvfile)
+ rows = list(reader)
+
+ if len(rows) <= 1:
+ return False
+
+ # 过滤掉匹配答案的记录(保留表头)
+ # 规范化答案:去掉换行符和首尾空格后比较
+ header = rows[0]
+ filtered_rows = [header]
+ removed_count = 0
+ answer_normalized = answer.replace('\n', '').replace('\r', '').strip()
+ for row in rows[1:]:
+ if len(row) >= 2:
+ row_answer_normalized = row[1].replace('\n', '').replace('\r', '').strip()
+ if row_answer_normalized == answer_normalized:
+ removed_count += 1
+ else:
+ filtered_rows.append(row)
+ else:
+ filtered_rows.append(row)
+
+ if removed_count > 0:
+ # 写回文件
+ with open(log_file, 'w', newline='', encoding='utf-8') as csvfile:
+ writer = csv.writer(csvfile)
+ writer.writerows(filtered_rows)
+ util.log(1, f'从QA文件中删除了 {removed_count} 条记录')
+ return True
+ else:
+ util.log(1, '未找到匹配的QA记录')
+ return False
+ except Exception as e:
+ util.log(1, f'删除QA记录时出错: {e}')
+ return False
+
+ def __get_keyword(self, keyword_dict, text, query_type):
+ threshold = 0.6
+ candidates = []
+
+ for qa in keyword_dict:
+ if len(qa) < 2:
+ continue
+ for quest in qa[0]:
+ similar = self.__string_similar(text, quest)
+ if quest in text:
+ similar += 0.3
+ if similar >= threshold:
+ action = qa[2] if (query_type == "qa" and len(qa) > 2) else None
+ candidates.append((similar, qa[1], action))
+
+ if not candidates:
+ return None, None
+
+ # 从所有超过阈值的候选项中随机选择一个
+ chosen = random.choice(candidates)
+ return chosen[1], chosen[2]
+
+ def __string_similar(self, s1, s2):
+ return difflib.SequenceMatcher(None, s1, s2).quick_ratio()
+
+
diff --git a/gui/flask_server.py b/gui/flask_server.py
index 9ab6450..aae743b 100644
--- a/gui/flask_server.py
+++ b/gui/flask_server.py
@@ -392,6 +392,8 @@ def adopt_msg():
info = content_db.new_instance().get_content_by_id(id)
content = info[3] if info else ''
if info is not None:
+ # 过滤掉 think 标签及其内容
+ content = re.sub(r'[\s\S]*?', '', content, flags=re.IGNORECASE).strip()
previous_info = content_db.new_instance().get_previous_user_message(id)
previous_content = previous_info[3] if previous_info else ''
result = content_db.new_instance().adopted_message(id)
@@ -405,6 +407,43 @@ def adopt_msg():
except Exception as e:
return jsonify({'status':'error', 'msg': f'采纳消息时出错: {e}'}), 500
+@__app.route('/api/unadopt-msg', methods=['POST'])
+def unadopt_msg():
+ # 取消采纳消息
+ data = request.get_json()
+ if not data:
+ return jsonify({'status':'error', 'msg': '未提供数据'})
+
+ id = data.get('id')
+
+ if not id:
+ return jsonify({'status':'error', 'msg': 'id不能为空'})
+
+ try:
+ info = content_db.new_instance().get_content_by_id(id)
+ if info is None:
+ return jsonify({'status':'error', 'msg': '消息未找到'}), 404
+
+ content = info[3]
+ # 过滤掉 think 标签及其内容,用于匹配 QA 文件中的答案
+ clean_content = re.sub(r'[\s\S]*?', '', content, flags=re.IGNORECASE).strip()
+
+ # 从数据库中删除采纳记录,并获取所有相同内容的消息ID
+ success, same_content_ids = content_db.new_instance().unadopt_message(id, clean_content)
+
+ if success:
+ # 从 QA 文件中删除对应记录
+ qa_service.QAService().remove_qapair(clean_content)
+ return jsonify({
+ 'status': 'success',
+ 'msg': '取消采纳成功',
+ 'unadopted_ids': same_content_ids
+ })
+ else:
+ return jsonify({'status':'error', 'msg': '取消采纳失败'}), 500
+ except Exception as e:
+ return jsonify({'status':'error', 'msg': f'取消采纳时出错: {e}'}), 500
+
def gpt_stream_response(last_content, username):
sm = stream_manager.new_instance()
_, nlp_Stream = sm.get_Stream(username)
diff --git a/gui/static/js/index.js b/gui/static/js/index.js
index cfdaab5..62bc301 100644
--- a/gui/static/js/index.js
+++ b/gui/static/js/index.js
@@ -555,7 +555,7 @@ this.fayService.fetchData(`${this.base_url}/api/adopt-msg`, {
message: response.msg, // 显示成功消息
type: 'success',
});
-
+
this.loadMessageHistory(this.selectedUser[1], 'adopt');
} else {
// 处理失败的响应
@@ -574,6 +574,48 @@ this.fayService.fetchData(`${this.base_url}/api/adopt-msg`, {
type: 'error',
});
});
+},
+
+// 取消采纳
+unadoptText(id) {
+ this.fayService.fetchData(`${this.base_url}/api/unadopt-msg`, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json'
+ },
+ body: JSON.stringify({ id })
+ })
+ .then((response) => {
+ if (response && response.status === 'success') {
+ this.$notify({
+ title: '成功',
+ message: response.msg,
+ type: 'success',
+ });
+
+ // 更新本地消息列表中所有相关消息的采纳状态
+ if (response.unadopted_ids && response.unadopted_ids.length > 0) {
+ this.messages.forEach(msg => {
+ if (response.unadopted_ids.includes(msg.id)) {
+ msg.is_adopted = 0;
+ }
+ });
+ }
+ } else {
+ this.$notify({
+ title: '失败',
+ message: response ? response.msg : '请求失败',
+ type: 'error',
+ });
+ }
+ })
+ .catch((error) => {
+ this.$notify({
+ title: '错误',
+ message: error.message || '请求失败',
+ type: 'error',
+ });
+ });
}
,
minimizeThinkPanel() {
diff --git a/gui/templates/index.html b/gui/templates/index.html
index 5c902bf..2627225 100644
--- a/gui/templates/index.html
+++ b/gui/templates/index.html
@@ -58,8 +58,8 @@
-
diff --git a/qa.csv b/qa.csv
index 72450a6..7af3515 100644
--- a/qa.csv
+++ b/qa.csv
@@ -1,2 +1,3 @@
-问题,答案,脚本
-
+
+问题,答案,脚本
+哈哈,哈哈,看来你找到乐子了呀!是不是觉得我刚才的回答像个复读机,有点好玩?
diff --git a/utils/stream_text_processor.py b/utils/stream_text_processor.py
index 07097e7..fd0df5c 100644
--- a/utils/stream_text_processor.py
+++ b/utils/stream_text_processor.py
@@ -119,7 +119,7 @@ class StreamTextProcessor:
username, marked_text, conversation_id=conversation_id
)
if success:
- accumulated_text = accumulated_text[punct_index + 1:].lstrip()
+ accumulated_text = accumulated_text[punct_index + 1:]
first_sentence_sent = True # 标记已发送第一个句子
sent_successfully = True
break