From ac8c1c6f2a512ac3222acbd5f5509313e66e616a Mon Sep 17 00:00:00 2001 From: guo zebin Date: Wed, 3 Dec 2025 22:56:08 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96qa=E7=AE=A1=E7=90=86=E7=AD=96?= =?UTF-8?q?=E7=95=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1、优化采纳及取消采纳的逻辑; 2、采纳不记录think内容; 3、相似问题的答案随机选择输出。 --- core/content_db.py | 416 ++++++++++++++++++--------------- core/qa_service.py | 270 ++++++++++++--------- gui/flask_server.py | 39 ++++ gui/static/js/index.js | 44 +++- gui/templates/index.html | 4 +- qa.csv | 5 +- utils/stream_text_processor.py | 2 +- 7 files changed, 474 insertions(+), 306 deletions(-) diff --git a/core/content_db.py b/core/content_db.py index a5ccce4..f93d91a 100644 --- a/core/content_db.py +++ b/core/content_db.py @@ -1,188 +1,228 @@ -import sqlite3 -import time -import threading -import functools -from utils import util - -def synchronized(func): - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - with self.lock: - return func(self, *args, **kwargs) - return wrapper - -__content_tb = None -def new_instance(): - global __content_tb - if __content_tb is None: - __content_tb = Content_Db() - __content_tb.init_db() - return __content_tb - -class Content_Db: - - def __init__(self) -> None: - self.lock = threading.Lock() - - # 初始化数据库 - def init_db(self): - conn = sqlite3.connect('memory/fay.db') - conn.text_factory = str - c = conn.cursor() - c.execute('''CREATE TABLE IF NOT EXISTS T_Msg - (id INTEGER PRIMARY KEY AUTOINCREMENT, - type CHAR(10), - way CHAR(10), - content TEXT NOT NULL, - createtime INT, - username TEXT DEFAULT 'User', - uid INT);''') - # 对话采纳记录表 - c.execute('''CREATE TABLE IF NOT EXISTS T_Adopted - (id INTEGER PRIMARY KEY AUTOINCREMENT, - msg_id INTEGER UNIQUE, - adopted_time INT, - FOREIGN KEY(msg_id) REFERENCES T_Msg(id));''') - conn.commit() - conn.close() - - # 添加对话 - @synchronized - def add_content(self, type, way, content, username='User', uid=0): - conn = sqlite3.connect("memory/fay.db") - conn.text_factory = str - cur = conn.cursor() - try: - cur.execute("INSERT INTO T_Msg (type, way, content, createtime, username, uid) VALUES (?, ?, ?, ?, ?, ?)", - (type, way, content, int(time.time()), username, uid)) - conn.commit() - last_id = cur.lastrowid - except Exception as e: - util.log(1, "请检查参数是否有误: {}".format(e)) - conn.close() - return 0 - conn.close() - return last_id - - # 更新对话内容 - @synchronized - def update_content(self, msg_id, content): - """ - 更新指定ID的消息内容 - :param msg_id: 消息ID - :param content: 新的内容 - :return: 是否更新成功 - """ - conn = sqlite3.connect("memory/fay.db") - conn.text_factory = str - cur = conn.cursor() - try: - cur.execute("UPDATE T_Msg SET content = ? WHERE id = ?", (content, msg_id)) - conn.commit() - affected_rows = cur.rowcount - except Exception as e: - util.log(1, f"更新消息内容失败: {e}") - conn.close() - return False - conn.close() - return affected_rows > 0 - - # 根据ID查询对话记录 - @synchronized - def get_content_by_id(self, msg_id): - conn = sqlite3.connect("memory/fay.db") - conn.text_factory = str - cur = conn.cursor() - cur.execute("SELECT * FROM T_Msg WHERE id = ?", (msg_id,)) - record = cur.fetchone() - conn.close() - return record - - # 添加对话采纳记录 - @synchronized - def adopted_message(self, msg_id): - conn = sqlite3.connect('memory/fay.db') - conn.text_factory = str - cur = conn.cursor() - # 检查消息ID是否存在 - cur.execute("SELECT 1 FROM T_Msg WHERE id = ?", (msg_id,)) - if cur.fetchone() is None: - util.log(1, "消息ID不存在") - conn.close() - return False - try: - cur.execute("INSERT INTO T_Adopted (msg_id, adopted_time) VALUES (?, ?)", (msg_id, int(time.time()))) - conn.commit() - except sqlite3.IntegrityError: - util.log(1, "该消息已被采纳") - conn.close() - return False - conn.close() - return True - - # 获取对话内容 - @synchronized - def get_list(self, way, order, limit, uid=0): - conn = sqlite3.connect("memory/fay.db") - conn.text_factory = str - cur = conn.cursor() - where_uid = "" - if int(uid) != 0: - where_uid = f" AND T_Msg.uid = {uid} " - base_query = f""" - SELECT T_Msg.type, T_Msg.way, T_Msg.content, T_Msg.createtime, - datetime(T_Msg.createtime, 'unixepoch', 'localtime') AS timetext, - T_Msg.username,T_Msg.id, - CASE WHEN T_Adopted.msg_id IS NOT NULL THEN 1 ELSE 0 END AS is_adopted - FROM T_Msg - LEFT JOIN T_Adopted ON T_Msg.id = T_Adopted.msg_id - WHERE 1 {where_uid} - """ - if way == 'all': - query = base_query + f" ORDER BY T_Msg.id {order} LIMIT ?" - cur.execute(query, (limit,)) - elif way == 'notappended': - query = base_query + f" AND T_Msg.way != 'appended' ORDER BY T_Msg.id {order} LIMIT ?" - cur.execute(query, (limit,)) - else: - query = base_query + f" AND T_Msg.way = ? ORDER BY T_Msg.id {order} LIMIT ?" - cur.execute(query, (way, limit)) - list = cur.fetchall() - conn.close() - return list - - - @synchronized - def get_recent_messages_by_user(self, username='User', limit=30): - conn = sqlite3.connect("memory/fay.db") - conn.text_factory = str - cur = conn.cursor() - cur.execute( - """ - SELECT type, content - FROM T_Msg - WHERE username = ? - ORDER BY id DESC - LIMIT ? - """, - (username, limit), - ) - rows = cur.fetchall() - conn.close() - rows.reverse() - return rows - - @synchronized - def get_previous_user_message(self, msg_id): - conn = sqlite3.connect("memory/fay.db") - cur = conn.cursor() - cur.execute(""" - SELECT id, type, way, content, createtime, datetime(createtime, 'unixepoch', 'localtime') AS timetext, username - FROM T_Msg - WHERE id < ? AND type != 'fay' - ORDER BY id DESC - LIMIT 1 - """, (msg_id,)) - record = cur.fetchone() - conn.close() - return record +import sqlite3 +import time +import threading +import functools +from utils import util + +def synchronized(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + with self.lock: + return func(self, *args, **kwargs) + return wrapper + +__content_tb = None +def new_instance(): + global __content_tb + if __content_tb is None: + __content_tb = Content_Db() + __content_tb.init_db() + return __content_tb + +class Content_Db: + + def __init__(self) -> None: + self.lock = threading.Lock() + + # 初始化数据库 + def init_db(self): + conn = sqlite3.connect('memory/fay.db') + conn.text_factory = str + c = conn.cursor() + c.execute('''CREATE TABLE IF NOT EXISTS T_Msg + (id INTEGER PRIMARY KEY AUTOINCREMENT, + type CHAR(10), + way CHAR(10), + content TEXT NOT NULL, + createtime INT, + username TEXT DEFAULT 'User', + uid INT);''') + # 对话采纳记录表 + c.execute('''CREATE TABLE IF NOT EXISTS T_Adopted + (id INTEGER PRIMARY KEY AUTOINCREMENT, + msg_id INTEGER UNIQUE, + adopted_time INT, + FOREIGN KEY(msg_id) REFERENCES T_Msg(id));''') + conn.commit() + conn.close() + + # 添加对话 + @synchronized + def add_content(self, type, way, content, username='User', uid=0): + conn = sqlite3.connect("memory/fay.db") + conn.text_factory = str + cur = conn.cursor() + try: + cur.execute("INSERT INTO T_Msg (type, way, content, createtime, username, uid) VALUES (?, ?, ?, ?, ?, ?)", + (type, way, content, int(time.time()), username, uid)) + conn.commit() + last_id = cur.lastrowid + except Exception as e: + util.log(1, "请检查参数是否有误: {}".format(e)) + conn.close() + return 0 + conn.close() + return last_id + + # 更新对话内容 + @synchronized + def update_content(self, msg_id, content): + """ + 更新指定ID的消息内容 + :param msg_id: 消息ID + :param content: 新的内容 + :return: 是否更新成功 + """ + conn = sqlite3.connect("memory/fay.db") + conn.text_factory = str + cur = conn.cursor() + try: + cur.execute("UPDATE T_Msg SET content = ? WHERE id = ?", (content, msg_id)) + conn.commit() + affected_rows = cur.rowcount + except Exception as e: + util.log(1, f"更新消息内容失败: {e}") + conn.close() + return False + conn.close() + return affected_rows > 0 + + # 根据ID查询对话记录 + @synchronized + def get_content_by_id(self, msg_id): + conn = sqlite3.connect("memory/fay.db") + conn.text_factory = str + cur = conn.cursor() + cur.execute("SELECT * FROM T_Msg WHERE id = ?", (msg_id,)) + record = cur.fetchone() + conn.close() + return record + + # 添加对话采纳记录 + @synchronized + def adopted_message(self, msg_id): + conn = sqlite3.connect('memory/fay.db') + conn.text_factory = str + cur = conn.cursor() + # 检查消息ID是否存在 + cur.execute("SELECT 1 FROM T_Msg WHERE id = ?", (msg_id,)) + if cur.fetchone() is None: + util.log(1, "消息ID不存在") + conn.close() + return False + try: + cur.execute("INSERT INTO T_Adopted (msg_id, adopted_time) VALUES (?, ?)", (msg_id, int(time.time()))) + conn.commit() + except sqlite3.IntegrityError: + util.log(1, "该消息已被采纳") + conn.close() + return False + conn.close() + return True + + # 取消采纳:删除采纳记录并返回相同clean_content的所有消息ID + @synchronized + def unadopt_message(self, msg_id, clean_content): + """ + 取消采纳消息 + :param msg_id: 消息ID + :param clean_content: 过滤掉think标签后的干净内容,用于匹配QA文件 + :return: (success, same_content_ids) + """ + import re + conn = sqlite3.connect('memory/fay.db') + conn.text_factory = str + cur = conn.cursor() + + # 获取所有fay类型的消息,检查过滤think后的内容是否匹配 + cur.execute("SELECT id, content FROM T_Msg WHERE type = 'fay'") + all_fay_msgs = cur.fetchall() + + # 规范化目标内容:去掉换行符和首尾空格 + clean_content_normalized = clean_content.replace('\n', '').replace('\r', '').strip() + + same_content_ids = [] + for row in all_fay_msgs: + row_id, row_content = row + # 过滤掉think标签内容后比较 + row_clean = re.sub(r'[\s\S]*?', '', row_content, flags=re.IGNORECASE).strip() + # 规范化后比较 + row_clean_normalized = row_clean.replace('\n', '').replace('\r', '').strip() + if row_clean_normalized == clean_content_normalized: + same_content_ids.append(row_id) + + # 删除这些消息的采纳记录 + if same_content_ids: + placeholders = ','.join('?' * len(same_content_ids)) + cur.execute(f"DELETE FROM T_Adopted WHERE msg_id IN ({placeholders})", same_content_ids) + conn.commit() + + conn.close() + return True, same_content_ids + + # 获取对话内容 + @synchronized + def get_list(self, way, order, limit, uid=0): + conn = sqlite3.connect("memory/fay.db") + conn.text_factory = str + cur = conn.cursor() + where_uid = "" + if int(uid) != 0: + where_uid = f" AND T_Msg.uid = {uid} " + base_query = f""" + SELECT T_Msg.type, T_Msg.way, T_Msg.content, T_Msg.createtime, + datetime(T_Msg.createtime, 'unixepoch', 'localtime') AS timetext, + T_Msg.username,T_Msg.id, + CASE WHEN T_Adopted.msg_id IS NOT NULL THEN 1 ELSE 0 END AS is_adopted + FROM T_Msg + LEFT JOIN T_Adopted ON T_Msg.id = T_Adopted.msg_id + WHERE 1 {where_uid} + """ + if way == 'all': + query = base_query + f" ORDER BY T_Msg.id {order} LIMIT ?" + cur.execute(query, (limit,)) + elif way == 'notappended': + query = base_query + f" AND T_Msg.way != 'appended' ORDER BY T_Msg.id {order} LIMIT ?" + cur.execute(query, (limit,)) + else: + query = base_query + f" AND T_Msg.way = ? ORDER BY T_Msg.id {order} LIMIT ?" + cur.execute(query, (way, limit)) + list = cur.fetchall() + conn.close() + return list + + + @synchronized + def get_recent_messages_by_user(self, username='User', limit=30): + conn = sqlite3.connect("memory/fay.db") + conn.text_factory = str + cur = conn.cursor() + cur.execute( + """ + SELECT type, content + FROM T_Msg + WHERE username = ? + ORDER BY id DESC + LIMIT ? + """, + (username, limit), + ) + rows = cur.fetchall() + conn.close() + rows.reverse() + return rows + + @synchronized + def get_previous_user_message(self, msg_id): + conn = sqlite3.connect("memory/fay.db") + cur = conn.cursor() + cur.execute(""" + SELECT id, type, way, content, createtime, datetime(createtime, 'unixepoch', 'localtime') AS timetext, username + FROM T_Msg + WHERE id < ? AND type != 'fay' + ORDER BY id DESC + LIMIT 1 + """, (msg_id,)) + record = cur.fetchone() + conn.close() + return record diff --git a/core/qa_service.py b/core/qa_service.py index ad94faf..4942995 100644 --- a/core/qa_service.py +++ b/core/qa_service.py @@ -1,112 +1,158 @@ -import os -import csv -import difflib -import random -from utils import config_util as cfg -from scheduler.thread_manager import MyThread -import shlex -import subprocess -import time -from utils import util - -class QAService: - - def __init__(self): - # 人设提问关键字 - self.attribute_keyword = [ - [['你叫什么名字', '你的名字是什么'], 'name'], - [['你是男的还是女的', '你是男生还是女生', '你的性别是什么', '你是男生吗', '你是女生吗', '你是男的吗', '你是女的吗', '你是男孩子吗', '你是女孩子吗', ], 'gender', ], - [['你今年多大了', '你多大了', '你今年多少岁', '你几岁了', '你今年几岁了', '你今年几岁了', '你什么时候出生', '你的生日是什么', '你的年龄'], 'age', ], - [['你的家乡在哪', '你的家乡是什么', '你家在哪', '你住在哪', '你出生在哪', '你的出生地在哪', '你的出生地是什么', ], 'birth', ], - [['你的生肖是什么', '你属什么', ], 'zodiac', ], - [['你是什么座', '你是什么星座', '你的星座是什么', ], 'constellation', ], - [['你是做什么的', '你的职业是什么', '你是干什么的', '你的职位是什么', '你的工作是什么', '你是做什么工作的'], 'job', ], - [['你的爱好是什么', '你有爱好吗', '你喜欢什么', '你喜欢做什么'], 'hobby'], - [['联系方式', '联系你们', '怎么联系客服', '有没有客服'], 'contact'] - ] - - self.command_keyword = [ - [['关闭', '再见', '你走吧'], 'stop'], - [['静音', '闭嘴', '我想静静'], 'mute'], - [['取消静音', '你在哪呢', '你可以说话了'], 'unmute'], - [['换个性别', '换个声音'], 'changeVoice'] - ] - - def question(self, query_type, text): - if query_type == 'qa': - answer_dict = self.__read_qna(cfg.config['interact'].get('QnA')) - answer, action = self.__get_keyword(answer_dict, text, query_type) - if action: - MyThread(target=self.__run, args=[action]).start() - return answer, 'qa' - - elif query_type == 'Persona': - answer_dict = self.attribute_keyword - answer, action = self.__get_keyword(answer_dict, text, query_type) - return answer, 'Persona' - elif query_type == 'command': - answer, action = self.__get_keyword(self.command_keyword, text, query_type) - return answer, 'command' - return None, None - - def __run(self, action): - time.sleep(0.1) - args = shlex.split(action) # 分割命令行参数 - subprocess.Popen(args) - - def __read_qna(self, filename): - qna = [] - try: - with open(filename, 'r', encoding='utf-8') as csvfile: - reader = csv.reader(csvfile) - next(reader) # 跳过表头 - for row in reader: - if len(row) >= 2: - qna.append([row[0].split(";"), row[1], row[2] if len(row) >= 3 else None]) - except Exception as e: - pass - return qna - - def record_qapair(self, question, answer): - if not cfg.config['interact']['QnA'] or cfg.config['interact']['QnA'][-3:] != 'csv': - util.log(1, 'qa文件没有指定,不记录大模型回复') - return - log_file = cfg.config['interact']['QnA'] # 指定 CSV 文件的名称或路径 - file_exists = os.path.isfile(log_file) - with open(log_file, 'a', newline='', encoding='utf-8') as csvfile: - writer = csv.writer(csvfile) - if not file_exists: - # 写入表头 - writer.writerow(['Question', 'Answer']) - writer.writerow([question, answer]) - - def __get_keyword(self, keyword_dict, text, query_type): - threshold = 0.6 - candidates = [] - - for qa in keyword_dict: - if len(qa) < 2: - continue - for quest in qa[0]: - similar = self.__string_similar(text, quest) - if quest in text: - similar += 0.3 - if similar >= threshold: - action = qa[2] if (query_type == "qa" and len(qa) > 2) else None - candidates.append((similar, qa[1], action)) - - if not candidates: - return None, None - - candidates.sort(key=lambda x: x[0], reverse=True) - - max_hits = max(1, int(len(keyword_dict) * 0.1)) - candidates = candidates[:max_hits] - - chosen = random.choice(candidates) - return chosen[1], chosen[2] - - def __string_similar(self, s1, s2): - return difflib.SequenceMatcher(None, s1, s2).quick_ratio() - - +import os +import csv +import difflib +import random +from utils import config_util as cfg +from scheduler.thread_manager import MyThread +import shlex +import subprocess +import time +from utils import util + +class QAService: + + def __init__(self): + # 人设提问关键字 + self.attribute_keyword = [ + [['你叫什么名字', '你的名字是什么'], 'name'], + [['你是男的还是女的', '你是男生还是女生', '你的性别是什么', '你是男生吗', '你是女生吗', '你是男的吗', '你是女的吗', '你是男孩子吗', '你是女孩子吗', ], 'gender', ], + [['你今年多大了', '你多大了', '你今年多少岁', '你几岁了', '你今年几岁了', '你今年几岁了', '你什么时候出生', '你的生日是什么', '你的年龄'], 'age', ], + [['你的家乡在哪', '你的家乡是什么', '你家在哪', '你住在哪', '你出生在哪', '你的出生地在哪', '你的出生地是什么', ], 'birth', ], + [['你的生肖是什么', '你属什么', ], 'zodiac', ], + [['你是什么座', '你是什么星座', '你的星座是什么', ], 'constellation', ], + [['你是做什么的', '你的职业是什么', '你是干什么的', '你的职位是什么', '你的工作是什么', '你是做什么工作的'], 'job', ], + [['你的爱好是什么', '你有爱好吗', '你喜欢什么', '你喜欢做什么'], 'hobby'], + [['联系方式', '联系你们', '怎么联系客服', '有没有客服'], 'contact'] + ] + + self.command_keyword = [ + [['关闭', '再见', '你走吧'], 'stop'], + [['静音', '闭嘴', '我想静静'], 'mute'], + [['取消静音', '你在哪呢', '你可以说话了'], 'unmute'], + [['换个性别', '换个声音'], 'changeVoice'] + ] + + def question(self, query_type, text): + if query_type == 'qa': + answer_dict = self.__read_qna(cfg.config['interact'].get('QnA')) + answer, action = self.__get_keyword(answer_dict, text, query_type) + if action: + MyThread(target=self.__run, args=[action]).start() + return answer, 'qa' + + elif query_type == 'Persona': + answer_dict = self.attribute_keyword + answer, action = self.__get_keyword(answer_dict, text, query_type) + return answer, 'Persona' + elif query_type == 'command': + answer, action = self.__get_keyword(self.command_keyword, text, query_type) + return answer, 'command' + return None, None + + def __run(self, action): + time.sleep(0.1) + args = shlex.split(action) # 分割命令行参数 + subprocess.Popen(args) + + def __read_qna(self, filename): + qna = [] + try: + with open(filename, 'r', encoding='utf-8') as csvfile: + reader = csv.reader(csvfile) + next(reader) # 跳过表头 + for row in reader: + if len(row) >= 2: + qna.append([row[0].split(";"), row[1], row[2] if len(row) >= 3 else None]) + except Exception as e: + pass + return qna + + def record_qapair(self, question, answer): + if not cfg.config['interact']['QnA'] or cfg.config['interact']['QnA'][-3:] != 'csv': + util.log(1, 'qa文件没有指定,不记录大模型回复') + return + log_file = cfg.config['interact']['QnA'] # 指定 CSV 文件的名称或路径 + file_exists = os.path.isfile(log_file) + with open(log_file, 'a', newline='', encoding='utf-8') as csvfile: + writer = csv.writer(csvfile) + if not file_exists: + # 写入表头 + writer.writerow(['Question', 'Answer']) + writer.writerow([question, answer]) + + def remove_qapair(self, answer): + """从QA文件中删除指定答案的记录""" + if not cfg.config['interact']['QnA'] or cfg.config['interact']['QnA'][-3:] != 'csv': + util.log(1, 'qa文件没有指定') + return False + log_file = cfg.config['interact']['QnA'] + if not os.path.isfile(log_file): + util.log(1, 'qa文件不存在') + return False + + try: + # 读取所有记录 + rows = [] + with open(log_file, 'r', encoding='utf-8') as csvfile: + reader = csv.reader(csvfile) + rows = list(reader) + + if len(rows) <= 1: + return False + + # 过滤掉匹配答案的记录(保留表头) + # 规范化答案:去掉换行符和首尾空格后比较 + header = rows[0] + filtered_rows = [header] + removed_count = 0 + answer_normalized = answer.replace('\n', '').replace('\r', '').strip() + for row in rows[1:]: + if len(row) >= 2: + row_answer_normalized = row[1].replace('\n', '').replace('\r', '').strip() + if row_answer_normalized == answer_normalized: + removed_count += 1 + else: + filtered_rows.append(row) + else: + filtered_rows.append(row) + + if removed_count > 0: + # 写回文件 + with open(log_file, 'w', newline='', encoding='utf-8') as csvfile: + writer = csv.writer(csvfile) + writer.writerows(filtered_rows) + util.log(1, f'从QA文件中删除了 {removed_count} 条记录') + return True + else: + util.log(1, '未找到匹配的QA记录') + return False + except Exception as e: + util.log(1, f'删除QA记录时出错: {e}') + return False + + def __get_keyword(self, keyword_dict, text, query_type): + threshold = 0.6 + candidates = [] + + for qa in keyword_dict: + if len(qa) < 2: + continue + for quest in qa[0]: + similar = self.__string_similar(text, quest) + if quest in text: + similar += 0.3 + if similar >= threshold: + action = qa[2] if (query_type == "qa" and len(qa) > 2) else None + candidates.append((similar, qa[1], action)) + + if not candidates: + return None, None + + # 从所有超过阈值的候选项中随机选择一个 + chosen = random.choice(candidates) + return chosen[1], chosen[2] + + def __string_similar(self, s1, s2): + return difflib.SequenceMatcher(None, s1, s2).quick_ratio() + + diff --git a/gui/flask_server.py b/gui/flask_server.py index 9ab6450..aae743b 100644 --- a/gui/flask_server.py +++ b/gui/flask_server.py @@ -392,6 +392,8 @@ def adopt_msg(): info = content_db.new_instance().get_content_by_id(id) content = info[3] if info else '' if info is not None: + # 过滤掉 think 标签及其内容 + content = re.sub(r'[\s\S]*?', '', content, flags=re.IGNORECASE).strip() previous_info = content_db.new_instance().get_previous_user_message(id) previous_content = previous_info[3] if previous_info else '' result = content_db.new_instance().adopted_message(id) @@ -405,6 +407,43 @@ def adopt_msg(): except Exception as e: return jsonify({'status':'error', 'msg': f'采纳消息时出错: {e}'}), 500 +@__app.route('/api/unadopt-msg', methods=['POST']) +def unadopt_msg(): + # 取消采纳消息 + data = request.get_json() + if not data: + return jsonify({'status':'error', 'msg': '未提供数据'}) + + id = data.get('id') + + if not id: + return jsonify({'status':'error', 'msg': 'id不能为空'}) + + try: + info = content_db.new_instance().get_content_by_id(id) + if info is None: + return jsonify({'status':'error', 'msg': '消息未找到'}), 404 + + content = info[3] + # 过滤掉 think 标签及其内容,用于匹配 QA 文件中的答案 + clean_content = re.sub(r'[\s\S]*?', '', content, flags=re.IGNORECASE).strip() + + # 从数据库中删除采纳记录,并获取所有相同内容的消息ID + success, same_content_ids = content_db.new_instance().unadopt_message(id, clean_content) + + if success: + # 从 QA 文件中删除对应记录 + qa_service.QAService().remove_qapair(clean_content) + return jsonify({ + 'status': 'success', + 'msg': '取消采纳成功', + 'unadopted_ids': same_content_ids + }) + else: + return jsonify({'status':'error', 'msg': '取消采纳失败'}), 500 + except Exception as e: + return jsonify({'status':'error', 'msg': f'取消采纳时出错: {e}'}), 500 + def gpt_stream_response(last_content, username): sm = stream_manager.new_instance() _, nlp_Stream = sm.get_Stream(username) diff --git a/gui/static/js/index.js b/gui/static/js/index.js index cfdaab5..62bc301 100644 --- a/gui/static/js/index.js +++ b/gui/static/js/index.js @@ -555,7 +555,7 @@ this.fayService.fetchData(`${this.base_url}/api/adopt-msg`, { message: response.msg, // 显示成功消息 type: 'success', }); - + this.loadMessageHistory(this.selectedUser[1], 'adopt'); } else { // 处理失败的响应 @@ -574,6 +574,48 @@ this.fayService.fetchData(`${this.base_url}/api/adopt-msg`, { type: 'error', }); }); +}, + +// 取消采纳 +unadoptText(id) { + this.fayService.fetchData(`${this.base_url}/api/unadopt-msg`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ id }) + }) + .then((response) => { + if (response && response.status === 'success') { + this.$notify({ + title: '成功', + message: response.msg, + type: 'success', + }); + + // 更新本地消息列表中所有相关消息的采纳状态 + if (response.unadopted_ids && response.unadopted_ids.length > 0) { + this.messages.forEach(msg => { + if (response.unadopted_ids.includes(msg.id)) { + msg.is_adopted = 0; + } + }); + } + } else { + this.$notify({ + title: '失败', + message: response ? response.msg : '请求失败', + type: 'error', + }); + } + }) + .catch((error) => { + this.$notify({ + title: '错误', + message: error.message || '请求失败', + type: 'error', + }); + }); } , minimizeThinkPanel() { diff --git a/gui/templates/index.html b/gui/templates/index.html index 5c902bf..2627225 100644 --- a/gui/templates/index.html +++ b/gui/templates/index.html @@ -58,8 +58,8 @@
采纳图标
-
- 采纳图标 +
+ 已采纳图标
diff --git a/qa.csv b/qa.csv index 72450a6..7af3515 100644 --- a/qa.csv +++ b/qa.csv @@ -1,2 +1,3 @@ -问题,答案,脚本 - + +问题,答案,脚本 +哈哈,哈哈,看来你找到乐子了呀!是不是觉得我刚才的回答像个复读机,有点好玩? diff --git a/utils/stream_text_processor.py b/utils/stream_text_processor.py index 07097e7..fd0df5c 100644 --- a/utils/stream_text_processor.py +++ b/utils/stream_text_processor.py @@ -119,7 +119,7 @@ class StreamTextProcessor: username, marked_text, conversation_id=conversation_id ) if success: - accumulated_text = accumulated_text[punct_index + 1:].lstrip() + accumulated_text = accumulated_text[punct_index + 1:] first_sentence_sent = True # 标记已发送第一个句子 sent_successfully = True break