paper-embedding/ingest_pipeline.py
"""
学术论文向量化和存储系统
模块化、面向对象的论文检索系统,支持将论文文档向量化并持久化到 Chroma 向量数据库。
Usage:
# 基本用法
python papers_embedding.py --input xxx.md
# 自定义输出目录
python papers_embedding.py --input xxx.md --output ./my_db
# 程序化调用
from papers_embedding import PaperIngestionPipeline
pipeline = PaperIngestionPipeline()
pipeline.ingest(Path("icde2025.md"))
"""
import os
import argparse
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
from parsers import BaseParser, MarkdownPaperParser, PaperFileReader
from embeddings import BaseEmbeddings, HuggingFaceEmbeddingsProvider, OpenAIEmbeddingsProvider
from vector_stores import BaseVectorStore, ChromaVectorStore
# ============================================================================
# Pipeline: Data ingestion pipeline
# ============================================================================
class PaperIngestionPipeline:
"""
论文摄取管道
整合解析、嵌入、存储三个模块,提供完整的论文向量化流程
"""
def __init__(
self,
parser: Optional[BaseParser] = None,
embeddings_provider: Optional[BaseEmbeddings] = None,
vector_store: Optional[BaseVectorStore] = None,
):
"""
初始化摄取管道
Args:
parser: 论文解析器(默认 MarkdownPaperParser
embeddings_provider: 嵌入模型提供者(需显式指定)
vector_store: 向量存储(需显式指定)
"""
self.parser = parser or MarkdownPaperParser()
self.embeddings_provider = embeddings_provider
self.vector_store = vector_store

    def ingest(self, input_file: Path) -> None:
        """
        Run the paper ingestion workflow.

        Args:
            input_file: Path to the input paper file.

        Raises:
            ValueError: If embeddings_provider or vector_store is not configured.
        """
        if self.embeddings_provider is None:
            raise ValueError("embeddings_provider must be configured")
        if self.vector_store is None:
            raise ValueError("vector_store must be configured")

        # Step 1: parse the paper
        reader = PaperFileReader(self.parser)
        texts, metadatas = reader.read(input_file)

        # Steps 2 & 3: embed and persist
        self.vector_store.persist(texts, metadatas)

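
# Example (illustrative sketch): wiring the pipeline explicitly with the local
# HuggingFace provider instead of the default OpenAI one. Constructor arguments
# mirror those used in create_default_pipeline() below; the model path is only
# a placeholder.
#
#   embeddings = HuggingFaceEmbeddingsProvider(
#       model_name=r"E:\hf_models\all-mpnet-base-v2"
#   )
#   store = ChromaVectorStore(
#       persist_directory=Path("papers_chroma_db"),
#       embeddings_provider=embeddings,
#   )
#   pipeline = PaperIngestionPipeline(
#       parser=MarkdownPaperParser(),
#       embeddings_provider=embeddings,
#       vector_store=store,
#   )
#   pipeline.ingest(Path("icde2025.md"))
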
# ============================================================================
# CLI Interface
# ============================================================================
def create_default_pipeline(output_dir: str) -> PaperIngestionPipeline:
"""
创建默认配置的摄取管道
Args:
output_dir: 输出目录路径
Returns:
配置好的 PaperIngestionPipeline 实例
"""
load_dotenv()
# 可选:使用 HuggingFace 本地模型
# embeddings = HuggingFaceEmbeddingsProvider(
# model_name=r"E:\hf_models\all-mpnet-base-v2"
# )
# 使用 OpenAI API
embeddings = OpenAIEmbeddingsProvider(
api_key=os.getenv("OPENAI_API_KEY"),
base_url="https://api.chatanywhere.tech/v1",
model="text-embedding-3-large"
)
vector_store = ChromaVectorStore(
persist_directory=Path(output_dir),
embeddings_provider=embeddings
)
return PaperIngestionPipeline(
embeddings_provider=embeddings,
vector_store=vector_store
)
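
# Note: create_default_pipeline() relies on load_dotenv() plus
# os.getenv("OPENAI_API_KEY"), so the key is expected to come from the
# environment or a .env file, e.g. (illustrative placeholder value):
#
#   OPENAI_API_KEY=sk-...
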
def main() -> None:
"""命令行入口"""
parser = argparse.ArgumentParser(
description="将论文 Markdown 文件向量化并存储到 Chroma 数据库"
)
parser.add_argument(
"--input", "-i",
type=str,
required=True,
help="输入的 Markdown 论文文件路径"
)
parser.add_argument(
"--output", "-o",
type=str,
default="papers_chroma_db",
help="Chroma 数据库输出目录(默认: papers_chroma_db"
)
args = parser.parse_args()
# 创建并执行管道
pipeline = create_default_pipeline(args.output)
pipeline.ingest(Path(args.input))
if __name__ == "__main__":
    main()