修改retriever的接收参数

This commit is contained in:
2026-01-11 17:30:00 +08:00
parent 165b5788d2
commit 4ce0a9e2cb

View File

@@ -128,17 +128,28 @@ def retrieve_papers(query: str, k: int = 10) -> List:
# ============================================================================ # ============================================================================
def main(): def main():
"""测试检索功能""" """命令行检索入口"""
import sys import argparse
if len(sys.argv) < 2: parser = argparse.ArgumentParser(
query = "natural language to SQL disambiguation" description="从论文向量数据库中检索相关论文"
else: )
query = " ".join(sys.argv[1:]) parser.add_argument(
"query",
type=str,
help="检索查询文本"
)
parser.add_argument(
"--top-k", "-k",
type=int,
default=20,
help="返回结果数量(默认: 20"
)
args = parser.parse_args()
load_dotenv() load_dotenv()
print(f"查询: {query}\n") print(f"查询: {args.query}\n")
docs = retrieve_papers(query) docs = retrieve_papers(args.query, args.top_k)
for i, doc in enumerate(docs, 1): for i, doc in enumerate(docs, 1):
print(f"[{i}] {doc.metadata.get('title', 'N/A')}") print(f"[{i}] {doc.metadata.get('title', 'N/A')}")