修改retriever的接收参数

This commit is contained in:
2026-01-11 17:30:00 +08:00
parent 165b5788d2
commit 4ce0a9e2cb

View File

@@ -128,17 +128,28 @@ def retrieve_papers(query: str, k: int = 10) -> List:
# ============================================================================
def main():
"""测试检索功能"""
import sys
"""命令行检索入口"""
import argparse
if len(sys.argv) < 2:
query = "natural language to SQL disambiguation"
else:
query = " ".join(sys.argv[1:])
parser = argparse.ArgumentParser(
description="从论文向量数据库中检索相关论文"
)
parser.add_argument(
"query",
type=str,
help="检索查询文本"
)
parser.add_argument(
"--top-k", "-k",
type=int,
default=20,
help="返回结果数量(默认: 20"
)
args = parser.parse_args()
load_dotenv()
print(f"查询: {query}\n")
docs = retrieve_papers(query)
print(f"查询: {args.query}\n")
docs = retrieve_papers(args.query, args.top_k)
for i, doc in enumerate(docs, 1):
print(f"[{i}] {doc.metadata.get('title', 'N/A')}")