"""
|
||
学术论文向量化和存储系统
|
||
|
||
模块化、面向对象的论文检索系统,支持将论文文档向量化并持久化到 Chroma 向量数据库。
|
||
|
||
Usage:
|
||
# 基本用法
|
||
python papers_embedding.py --input xxx.md
|
||
|
||
# 自定义输出目录
|
||
python papers_embedding.py --input xxx.md --output ./my_db
|
||
|
||
# 程序化调用
|
||
from papers_embedding import PaperIngestionPipeline
|
||
pipeline = PaperIngestionPipeline()
|
||
pipeline.ingest(Path("icde2025.md"))
|
||
"""
|
||
|
||
import os
|
||
import argparse
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
from dotenv import load_dotenv
|
||
|
||
from parsers import BaseParser, MarkdownPaperParser, PaperFileReader
|
||
from embeddings import BaseEmbeddings, HuggingFaceEmbeddingsProvider, OpenAIEmbeddingsProvider
|
||
from vector_stores import BaseVectorStore, ChromaVectorStore
|
||
|
||
|
||
# ============================================================================
|
||
# Pipeline: 数据摄取管道
|
||
# ============================================================================
|
||
|
||
class PaperIngestionPipeline:
    """Paper ingestion pipeline.

    Wires the three modules — parsing, embedding, and storage — into a
    single end-to-end paper-vectorization workflow.
    """

    def __init__(
        self,
        parser: Optional[BaseParser] = None,
        embeddings_provider: Optional[BaseEmbeddings] = None,
        vector_store: Optional[BaseVectorStore] = None,
    ):
        """Initialize the ingestion pipeline.

        Args:
            parser: Paper parser (defaults to ``MarkdownPaperParser``).
            embeddings_provider: Embedding model provider (must be supplied
                explicitly before ``ingest`` is called).
            vector_store: Vector store (must be supplied explicitly before
                ``ingest`` is called).
        """
        # Fall back to the Markdown parser when none is supplied.
        self.parser = parser if parser else MarkdownPaperParser()
        self.embeddings_provider = embeddings_provider
        self.vector_store = vector_store

    def ingest(self, input_file: Path) -> None:
        """Run the full ingestion workflow for one paper file.

        Args:
            input_file: Path of the paper file to ingest.

        Raises:
            ValueError: If ``embeddings_provider`` or ``vector_store``
                was not configured.
        """
        # Fail fast when the pipeline is only partially configured.
        if self.embeddings_provider is None:
            raise ValueError("embeddings_provider must be configured")
        if self.vector_store is None:
            raise ValueError("vector_store must be configured")

        # Step 1: parse the paper into text chunks plus per-chunk metadata.
        texts, metadatas = PaperFileReader(self.parser).read(input_file)

        # Steps 2 & 3: embed and persist — the store handles both at once.
        self.vector_store.persist(texts, metadatas)
|
||
|
||
|
||
# ============================================================================
|
||
# CLI Interface
|
||
# ============================================================================
|
||
|
||
def create_default_pipeline(output_dir: str) -> PaperIngestionPipeline:
    """Create an ingestion pipeline with the default configuration.

    Environment variables (loaded from ``.env`` via python-dotenv):
        OPENAI_API_KEY: required — API key for the embeddings service.
        OPENAI_BASE_URL: optional endpoint override
            (default: ``https://api.chatanywhere.tech/v1``).
        OPENAI_EMBEDDING_MODEL: optional embedding-model override
            (default: ``text-embedding-3-large``).

    Args:
        output_dir: Output directory for the Chroma database.

    Returns:
        A fully configured PaperIngestionPipeline instance.

    Raises:
        ValueError: If OPENAI_API_KEY is not set.
    """
    load_dotenv()

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        # Fail early with a clear message instead of letting the provider
        # surface an opaque authentication error much later.
        raise ValueError(
            "OPENAI_API_KEY is not set; add it to the environment or a .env file"
        )

    # Alternative: use a local HuggingFace model instead of the OpenAI API, e.g.
    #   embeddings = HuggingFaceEmbeddingsProvider(
    #       model_name=r"E:\hf_models\all-mpnet-base-v2"
    #   )

    # OpenAI-compatible API; endpoint and model can be overridden via env vars.
    embeddings = OpenAIEmbeddingsProvider(
        api_key=api_key,
        base_url=os.getenv("OPENAI_BASE_URL", "https://api.chatanywhere.tech/v1"),
        model=os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-large"),
    )

    vector_store = ChromaVectorStore(
        persist_directory=Path(output_dir),
        embeddings_provider=embeddings,
    )

    return PaperIngestionPipeline(
        embeddings_provider=embeddings,
        vector_store=vector_store,
    )
|
||
|
||
|
||
def main() -> None:
    """Command-line entry point: parse arguments and run the ingestion pipeline."""
    parser = argparse.ArgumentParser(
        description="将论文 Markdown 文件向量化并存储到 Chroma 数据库"
    )
    parser.add_argument(
        "--input", "-i",
        type=str,
        required=True,
        help="输入的 Markdown 论文文件路径"
    )
    parser.add_argument(
        "--output", "-o",
        type=str,
        default="papers_chroma_db",
        help="Chroma 数据库输出目录(默认: papers_chroma_db)"
    )
    args = parser.parse_args()

    # Validate the input path up front so the user gets a clean CLI error
    # (usage + exit code 2) instead of a traceback from the parsing layer.
    input_path = Path(args.input)
    if not input_path.is_file():
        parser.error(f"input file not found: {input_path}")

    # Build the default pipeline and run the ingestion.
    pipeline = create_default_pipeline(args.output)
    pipeline.ingest(input_path)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|