first commit

This commit is contained in:
2026-01-11 16:09:16 +08:00
parent 38192d5b9d
commit df1279633e
9 changed files with 4119 additions and 0 deletions

6
.gitignore vendored
View File

@@ -1,3 +1,9 @@
# custom .gitignore
.claude/
.vscode/
paper/
papers_chroma_db/
# ---> Python
# Byte-compiled / optimized / DLL files
__pycache__/

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.12

70
embeddings.py Normal file
View File

@@ -0,0 +1,70 @@
"""
嵌入模型模块
提供多种文本嵌入模型的统一接口
"""
from abc import ABC, abstractmethod
from typing import Optional
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import OpenAIEmbeddings
class BaseEmbeddings(ABC):
    """Abstract base for embedding-model providers.

    A concrete provider wraps one backend (local HuggingFace model,
    OpenAI API, ...) and exposes it through ``get_embeddings``.
    """

    @abstractmethod
    def get_embeddings(self):
        """Return a LangChain-compatible embeddings instance."""
        ...
class HuggingFaceEmbeddingsProvider(BaseEmbeddings):
    """Provider backed by a local HuggingFace sentence-embedding model."""

    def __init__(self, model_name: str):
        """Store the model identifier; the model itself is loaded lazily.

        Args:
            model_name: HuggingFace model name or local filesystem path.
        """
        self.model_name = model_name
        # Cache for the lazily built backend instance.
        self._embeddings: Optional[HuggingFaceEmbeddings] = None

    def get_embeddings(self) -> HuggingFaceEmbeddings:
        """Return the cached embeddings instance, creating it on first use."""
        if self._embeddings is not None:
            return self._embeddings
        print(f"[INFO] Initializing HuggingFace embeddings: {self.model_name}")
        self._embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
        return self._embeddings
class OpenAIEmbeddingsProvider(BaseEmbeddings):
    """OpenAI-API-backed embedding provider (lazily initialized).

    Works with any OpenAI-compatible endpoint selected via ``base_url``.
    """

    def __init__(self, api_key: str, base_url: str, model: str = "text-embedding-3-large"):
        """Store connection settings; the client is built lazily.

        Args:
            api_key: OpenAI API key.
            base_url: Base URL of the (OpenAI-compatible) API endpoint.
            model: Embedding model name.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.model = model
        # Cache for the lazily built backend instance.
        self._embeddings: Optional[OpenAIEmbeddings] = None

    def get_embeddings(self) -> OpenAIEmbeddings:
        """Return the cached OpenAI embeddings instance, creating it on first use."""
        if self._embeddings is None:
            print(f"[INFO] Initializing OpenAI embeddings: {self.model}")
            # FIX: use the current `base_url` parameter name; the previously
            # used `openai_api_base` is a deprecated alias in langchain-openai,
            # and mixing it with the modern `api_key` kwarg was inconsistent.
            self._embeddings = OpenAIEmbeddings(
                model=self.model,
                api_key=self.api_key,
                base_url=self.base_url,
            )
        return self._embeddings

147
ingest_pipeline.py Normal file
View File

@@ -0,0 +1,147 @@
"""
学术论文向量化和存储系统
模块化、面向对象的论文检索系统,支持将论文文档向量化并持久化到 Chroma 向量数据库。
Usage:
# 基本用法
python ingest_pipeline.py --input xxx.md
# 自定义输出目录
python ingest_pipeline.py --input xxx.md --output ./my_db
# 程序化调用
from ingest_pipeline import PaperIngestionPipeline
pipeline = PaperIngestionPipeline()
pipeline.ingest(Path("icde2025.md"))
"""
import os
import argparse
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
from parsers import BaseParser, MarkdownPaperParser, PaperFileReader
from embeddings import BaseEmbeddings, HuggingFaceEmbeddingsProvider, OpenAIEmbeddingsProvider
from vector_stores import BaseVectorStore, ChromaVectorStore
# ============================================================================
# Pipeline: 数据摄取管道
# ============================================================================
class PaperIngestionPipeline:
    """End-to-end paper ingestion pipeline.

    Wires together three collaborators — parser, embeddings provider and
    vector store — into a single parse -> embed -> persist flow.
    """

    def __init__(
        self,
        parser: Optional[BaseParser] = None,
        embeddings_provider: Optional[BaseEmbeddings] = None,
        vector_store: Optional[BaseVectorStore] = None,
    ):
        """Configure the pipeline's collaborators.

        Args:
            parser: Paper parser (defaults to ``MarkdownPaperParser``).
            embeddings_provider: Embedding provider (must be set before ingest).
            vector_store: Vector store (must be set before ingest).
        """
        self.parser = parser or MarkdownPaperParser()
        self.embeddings_provider = embeddings_provider
        self.vector_store = vector_store

    def ingest(self, input_file: Path) -> None:
        """Parse *input_file* and persist its papers into the vector store.

        Args:
            input_file: Path to the input paper file.

        Raises:
            ValueError: If ``embeddings_provider`` or ``vector_store`` is unset.
        """
        # Guard clauses: both backends are required for a full run.
        if self.embeddings_provider is None:
            raise ValueError("embeddings_provider must be configured")
        if self.vector_store is None:
            raise ValueError("vector_store must be configured")

        # Parse the source document into parallel texts/metadata lists.
        texts, metadatas = PaperFileReader(self.parser).read(input_file)

        # Embed and persist in one step.
        self.vector_store.persist(texts, metadatas)
# ============================================================================
# CLI Interface
# ============================================================================
def create_default_pipeline(output_dir: str) -> PaperIngestionPipeline:
    """Create an ingestion pipeline with the default configuration.

    Loads environment variables from ``.env``, builds an OpenAI embeddings
    provider and a Chroma vector store rooted at *output_dir*.

    Args:
        output_dir: Output directory for the Chroma database.

    Returns:
        A fully configured PaperIngestionPipeline instance.

    Raises:
        ValueError: If the ``OPENAI_API_KEY`` environment variable is missing.
    """
    load_dotenv()

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        # FIX: fail fast with a clear message instead of passing None
        # downstream, where it would only surface as an opaque auth error
        # at embedding time.
        raise ValueError(
            "OPENAI_API_KEY is not set; add it to your environment or .env file"
        )

    # Alternative: local HuggingFace model
    # embeddings = HuggingFaceEmbeddingsProvider(
    #     model_name=r"E:\hf_models\all-mpnet-base-v2"
    # )

    embeddings = OpenAIEmbeddingsProvider(
        api_key=api_key,
        base_url="https://api.chatanywhere.tech/v1",
        model="text-embedding-3-large"
    )
    vector_store = ChromaVectorStore(
        persist_directory=Path(output_dir),
        embeddings_provider=embeddings
    )
    return PaperIngestionPipeline(
        embeddings_provider=embeddings,
        vector_store=vector_store
    )
def main() -> None:
    """CLI entry point: vectorize a Markdown paper file into Chroma."""
    arg_parser = argparse.ArgumentParser(
        description="将论文 Markdown 文件向量化并存储到 Chroma 数据库"
    )
    arg_parser.add_argument(
        "--input", "-i",
        type=str,
        required=True,
        help="输入的 Markdown 论文文件路径"
    )
    arg_parser.add_argument(
        "--output", "-o",
        type=str,
        default="papers_chroma_db",
        help="Chroma 数据库输出目录(默认: papers_chroma_db"
    )
    args = arg_parser.parse_args()

    # Build the default pipeline and run the full ingestion flow.
    create_default_pipeline(args.output).ingest(Path(args.input))


if __name__ == "__main__":
    main()

137
parsers.py Normal file
View File

@@ -0,0 +1,137 @@
"""
数据解析模块
提供论文文档的解析功能,支持多种格式(当前实现 Markdown 格式)
"""
import re
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Tuple
class BaseParser(ABC):
    """Abstract base for document parsers."""

    @abstractmethod
    def parse(self, content: str) -> Tuple[List[str], List[dict]]:
        """Parse raw text into paper chunks.

        Args:
            content: Raw text to parse.

        Returns:
            A ``(texts, metadatas)`` pair of parallel lists.
        """
        ...
class MarkdownPaperParser(BaseParser):
    """Parser for Markdown files containing many papers.

    Expected layout:
    - papers separated by '---' lines
    - each paper opens with a '## Title' heading
    - body text (abstract etc.) follows the heading
    """

    def __init__(self, separator: str = r'\n---\s*\n', title_pattern: str = r'^##\s+(.+)$'):
        """Configure the split/match regexes.

        Args:
            separator: Regex that splits the file into per-paper chunks.
            title_pattern: Regex whose first group captures the paper title.
        """
        self.separator = separator
        self.title_pattern = title_pattern

    def parse(self, content: str) -> Tuple[List[str], List[dict]]:
        """Split *content* into papers and extract per-paper metadata.

        Args:
            content: Full markdown file content.

        Returns:
            ``(texts, metadatas)`` — one entry per successfully parsed paper.

        Raises:
            ValueError: If no valid paper chunk is found.
        """
        texts: List[str] = []
        metadatas: List[dict] = []

        for raw_chunk in re.split(self.separator, content):
            chunk = raw_chunk.strip()
            if not chunk:
                continue

            match = re.search(self.title_pattern, chunk, re.MULTILINE)
            if match is None:
                # No '## Title' heading — warn and skip this chunk.
                self._handle_missing_title(chunk)
                continue

            texts.append(chunk)
            metadatas.append({
                "title": match.group(1).strip(),
                "content_length": len(chunk),
            })

        if not texts:
            raise ValueError("No valid papers were found in the content.")
        return texts, metadatas

    def _handle_missing_title(self, chunk: str) -> None:
        """Emit a warning for a chunk that lacks a '## Title' heading."""
        preview = chunk[:50].replace('\n', ' ')
        print(f"[WARN] Skipping paper without ## title: {preview}...")
class PaperFileReader:
    """Reads a paper file from disk and delegates parsing to a parser."""

    def __init__(self, parser: BaseParser):
        """Store the parser used on the file's text.

        Args:
            parser: Content parser instance.
        """
        self.parser = parser

    def read(self, file_path: Path) -> Tuple[List[str], List[dict]]:
        """Read *file_path*, parse it, and tag metadata with the source file.

        Args:
            file_path: Path to the paper file.

        Returns:
            ``(texts, metadatas)`` from the parser, each metadata dict
            augmented with a ``source_file`` key.

        Raises:
            FileNotFoundError: If *file_path* does not exist.
        """
        if not file_path.exists():
            raise FileNotFoundError(f"Paper file not found: {file_path}")

        print(f"[INFO] Reading papers from: {file_path}")
        texts, metadatas = self.parser.parse(file_path.read_text(encoding="utf-8"))

        # Stamp every paper with the file it came from.
        for meta in metadatas:
            meta["source_file"] = file_path.name
        return texts, metadatas

17
pyproject.toml Normal file
View File

@@ -0,0 +1,17 @@
[project]
name = "embedding"
version = "0.1.0"
description = "Academic paper vectorization and retrieval using LangChain embeddings and a Chroma vector store"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"chromadb>=1.4.0",
"huggingface>=0.0.1",
"langchain>=1.2.3",
"langchain-community>=0.4.1",
"langchain-huggingface>=1.2.0",
"langchain-openai>=1.1.7",
"python-dotenv>=1.2.1",
"requests>=2.32.5",
"sentence-transformers>=5.2.0",
]

151
retriever.py Normal file
View File

@@ -0,0 +1,151 @@
"""
学术论文检索工具
Usage:
from retriever import retrieve_papers
results = retrieve_papers("natural language to SQL")
"""
from pathlib import Path
from typing import List
import os
from dotenv import load_dotenv
from langchain_community.vectorstores import Chroma
from embeddings import HuggingFaceEmbeddingsProvider, OpenAIEmbeddingsProvider
from vector_stores import BaseEmbeddings
# ============================================================================
# Configuration
# ============================================================================
CHROMA_DIR = Path(r"E:\studio2\embedding\papers_chroma_db")
HUGGINGFACE_MODEL = r"E:\hf_models\all-mpnet-base-v2"
# ============================================================================
# Retrieval Service
# ============================================================================
class PaperRetriever:
    """
    Paper retrieval service.

    Runs similarity search over an existing Chroma database using the
    shared embeddings module.
    """

    def __init__(
        self,
        persist_directory: Path,
        embeddings_provider: BaseEmbeddings,
    ):
        """Configure the retriever.

        Args:
            persist_directory: Chroma database directory.
            embeddings_provider: Embedding model provider.
        """
        self.persist_directory = Path(persist_directory)
        self.embeddings_provider = embeddings_provider

    def _validate_db_exists(self) -> None:
        """Raise FileNotFoundError if the vector database directory is missing.

        FIX: the two f-string fragments were previously concatenated with no
        separator and referenced a non-existent script (papers_embedding.py);
        the ingestion script in this repo is ingest_pipeline.py.
        """
        if not self.persist_directory.exists():
            raise FileNotFoundError(
                f"论文知识库目录不存在: {self.persist_directory}\n"
                f"请先运行 ingest_pipeline.py 生成数据库"
            )

    def retrieve(self, query: str, k: int = 10) -> List:
        """Retrieve the most relevant papers for *query*.

        Args:
            query: Query text.
            k: Number of results to return (default 10).

        Returns:
            List of retrieved documents.

        Raises:
            FileNotFoundError: If the database directory does not exist.
        """
        self._validate_db_exists()

        embeddings = self.embeddings_provider.get_embeddings()
        vectorstore = Chroma(
            persist_directory=str(self.persist_directory),
            embedding_function=embeddings
        )
        retriever = vectorstore.as_retriever(search_kwargs={"k": k})
        docs = retriever.invoke(query)
        print(f"检索到 {len(docs)} 篇相关论文")
        return docs
# ============================================================================
# Convenience Functions
# ============================================================================
def create_default_retriever() -> PaperRetriever:
    """Create a retriever with the default (OpenAI API) configuration.

    Loads environment variables from ``.env`` before reading the API key,
    so programmatic callers (``from retriever import retrieve_papers``)
    work even when ``main()`` was never run.

    Returns:
        A configured PaperRetriever instance.

    Raises:
        ValueError: If the ``OPENAI_API_KEY`` environment variable is missing.
    """
    # FIX: previously os.getenv ran before any load_dotenv() call on the
    # programmatic path, silently passing api_key=None to the provider.
    load_dotenv()

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError(
            "OPENAI_API_KEY is not set; add it to your environment or .env file"
        )

    # Alternative: local HuggingFace model
    # embeddings = HuggingFaceEmbeddingsProvider(model_name=HUGGINGFACE_MODEL)
    embeddings = OpenAIEmbeddingsProvider(
        api_key=api_key,
        base_url="https://api.chatanywhere.tech/v1",
        model="text-embedding-3-large"
    )
    return PaperRetriever(
        persist_directory=CHROMA_DIR,
        embeddings_provider=embeddings
    )
def retrieve_papers(query: str, k: int = 10) -> List:
    """Convenience wrapper: build a default retriever and run one query.

    Args:
        query: Query text.
        k: Number of results to return (default 10).

    Returns:
        List of retrieved documents.
    """
    return create_default_retriever().retrieve(query, k)
# ============================================================================
# CLI Interface
# ============================================================================
def main():
    """Ad-hoc CLI for trying out retrieval."""
    import sys

    # Join all CLI words into one query; fall back to a sample query.
    if len(sys.argv) > 1:
        query = " ".join(sys.argv[1:])
    else:
        query = "natural language to SQL disambiguation"

    load_dotenv()
    print(f"查询: {query}\n")

    for i, doc in enumerate(retrieve_papers(query), 1):
        print(f"[{i}] {doc.metadata.get('title', 'N/A')}")
        # print(f"    来源: {doc.metadata.get('source_file', 'N/A')}")
        # print(f"    内容预览: {doc.page_content[:200]}...")
        print()


if __name__ == "__main__":
    main()

3526
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff

64
vector_stores.py Normal file
View File

@@ -0,0 +1,64 @@
"""
向量存储模块
提供向量数据库的统一接口,当前支持 Chroma
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List
from langchain_community.vectorstores import Chroma
from embeddings import BaseEmbeddings
class BaseVectorStore(ABC):
    """Abstract base for vector-database backends."""

    @abstractmethod
    def persist(self, texts: List[str], metadatas: List[dict]) -> None:
        """Embed *texts* and write them to the backing store.

        Args:
            texts: Document texts.
            metadatas: One metadata dict per text.
        """
        ...
class ChromaVectorStore(BaseVectorStore):
    """Chroma-backed implementation of BaseVectorStore."""

    def __init__(self, persist_directory: Path, embeddings_provider: BaseEmbeddings):
        """Configure the store.

        Args:
            persist_directory: Directory the database is persisted under.
            embeddings_provider: Embedding model provider.
        """
        self.persist_directory = Path(persist_directory)
        self.embeddings_provider = embeddings_provider

    def persist(self, texts: List[str], metadatas: List[dict]) -> None:
        """Embed *texts* and write them into a persistent Chroma collection.

        Args:
            texts: Document texts.
            metadatas: One metadata dict per text.
        """
        # Make sure the target directory exists before Chroma writes to it.
        self.persist_directory.mkdir(parents=True, exist_ok=True)

        print(f"[INFO] Writing {len(texts)} papers to Chroma: {self.persist_directory}")
        Chroma.from_texts(
            texts=texts,
            metadatas=metadatas,
            embedding=self.embeddings_provider.get_embeddings(),
            persist_directory=str(self.persist_directory),
        )
        print("[OK] Chroma persistence complete.")