From 4ce0a9e2cb78aababead3139e80c16d6e479d675 Mon Sep 17 00:00:00 2001 From: along <1015042407@qq.com> Date: Sun, 11 Jan 2026 17:30:00 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9retriever=E7=9A=84=E6=8E=A5?= =?UTF-8?q?=E6=94=B6=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- retriever.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/retriever.py b/retriever.py index 0974361..5cd6c81 100644 --- a/retriever.py +++ b/retriever.py @@ -128,17 +128,28 @@ def retrieve_papers(query: str, k: int = 10) -> List: # ============================================================================ def main(): - """测试检索功能""" - import sys + """命令行检索入口""" + import argparse - if len(sys.argv) < 2: - query = "natural language to SQL disambiguation" - else: - query = " ".join(sys.argv[1:]) + parser = argparse.ArgumentParser( + description="从论文向量数据库中检索相关论文" + ) + parser.add_argument( + "query", + type=str, + help="检索查询文本" + ) + parser.add_argument( + "--top-k", "-k", + type=int, + default=20, + help="返回结果数量(默认: 20)" + ) + args = parser.parse_args() load_dotenv() - print(f"查询: {query}\n") - docs = retrieve_papers(query) + print(f"查询: {args.query}\n") + docs = retrieve_papers(args.query, args.top_k) for i, doc in enumerate(docs, 1): print(f"[{i}] {doc.metadata.get('title', 'N/A')}")