修改retriever的接收参数
This commit is contained in:
27
retriever.py
27
retriever.py
@@ -128,17 +128,28 @@ def retrieve_papers(query: str, k: int = 10) -> List:
|
||||
# ============================================================================
|
||||
|
||||
def main():
|
||||
"""测试检索功能"""
|
||||
import sys
|
||||
"""命令行检索入口"""
|
||||
import argparse
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
query = "natural language to SQL disambiguation"
|
||||
else:
|
||||
query = " ".join(sys.argv[1:])
|
||||
parser = argparse.ArgumentParser(
|
||||
description="从论文向量数据库中检索相关论文"
|
||||
)
|
||||
parser.add_argument(
|
||||
"query",
|
||||
type=str,
|
||||
help="检索查询文本"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k", "-k",
|
||||
type=int,
|
||||
default=20,
|
||||
help="返回结果数量(默认: 20)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
load_dotenv()
|
||||
print(f"查询: {query}\n")
|
||||
docs = retrieve_papers(query)
|
||||
print(f"查询: {args.query}\n")
|
||||
docs = retrieve_papers(args.query, args.top_k)
|
||||
|
||||
for i, doc in enumerate(docs, 1):
|
||||
print(f"[{i}] {doc.metadata.get('title', 'N/A')}")
|
||||
|
||||
Reference in New Issue
Block a user