163 lines
4.5 KiB
Python
163 lines
4.5 KiB
Python
"""
学术论文检索工具 (academic paper retrieval tool)

Usage:
    from retriever import retrieve_papers
    results = retrieve_papers("natural language to SQL")
"""
|
||
|
||
from pathlib import Path
|
||
from typing import List
|
||
import os
|
||
from dotenv import load_dotenv
|
||
|
||
from langchain_community.vectorstores import Chroma
|
||
|
||
from embeddings import HuggingFaceEmbeddingsProvider, OpenAIEmbeddingsProvider
|
||
from vector_stores import BaseEmbeddings
|
||
|
||
|
||
# ============================================================================
|
||
# Configuration
|
||
# ============================================================================
|
||
|
||
# Both locations may be overridden through environment variables so the module
# can run on machines other than the original author's; the defaults are the
# previously hard-coded Windows paths, so behavior is unchanged when the
# variables are unset (or empty). NOTE: these are read once at import time,
# before main() calls load_dotenv(), so set them in the real process
# environment rather than in a .env file.
CHROMA_DIR = Path(
    os.environ.get("PAPERS_CHROMA_DIR") or r"E:\studio2\embedding\papers_chroma_db"
)
# Local sentence-transformers model directory. Within this file it is only
# referenced by the HuggingFace embeddings alternative noted in
# create_default_retriever(); the default configuration uses OpenAI instead.
HUGGINGFACE_MODEL = (
    os.environ.get("PAPERS_HF_MODEL") or r"E:\hf_models\all-mpnet-base-v2"
)
|
||
|
||
|
||
# ============================================================================
|
||
# Retrieval Service
|
||
# ============================================================================
|
||
|
||
class PaperRetriever:
    """
    Paper retrieval service.

    Runs similarity search over a persisted Chroma vector store, using an
    embeddings provider supplied by the caller (built on the project's
    ``embeddings`` / ``vector_stores`` modules).
    """

    def __init__(
        self,
        persist_directory: Path,
        embeddings_provider: "BaseEmbeddings",
    ):
        """
        Initialise the retriever.

        Args:
            persist_directory: Directory of the persisted Chroma database.
                Any value accepted by ``Path()`` works; it is converted here.
            embeddings_provider: Object exposing ``get_embeddings()``; its
                result is passed to Chroma as ``embedding_function``.
        """
        self.persist_directory = Path(persist_directory)
        self.embeddings_provider = embeddings_provider
        # Opened lazily on the first retrieve() and reused afterwards, so
        # repeated queries do not rebuild the embeddings and the store.
        self._vectorstore = None

    def _validate_db_exists(self) -> None:
        """
        Ensure the vector-store directory is present.

        Raises:
            FileNotFoundError: If ``persist_directory`` does not exist or is
                not a directory (e.g. a regular file sits at that path).
        """
        # is_dir() rather than exists(): a plain file at this path would
        # otherwise pass the check and only fail later inside Chroma.
        if not self.persist_directory.is_dir():
            raise FileNotFoundError(
                f"论文知识库目录不存在: {self.persist_directory},"
                "请先运行 papers_embedding.py 生成数据库"
            )

    def _get_vectorstore(self) -> "Chroma":
        """Return the cached Chroma store, opening it on first use."""
        if self._vectorstore is None:
            self._vectorstore = Chroma(
                persist_directory=str(self.persist_directory),
                embedding_function=self.embeddings_provider.get_embeddings(),
            )
        return self._vectorstore

    def retrieve(self, query: str, k: int = 10) -> List:
        """
        Retrieve the papers most relevant to ``query``.

        Args:
            query: Free-text search query.
            k: Number of results to return (default 10); forwarded to the
                retriever as ``search_kwargs={"k": k}``.

        Returns:
            The list of documents returned by the Chroma retriever.

        Raises:
            FileNotFoundError: If the database directory is missing.
        """
        # Validate on every call (not just when first opening the store) so a
        # directory removed after the first query is still reported clearly.
        self._validate_db_exists()

        retriever = self._get_vectorstore().as_retriever(search_kwargs={"k": k})
        docs = retriever.invoke(query)

        print(f"检索到 {len(docs)} 篇相关论文")
        return docs
|
||
|
||
|
||
# ============================================================================
|
||
# Convenience Functions
|
||
# ============================================================================
|
||
|
||
def create_default_retriever() -> PaperRetriever:
    """
    Build a PaperRetriever with the project's default configuration.

    Embeddings come from ``OpenAIEmbeddingsProvider``. It uses the
    ``text-embedding-3-large`` model through the chatanywhere
    OpenAI-compatible endpoint, and the API key is read from the
    ``OPENAI_API_KEY`` environment variable. The vector store is opened
    from ``CHROMA_DIR``.

    Returns:
        A configured PaperRetriever instance.
    """
    # Load .env here as well as in main(): library callers
    # (``from retriever import retrieve_papers``) never run main(), so the
    # key from .env would otherwise be missing. load_dotenv() does not
    # override variables that are already set in the environment.
    load_dotenv()

    # Alternative (local model, no API key needed):
    #     embeddings = HuggingFaceEmbeddingsProvider(model_name=HUGGINGFACE_MODEL)
    embeddings = OpenAIEmbeddingsProvider(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url="https://api.chatanywhere.tech/v1",
        model="text-embedding-3-large",
    )
    return PaperRetriever(
        persist_directory=CHROMA_DIR,
        embeddings_provider=embeddings,
    )
|
||
|
||
|
||
def retrieve_papers(query: str, k: int = 10) -> List:
    """
    Convenience wrapper: search the default paper store.

    A fresh retriever is built via ``create_default_retriever()`` on every
    call, and the search is delegated to its ``retrieve`` method.

    Args:
        query: Free-text search query.
        k: Number of results to return (default 10).

    Returns:
        The documents returned by ``PaperRetriever.retrieve``.
    """
    return create_default_retriever().retrieve(query, k)
|
||
|
||
|
||
# ============================================================================
|
||
# CLI Interface
|
||
# ============================================================================
|
||
|
||
def main():
    """
    Command-line entry point: search the paper store and print the hits.

    Prints each hit's ``title`` and ``source_file`` metadata (``N/A`` when
    absent). Exits with status 1 and a message on stderr when the vector
    database directory is missing.
    """
    import argparse
    import sys

    parser = argparse.ArgumentParser(
        description="从论文向量数据库中检索相关论文"
    )
    parser.add_argument(
        "query",
        type=str,
        help="检索查询文本"
    )
    parser.add_argument(
        "--top-k", "-k",
        type=int,
        default=20,
        help="返回结果数量(默认: 20)"
    )
    args = parser.parse_args()
    # A zero or negative result count is meaningless; reject it with a usage
    # message instead of passing it on to the vector store.
    if args.top_k < 1:
        parser.error("--top-k 必须为正整数")

    load_dotenv()
    print(f"查询: {args.query}\n")

    try:
        docs = retrieve_papers(args.query, args.top_k)
    except FileNotFoundError as exc:
        # Missing database: show the actionable message without a traceback.
        print(exc, file=sys.stderr)
        sys.exit(1)

    for i, doc in enumerate(docs, 1):
        print(f"[{i}] {doc.metadata.get('title', 'N/A')}")
        print(f" 来源: {doc.metadata.get('source_file', 'N/A')}")
        print()
|
||
|
||
|
||
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|