# chinese_chat.py
import os
import gc
import threading
import tkinter as tk
from tkinter import ttk, scrolledtext, messagebox
import queue
import re
# Environment setup: hide every CUDA device from libraries imported later
# (torch / transformers), so inference runs on the CPU.
os.environ.update(CUDA_VISIBLE_DEVICES='-1')
class ChineseChatBot:
    """Tkinter chat window that answers user messages with a local GPT-2 model.

    Threading model: the model is loaded, and replies are generated, on daemon
    worker threads. Those threads never touch Tk widgets directly; they post
    tuples onto ``self.message_queue``, which ``process_queue`` drains on the
    Tk main thread every 100 ms.

    Queue message formats:
        ("status", text)           -- update the status label
        ("add_msg", sender, text)  -- append a chat line
        ("response", text)         -- append an AI reply and re-enable input
        ("ready",)                 -- model loaded; enable the send button
    """

    # Maximum number of tokens generated for a single reply.
    MAX_NEW_TOKENS = 50
    # Context size assumed when the model config reports none (GPT-2 uses 1024).
    DEFAULT_CONTEXT_SIZE = 1024
    # A cleaned reply must contain at least this many CJK ideographs to be shown.
    MIN_IDEOGRAPHS = 2
    # Markers showing the model has started writing a new dialogue turn
    # (ASCII and full-width colon variants).
    TURN_MARKERS = ("用户:", "用户\uff1a", "助手:", "助手\uff1a")
    # Matches everything except CJK ideographs, whitespace, and common
    # punctuation (ASCII and full-width ，！？；：).
    _NON_CHINESE_RE = re.compile(
        r'[^\u4e00-\u9fff\s,。!?;:、\uff0c\uff01\uff1f\uff1b\uff1a]'
    )
    # Matches a single CJK unified ideograph.
    _IDEOGRAPH_RE = re.compile(r'[\u4e00-\u9fff]')
    # Canned replies used when the model output is unusable.
    _FALLBACKS = (
        "我明白你的意思。",
        "这是一个有趣的话题。",
        "你能详细说说吗?",
        "我理解你的观点。",
        "这让我想到了相关的内容。",
        "感谢你分享这个想法。",
        "我会认真考虑你说的。",
        "这个话题很有意思。",
    )

    def __init__(self, root):
        """Build the UI on ``root`` and schedule model loading and queue polling."""
        self.root = root
        self.root.title("中文AI聊天机器人")
        self.root.geometry("700x600")
        self.model_path = "./models/gpt2"
        self.model = None
        self.tokenizer = None
        self.model_loaded = False
        # True while a reply is being generated; blocks overlapping requests.
        self._busy = False
        self.message_queue = queue.Queue()
        self.create_widgets()
        # Model loading is deferred by 1 s (presumably so the window paints first).
        self.root.after(1000, self.load_model)
        self.root.after(100, self.process_queue)

    def create_widgets(self):
        """Create the title, status label, chat log, input entry and send button."""
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.pack(fill=tk.BOTH, expand=True)

        title_label = ttk.Label(main_frame, text="🤖 中文AI聊天机器人",
                                font=("微软雅黑", 16, "bold"))
        title_label.pack(pady=10)

        self.status_label = ttk.Label(main_frame, text="正在初始化...",
                                      font=("微软雅黑", 10))
        self.status_label.pack(pady=5)

        self.chat_text = scrolledtext.ScrolledText(
            main_frame,
            wrap=tk.WORD,
            width=70,
            height=20,
            font=("微软雅黑", 11),
            state=tk.DISABLED
        )
        self.chat_text.pack(fill=tk.BOTH, expand=True, pady=10)
        # Configure the per-sender colours once instead of on every message.
        self.chat_text.tag_config("user", foreground="blue")
        self.chat_text.tag_config("ai", foreground="green")
        self.chat_text.tag_config("system", foreground="gray")

        input_frame = ttk.Frame(main_frame)
        input_frame.pack(fill=tk.X, pady=10)
        self.input_entry = ttk.Entry(input_frame, font=("微软雅黑", 12))
        self.input_entry.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 10))
        self.input_entry.bind("<Return>", self.on_send)
        self.send_button = ttk.Button(input_frame, text="发送",
                                      command=self.on_send, state=tk.DISABLED)
        self.send_button.pack(side=tk.RIGHT)

        self.add_message("system", "🚀 中文AI聊天机器人启动中...")

    def load_model(self):
        """Load the tokenizer and model from ``self.model_path`` on a daemon thread.

        Progress and errors are reported through the message queue; the send
        button is enabled via a ("ready",) message so that widget changes
        happen on the Tk thread only.
        """
        def load():
            try:
                self.update_status("正在加载AI模型...")
                from transformers import GPT2Tokenizer, GPT2LMHeadModel
                # Load the tokenizer. GPT-2 defines no pad token, so reuse EOS.
                self.tokenizer = GPT2Tokenizer.from_pretrained(self.model_path)
                self.tokenizer.pad_token = self.tokenizer.eos_token
                # Load the model in inference mode.
                self.model = GPT2LMHeadModel.from_pretrained(self.model_path)
                self.model.eval()
                self.model_loaded = True
                self.update_status("✅ AI模型加载成功")
                self.add_message("system", "🎉 AI模型加载完成!现在可以进行中文对话")
                # Tk is not thread-safe: let process_queue enable the button.
                self.message_queue.put(("ready",))
            except Exception as e:
                self.update_status(f"❌ 加载失败: {e}")
                self.add_message("system", f"模型加载失败: {e}")

        threading.Thread(target=load, daemon=True).start()

    @staticmethod
    def _extract_reply(generated):
        """Return ``generated`` up to the first turn marker, stripped of whitespace.

        The model often continues the dialogue with an invented "用户:" turn;
        everything from the first such marker onward is discarded.
        """
        for marker in ChineseChatBot.TURN_MARKERS:
            generated = generated.split(marker, 1)[0]
        return generated.strip()

    def generate_chinese_response(self, user_input):
        """Generate a cleaned Chinese reply to ``user_input``.

        Returns the cleaned reply string, or None if generation raised (the
        error is printed).
        """
        try:
            import torch
            # Prompt in a Chinese dialogue format to steer the model toward Chinese.
            input_text = f"用户: {user_input}\n助手:"
            input_ids = self.tokenizer.encode(input_text, return_tensors='pt')
            # Keep only the most recent tokens so prompt + reply fit the context.
            context_size = getattr(self.model.config, "n_positions",
                                   self.DEFAULT_CONTEXT_SIZE)
            prompt_budget = max(1, context_size - self.MAX_NEW_TOKENS)
            input_ids = input_ids[:, -prompt_budget:]
            attention_mask = torch.ones_like(input_ids)
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=self.MAX_NEW_TOKENS,
                    num_return_sequences=1,
                    # Random sampling (not greedy search) with a single beam.
                    do_sample=True,
                    temperature=0.8,
                    num_beams=1,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    no_repeat_ngram_size=2,
                    repetition_penalty=1.1,
                )
            # Decode only the newly generated tokens, not the echoed prompt.
            new_tokens = outputs[0][input_ids.shape[1]:]
            generated = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            response = self._extract_reply(generated)
            return self.clean_chinese_response(response)
        except Exception as e:
            print(f"生成错误: {e}")
            return None

    def clean_chinese_response(self, text):
        """Keep only Chinese characters, whitespace and punctuation from ``text``.

        Runs of whitespace are collapsed to one space. If fewer than
        ``MIN_IDEOGRAPHS`` CJK ideographs remain (including empty or None
        input), a random fallback reply is returned instead.
        """
        if not text:
            return self.get_chinese_fallback()
        cleaned = self._NON_CHINESE_RE.sub('', text)
        cleaned = re.sub(r'\s+', ' ', cleaned).strip()
        # Count ideographs rather than characters so punctuation-only output
        # (e.g. ", !") is not mistaken for a Chinese reply.
        if len(self._IDEOGRAPH_RE.findall(cleaned)) >= self.MIN_IDEOGRAPHS:
            return cleaned
        return self.get_chinese_fallback()

    def get_chinese_fallback(self):
        """Return a random canned Chinese reply from ``_FALLBACKS``."""
        import random
        return random.choice(self._FALLBACKS)

    def on_send(self, event=None):
        """Handle the send button / <Return>: post the input and start generation.

        Ignored until the model is loaded and while a previous reply is still
        being generated (the <Return> binding fires even when the button is
        disabled).
        """
        if not self.model_loaded or self._busy:
            return
        user_input = self.input_entry.get().strip()
        if not user_input:
            return
        self.input_entry.delete(0, tk.END)
        self.add_message("user", user_input)
        self._busy = True
        self.send_button.config(state=tk.DISABLED)
        threading.Thread(target=self.process_chinese_message,
                         args=(user_input,), daemon=True).start()

    def process_chinese_message(self, user_input):
        """Worker-thread entry: generate a reply and always post a "response" message."""
        try:
            response = self.generate_chinese_response(user_input)
            if not response:
                response = self.get_chinese_fallback()
            self.message_queue.put(("response", response))
        except Exception as e:
            print(f"处理错误: {e}")
            self.message_queue.put(("response", "抱歉,我遇到了一些问题。"))

    def update_status(self, message):
        """Queue a status-label update (safe to call from any thread)."""
        self.message_queue.put(("status", message))

    def add_message(self, sender, message):
        """Queue a chat-log line from ``sender`` (safe to call from any thread)."""
        self.message_queue.put(("add_msg", sender, message))

    def process_queue(self):
        """Drain pending queue messages and apply them to the UI (Tk thread only).

        Re-schedules itself every 100 ms; the ``finally`` keeps polling alive
        even if handling one message raises.
        """
        try:
            while True:
                try:
                    msg = self.message_queue.get_nowait()
                except queue.Empty:
                    break
                kind = msg[0]
                if kind == "status":
                    self.status_label.config(text=msg[1])
                elif kind == "add_msg":
                    self._add_message(msg[1], msg[2])
                elif kind == "response":
                    self._add_message("ai", msg[1])
                    self._busy = False
                    self.send_button.config(state=tk.NORMAL)
                elif kind == "ready":
                    self.send_button.config(state=tk.NORMAL)
        finally:
            self.root.after(100, self.process_queue)

    def _add_message(self, sender, message):
        """Append ``message`` to the read-only chat log with a sender prefix and colour tag."""
        if sender == "user":
            prefix = "👤 你: "
            tag = "user"
        elif sender == "ai":
            prefix = "🤖 AI: "
            tag = "ai"
        else:
            prefix = "💡 系统: "
            tag = "system"
        # The widget is kept DISABLED so users cannot edit it; unlock to insert.
        self.chat_text.config(state=tk.NORMAL)
        self.chat_text.insert(tk.END, f"{prefix}{message}\n\n", tag)
        self.chat_text.config(state=tk.DISABLED)
        self.chat_text.see(tk.END)
def main():
    """Build the Tk root window, mount the chat bot on it, and run the event loop."""
    window = tk.Tk()
    # Keep a reference to the bot for as long as the event loop runs.
    chat_app = ChineseChatBot(window)
    window.mainloop()


if __name__ == "__main__":
    # Start the GUI only when executed as a script, not on import.
    main()
中文AI聊天机器人代码ZQ-gpt2模型
最新推荐文章于 2025-12-26 02:23:27 发布

831

被折叠的 条评论
为什么被折叠?



