Files
paper-embedding/retriever.py
2026-01-11 16:09:16 +08:00

152 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
学术论文检索工具
Usage:
from retriever import retrieve_papers
results = retrieve_papers("natural language to SQL")
"""
from pathlib import Path
from typing import List
import os
from dotenv import load_dotenv
from langchain_community.vectorstores import Chroma
from embeddings import HuggingFaceEmbeddingsProvider, OpenAIEmbeddingsProvider
from vector_stores import BaseEmbeddings
# ============================================================================
# Configuration
# ============================================================================
# Directory of the persisted Chroma database; passed as the default
# persist_directory when the default retriever is created.
CHROMA_DIR: Path = Path(r"E:\studio2\embedding\papers_chroma_db")
# Local path of the all-mpnet-base-v2 model. In this file it is referenced
# only by the HuggingFace alternative noted in create_default_retriever.
HUGGINGFACE_MODEL: str = r"E:\hf_models\all-mpnet-base-v2"
# ============================================================================
# Retrieval Service
# ============================================================================
class PaperRetriever:
    """Retrieve papers from a persisted Chroma vector database.

    Each call to :meth:`retrieve` asks the embeddings provider for its
    embedding function, opens the Chroma store located at
    ``persist_directory`` and runs a top-``k`` retriever query against it.
    """

    def __init__(
        self,
        persist_directory: Path,
        embeddings_provider: "BaseEmbeddings",
    ):
        """Store the configuration; the database is not opened here.

        Args:
            persist_directory: Directory of the Chroma database. Any value
                accepted by ``Path`` works; it is normalised to a ``Path``.
            embeddings_provider: Object whose ``get_embeddings()`` result is
                handed to Chroma as ``embedding_function``.
        """
        self.persist_directory = Path(persist_directory)
        self.embeddings_provider = embeddings_provider

    def _validate_db_exists(self) -> None:
        """Check that the vector database directory is present.

        Raises:
            FileNotFoundError: If ``persist_directory`` does not exist or is
                not a directory.
        """
        # is_dir() (rather than exists()) also rejects a regular file that
        # happens to sit at the configured path.
        if not self.persist_directory.is_dir():
            # The newline keeps the path and the hint from running together.
            raise FileNotFoundError(
                f"论文知识库目录不存在: {self.persist_directory}\n"
                "请先运行 papers_embedding.py 生成数据库"
            )

    def retrieve(self, query: str, k: int = 10) -> List:
        """Retrieve the documents most relevant to ``query``.

        Args:
            query: Query text.
            k: Number of results requested from the retriever (default 10).

        Returns:
            The list of documents returned by the Chroma retriever.

        Raises:
            FileNotFoundError: If the database directory is missing; raised
                before the embeddings provider is touched.
        """
        self._validate_db_exists()

        embeddings = self.embeddings_provider.get_embeddings()
        vectorstore = Chroma(
            persist_directory=str(self.persist_directory),
            embedding_function=embeddings,
        )
        retriever = vectorstore.as_retriever(search_kwargs={"k": k})
        docs = retriever.invoke(query)

        print(f"检索到 {len(docs)} 篇相关论文")
        return docs
# ============================================================================
# Convenience Functions
# ============================================================================
def create_default_retriever() -> PaperRetriever:
    """Create a retriever with the default configuration.

    The default uses ``OpenAIEmbeddingsProvider`` with the
    ``text-embedding-3-large`` model and the Chroma database at
    ``CHROMA_DIR``. The API key is read from the ``OPENAI_API_KEY``
    environment variable.

    Returns:
        A configured ``PaperRetriever`` instance.
    """
    # Load .env here so that library callers
    # (``from retriever import retrieve_papers``) also get OPENAI_API_KEY.
    # By default load_dotenv() does not override variables already set.
    load_dotenv()

    # Alternative local backend, if preferred over the OpenAI-compatible API:
    #   HuggingFaceEmbeddingsProvider(model_name=HUGGINGFACE_MODEL)
    embeddings = OpenAIEmbeddingsProvider(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url="https://api.chatanywhere.tech/v1",
        model="text-embedding-3-large",
    )
    return PaperRetriever(
        persist_directory=CHROMA_DIR,
        embeddings_provider=embeddings,
    )
def retrieve_papers(query: str, k: int = 10) -> List:
    """Convenience wrapper: search papers with the default retriever.

    Args:
        query: Query text.
        k: Number of results to return (default 10).

    Returns:
        The list of retrieved documents.
    """
    return create_default_retriever().retrieve(query, k)
# ============================================================================
# CLI Interface
# ============================================================================
def main():
    """Run a test retrieval from the command line and print the titles."""
    import sys

    # Everything after the script name forms the query; fall back to a
    # sample query when no arguments were given.
    cli_words = sys.argv[1:]
    if cli_words:
        query = " ".join(cli_words)
    else:
        query = "natural language to SQL disambiguation"

    load_dotenv()
    print(f"查询: {query}\n")

    for rank, paper in enumerate(retrieve_papers(query), start=1):
        title = paper.metadata.get('title', 'N/A')
        print(f"[{rank}] {title}")
        print()
# Run the CLI test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()