🏗️ 系统架构总览
文档加载 → 文本分块 → 本地LLM信息抽取 → 图谱存储
↓ ↓ ↓ ↓
LangChain 文本处理 Ollama Neo4j
向量检索 Qwen2.5
1. 环境准备与安装
1.1 安装必需软件
# 1. 安装Ollama(直接下载exe)
# 访问 https://ollama.ai/download 下载 Windows 版本
# 2. 安装Python依赖
pip install langchain langchain-community langchain-huggingface neo4j chromadb pymupdf python-docx
pip install sentence-transformers flask streamlit
pip install requests python-dotenv
1.2 下载中文模型
# 在新的PowerShell窗口下载模型
ollama pull qwen2.5:7b
ollama pull llama3.1:8b
# 验证模型
ollama list
2. 📁 项目文件结构
knowledge_graph/
├── config.py
├── document_processor.py
├── llm_extractor.py
├── kg_storage.py
├── vector_search.py
├── batch_processor.py
├── main_builder.py
├── start_windows.py
├── check_environment.py
├── requirements.txt
└── documents/ # 存放你的文档
2.1 配置文件 - config.py
import os
from pathlib import Path
class WindowsConfig:
    """Static configuration for the knowledge-graph pipeline on Windows.

    All values are class attributes, so they can be read either from the
    class or from the module-level ``config`` instance. Connection settings
    can be overridden through environment variables of the same name
    (``OLLAMA_BASE_URL``, ``NEO4J_URI``, ``NEO4J_USER``, ``NEO4J_PASSWORD``);
    when a variable is unset the original hard-coded default is used.
    """

    # Base paths, all resolved relative to the directory containing this file.
    BASE_DIR = Path(__file__).parent.absolute()
    DATA_DIR = BASE_DIR / "data"
    DOCUMENTS_DIR = BASE_DIR / "documents"
    CHROMA_DIR = BASE_DIR / "chroma_db"
    LOG_DIR = BASE_DIR / "logs"

    # Create the working directories when the class body is executed
    # (i.e. on first import). The loop variable is deleted afterwards so it
    # does not end up as a stray class attribute.
    for _directory in (DATA_DIR, DOCUMENTS_DIR, CHROMA_DIR, LOG_DIR):
        _directory.mkdir(parents=True, exist_ok=True)
    del _directory

    # Ollama settings
    OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
    DEFAULT_MODEL = "qwen2.5:7b"

    # Neo4j settings. Prefer setting NEO4J_PASSWORD in the environment over
    # editing the default below.
    NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
    NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
    NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "password")

    # Processing settings
    CHUNK_SIZE = 800
    CHUNK_OVERLAP = 100
    BATCH_SIZE = 3  # small batches to keep resource usage low on Windows
    MAX_TEXT_LENGTH = 2500  # cap on text sent to the LLM, to limit memory use

    # Timeouts, in seconds
    REQUEST_TIMEOUT = 120
    HEALTH_CHECK_TIMEOUT = 10


config = WindowsConfig()
2.2 文档处理器 - document_processor.py
import os
import fitz # PyMuPDF
import docx
from pathlib import Path
from typing import List, Dict, Any
import hashlib
import logging
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class WindowsDocumentProcessor:
    """Load supported documents from a folder and return their text content.

    Supported extensions are listed in ``supported_formats``; ``.md`` files
    are read the same way as ``.txt`` files.
    """

    def __init__(self):
        self.supported_formats = ['.pdf', '.txt', '.md', '.docx']
        logger.info("文档处理器初始化完成")

    def load_documents(self, folder_path: str) -> List[Dict[str, Any]]:
        """Load every supported file directly inside ``folder_path``.

        Args:
            folder_path: Directory to scan (sub-directories are not visited).

        Returns:
            One dict per successfully loaded file with the keys ``id``,
            ``filename``, ``content``, ``file_path``, ``file_size`` and
            ``file_type``. Files that raise while loading, or whose text is
            shorter than 10 characters after stripping, are skipped. An empty
            list is returned when the folder is missing, is not a directory,
            or contains no supported files.
        """
        documents = []
        folder_path = Path(folder_path)
        if not folder_path.exists():
            logger.error(f"目录不存在: {folder_path}")
            return documents
        if not folder_path.is_dir():
            logger.error(f"路径不是目录: {folder_path}")
            return documents

        # Scan the directory once; sorting makes the processing order
        # deterministic, since iterdir() yields entries in arbitrary order.
        candidates = sorted(
            path for path in folder_path.iterdir()
            if path.is_file() and path.suffix.lower() in self.supported_formats
        )
        if not candidates:
            logger.warning(f"在 {folder_path} 中没有找到支持的文档文件")
            logger.info(f"支持的文件格式: {', '.join(self.supported_formats)}")
            return documents
        logger.info(f"找到 {len(candidates)} 个文档文件,开始处理...")

        loaders = {
            '.pdf': self._load_pdf,
            '.txt': self._load_txt,
            '.md': self._load_txt,
            '.docx': self._load_docx,
        }
        for file_path in candidates:
            try:
                logger.info(f"处理文件: {file_path.name}")
                file_ext = file_path.suffix.lower()
                loader = loaders.get(file_ext)
                # An extension added to supported_formats without a loader
                # yields empty content and is reported as empty below.
                content = loader(file_path) if loader else ""
                if not content or len(content.strip()) < 10:
                    logger.warning(f"文件内容过少或为空: {file_path.name}")
                    continue
                # Document id: hash of the file name plus the first 1000
                # characters of the text.
                doc_id = hashlib.md5(
                    f"{file_path.name}{content[:1000]}".encode()
                ).hexdigest()[:16]
                documents.append({
                    'id': doc_id,
                    'filename': file_path.name,
                    'content': content.strip(),
                    'file_path': str(file_path),
                    'file_size': len(content),
                    'file_type': file_ext[1:],  # drop the leading dot
                })
                logger.info(f"成功加载: {file_path.name} ({len(content)} 字符)")
            except Exception as e:
                logger.error(f"处理文件 {file_path.name} 时出错: {str(e)}")
                continue
        logger.info(f"成功加载 {len(documents)} 个文档")
        return documents

    def _load_pdf(self, file_path: Path) -> str:
        """Return the text of a PDF with a page marker before each page.

        Pages without text are omitted. Returns ``""`` on any read error.
        """
        try:
            doc = fitz.open(file_path)
            try:
                pages = []
                for page_num, page in enumerate(doc):
                    page_text = page.get_text()
                    if page_text.strip():
                        pages.append(f"第{page_num + 1}页:\n{page_text}\n\n")
            finally:
                # Release the file handle even if a page fails to extract.
                doc.close()
            return "".join(pages).strip()
        except Exception as e:
            logger.error(f"PDF读取错误 {file_path}: {str(e)}")
            return ""

    def _load_txt(self, file_path: Path) -> str:
        """Read a text file, trying several encodings in turn.

        ``utf-8-sig`` is tried first so that a UTF-8 byte-order mark is not
        left at the start of the text. ``latin-1`` accepts any byte
        sequence, so it acts as a last-resort fallback. Returns ``""`` when
        the file cannot be read.
        """
        try:
            encodings = ['utf-8-sig', 'gbk', 'gb18030', 'latin-1']
            for encoding in encodings:
                try:
                    with open(file_path, 'r', encoding=encoding) as f:
                        return f.read()
                except UnicodeDecodeError:
                    continue
            logger.error(f"无法解码文本文件: {file_path}")
            return ""
        except Exception as e:
            logger.error(f"文本文件读取错误 {file_path}: {str(e)}")
            return ""

    def _load_docx(self, file_path: Path) -> str:
        """Return the non-blank paragraphs of a Word document, one per line.

        Returns ``""`` on any read error.
        """
        try:
            doc = docx.Document(file_path)
            lines = [
                paragraph.text
                for paragraph in doc.paragraphs
                if paragraph.text.strip()
            ]
            return "\n".join(lines).strip()
        except Exception as e:
            logger.error(f"DOCX读取错误 {file_path}: {str(e)}")
            return ""
if __name__ == "__main__":
    # Smoke test: load whatever is in ./documents and report how many
    # documents were read.
    loaded_docs = WindowsDocumentProcessor().load_documents("./documents")
    print(f"测试加载: {len(loaded_docs)} 个文档")
2.3 LLM提取器 - llm_extractor.py
import json
import requests
import time
from typing import List, Dict, Any
import logging
from config import config
logger = logging.getLogger(__name__)
class WindowsLLMExtractor:
    """Extract entities and relations from text with a local Ollama model.

    Connection settings, the default model and timeouts are read from the
    module-level ``config`` object.
    """

    def __init__(self):
        self.base_url = config.OLLAMA_BASE_URL
        self.current_model = config.DEFAULT_MODEL
        self.request_timeout = config.REQUEST_TIMEOUT
        logger.info(f"LLM提取器初始化,使用模型: {self.current_model}")

    def check_ollama_health(self) -> bool:
        """Return True when ``GET /api/tags`` answers with HTTP 200."""
        try:
            response = requests.get(
                f"{self.base_url}/api/tags",
                timeout=config.HEALTH_CHECK_TIMEOUT
            )
        except requests.exceptions.ConnectionError:
            logger.error("无法连接到Ollama服务,请检查是否启动")
            return False
        except Exception as e:
            logger.error(f"检查Ollama健康状态时出错: {str(e)}")
            return False
        if response.status_code == 200:
            return True
        logger.warning(f"Ollama服务响应异常: {response.status_code}")
        return False

    def get_available_models(self) -> List[str]:
        """Return the model names reported by Ollama, or ``[]`` on failure."""
        try:
            response = requests.get(
                f"{self.base_url}/api/tags",
                timeout=config.HEALTH_CHECK_TIMEOUT
            )
            if response.status_code == 200:
                models = response.json().get('models', [])
                return [model['name'] for model in models]
            return []
        except Exception as e:
            logger.error(f"获取模型列表失败: {str(e)}")
            return []

    @staticmethod
    def _resolve_model_name(model_name: str, available_models: List[str]):
        """Map a requested model name to an installed one, or return None.

        An exact match wins. A name without a tag (no ``:``) matches an
        installed model with the same base name, preferring ``:latest``.
        Arbitrary substrings are rejected: the name is later sent to the
        Ollama API unchanged, so it must identify a real model.
        """
        if model_name in available_models:
            return model_name
        if ':' in model_name:
            return None
        tagged = [
            name for name in available_models
            if name.split(':', 1)[0] == model_name
        ]
        if not tagged:
            return None
        latest = f"{model_name}:latest"
        return latest if latest in tagged else tagged[0]

    def set_model(self, model_name: str) -> bool:
        """Switch to ``model_name`` if it is installed.

        Returns:
            True when the model was found and selected, False otherwise.
        """
        available_models = self.get_available_models()
        if not available_models:
            logger.error("无法获取可用模型列表")
            return False
        resolved = self._resolve_model_name(model_name, available_models)
        if resolved is None:
            logger.error(f"模型 {model_name} 不可用,可用模型: {available_models}")
            return False
        self.current_model = resolved
        logger.info(f"切换到模型: {resolved}")
        return True

    def extract_knowledge(self, text: str, max_retries: int = 3) -> Dict[str, Any]:
        """Extract entities and relations from ``text`` using the local LLM.

        Args:
            text: Source text. Text shorter than 10 characters after
                stripping is skipped.
            max_retries: Maximum number of LLM calls to attempt.

        Returns:
            ``{"entities": [...], "relations": [...]}``. Both lists are empty
            when the text is too short, the service is unavailable, or every
            attempt fails or yields nothing.
        """
        empty = {"entities": [], "relations": []}
        if not text or len(text.strip()) < 10:
            logger.warning("输入文本过短,跳过处理")
            return empty
        if not self.check_ollama_health():
            logger.error("Ollama服务不可用,跳过处理")
            return empty

        clean_text = self._preprocess_text(text)
        logger.info(f"开始知识提取,文本长度: {len(clean_text)}")
        # The prompt does not change between attempts, so build it once.
        prompt = self._build_extraction_prompt(clean_text)

        for attempt in range(max_retries):
            delay = 2
            try:
                # A fixed seed with a low temperature would make every retry
                # reproduce the same output; vary the seed per attempt (the
                # first attempt still uses 42).
                response = self._call_ollama(prompt, seed=42 + attempt)
                if not response:
                    logger.warning(f"第 {attempt + 1} 次调用返回空响应")
                else:
                    result = self._parse_response(response)
                    if self._validate_result(result):
                        logger.info(f"成功提取 {len(result['entities'])} 实体, {len(result['relations'])} 关系")
                        return result
                    logger.warning(f"第 {attempt + 1} 次提取结果验证失败")
            except requests.exceptions.Timeout:
                logger.error(f"请求超时 (尝试 {attempt + 1}/{max_retries})")
                delay = 3
            except Exception as e:
                logger.error(f"提取失败 (尝试 {attempt + 1}/{max_retries}): {str(e)}")
            # Back off before the next attempt, but not after the last one.
            if attempt < max_retries - 1:
                time.sleep(delay)

        logger.error(f"所有 {max_retries} 次尝试都失败")
        return empty

    def _preprocess_text(self, text: str) -> str:
        """Collapse whitespace and cap the text at ``config.MAX_TEXT_LENGTH``.

        Whitespace is collapsed before measuring so that padding does not
        count towards the limit. Over-long text keeps its head and tail with
        a truncation marker in between.
        """
        text = ' '.join(text.split())
        if len(text) > config.MAX_TEXT_LENGTH:
            logger.warning(f"文本过长 ({len(text)} 字符),进行截断")
            half_max = config.MAX_TEXT_LENGTH // 2
            # text[-0:] would return the whole string, so guard the zero case.
            tail = text[-half_max:] if half_max else ""
            text = text[:half_max] + "\n[...文本过长,进行截断...]\n" + tail
        return text

    def _build_extraction_prompt(self, text: str) -> str:
        """Return the extraction prompt with ``text`` embedded."""
        return f"""请从以下文本中提取关键信息,严格按照JSON格式返回。
文本内容:
{text}
提取要求:
1. 识别重要实体(人物、组织、地点、概念、技术、产品、事件、方法等)
2. 提取实体间的重要关系
3. 实体类型要准确
4. 关系描述要具体明确
返回格式:
{{
"entities": [
{{
"name": "实体名称",
"type": "人物|组织|地点|概念|技术|产品|事件|方法",
"description": "实体简要描述"
}}
],
"relations": [
{{
"subject": "主体实体名称",
"relation": "关系类型",
"object": "客体实体名称"
}}
]
}}
请确保:
- 实体名称要准确完整
- 关系描述要具体(如"发明了"、"位于"、"属于"等)
- 只返回JSON格式,不要其他任何内容
- 如果文本中没有明确的关系,relations数组可以为空"""

    def _call_ollama(self, prompt: str, seed: int = 42) -> str:
        """Send ``prompt`` to ``/api/generate`` and return the response text.

        Args:
            prompt: Full prompt text.
            seed: Sampling seed passed to Ollama.

        Raises:
            requests.exceptions.Timeout: The request exceeded the timeout.
            RuntimeError: Ollama answered with a non-200 status.
            Exception: Any other request or decoding error, re-raised after
                logging.
        """
        payload = {
            "model": self.current_model,
            "prompt": prompt,
            "stream": False,
            # Ask Ollama to constrain the output to valid JSON.
            "format": "json",
            "options": {
                "temperature": 0.1,
                "num_predict": 1200,  # cap the output length
                "top_k": 40,
                "top_p": 0.9,
                "seed": seed
            }
        }
        try:
            response = requests.post(
                f"{self.base_url}/api/generate",
                json=payload,
                timeout=self.request_timeout
            )
            if response.status_code == 200:
                return response.json().get("response", "")
            error_msg = f"Ollama API错误: {response.status_code}"
            if response.text:
                error_msg += f" - {response.text}"
            raise RuntimeError(error_msg)
        except requests.exceptions.Timeout:
            logger.error(f"Ollama请求超时 (超过 {self.request_timeout} 秒)")
            raise
        except Exception as e:
            logger.error(f"Ollama调用失败: {str(e)}")
            raise

    @staticmethod
    def _extract_json_text(response: str) -> str:
        """Isolate the JSON payload from a raw model response.

        Handles Markdown code fences and, as a fallback, prose surrounding
        the JSON object by keeping the span from the first ``{`` to the last
        ``}``.
        """
        text = response.strip()
        if '```json' in text:
            text = text.split('```json')[1].split('```')[0]
        elif '```' in text:
            text = text.split('```')[1]
        text = text.replace('```', '').strip()
        start, end = text.find('{'), text.rfind('}')
        if start != -1 and end > start:
            text = text[start:end + 1]
        return text

    def _parse_response(self, response: str) -> Dict[str, Any]:
        """Parse a raw model response into a cleaned result dict.

        Returns ``{"entities": [], "relations": []}`` when the response is
        empty, is not valid JSON, or is not a JSON object.
        """
        empty = {"entities": [], "relations": []}
        if not response:
            return empty
        try:
            json_str = self._extract_json_text(response)
            try:
                result = json.loads(json_str)
            except json.JSONDecodeError:
                # Only attempt textual repairs when strict parsing fails, so
                # valid JSON is never modified.
                result = json.loads(self._fix_json_format(json_str))
            if not isinstance(result, dict):
                logger.error(f"解析响应时出错: 顶层JSON不是对象 ({type(result).__name__})")
                return empty
            return self._post_process_result(result)
        except json.JSONDecodeError as e:
            logger.error(f"JSON解析错误: {str(e)}")
            logger.debug(f"原始响应: {response[:500]}...")
            return empty
        except Exception as e:
            logger.error(f"解析响应时出错: {str(e)}")
            return empty

    def _fix_json_format(self, json_str: str) -> str:
        """Repair common JSON formatting mistakes in model output.

        Removes surrounding whitespace, a leading byte-order mark, and
        trailing commas before ``}`` or ``]``.
        """
        import re

        json_str = json_str.strip().lstrip('\ufeff').strip()
        return re.sub(r',\s*([}\]])', r'\1', json_str)

    def _post_process_result(self, result: Dict) -> Dict:
        """Drop malformed entities and relations from a parsed result.

        An entity is kept when it is a dict with non-empty ``name`` and
        ``type``; a missing ``description`` is filled with ``""``. A relation
        is kept when it is a dict with non-empty ``subject``, ``relation``
        and ``object``. A missing or non-list field is treated as empty.
        ``result`` is updated in place and returned.
        """
        entities = result.get("entities")
        if not isinstance(entities, list):
            entities = []
        relations = result.get("relations")
        if not isinstance(relations, list):
            relations = []

        cleaned_entities = []
        for entity in entities:
            if isinstance(entity, dict) and entity.get("name") and entity.get("type"):
                entity.setdefault("description", "")
                cleaned_entities.append(entity)
        result["entities"] = cleaned_entities

        result["relations"] = [
            relation for relation in relations
            if isinstance(relation, dict)
            and relation.get("subject")
            and relation.get("relation")
            and relation.get("object")
        ]
        return result

    def _validate_result(self, result: Dict) -> bool:
        """Return True when ``result`` is well-formed and not empty.

        Well-formed means a dict whose ``entities`` and ``relations`` values
        are both lists; at least one of the two must be non-empty.
        """
        if not isinstance(result, dict):
            return False
        if "entities" not in result or "relations" not in result:
            return False
        entities, relations = result["entities"], result["relations"]
        if not isinstance(entities, list) or not isinstance(relations, list):
            return False
        return bool(entities or relations)
if __name__ == "__main__":
    # Smoke test: run a single extraction when the Ollama service is up.
    extractor = WindowsLLMExtractor()
    if not extractor.check_ollama_health():
        print("Ollama服务不可用")
    else:
        test_text = "苹果公司由史蒂夫·乔布斯在1976年创立,总部位于加利福尼亚州。"
        result = extractor.extract_knowledge(test_text)
        rendered = json.dumps(result, ensure_ascii=False, indent=2)
        print("测试提取结果:", rendered)
2.4 知识图谱存储 - kg_storage.py
from neo4j import GraphDatabase
from typing import List, Dict, Any
import logging
from config import config
logger = logging.getLogger(__name__)
class KnowledgeGraphStorage:
    """Persist extracted entities and relations in a Neo4j database.

    The constructor attempts to connect immediately. On failure it logs the
    problem and leaves ``self.driver`` as ``None``; callers should check
    ``is_connected()`` before relying on the instance.
    """

    def __init__(self, uri: str = None, user: str = None, password: str = None):
        self.uri = uri or config.NEO4J_URI
        self.user = user or config.NEO4J_USER
        self.password = password or config.NEO4J_PASSWORD
        self.driver = None
        self._connect()

    def _connect(self):
        """Open the driver, verify it with a trivial query, create constraints."""
        driver = None
        try:
            driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password))
            with driver.session() as session:
                session.run("RETURN 1 AS test")
            self.driver = driver
            logger.info("成功连接到Neo4j数据库")
            self._init_constraints()
        except Exception as e:
            logger.error(f"连接Neo4j数据库失败: {str(e)}")
            logger.info("请确保:")
            logger.info("1. Neo4j数据库已启动")
            logger.info("2. 连接信息正确 (URI: %s, 用户: %s)", self.uri, self.user)
            logger.info("3. 数据库密码正确")
            # Close a driver that was created but failed the connection test,
            # instead of just dropping the reference.
            if driver is not None:
                try:
                    driver.close()
                except Exception as close_error:
                    logger.debug(f"关闭Neo4j驱动失败: {str(close_error)}")
            self.driver = None

    def _init_constraints(self):
        """Create the uniqueness constraints if they do not exist yet."""
        if not self.driver:
            return
        constraints = [
            "CREATE CONSTRAINT IF NOT EXISTS FOR (e:Entity) REQUIRE e.name IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (d:Document) REQUIRE d.id IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (c:Concept) REQUIRE c.name IS UNIQUE"
        ]
        try:
            with self.driver.session() as session:
                for constraint in constraints:
                    session.run(constraint)
            logger.info("数据库约束初始化完成")
        except Exception as e:
            logger.error(f"初始化约束失败: {str(e)}")

    def is_connected(self) -> bool:
        """Return True when a trivial query succeeds on the current driver."""
        if not self.driver:
            return False
        try:
            with self.driver.session() as session:
                result = session.run("RETURN 1 AS test")
                return result.single() is not None
        except Exception:
            return False

    def store_knowledge(self, document: Dict, knowledge: Dict):
        """Store a document with its entities and relations in one transaction.

        Args:
            document: Document record; ``id``, ``filename`` and ``content``
                are read directly, other keys are optional.
            knowledge: Dict with optional ``entities`` and ``relations`` lists.

        Returns:
            True when the transaction was committed, False otherwise.
        """
        if not self.is_connected():
            logger.error("数据库未连接,无法存储知识")
            return False
        try:
            with self.driver.session() as session:
                transaction = session.begin_transaction()
                try:
                    # 1. Document node
                    self._store_document(transaction, document)
                    # 2. Entities; remember the ones that were stored so
                    #    relations can be checked against them.
                    entity_map = {}
                    for entity in knowledge.get("entities", []):
                        if self._store_entity(transaction, entity, document['id']):
                            entity_map[entity["name"]] = entity
                    # 3. Relations; count only the ones actually written.
                    stored_relations = 0
                    for relation in knowledge.get("relations", []):
                        if self._store_relation(transaction, relation, entity_map):
                            stored_relations += 1
                    transaction.commit()
                    logger.debug(f"成功存储知识: {len(entity_map)} 实体, {stored_relations} 关系")
                    return True
                except Exception as e:
                    transaction.rollback()
                    logger.error(f"存储知识事务失败: {str(e)}")
                    return False
        except Exception as e:
            logger.error(f"存储知识失败: {str(e)}")
            return False

    def _store_document(self, tx, document: Dict):
        """Create or update the ``Document`` node for ``document``."""
        query = """
        MERGE (d:Document {id: $id})
        SET d.filename = $filename,
            d.file_path = $file_path,
            d.file_type = $file_type,
            d.file_size = $file_size,
            d.processed_at = datetime(),
            d.content_preview = $content_preview
        """
        content = document['content']
        # Keep at most the first 200 characters as a preview.
        content_preview = content[:200] + "..." if len(content) > 200 else content
        tx.run(query,
               id=document['id'],
               filename=document['filename'],
               file_path=document.get('file_path', ''),
               file_type=document.get('file_type', 'unknown'),
               file_size=document.get('file_size', 0),
               content_preview=content_preview)

    def _store_entity(self, tx, entity: Dict, doc_id: str) -> bool:
        """Create or update an ``Entity`` node and link it to its document.

        Returns:
            True when the statement was issued, False when the entity lacks a
            name or type or the statement raised.
        """
        if not entity.get("name") or not entity.get("type"):
            return False
        try:
            query = """
            MERGE (e:Entity {name: $name})
            SET e.type = $type,
                e.description = $description,
                e.last_updated = datetime(),
                e.source_count = COALESCE(e.source_count, 0) + 1
            WITH e
            MATCH (d:Document {id: $doc_id})
            MERGE (e)-[r:APPEARS_IN]->(d)
            SET r.first_seen = COALESCE(r.first_seen, datetime())
            """
            tx.run(query,
                   name=entity["name"],
                   type=entity["type"],
                   description=entity.get("description", ""),
                   doc_id=doc_id)
            return True
        except Exception as e:
            logger.error(f"存储实体失败 {entity.get('name')}: {str(e)}")
            return False

    def _store_relation(self, tx, relation: Dict, entity_map: Dict) -> bool:
        """Create or update a ``RELATION`` edge between two stored entities.

        Relations whose subject or object is not in ``entity_map`` are
        skipped.

        Returns:
            True when the statement was issued, False when the relation was
            skipped or the statement raised.
        """
        if not all([relation.get("subject"), relation.get("relation"), relation.get("object")]):
            return False
        if relation["subject"] not in entity_map or relation["object"] not in entity_map:
            return False
        try:
            # created_at is set only when the edge is first created; later
            # occurrences just bump the counter.
            query = """
            MATCH (a:Entity {name: $subject}), (b:Entity {name: $object})
            MERGE (a)-[r:RELATION {type: $relation_type}]->(b)
            ON CREATE SET r.created_at = datetime()
            SET r.confidence = $confidence,
                r.occurrence_count = COALESCE(r.occurrence_count, 0) + 1
            """
            tx.run(query,
                   subject=relation["subject"],
                   object=relation["object"],
                   relation_type=relation["relation"],
                   confidence=relation.get("confidence", 0.9))
            return True
        except Exception as e:
            logger.error(f"存储关系失败 {relation.get('subject')}->{relation.get('object')}: {str(e)}")
            return False

    def query_entities_by_type(self, entity_type: str, limit: int = 100) -> List[Dict]:
        """Return entities of ``entity_type`` ordered by ``source_count``.

        Returns ``[]`` when the database is not connected or the query fails.
        """
        if not self.is_connected():
            return []
        try:
            query = """
            MATCH (e:Entity {type: $type})
            RETURN e.name as name,
                   e.description as description,
                   e.source_count as count
            ORDER BY e.source_count DESC
            LIMIT $limit
            """
            with self.driver.session() as session:
                result = session.run(query, type=entity_type, limit=limit)
                return [dict(record) for record in result]
        except Exception as e:
            logger.error(f"查询实体失败: {str(e)}")
            return []

    def find_related_entities(self, entity_name: str, max_depth: int = 2) -> List[Dict]:
        """Return relationships on paths from ``entity_name`` to other entities.

        Args:
            entity_name: Name of the starting entity.
            max_depth: Maximum path length; values below 1 are raised to 1.

        Returns:
            Rows with ``start_entity``, ``relation_type``, ``end_entity`` and
            ``frequency``; ``[]`` when not connected, when ``max_depth`` is
            not an integer, or when the query fails.
        """
        if not self.is_connected():
            return []
        # The path length is formatted into the query text, so it must be
        # validated as an integer first.
        try:
            depth = max(1, int(max_depth))
        except (TypeError, ValueError):
            logger.error(f"查找相关实体失败: 无效的深度 {max_depth!r}")
            return []
        try:
            query = """
            MATCH path = (start:Entity {name: $name})-[*1..%d]-(related:Entity)
            WHERE start <> related
            UNWIND relationships(path) as rel
            RETURN DISTINCT
                start.name as start_entity,
                type(rel) as relation_type,
                endNode(rel).name as end_entity,
                rel.occurrence_count as frequency
            ORDER BY frequency DESC
            """ % depth
            with self.driver.session() as session:
                result = session.run(query, name=entity_name)
                return [dict(record) for record in result]
        except Exception as e:
            logger.error(f"查找相关实体失败: {str(e)}")
            return []

    def get_graph_statistics(self) -> Dict[str, Any]:
        """Return node and relation counts plus type and top-entity breakdowns.

        Returns ``{}`` when the database is not connected or a query fails.
        """
        if not self.is_connected():
            return {}
        try:
            queries = {
                "total_entities": "MATCH (e:Entity) RETURN count(e) as count",
                "total_documents": "MATCH (d:Document) RETURN count(d) as count",
                "total_relations": "MATCH ()-[r:RELATION]->() RETURN count(r) as count",
                "entity_types": "MATCH (e:Entity) RETURN e.type as type, count(e) as count ORDER BY count DESC",
                "top_entities": "MATCH (e:Entity) RETURN e.name as name, e.type as type, e.source_count as count ORDER BY e.source_count DESC LIMIT 10"
            }
            list_keys = ("entity_types", "top_entities")
            stats = {}
            with self.driver.session() as session:
                for key, query in queries.items():
                    result = session.run(query)
                    if key in list_keys:
                        stats[key] = [dict(record) for record in result]
                    else:
                        # single() consumes the result, so call it exactly
                        # once and reuse the record.
                        record = result.single()
                        stats[key] = record[0] if record is not None else 0
            return stats
        except Exception as e:
            logger.error(f"获取统计信息失败: {str(e)}")
            return {}

    def clear_database(self, confirm: bool = False) -> bool:
        """Delete every node and relationship. Requires ``confirm=True``.

        Returns:
            True when the delete statement ran, False otherwise.
        """
        if not confirm:
            logger.warning("清空数据库需要确认")
            return False
        if not self.is_connected():
            return False
        try:
            with self.driver.session() as session:
                session.run("MATCH (n) DETACH DELETE n")
            logger.warning("数据库已清空")
            return True
        except Exception as e:
            logger.error(f"清空数据库失败: {str(e)}")
            return False

    def close(self):
        """Close the driver if one is open."""
        if self.driver:
            self.driver.close()
            logger.info("数据库连接已关闭")
if __name__ == "__main__":
    # Smoke test: connect, print the graph statistics, and always release
    # the driver afterwards.
    storage = KnowledgeGraphStorage()
    try:
        if storage.is_connected():
            stats = storage.get_graph_statistics()
            print("数据库统计:", stats)
        else:
            print("数据库连接失败")
    finally:
        storage.close()
2.5 向量搜索 - vector_search.py
import logging
from typing import List, Dict, Any
from config import config
logger = logging.getLogger(__name__)
try:
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
VECTOR_SEARCH_AVAILABLE = True
except ImportError as e:
logger.warning(f"向量搜索组件导入失败: {str(e)}")
VECTOR_SEARCH_AVAILABLE = False
class VectorSearch:
    """Semantic document search backed by a persistent Chroma store.

    When the optional dependencies could not be imported, or initialisation
    fails, the instance is created with ``available`` set to False and every
    operation becomes a logged no-op.
    """

    def __init__(self, persist_directory: str = None):
        if not VECTOR_SEARCH_AVAILABLE:
            logger.error("向量搜索不可用,请安装相关依赖")
            self.available = False
            return
        self.persist_directory = persist_directory or str(config.CHROMA_DIR)
        self.available = False
        try:
            # Chinese embedding model, run on CPU.
            self.embeddings = HuggingFaceEmbeddings(
                model_name="BAAI/bge-small-zh-v1.5",
                model_kwargs={'device': 'cpu'},
                encode_kwargs={'normalize_embeddings': True}
            )
            self.vectorstore = Chroma(
                persist_directory=self.persist_directory,
                embedding_function=self.embeddings
            )
            self.available = True
            logger.info("向量搜索初始化完成")
        except Exception as e:
            logger.error(f"向量搜索初始化失败: {str(e)}")
            self.available = False

    def is_available(self) -> bool:
        """Return True when the vector store was initialised successfully."""
        return self.available

    def add_documents(self, documents: List[Dict]):
        """Add documents to the vector store.

        Each document's ``id`` is used as the vector-store id, so adding the
        same document again updates the existing entry instead of creating a
        duplicate.

        Args:
            documents: Dicts with ``id``, ``filename`` and ``content``;
                ``file_type`` and ``file_size`` are optional.

        Returns:
            True on success, False when the store is unavailable, the list is
            empty, or the insert fails.
        """
        if not self.available:
            logger.error("向量搜索不可用")
            return False
        if not documents:
            logger.warning("没有文档可添加")
            return False
        try:
            # Ids must be unique within one call; keep the last document for
            # any repeated id.
            unique_docs = list({doc['id']: doc for doc in documents}.values())
            texts = [doc['content'] for doc in unique_docs]
            ids = [doc['id'] for doc in unique_docs]
            metadatas = [{
                'doc_id': doc['id'],
                'filename': doc['filename'],
                'file_type': doc.get('file_type', 'unknown'),
                'file_size': doc.get('file_size', 0)
            } for doc in unique_docs]
            self.vectorstore.add_texts(texts, metadatas=metadatas, ids=ids)
            logger.info(f"成功添加 {len(unique_docs)} 个文档到向量数据库")
            return True
        except Exception as e:
            logger.error(f"添加文档到向量数据库失败: {str(e)}")
            return False

    def semantic_search(self, query: str, k: int = 5) -> List[Dict]:
        """Return up to ``k`` documents most similar to ``query``.

        Returns:
            Dicts with ``rank`` (1-based), ``content``, ``metadata`` and
            ``score``. ``score`` is the value reported by the vector store
            for each hit; ``[]`` is returned when the store is unavailable,
            the query is blank, or the search fails.
        """
        if not self.available:
            logger.error("向量搜索不可用")
            return []
        if not query or not query.strip():
            logger.warning("搜索查询为空")
            return []
        try:
            # Plain similarity_search() returns documents without a score, so
            # use the scored variant to report real values.
            scored = self.vectorstore.similarity_search_with_score(query, k=k)
            formatted_results = [
                {
                    'rank': rank,
                    'content': doc.page_content,
                    'metadata': doc.metadata,
                    'score': float(score),
                }
                for rank, (doc, score) in enumerate(scored, start=1)
            ]
            logger.info(f"语义搜索完成: '{query}' -> {len(formatted_results)} 结果")
            return formatted_results
        except Exception as e:
            logger.error(f"语义搜索失败: {str(e)}")
            return []

    def get_document_count(self) -> int:
        """Return the number of entries in the vector store (0 on failure)."""
        if not self.available:
            return 0
        try:
            # Listing the stored ids gives an exact count without running an
            # embedding query.
            stored = self.vectorstore.get(include=["metadatas"])
            return len(stored["ids"])
        except Exception as e:
            logger.error(f"获取文档数量失败: {str(e)}")
            return 0
if __name__ == "__main__":
    # Smoke test: report whether the vector store initialises and how many
    # documents it holds.
    vs = VectorSearch()
    if not vs.is_available():
        print("向量搜索不可用")
    else:
        print("向量搜索测试通过")
        print(f"文档数量: {vs.get_document_count()}")
2.6 批处理管理器 - batch_processor.py
import time
import logging
from typing import List, Dict, Any, Callable
from config import config
logger = logging.getLogger(__name__)
class BatchProcessor:
    """Run a function over a list of items in small, throttled batches.

    Args:
        batch_delay: Seconds to wait between batches.
        item_delay: Seconds to wait after each successfully processed item.
    """

    def __init__(self, batch_delay: float = 2, item_delay: float = 0.3):
        self.batch_delay = batch_delay
        self.item_delay = item_delay
        self.processed_count = 0
        self.error_count = 0
        self.start_time = None

    def process_in_batches(self,
                           items: List[Any],
                           processor_func: Callable,
                           batch_size: int = None,
                           description: str = "处理") -> List[Any]:
        """Apply ``processor_func`` to every item, batch by batch.

        Args:
            items: Items to process.
            processor_func: Called with one item at a time.
            batch_size: Items per batch; ``None`` or 0 selects
                ``config.BATCH_SIZE``. Values below 1 are raised to 1.
            description: Label used in log messages.

        Returns:
            One result per item, in input order; items whose processing
            raised contribute ``None``. An empty list when ``items`` is empty.
        """
        if not items:
            logger.warning("没有项目需要处理")
            return []
        batch_size = max(1, int(batch_size or config.BATCH_SIZE))
        total = len(items)
        total_batches = (total + batch_size - 1) // batch_size
        results = []
        self.start_time = time.time()
        self.processed_count = 0
        self.error_count = 0
        logger.info(f"开始{description},共 {total} 个项目,批次大小: {batch_size}")

        for batch_num, start in enumerate(range(0, total, batch_size), start=1):
            batch = items[start:start + batch_size]
            logger.info(f"{description}批次 {batch_num}/{total_batches} ({len(batch)} 个项目)")
            results.extend(self._process_batch(batch, processor_func, batch_num))
            # Pause between batches (not after the last) to avoid
            # overloading local resources.
            if start + batch_size < total:
                delay = self.batch_delay
                logger.info(f"批次完成,等待 {delay} 秒...")
                if delay > 0:
                    time.sleep(delay)

        elapsed = time.time() - self.start_time
        logger.info(f"{description}完成! 成功: {self.processed_count}, 失败: {self.error_count}, 耗时: {elapsed:.1f}秒")
        return results

    def _process_batch(self, batch: List[Any], processor_func: Callable, batch_num: int) -> List[Any]:
        """Process one batch, recording ``None`` for items that raise."""
        batch_results = []
        for item in batch:
            try:
                result = processor_func(item)
                batch_results.append(result)
                self.processed_count += 1
                # Small pause between items.
                if self.item_delay > 0:
                    time.sleep(self.item_delay)
            except Exception as e:
                logger.error(f"处理项目失败: {str(e)}")
                self.error_count += 1
                batch_results.append(None)
        return batch_results

    def get_progress(self) -> Dict[str, Any]:
        """Return counters and elapsed seconds, or ``{}`` before any run."""
        if not self.start_time:
            return {}
        return {
            'processed': self.processed_count,
            'errors': self.error_count,
            'elapsed_seconds': time.time() - self.start_time
        }
if __name__ == "__main__":
    # Smoke test: push ten dummy items through the processor in batches of 3.
    def test_processor(item):
        print(f"处理: {item}")
        time.sleep(0.1)
        return f"处理结果: {item}"

    test_items = []
    for i in range(1, 11):
        test_items.append(f"项目{i}")

    processor = BatchProcessor()
    results = processor.process_in_batches(
        test_items,
        test_processor,
        batch_size=3,
    )
    print("处理结果:", results)
2.7 主构建器 - main_builder.py
import os
import time
import logging
from typing import List, Dict, Any
from config import config
from document_processor import WindowsDocumentProcessor
from llm_extractor import WindowsLLMExtractor
from kg_storage import KnowledgeGraphStorage
from vector_search import VectorSearch
from batch_processor import BatchProcessor
logger = logging.getLogger(__name__)
class WindowsKnowledgeGraphBuilder:
    """Build a knowledge graph from a folder of documents.

    Pipeline: load documents, extract entities and relations with the local
    LLM, store them in Neo4j, then index the documents for semantic search.
    """

    def __init__(self):
        self.config = config
        # Components that need no external service at construction time.
        self.doc_processor = WindowsDocumentProcessor()
        self.llm_extractor = WindowsLLMExtractor()
        self.batch_processor = BatchProcessor()
        # Components created later by initialize_components().
        self.kg_storage = None
        self.vector_search = None
        self.stats = self._new_stats()
        logger.info("知识图谱构建器初始化完成")

    @staticmethod
    def _new_stats() -> Dict[str, Any]:
        """Return a zeroed statistics dict for one build run."""
        return {
            'documents_processed': 0,
            'entities_extracted': 0,
            'relations_extracted': 0,
            'errors': 0,
            'start_time': None,
            'end_time': None
        }

    def initialize_components(self) -> bool:
        """Check Ollama and connect to Neo4j and the vector store.

        Returns:
            True when Ollama and Neo4j are usable. An unavailable vector
            store only produces a warning.
        """
        logger.info("初始化组件...")
        if not self.llm_extractor.check_ollama_health():
            logger.error("Ollama服务不可用,请检查是否启动")
            return False

        # Close a storage object left over from an earlier call so its
        # driver is not leaked when a new one is created.
        if self.kg_storage is not None:
            self.kg_storage.close()
            self.kg_storage = None
        try:
            self.kg_storage = KnowledgeGraphStorage()
            if not self.kg_storage.is_connected():
                logger.error("Neo4j数据库连接失败")
                return False
        except Exception as e:
            logger.error(f"初始化知识图谱存储失败: {str(e)}")
            return False

        try:
            self.vector_search = VectorSearch()
            if not self.vector_search.is_available():
                logger.warning("向量搜索不可用,将继续但不支持语义搜索")
        except Exception as e:
            logger.warning(f"初始化向量搜索失败: {str(e)}")

        logger.info("所有组件初始化完成")
        return True

    def _process_document(self, doc: Dict) -> Dict:
        """Extract knowledge from one document and store it.

        NOTE(review): the whole document text is passed to the extractor in
        one call and CHUNK_SIZE / CHUNK_OVERLAP are not used here - confirm
        whether chunked extraction was intended.

        Returns:
            The extracted knowledge dict, or ``None`` when processing raised.
        """
        try:
            logger.info(f"处理文档: {doc['filename']} ({len(doc['content'])} 字符)")
            knowledge = self.llm_extractor.extract_knowledge(doc['content'])
            if not (knowledge and (knowledge['entities'] or knowledge['relations'])):
                logger.warning(" ⚠ 未提取到知识")
                self.stats['errors'] += 1
                return knowledge

            if not (self.kg_storage and self.kg_storage.is_connected()):
                logger.warning(" ⚠ 知识图谱存储不可用,跳过存储")
                return knowledge

            if self.kg_storage.store_knowledge(doc, knowledge):
                self.stats['entities_extracted'] += len(knowledge.get('entities', []))
                self.stats['relations_extracted'] += len(knowledge.get('relations', []))
                self.stats['documents_processed'] += 1
                logger.info(f" ✅ 提取 {len(knowledge['entities'])} 实体, {len(knowledge['relations'])} 关系")
            else:
                logger.error(" ❌ 存储知识失败")
                self.stats['errors'] += 1
            return knowledge
        except Exception as e:
            logger.error(f"处理文档 {doc.get('filename', 'unknown')} 失败: {str(e)}")
            self.stats['errors'] += 1
            return None

    def build_from_folder(self, folder_path: str) -> bool:
        """Build the knowledge graph from every document in ``folder_path``.

        Returns:
            True when the pipeline ran to completion (individual documents
            may still have failed; see ``self.stats``), False when the folder
            is missing, components could not be initialised, or no document
            could be loaded.
        """
        logger.info("=== 开始构建知识图谱 ===")
        # Start every build from clean counters so repeated builds on the
        # same instance do not accumulate.
        self.stats = self._new_stats()
        self.stats['start_time'] = time.time()

        # Check the folder first: it is cheap and avoids opening service
        # connections for a path that does not exist.
        if not os.path.exists(folder_path):
            logger.error(f"文件夹不存在: {folder_path}")
            return False
        if not self.initialize_components():
            logger.error("组件初始化失败,无法继续")
            return False

        # 1. Load documents
        documents = self.doc_processor.load_documents(folder_path)
        if not documents:
            logger.error("没有找到可处理的文档")
            return False
        logger.info(f"成功加载 {len(documents)} 个文档")

        # 2. Extract and store knowledge, in batches
        self.batch_processor.process_in_batches(
            documents,
            self._process_document,
            batch_size=self.config.BATCH_SIZE,
            description="文档处理"
        )

        # 3. Build the vector index
        if self.vector_search and self.vector_search.is_available():
            logger.info("构建向量检索索引...")
            try:
                self.vector_search.add_documents(documents)
                logger.info("向量索引构建完成")
            except Exception as e:
                logger.error(f"构建向量索引失败: {str(e)}")

        self.stats['end_time'] = time.time()
        self._print_statistics()
        self._save_build_report()
        return True

    def _print_statistics(self):
        """Print build counters and, when connected, graph statistics."""
        total_time = self.stats['end_time'] - self.stats['start_time']
        print("\n" + "="*60)
        print("🎉 知识图谱构建完成!")
        print("="*60)
        print("📊 构建统计")
        print("-" * 60)
        print(f"📄 处理文档: {self.stats['documents_processed']}")
        print(f"🏷️ 提取实体: {self.stats['entities_extracted']}")
        print(f"🔗 提取关系: {self.stats['relations_extracted']}")
        print(f"❌ 错误数量: {self.stats['errors']}")
        print(f"⏱️ 总耗时: {total_time:.1f} 秒")
        if self.stats['documents_processed'] > 0:
            avg_time_per_doc = total_time / self.stats['documents_processed']
            print(f"📈 平均每文档: {avg_time_per_doc:.1f} 秒")

        if self.kg_storage and self.kg_storage.is_connected():
            print("\n🗃️ 知识图谱统计")
            print("-" * 60)
            kg_stats = self.kg_storage.get_graph_statistics()
            if kg_stats:
                print(f"📦 总实体数: {kg_stats.get('total_entities', 0)}")
                print(f"📁 总文档数: {kg_stats.get('total_documents', 0)}")
                print(f"🔗 总关系数: {kg_stats.get('total_relations', 0)}")
                entity_types = kg_stats.get('entity_types', [])
                if entity_types:
                    print("🏷️ 实体类型分布:")
                    for et in entity_types[:5]:  # show the five largest types
                        print(f" {et['type']}: {et['count']}")
        print("="*60)

    def _save_build_report(self):
        """Write the build counters and graph totals to ``build_report.txt``."""
        report_path = self.config.LOG_DIR / "build_report.txt"
        try:
            total_time = self.stats['end_time'] - self.stats['start_time']
            lines = [
                "知识图谱构建报告\n",
                "=" * 50 + "\n\n",
                f"构建时间: {time.strftime('%Y-%m-%d %H:%M:%S')}\n",
                f"处理文档: {self.stats['documents_processed']}\n",
                f"提取实体: {self.stats['entities_extracted']}\n",
                f"提取关系: {self.stats['relations_extracted']}\n",
                f"错误数量: {self.stats['errors']}\n",
                f"总耗时: {total_time:.1f} 秒\n\n",
            ]
            if self.kg_storage and self.kg_storage.is_connected():
                kg_stats = self.kg_storage.get_graph_statistics()
                if kg_stats:
                    lines.extend([
                        "知识图谱统计:\n",
                        f"- 总实体数: {kg_stats.get('total_entities', 0)}\n",
                        f"- 总文档数: {kg_stats.get('total_documents', 0)}\n",
                        f"- 总关系数: {kg_stats.get('total_relations', 0)}\n",
                    ])
            with open(report_path, 'w', encoding='utf-8') as f:
                f.writelines(lines)
            logger.info(f"构建报告已保存: {report_path}")
        except Exception as e:
            logger.error(f"保存构建报告失败: {str(e)}")

    def query_knowledge_graph(self, query_type: str = "stats", **kwargs):
        """Run one of the supported graph queries.

        Args:
            query_type: ``"stats"``, ``"entities_by_type"`` (keyword
                ``entity_type``, default ``'技术'``) or ``"related_entities"``
                (keywords ``entity_name`` and ``depth``, default 2).

        Returns:
            The query result, or ``None`` when storage is unavailable, the
            query type is unknown, a required argument is missing, or the
            query raised.
        """
        if not self.kg_storage or not self.kg_storage.is_connected():
            logger.error("知识图谱存储不可用")
            return None
        try:
            if query_type == "stats":
                return self.kg_storage.get_graph_statistics()
            if query_type == "entities_by_type":
                entity_type = kwargs.get('entity_type', '技术')
                return self.kg_storage.query_entities_by_type(entity_type)
            if query_type == "related_entities":
                entity_name = kwargs.get('entity_name')
                depth = kwargs.get('depth', 2)
                if entity_name:
                    return self.kg_storage.find_related_entities(entity_name, depth)
                logger.error("需要提供实体名称")
                return None
            logger.error(f"不支持的查询类型: {query_type}")
            return None
        except Exception as e:
            logger.error(f"查询知识图谱失败: {str(e)}")
            return None

    def semantic_search(self, query: str, k: int = 5):
        """Return up to ``k`` semantic matches, or ``[]`` when unavailable."""
        if not self.vector_search or not self.vector_search.is_available():
            logger.error("向量搜索不可用")
            return []
        return self.vector_search.semantic_search(query, k)

    def cleanup(self):
        """Close the Neo4j connection if one was opened."""
        if self.kg_storage:
            self.kg_storage.close()
        logger.info("资源清理完成")
if __name__ == "__main__":
    # Smoke test: check the environment, print the current graph statistics,
    # and always release the Neo4j connection afterwards.
    builder = WindowsKnowledgeGraphBuilder()
    try:
        if builder.initialize_components():
            print("环境检查通过")
            stats = builder.query_knowledge_graph("stats")
            print("当前图谱统计:", stats)
        else:
            print("环境检查失败")
    finally:
        builder.cleanup()
2.8 环境检查脚本 - check_environment.py
#!/usr/bin/env python3
"""
Windows环境检查脚本
检查所有依赖和服务状态
"""
import sys
import importlib
import requests
from config import config
def check_python_version():
    """Print the running Python version and return True when it is 3.8+."""
    print("🔍 检查Python版本...")
    print(f" Python版本: {sys.version}")
    meets_minimum = tuple(sys.version_info[:2]) >= (3, 8)
    if meets_minimum:
        print(" ✅ Python版本符合要求")
    else:
        print(" ❌ 需要Python 3.8或更高版本")
    return meets_minimum
def check_dependencies(dependencies=None):
    """Check that the required Python packages can be imported.

    Args:
        dependencies: Optional mapping of pip distribution name to importable
            module name. An iterable of names is also accepted, in which case
            the module name is the distribution name with ``-`` replaced by
            ``_``. Defaults to the packages this project uses.

    Returns:
        True when every module imports, False otherwise. Missing packages
        are listed with a ready-to-run ``pip install`` command.
    """
    print("\n🔍 检查Python依赖...")
    if dependencies is None:
        # The pip name and the import name differ for several packages, so
        # they are listed explicitly instead of being derived from each other.
        dependencies = {
            "langchain": "langchain",
            "langchain-community": "langchain_community",
            "langchain-huggingface": "langchain_huggingface",
            "chromadb": "chromadb",
            "pymupdf": "fitz",
            "sentence-transformers": "sentence_transformers",
            "flask": "flask",
            "streamlit": "streamlit",
            "neo4j": "neo4j",
            "requests": "requests",
            "python-dotenv": "dotenv",
            "python-docx": "docx",
        }
    elif not isinstance(dependencies, dict):
        dependencies = {dep: dep.replace('-', '_') for dep in dependencies}

    missing_deps = []
    for dep, module_name in dependencies.items():
        try:
            importlib.import_module(module_name)
            print(f" ✅ {dep}")
        except ImportError:
            print(f" ❌ {dep} - 未安装")
            missing_deps.append(dep)

    if missing_deps:
        print(f"\n⚠ 缺少依赖: {', '.join(missing_deps)}")
        print("请运行: pip install " + " ".join(missing_deps))
        return False
    print(" ✅ 所有依赖已安装")
    return True
def check_ollama():
    """Check that the Ollama HTTP service answers and list its models.

    Requests ``/api/tags`` from ``config.OLLAMA_BASE_URL``. A 200 response
    counts as success even when none of the recommended models is present
    (only a warning is printed in that case).

    Returns:
        True when the service answered with HTTP 200, False on any other
        status code or on any exception raised during the check.
    """
    print("\n🔍 检查Ollama服务...")
    try:
        response = requests.get(f"{config.OLLAMA_BASE_URL}/api/tags", timeout=10)
        if response.status_code != 200:
            print(f" ❌ Ollama服务异常: {response.status_code}")
            return False

        model_names = [entry['name'] for entry in response.json().get('models', [])]
        print(" ✅ Ollama服务正常")
        print(f" 可用模型: {', '.join(model_names)}")

        # A recommended model counts as available when its name is a
        # substring of any installed model name.
        available_recommended = []
        for wanted in ("qwen2.5:7b", "llama3.1:8b"):
            if any(wanted in installed for installed in model_names):
                available_recommended.append(wanted)

        if available_recommended:
            print(f" ✅ 推荐模型可用: {', '.join(available_recommended)}")
        else:
            print(" ⚠ 推荐模型不可用,请运行: ollama pull qwen2.5:7b")
        return True
    except Exception as e:
        print(f" ❌ 无法连接到Ollama服务: {str(e)}")
        print(" 请确保Ollama已安装并运行: https://ollama.ai/download")
        return False
def check_neo4j():
    """Check that the Neo4j database is reachable with the configured login.

    Opens a driver with the URI and credentials from ``config``, runs
    ``RETURN 1`` as a connectivity probe and then tries to print the server
    version.

    Returns:
        True when the probe query returned 1, otherwise False (including
        when the driver cannot be imported or the connection fails).
    """
    print("\n🔍 检查Neo4j数据库...")
    driver = None
    try:
        from neo4j import GraphDatabase

        driver = GraphDatabase.driver(
            config.NEO4J_URI,
            auth=(config.NEO4J_USER, config.NEO4J_PASSWORD),
        )
        with driver.session() as session:
            test_value = session.run("RETURN 1 AS test").single()[0]
            if test_value != 1:
                print(" ❌ Neo4j数据库测试失败")
                return False
            print(" ✅ Neo4j数据库连接正常")

            # The version lookup is informational only: a failure here must
            # not turn an already-confirmed connection into a failed check.
            try:
                version = session.run(
                    "CALL dbms.components() YIELD versions RETURN versions[0] as version"
                ).single()[0]
                print(f" Neo4j版本: {version}")
            except Exception as e:
                print(f" ⚠ 无法获取Neo4j版本: {str(e)}")
        return True
    except Exception as e:
        print(f" ❌ Neo4j数据库连接失败: {str(e)}")
        print(" 请确保:")
        print(" 1. Neo4j数据库已启动")
        print(" 2. 连接信息正确 (在config.py中配置)")
        print(" 3. 数据库密码正确")
        return False
    finally:
        # Close the driver on every path, including exceptions raised after
        # it was created, so no connection is leaked.
        if driver is not None:
            driver.close()
def check_directories():
    """Check that the configured working directories exist.

    Also reports how many entries in the documents directory carry a
    supported extension; an empty documents directory only produces a
    warning and does not affect the result.

    Returns:
        True when all four configured directories exist, otherwise False.
    """
    print("\n🔍 检查目录结构...")
    all_ok = True
    for folder in (config.DOCUMENTS_DIR, config.CHROMA_DIR, config.LOG_DIR, config.DATA_DIR):
        if not folder.exists():
            print(f" ❌ {folder.name} 目录不存在")
            all_ok = False
            continue
        print(f" ✅ {folder.name} 目录存在")

    # Count the entries in the documents directory with a supported suffix.
    supported_suffixes = ('.pdf', '.txt', '.docx', '.md')
    supported_files = [
        entry
        for entry in config.DOCUMENTS_DIR.glob("*")
        if entry.suffix.lower() in supported_suffixes
    ]
    if not supported_files:
        print(" ⚠ 文档目录中没有支持的文件")
        print(" 支持格式: PDF, TXT, DOCX, MD")
    else:
        print(f" ✅ 文档目录中有 {len(supported_files)} 个支持的文件")
    return all_ok
def main():
    """Run every environment check and print a summary.

    All checks are executed even when an earlier one fails, so the user
    sees the complete picture in one run.

    Returns:
        True when every check passed, otherwise False. Callers that import
        this function can use the result to decide whether to continue.
    """
    print("🚀 Windows知识图谱环境检查")
    print("=" * 50)

    # Keyed by name so the troubleshooting hints below do not depend on the
    # position of a check in a list.
    results = {
        "python": check_python_version(),
        "dependencies": check_dependencies(),
        "ollama": check_ollama(),
        "neo4j": check_neo4j(),
        "directories": check_directories(),
    }

    print("\n" + "=" * 50)
    all_passed = all(results.values())
    if all_passed:
        print("🎉 所有检查通过! 可以开始构建知识图谱")
        print("\n下一步:")
        print("1. 将文档放入 documents/ 文件夹")
        print("2. 运行: python start_windows.py")
    else:
        print("❌ 环境检查失败,请解决上述问题后再试")
        if not results["ollama"]:
            print("\nOllama问题解决:")
            print("1. 下载安装: https://ollama.ai/download")
            print("2. 启动服务: ollama serve")
            print("3. 下载模型: ollama pull qwen2.5:7b")
        if not results["neo4j"]:
            print("\nNeo4j问题解决:")
            print("1. 下载安装: https://neo4j.com/download/")
            print("2. 启动Neo4j Desktop")
            print("3. 修改config.py中的密码配置")
    return all_passed


if __name__ == "__main__":
    main()
2.9 启动脚本 - start_windows.py
#!/usr/bin/env python3
"""
Windows知识图谱构建启动脚本
"""
import os
import sys
from pathlib import Path
def main():
    """Interactive entry point for building the knowledge graph.

    Steps: run the environment check, ask for the documents folder (falling
    back to ``config.DOCUMENTS_DIR``), optionally create a sample document
    when the folder has no supported files, build the graph after user
    confirmation, and finally offer an interactive query menu.

    The builder is always cleaned up once it has been created, whether the
    build succeeds, fails, raises, or is interrupted by the user.
    """
    print("🚀 Windows知识图谱构建器")
    print("=" * 50)

    # Make the sibling modules importable regardless of the working directory.
    current_dir = Path(__file__).parent
    sys.path.insert(0, str(current_dir))

    try:
        from check_environment import main as check_environment
        print("正在检查环境...")
        check_environment()
        print()
    except Exception as e:
        print(f"环境检查失败: {e}")
        return

    from config import config
    documents_path = input(
        f"请输入文档文件夹路径 (直接回车使用默认路径: {config.DOCUMENTS_DIR}): "
    ).strip()
    if not documents_path:
        documents_path = str(config.DOCUMENTS_DIR)

    # parents=True lets the user enter a nested path that does not exist yet;
    # OSError covers the path being an existing file or otherwise invalid.
    doc_path = Path(documents_path)
    try:
        doc_path.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        print(f"无法创建文档目录 {documents_path}: {e}")
        return

    supported_extensions = ['.pdf', '.txt', '.docx', '.md']
    doc_files = _list_supported_documents(doc_path, supported_extensions)
    if not doc_files:
        print(f"⚠ {documents_path} 文件夹中没有支持的文档文件")
        print(f"支持的文件格式: {', '.join(supported_extensions)}")
        create_sample = input("是否创建示例文档? (y/n): ").lower().strip()
        if create_sample == 'y':
            _create_sample_documents(doc_path)
            doc_files = _list_supported_documents(doc_path, supported_extensions)
        # Also stops when sample creation failed and the folder is still
        # empty, instead of starting a build with zero documents.
        if not doc_files:
            print("请将文档文件放入文件夹后重新运行程序")
            return

    print(f"找到 {len(doc_files)} 个文档文件:")
    for doc_file in doc_files:
        print(f" - {doc_file.name}")

    print("\n即将开始构建知识图谱...")
    confirm = input("是否继续? (y/n): ").lower().strip()
    if confirm != 'y':
        print("构建取消")
        return

    builder = None
    try:
        from main_builder import WindowsKnowledgeGraphBuilder
        builder = WindowsKnowledgeGraphBuilder()
        if builder.build_from_folder(documents_path):
            print("\n🎉 知识图谱构建完成!")
            _run_query_menu(builder)
        else:
            print("\n❌ 构建失败,请检查错误信息")
    except KeyboardInterrupt:
        print("\n\n构建被用户中断")
    except Exception as e:
        print(f"\n❌ 构建过程中发生错误: {str(e)}")
        import traceback
        traceback.print_exc()
    finally:
        # Release the builder's resources on every exit path, not only
        # after a successful build.
        if builder is not None:
            try:
                builder.cleanup()
            except Exception as e:
                print(f"清理资源失败: {e}")


def _list_supported_documents(doc_path, supported_extensions):
    """Return the files directly inside doc_path whose suffix is supported."""
    return [
        f for f in doc_path.iterdir()
        if f.is_file() and f.suffix.lower() in supported_extensions
    ]


def _run_query_menu(builder):
    """Serve the post-build query menu until the user chooses to exit."""
    while True:
        print("\n选择操作:")
        print("1. 查看图谱统计")
        print("2. 查询实体")
        print("3. 查找相关实体")
        print("4. 语义搜索")
        print("5. 退出")
        choice = input("请输入选择 (1-5): ").strip()

        if choice == '1':
            stats = builder.query_knowledge_graph("stats")
            if stats:
                print("\n知识图谱统计:")
                print(f"实体总数: {stats.get('total_entities', 0)}")
                print(f"文档总数: {stats.get('total_documents', 0)}")
                print(f"关系总数: {stats.get('total_relations', 0)}")
        elif choice == '2':
            entity_type = input("请输入实体类型 (如: 人物、组织、技术): ").strip()
            if entity_type:
                entities = builder.query_knowledge_graph("entities_by_type", entity_type=entity_type)
                if entities:
                    print(f"\n{entity_type}类实体:")
                    # Only the first ten matches are shown.
                    for i, entity in enumerate(entities[:10], 1):
                        print(f"{i}. {entity['name']} - {entity.get('description', '')}")
                else:
                    print("未找到相关实体")
        elif choice == '3':
            entity_name = input("请输入实体名称: ").strip()
            if entity_name:
                related = builder.query_knowledge_graph("related_entities", entity_name=entity_name)
                if related:
                    print(f"\n与 '{entity_name}' 相关的实体:")
                    for rel in related[:10]:
                        print(f"{rel['start_entity']} --[{rel['relation_type']}]--> {rel['end_entity']}")
                else:
                    print("未找到相关实体")
        elif choice == '4':
            # The attribute may exist but hold None when vector search was
            # not set up, so test its value rather than its presence.
            vector_search = getattr(builder, 'vector_search', None)
            if vector_search and vector_search.is_available():
                query = input("请输入搜索查询: ").strip()
                if query:
                    results = builder.semantic_search(query)
                    if results:
                        print("\n语义搜索结果:")
                        for result in results:
                            print(f"文档: {result['metadata']['filename']}")
                            print(f"内容: {result['content'][:100]}...")
                            print()
                    else:
                        print("未找到相关文档")
            else:
                print("向量搜索不可用")
        elif choice == '5':
            break
        else:
            print("无效选择")
def _create_sample_documents(doc_path: Path):
    """Write one sample text document about AI into doc_path.

    The file is named "AI技术介绍.txt" and is written as UTF-8. Any error
    raised while writing is reported on stdout and not re-raised.
    """
    sample_content = """
人工智能知识图谱示例文档
人工智能(Artificial Intelligence)是计算机科学的一个分支,由约翰·麦卡锡于1956年提出。
机器学习是人工智能的重要分支,主要研究者包括 Geoffrey Hinton、Yann LeCun 和 Yoshua Bengio。
深度学习是机器学习的一种方法,使用神经网络进行特征学习。
主要技术包括:
- 自然语言处理(NLP):让计算机理解人类语言
- 计算机视觉:让计算机识别图像和视频
- 强化学习:通过试错学习最优策略
应用领域:
- 医疗健康:辅助诊断、药物研发
- 金融服务:风险控制、智能投顾
- 教育培训:个性化学习、智能辅导
知名组织:
- OpenAI:开发了GPT系列模型
- 深度求索(DeepSeek):专注于AI推理优化
- 百度:在中文NLP领域有深入研究
重要人物:
- 艾伦·图灵:提出了图灵测试
- 约翰·麦卡锡:人工智能之父
- 李飞飞:计算机视觉专家
"""
    target = doc_path.joinpath("AI技术介绍.txt")
    try:
        target.write_text(sample_content, encoding='utf-8')
        print(f"✅ 创建示例文档: {target}")
    except Exception as e:
        print(f"创建示例文档失败: {e}")


if __name__ == "__main__":
    main()
2.10 requirements.txt
langchain>=0.1.0
langchain-community>=0.0.10
chromadb>=0.4.0
pymupdf>=1.23.0
sentence-transformers>=2.2.2
flask>=2.3.0
streamlit>=1.28.0
neo4j>=5.12.0
requests>=2.31.0
python-dotenv>=1.0.0
python-docx>=1.1.0
3. 使用前检查:
# 1. 运行环境检查
python check_environment.py
# 2. 安装缺失依赖
pip install -r requirements.txt
# 3. 启动服务
ollama serve
# 4. 运行构建
python start_windows.py

搭建本地知识图谱&spm=1001.2101.3001.5002&articleId=154354356&d=1&t=3&u=014fa232f861462b82c33a71b3cbba8c)
2万+

被折叠的 条评论
为什么被折叠?



