修改retriever的接收参数
This commit is contained in:
27
retriever.py
27
retriever.py
@@ -128,17 +128,28 @@ def retrieve_papers(query: str, k: int = 10) -> List:
|
|||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""测试检索功能"""
|
"""命令行检索入口"""
|
||||||
import sys
|
import argparse
|
||||||
|
|
||||||
if len(sys.argv) < 2:
|
parser = argparse.ArgumentParser(
|
||||||
query = "natural language to SQL disambiguation"
|
description="从论文向量数据库中检索相关论文"
|
||||||
else:
|
)
|
||||||
query = " ".join(sys.argv[1:])
|
parser.add_argument(
|
||||||
|
"query",
|
||||||
|
type=str,
|
||||||
|
help="检索查询文本"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-k", "-k",
|
||||||
|
type=int,
|
||||||
|
default=20,
|
||||||
|
help="返回结果数量(默认: 20)"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
print(f"查询: {query}\n")
|
print(f"查询: {args.query}\n")
|
||||||
docs = retrieve_papers(query)
|
docs = retrieve_papers(args.query, args.top_k)
|
||||||
|
|
||||||
for i, doc in enumerate(docs, 1):
|
for i, doc in enumerate(docs, 1):
|
||||||
print(f"[{i}] {doc.metadata.get('title', 'N/A')}")
|
print(f"[{i}] {doc.metadata.get('title', 'N/A')}")
|
||||||
|
|||||||
Reference in New Issue
Block a user