first commit
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -1,3 +1,9 @@
|
||||
# custom .gitignore
|
||||
.claude/
|
||||
.vscode/
|
||||
paper/
|
||||
papers_chroma_db/
|
||||
|
||||
# ---> Python
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.12
|
||||
70
embeddings.py
Normal file
70
embeddings.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
嵌入模型模块
|
||||
|
||||
提供多种文本嵌入模型的统一接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
|
||||
class BaseEmbeddings(ABC):
    """Abstract base class for embedding-model providers."""

    @abstractmethod
    def get_embeddings(self):
        """Return a LangChain-compatible embeddings instance."""
        ...
|
||||
|
||||
|
||||
class HuggingFaceEmbeddingsProvider(BaseEmbeddings):
    """Embedding provider backed by a local HuggingFace model."""

    def __init__(self, model_name: str):
        """Create the provider.

        Args:
            model_name: Model name or filesystem path of the HF model.
        """
        self.model_name = model_name
        # Built on first use; see get_embeddings().
        self._embeddings: Optional[HuggingFaceEmbeddings] = None

    def get_embeddings(self) -> HuggingFaceEmbeddings:
        """Return the HuggingFace embeddings instance, creating it lazily."""
        if self._embeddings is not None:
            return self._embeddings
        print(f"[INFO] Initializing HuggingFace embeddings: {self.model_name}")
        self._embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
        return self._embeddings
|
||||
|
||||
|
||||
class OpenAIEmbeddingsProvider(BaseEmbeddings):
    """Embedding provider backed by the OpenAI embeddings API."""

    def __init__(self, api_key: str, base_url: str, model: str = "text-embedding-3-large"):
        """Initialize the OpenAI embedding provider.

        Args:
            api_key: OpenAI API key.
            base_url: Base URL of the API endpoint.
            model: Embedding model name.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.model = model
        # Built on first use; see get_embeddings().
        self._embeddings: Optional[OpenAIEmbeddings] = None

    def get_embeddings(self) -> OpenAIEmbeddings:
        """Return the OpenAI embeddings instance, creating it lazily."""
        if self._embeddings is None:
            print(f"[INFO] Initializing OpenAI embeddings: {self.model}")
            # Use the current `base_url` parameter rather than the legacy
            # `openai_api_base` alias, consistent with the `api_key` naming.
            self._embeddings = OpenAIEmbeddings(
                model=self.model,
                api_key=self.api_key,
                base_url=self.base_url,
            )
        return self._embeddings
|
||||
147
ingest_pipeline.py
Normal file
147
ingest_pipeline.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
学术论文向量化和存储系统
|
||||
|
||||
模块化、面向对象的论文检索系统,支持将论文文档向量化并持久化到 Chroma 向量数据库。
|
||||
|
||||
Usage:
|
||||
# 基本用法
|
||||
python papers_embedding.py --input xxx.md
|
||||
|
||||
# 自定义输出目录
|
||||
python papers_embedding.py --input xxx.md --output ./my_db
|
||||
|
||||
# 程序化调用
|
||||
from papers_embedding import PaperIngestionPipeline
|
||||
pipeline = PaperIngestionPipeline()
|
||||
pipeline.ingest(Path("icde2025.md"))
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from parsers import BaseParser, MarkdownPaperParser, PaperFileReader
|
||||
from embeddings import BaseEmbeddings, HuggingFaceEmbeddingsProvider, OpenAIEmbeddingsProvider
|
||||
from vector_stores import BaseVectorStore, ChromaVectorStore
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Pipeline: 数据摄取管道
|
||||
# ============================================================================
|
||||
|
||||
class PaperIngestionPipeline:
    """Paper-ingestion pipeline.

    Wires the parsing, embedding and storage modules together into a
    complete paper-vectorization workflow.
    """

    def __init__(
        self,
        parser: Optional[BaseParser] = None,
        embeddings_provider: Optional[BaseEmbeddings] = None,
        vector_store: Optional[BaseVectorStore] = None,
    ):
        """Set up the pipeline.

        Args:
            parser: Paper parser (defaults to MarkdownPaperParser).
            embeddings_provider: Embedding provider (must be set explicitly).
            vector_store: Vector store (must be set explicitly).
        """
        self.parser = parser if parser else MarkdownPaperParser()
        self.embeddings_provider = embeddings_provider
        self.vector_store = vector_store

    def ingest(self, input_file: Path) -> None:
        """Run the full ingestion flow for one paper file.

        Args:
            input_file: Path of the input paper file.

        Raises:
            ValueError: If embeddings_provider or vector_store is not configured.
        """
        # Fail fast on an incompletely configured pipeline.
        if self.embeddings_provider is None:
            raise ValueError("embeddings_provider must be configured")
        if self.vector_store is None:
            raise ValueError("vector_store must be configured")

        # Step 1: parse the paper file into texts + metadata.
        texts, metadatas = PaperFileReader(self.parser).read(input_file)

        # Steps 2 & 3: embed and persist in one call.
        self.vector_store.persist(texts, metadatas)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI Interface
|
||||
# ============================================================================
|
||||
|
||||
def create_default_pipeline(output_dir: str) -> PaperIngestionPipeline:
    """Create an ingestion pipeline with the default configuration.

    Args:
        output_dir: Output directory for the Chroma database.

    Returns:
        A configured PaperIngestionPipeline instance.

    Raises:
        ValueError: If the OPENAI_API_KEY environment variable is not set.
    """
    load_dotenv()

    # Optional: use a local HuggingFace model instead.
    # embeddings = HuggingFaceEmbeddingsProvider(
    #     model_name=r"E:\hf_models\all-mpnet-base-v2"
    # )

    # Fail fast with a clear message instead of an opaque auth error later.
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError(
            "OPENAI_API_KEY is not set; define it in the environment or a .env file"
        )

    # Use the OpenAI API; the endpoint can be overridden via OPENAI_BASE_URL.
    embeddings = OpenAIEmbeddingsProvider(
        api_key=api_key,
        base_url=os.getenv("OPENAI_BASE_URL", "https://api.chatanywhere.tech/v1"),
        model="text-embedding-3-large"
    )

    vector_store = ChromaVectorStore(
        persist_directory=Path(output_dir),
        embeddings_provider=embeddings
    )

    return PaperIngestionPipeline(
        embeddings_provider=embeddings,
        vector_store=vector_store
    )
|
||||
|
||||
|
||||
def main() -> None:
    """Command-line entry point."""
    arg_parser = argparse.ArgumentParser(
        description="将论文 Markdown 文件向量化并存储到 Chroma 数据库"
    )
    arg_parser.add_argument(
        "--input", "-i",
        type=str,
        required=True,
        help="输入的 Markdown 论文文件路径",
    )
    arg_parser.add_argument(
        "--output", "-o",
        type=str,
        default="papers_chroma_db",
        help="Chroma 数据库输出目录(默认: papers_chroma_db)",
    )
    options = arg_parser.parse_args()

    # Build the default pipeline and run the ingestion end to end.
    create_default_pipeline(options.output).ingest(Path(options.input))


if __name__ == "__main__":
    main()
|
||||
137
parsers.py
Normal file
137
parsers.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
数据解析模块
|
||||
|
||||
提供论文文档的解析功能,支持多种格式(当前实现 Markdown 格式)
|
||||
"""
|
||||
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class BaseParser(ABC):
    """Abstract base class for paper-content parsers."""

    @abstractmethod
    def parse(self, content: str) -> Tuple[List[str], List[dict]]:
        """Parse raw text content.

        Args:
            content: The text to parse.

        Returns:
            (texts, metadatas): Paper texts and the matching metadata dicts.
        """
        ...


class MarkdownPaperParser(BaseParser):
    """Parser for papers stored in a single Markdown document.

    Expected layout:
    - Papers are separated by '---' lines.
    - Each paper starts with a '## Title' heading.
    - The body contains the title, abstract and any further sections.
    """

    def __init__(self, separator: str = r'\n---\s*\n', title_pattern: str = r'^##\s+(.+)$'):
        """Configure the parser.

        Args:
            separator: Regex that splits the document into papers.
            title_pattern: Regex that captures a paper's title line.
        """
        self.separator = separator
        self.title_pattern = title_pattern

    def parse(self, content: str) -> Tuple[List[str], List[dict]]:
        """Parse markdown content into per-paper texts and metadata.

        Args:
            content: Full markdown file content.

        Returns:
            (texts, metadatas): Paper texts and their metadata dicts.

        Raises:
            ValueError: If no valid paper is found.
        """
        texts: List[str] = []
        metadatas: List[dict] = []

        for raw in re.split(self.separator, content):
            paper = raw.strip()
            if not paper:
                continue

            # A paper is only kept when a '## Title' heading is present.
            match = re.search(self.title_pattern, paper, re.MULTILINE)
            if match is None:
                self._handle_missing_title(paper)
                continue

            texts.append(paper)
            metadatas.append({
                "title": match.group(1).strip(),
                "content_length": len(paper),
            })

        if not texts:
            raise ValueError("No valid papers were found in the content.")

        return texts, metadatas

    def _handle_missing_title(self, chunk: str) -> None:
        """Warn about (and skip) a chunk that has no '## Title' heading."""
        preview = chunk[:50].replace('\n', ' ')
        print(f"[WARN] Skipping paper without ## title: {preview}...")
|
||||
|
||||
|
||||
class PaperFileReader:
    """Reads a paper file from disk and delegates parsing."""

    def __init__(self, parser: BaseParser):
        """Create the reader.

        Args:
            parser: The content parser to delegate to.
        """
        self.parser = parser

    def read(self, file_path: Path) -> Tuple[List[str], List[dict]]:
        """Read and parse papers from a file.

        Args:
            file_path: Path of the paper file.

        Returns:
            (texts, metadatas): Parsed texts and metadata.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        if not file_path.exists():
            raise FileNotFoundError(f"Paper file not found: {file_path}")

        print(f"[INFO] Reading papers from: {file_path}")
        texts, metadatas = self.parser.parse(file_path.read_text(encoding="utf-8"))

        # Stamp the originating file name onto every paper's metadata.
        for entry in metadatas:
            entry["source_file"] = file_path.name

        return texts, metadatas
|
||||
17
pyproject.toml
Normal file
17
pyproject.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[project]
|
||||
name = "embedding"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"chromadb>=1.4.0",
|
||||
"huggingface>=0.0.1",
|
||||
"langchain>=1.2.3",
|
||||
"langchain-community>=0.4.1",
|
||||
"langchain-huggingface>=1.2.0",
|
||||
"langchain-openai>=1.1.7",
|
||||
"python-dotenv>=1.2.1",
|
||||
"requests>=2.32.5",
|
||||
"sentence-transformers>=5.2.0",
|
||||
]
|
||||
151
retriever.py
Normal file
151
retriever.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
学术论文检索工具
|
||||
|
||||
Usage:
|
||||
from retriever import retrieve_papers
|
||||
results = retrieve_papers("natural language to SQL")
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from langchain_community.vectorstores import Chroma
|
||||
|
||||
from embeddings import HuggingFaceEmbeddingsProvider, OpenAIEmbeddingsProvider
|
||||
from vector_stores import BaseEmbeddings
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
|
||||
# Directory of the persisted Chroma database (machine-specific absolute path).
CHROMA_DIR = Path(r"E:\studio2\embedding\papers_chroma_db")
# Local path of the HuggingFace embedding model (used by the optional HF provider).
HUGGINGFACE_MODEL = r"E:\hf_models\all-mpnet-base-v2"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Retrieval Service
|
||||
# ============================================================================
|
||||
|
||||
class PaperRetriever:
    """Paper retrieval service.

    Uses the existing embeddings and vector_stores modules to search the
    persisted paper database.
    """

    def __init__(
        self,
        persist_directory: Path,
        embeddings_provider: BaseEmbeddings,
    ):
        """Create the retriever.

        Args:
            persist_directory: Directory of the Chroma database.
            embeddings_provider: Embedding-model provider.
        """
        self.persist_directory = Path(persist_directory)
        self.embeddings_provider = embeddings_provider

    def _validate_db_exists(self) -> None:
        """Raise if the vector-database directory is missing."""
        if self.persist_directory.exists():
            return
        raise FileNotFoundError(
            f"论文知识库目录不存在: {self.persist_directory},"
            f"请先运行 papers_embedding.py 生成数据库"
        )

    def retrieve(self, query: str, k: int = 10) -> List:
        """Search the paper vector database for relevant papers.

        Args:
            query: Query text.
            k: Number of results to return (default 10).

        Returns:
            The list of retrieved documents.
        """
        self._validate_db_exists()

        store = Chroma(
            persist_directory=str(self.persist_directory),
            embedding_function=self.embeddings_provider.get_embeddings(),
        )
        docs = store.as_retriever(search_kwargs={"k": k}).invoke(query)

        print(f"检索到 {len(docs)} 篇相关论文")
        return docs
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Convenience Functions
|
||||
# ============================================================================
|
||||
|
||||
def create_default_retriever() -> PaperRetriever:
    """Create a retriever with the default configuration.

    Uses the OpenAI embeddings API (the original docstring incorrectly said
    a local HuggingFace model was used); a local HuggingFace model can be
    enabled instead via the commented-out provider below.

    Returns:
        A configured PaperRetriever instance.
    """
    # embeddings = HuggingFaceEmbeddingsProvider(model_name=HUGGINGFACE_MODEL)
    embeddings = OpenAIEmbeddingsProvider(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url="https://api.chatanywhere.tech/v1",
        model="text-embedding-3-large"
    )
    return PaperRetriever(
        persist_directory=CHROMA_DIR,
        embeddings_provider=embeddings
    )
|
||||
|
||||
|
||||
def retrieve_papers(query: str, k: int = 10) -> List:
    """Convenience wrapper: retrieve papers relevant to *query*.

    Args:
        query: Query text.
        k: Number of results to return (default 10).

    Returns:
        The list of retrieved documents.
    """
    return create_default_retriever().retrieve(query, k)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI Interface
|
||||
# ============================================================================
|
||||
|
||||
def main():
    """Exercise the retrieval path from the command line."""
    import sys

    # Use the CLI words as the query, or a default query when none is given.
    if len(sys.argv) >= 2:
        query = " ".join(sys.argv[1:])
    else:
        query = "natural language to SQL disambiguation"

    load_dotenv()
    print(f"查询: {query}\n")

    for i, doc in enumerate(retrieve_papers(query), 1):
        print(f"[{i}] {doc.metadata.get('title', 'N/A')}")
        # print(f"    来源: {doc.metadata.get('source_file', 'N/A')}")
        # print(f"    内容预览: {doc.page_content[:200]}...")
        print()


if __name__ == "__main__":
    main()
|
||||
64
vector_stores.py
Normal file
64
vector_stores.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
向量存储模块
|
||||
|
||||
提供向量数据库的统一接口,当前支持 Chroma
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from langchain_community.vectorstores import Chroma
|
||||
|
||||
from embeddings import BaseEmbeddings
|
||||
|
||||
|
||||
class BaseVectorStore(ABC):
    """Abstract base class for vector stores."""

    @abstractmethod
    def persist(self, texts: List[str], metadatas: List[dict]) -> None:
        """Embed the given texts and persist them to the store.

        Args:
            texts: Texts to store.
            metadatas: Metadata dict per text.
        """
        ...
|
||||
|
||||
|
||||
class ChromaVectorStore(BaseVectorStore):
    """Vector store backed by a persistent Chroma database."""

    def __init__(self, persist_directory: Path, embeddings_provider: BaseEmbeddings):
        """Create the store.

        Args:
            persist_directory: Directory to persist the database in.
            embeddings_provider: Embedding-model provider.
        """
        self.persist_directory = Path(persist_directory)
        self.embeddings_provider = embeddings_provider

    def persist(self, texts: List[str], metadatas: List[dict]) -> None:
        """Embed *texts* and write them into the Chroma database.

        Args:
            texts: Texts to store.
            metadatas: Metadata dict per text.
        """
        self.persist_directory.mkdir(parents=True, exist_ok=True)

        # Resolve the embedder first so its lazy-init message (if any)
        # precedes the write message, matching the original print order.
        embedder = self.embeddings_provider.get_embeddings()

        print(f"[INFO] Writing {len(texts)} papers to Chroma: {self.persist_directory}")
        Chroma.from_texts(
            texts=texts,
            metadatas=metadatas,
            embedding=embedder,
            persist_directory=str(self.persist_directory),
        )
        print("[OK] Chroma persistence complete.")
|
||||
Reference in New Issue
Block a user