first commit
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -1,3 +1,9 @@
|
|||||||
|
# custom .gitignore
|
||||||
|
.claude/
|
||||||
|
.vscode/
|
||||||
|
paper/
|
||||||
|
papers_chroma_db/
|
||||||
|
|
||||||
# ---> Python
|
# ---> Python
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
|||||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
|||||||
|
3.12
|
||||||
70
embeddings.py
Normal file
70
embeddings.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
"""
|
||||||
|
嵌入模型模块
|
||||||
|
|
||||||
|
提供多种文本嵌入模型的统一接口
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
from langchain_openai import OpenAIEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEmbeddings(ABC):
|
||||||
|
"""嵌入模型基类"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_embeddings(self):
|
||||||
|
"""获取 LangChain 兼容的 embeddings 实例"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceEmbeddingsProvider(BaseEmbeddings):
|
||||||
|
"""HuggingFace 本地嵌入模型提供者"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str):
|
||||||
|
"""
|
||||||
|
初始化 HuggingFace 嵌入模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: 模型名称或路径
|
||||||
|
"""
|
||||||
|
self.model_name = model_name
|
||||||
|
self._embeddings: Optional[HuggingFaceEmbeddings] = None
|
||||||
|
|
||||||
|
def get_embeddings(self) -> HuggingFaceEmbeddings:
|
||||||
|
"""获取 HuggingFace embeddings 实例(懒加载)"""
|
||||||
|
if self._embeddings is None:
|
||||||
|
print(f"[INFO] Initializing HuggingFace embeddings: {self.model_name}")
|
||||||
|
self._embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
||||||
|
return self._embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIEmbeddingsProvider(BaseEmbeddings):
|
||||||
|
"""OpenAI API 嵌入模型提供者"""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, base_url: str, model: str = "text-embedding-3-large"):
|
||||||
|
"""
|
||||||
|
初始化 OpenAI 嵌入模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: OpenAI API 密钥
|
||||||
|
base_url: API 基础 URL
|
||||||
|
model: 嵌入模型名称
|
||||||
|
"""
|
||||||
|
self.api_key = api_key
|
||||||
|
self.base_url = base_url
|
||||||
|
self.model = model
|
||||||
|
self._embeddings: Optional[OpenAIEmbeddings] = None
|
||||||
|
|
||||||
|
def get_embeddings(self) -> OpenAIEmbeddings:
|
||||||
|
"""获取 OpenAI embeddings 实例(懒加载)"""
|
||||||
|
if self._embeddings is None:
|
||||||
|
print(f"[INFO] Initializing OpenAI embeddings: {self.model}")
|
||||||
|
self._embeddings = OpenAIEmbeddings(
|
||||||
|
model=self.model,
|
||||||
|
api_key=self.api_key,
|
||||||
|
openai_api_base=self.base_url
|
||||||
|
)
|
||||||
|
return self._embeddings
|
||||||
147
ingest_pipeline.py
Normal file
147
ingest_pipeline.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
"""
|
||||||
|
学术论文向量化和存储系统
|
||||||
|
|
||||||
|
模块化、面向对象的论文检索系统,支持将论文文档向量化并持久化到 Chroma 向量数据库。
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# 基本用法
|
||||||
|
python papers_embedding.py --input xxx.md
|
||||||
|
|
||||||
|
# 自定义输出目录
|
||||||
|
python papers_embedding.py --input xxx.md --output ./my_db
|
||||||
|
|
||||||
|
# 程序化调用
|
||||||
|
from papers_embedding import PaperIngestionPipeline
|
||||||
|
pipeline = PaperIngestionPipeline()
|
||||||
|
pipeline.ingest(Path("icde2025.md"))
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from parsers import BaseParser, MarkdownPaperParser, PaperFileReader
|
||||||
|
from embeddings import BaseEmbeddings, HuggingFaceEmbeddingsProvider, OpenAIEmbeddingsProvider
|
||||||
|
from vector_stores import BaseVectorStore, ChromaVectorStore
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Pipeline: 数据摄取管道
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class PaperIngestionPipeline:
|
||||||
|
"""
|
||||||
|
论文摄取管道
|
||||||
|
|
||||||
|
整合解析、嵌入、存储三个模块,提供完整的论文向量化流程
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parser: Optional[BaseParser] = None,
|
||||||
|
embeddings_provider: Optional[BaseEmbeddings] = None,
|
||||||
|
vector_store: Optional[BaseVectorStore] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化摄取管道
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parser: 论文解析器(默认 MarkdownPaperParser)
|
||||||
|
embeddings_provider: 嵌入模型提供者(需显式指定)
|
||||||
|
vector_store: 向量存储(需显式指定)
|
||||||
|
"""
|
||||||
|
self.parser = parser or MarkdownPaperParser()
|
||||||
|
self.embeddings_provider = embeddings_provider
|
||||||
|
self.vector_store = vector_store
|
||||||
|
|
||||||
|
def ingest(self, input_file: Path) -> None:
|
||||||
|
"""
|
||||||
|
执行论文摄取流程
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_file: 输入论文文件路径
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果未配置 embeddings_provider 或 vector_store
|
||||||
|
"""
|
||||||
|
if self.embeddings_provider is None:
|
||||||
|
raise ValueError("embeddings_provider must be configured")
|
||||||
|
if self.vector_store is None:
|
||||||
|
raise ValueError("vector_store must be configured")
|
||||||
|
|
||||||
|
# Step 1: 解析论文
|
||||||
|
reader = PaperFileReader(self.parser)
|
||||||
|
texts, metadatas = reader.read(input_file)
|
||||||
|
|
||||||
|
# Step 2 & 3: 向量化并存储
|
||||||
|
self.vector_store.persist(texts, metadatas)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# CLI Interface
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def create_default_pipeline(output_dir: str) -> PaperIngestionPipeline:
|
||||||
|
"""
|
||||||
|
创建默认配置的摄取管道
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir: 输出目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
配置好的 PaperIngestionPipeline 实例
|
||||||
|
"""
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# 可选:使用 HuggingFace 本地模型
|
||||||
|
# embeddings = HuggingFaceEmbeddingsProvider(
|
||||||
|
# model_name=r"E:\hf_models\all-mpnet-base-v2"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# 使用 OpenAI API
|
||||||
|
embeddings = OpenAIEmbeddingsProvider(
|
||||||
|
api_key=os.getenv("OPENAI_API_KEY"),
|
||||||
|
base_url="https://api.chatanywhere.tech/v1",
|
||||||
|
model="text-embedding-3-large"
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_store = ChromaVectorStore(
|
||||||
|
persist_directory=Path(output_dir),
|
||||||
|
embeddings_provider=embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
return PaperIngestionPipeline(
|
||||||
|
embeddings_provider=embeddings,
|
||||||
|
vector_store=vector_store
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""命令行入口"""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="将论文 Markdown 文件向量化并存储到 Chroma 数据库"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--input", "-i",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="输入的 Markdown 论文文件路径"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output", "-o",
|
||||||
|
type=str,
|
||||||
|
default="papers_chroma_db",
|
||||||
|
help="Chroma 数据库输出目录(默认: papers_chroma_db)"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 创建并执行管道
|
||||||
|
pipeline = create_default_pipeline(args.output)
|
||||||
|
pipeline.ingest(Path(args.input))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
137
parsers.py
Normal file
137
parsers.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""
|
||||||
|
数据解析模块
|
||||||
|
|
||||||
|
提供论文文档的解析功能,支持多种格式(当前实现 Markdown 格式)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class BaseParser(ABC):
|
||||||
|
"""解析器基类"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse(self, content: str) -> Tuple[List[str], List[dict]]:
|
||||||
|
"""
|
||||||
|
解析内容文本
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 待解析的文本内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(texts, metadatas): 文本列表和对应的元数据列表
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MarkdownPaperParser(BaseParser):
|
||||||
|
"""
|
||||||
|
Markdown 格式论文解析器
|
||||||
|
|
||||||
|
解析格式:
|
||||||
|
- 论文以 '---' 分隔
|
||||||
|
- 每篇论文以 '## Title' 开头
|
||||||
|
- 内容包含标题、摘要和其他部分
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, separator: str = r'\n---\s*\n', title_pattern: str = r'^##\s+(.+)$'):
|
||||||
|
"""
|
||||||
|
初始化解析器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
separator: 论文分隔符正则表达式
|
||||||
|
title_pattern: 标题匹配正则表达式
|
||||||
|
"""
|
||||||
|
self.separator = separator
|
||||||
|
self.title_pattern = title_pattern
|
||||||
|
|
||||||
|
def parse(self, content: str) -> Tuple[List[str], List[dict]]:
|
||||||
|
"""
|
||||||
|
解析 markdown 格式的论文内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: markdown 文件内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(texts, metadatas): 论文文本列表和元数据列表
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果未找到有效论文
|
||||||
|
"""
|
||||||
|
raw_chunks = re.split(self.separator, content)
|
||||||
|
|
||||||
|
texts: List[str] = []
|
||||||
|
metadatas: List[dict] = []
|
||||||
|
|
||||||
|
for chunk in raw_chunks:
|
||||||
|
chunk = chunk.strip()
|
||||||
|
if not chunk:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 提取标题
|
||||||
|
title_match = re.search(self.title_pattern, chunk, re.MULTILINE)
|
||||||
|
if not title_match:
|
||||||
|
self._handle_missing_title(chunk)
|
||||||
|
continue
|
||||||
|
|
||||||
|
title = title_match.group(1).strip()
|
||||||
|
paper_content = chunk
|
||||||
|
|
||||||
|
texts.append(paper_content)
|
||||||
|
metadatas.append({
|
||||||
|
"title": title,
|
||||||
|
"content_length": len(paper_content),
|
||||||
|
})
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
raise ValueError("No valid papers were found in the content.")
|
||||||
|
|
||||||
|
return texts, metadatas
|
||||||
|
|
||||||
|
def _handle_missing_title(self, chunk: str) -> None:
|
||||||
|
"""处理缺少标题的论文块"""
|
||||||
|
preview = chunk[:50].replace('\n', ' ')
|
||||||
|
print(f"[WARN] Skipping paper without ## title: {preview}...")
|
||||||
|
|
||||||
|
|
||||||
|
class PaperFileReader:
|
||||||
|
"""论文文件读取器"""
|
||||||
|
|
||||||
|
def __init__(self, parser: BaseParser):
|
||||||
|
"""
|
||||||
|
初始化文件读取器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parser: 内容解析器实例
|
||||||
|
"""
|
||||||
|
self.parser = parser
|
||||||
|
|
||||||
|
def read(self, file_path: Path) -> Tuple[List[str], List[dict]]:
|
||||||
|
"""
|
||||||
|
从文件读取并解析论文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: 论文文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(texts, metadatas): 解析后的文本和元数据
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: 如果文件不存在
|
||||||
|
"""
|
||||||
|
if not file_path.exists():
|
||||||
|
raise FileNotFoundError(f"Paper file not found: {file_path}")
|
||||||
|
|
||||||
|
print(f"[INFO] Reading papers from: {file_path}")
|
||||||
|
content = file_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
texts, metadatas = self.parser.parse(content)
|
||||||
|
|
||||||
|
# 添加源文件信息到元数据
|
||||||
|
for meta in metadatas:
|
||||||
|
meta["source_file"] = file_path.name
|
||||||
|
|
||||||
|
return texts, metadatas
|
||||||
17
pyproject.toml
Normal file
17
pyproject.toml
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
[project]
|
||||||
|
name = "embedding"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"chromadb>=1.4.0",
|
||||||
|
"huggingface>=0.0.1",
|
||||||
|
"langchain>=1.2.3",
|
||||||
|
"langchain-community>=0.4.1",
|
||||||
|
"langchain-huggingface>=1.2.0",
|
||||||
|
"langchain-openai>=1.1.7",
|
||||||
|
"python-dotenv>=1.2.1",
|
||||||
|
"requests>=2.32.5",
|
||||||
|
"sentence-transformers>=5.2.0",
|
||||||
|
]
|
||||||
151
retriever.py
Normal file
151
retriever.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
"""
|
||||||
|
学术论文检索工具
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from retriever import retrieve_papers
|
||||||
|
results = retrieve_papers("natural language to SQL")
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from langchain_community.vectorstores import Chroma
|
||||||
|
|
||||||
|
from embeddings import HuggingFaceEmbeddingsProvider, OpenAIEmbeddingsProvider
|
||||||
|
from vector_stores import BaseEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Configuration
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
CHROMA_DIR = Path(r"E:\studio2\embedding\papers_chroma_db")
|
||||||
|
HUGGINGFACE_MODEL = r"E:\hf_models\all-mpnet-base-v2"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Retrieval Service
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class PaperRetriever:
|
||||||
|
"""
|
||||||
|
论文检索服务
|
||||||
|
|
||||||
|
使用已有的 embeddings 和 vector_stores 模块进行论文检索
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
persist_directory: Path,
|
||||||
|
embeddings_provider: BaseEmbeddings,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化检索器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
persist_directory: Chroma 数据库目录
|
||||||
|
embeddings_provider: 嵌入模型提供者
|
||||||
|
"""
|
||||||
|
self.persist_directory = Path(persist_directory)
|
||||||
|
self.embeddings_provider = embeddings_provider
|
||||||
|
|
||||||
|
def _validate_db_exists(self) -> None:
|
||||||
|
"""验证向量数据库是否存在"""
|
||||||
|
if not self.persist_directory.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"论文知识库目录不存在: {self.persist_directory},"
|
||||||
|
f"请先运行 papers_embedding.py 生成数据库"
|
||||||
|
)
|
||||||
|
|
||||||
|
def retrieve(self, query: str, k: int = 10) -> List:
|
||||||
|
"""
|
||||||
|
从论文向量数据库中检索相关论文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 检索查询文本
|
||||||
|
k: 返回结果数量,默认10
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
检索到的文档列表
|
||||||
|
"""
|
||||||
|
self._validate_db_exists()
|
||||||
|
|
||||||
|
embeddings = self.embeddings_provider.get_embeddings()
|
||||||
|
vectorstore = Chroma(
|
||||||
|
persist_directory=str(self.persist_directory),
|
||||||
|
embedding_function=embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
retriever = vectorstore.as_retriever(search_kwargs={"k": k})
|
||||||
|
docs = retriever.invoke(query)
|
||||||
|
|
||||||
|
print(f"检索到 {len(docs)} 篇相关论文")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Convenience Functions
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def create_default_retriever() -> PaperRetriever:
|
||||||
|
"""
|
||||||
|
创建默认配置的检索器(使用 HuggingFace 本地模型)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
配置好的 PaperRetriever 实例
|
||||||
|
"""
|
||||||
|
# embeddings = HuggingFaceEmbeddingsProvider(model_name=HUGGINGFACE_MODEL)
|
||||||
|
embeddings = OpenAIEmbeddingsProvider(
|
||||||
|
api_key=os.getenv("OPENAI_API_KEY"),
|
||||||
|
base_url="https://api.chatanywhere.tech/v1",
|
||||||
|
model="text-embedding-3-large"
|
||||||
|
)
|
||||||
|
return PaperRetriever(
|
||||||
|
persist_directory=CHROMA_DIR,
|
||||||
|
embeddings_provider=embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def retrieve_papers(query: str, k: int = 10) -> List:
|
||||||
|
"""
|
||||||
|
便捷函数:检索相关论文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 检索查询文本
|
||||||
|
k: 返回结果数量,默认10
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
检索到的文档列表
|
||||||
|
"""
|
||||||
|
retriever = create_default_retriever()
|
||||||
|
return retriever.retrieve(query, k)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# CLI Interface
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""测试检索功能"""
|
||||||
|
import sys
|
||||||
|
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
query = "natural language to SQL disambiguation"
|
||||||
|
else:
|
||||||
|
query = " ".join(sys.argv[1:])
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
print(f"查询: {query}\n")
|
||||||
|
docs = retrieve_papers(query)
|
||||||
|
|
||||||
|
for i, doc in enumerate(docs, 1):
|
||||||
|
print(f"[{i}] {doc.metadata.get('title', 'N/A')}")
|
||||||
|
# print(f" 来源: {doc.metadata.get('source_file', 'N/A')}")
|
||||||
|
# print(f" 内容预览: {doc.page_content[:200]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
64
vector_stores.py
Normal file
64
vector_stores.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""
|
||||||
|
向量存储模块
|
||||||
|
|
||||||
|
提供向量数据库的统一接口,当前支持 Chroma
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain_community.vectorstores import Chroma
|
||||||
|
|
||||||
|
from embeddings import BaseEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
class BaseVectorStore(ABC):
|
||||||
|
"""向量存储基类"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def persist(self, texts: List[str], metadatas: List[dict]) -> None:
|
||||||
|
"""
|
||||||
|
持久化文本到向量数据库
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: 文本列表
|
||||||
|
metadatas: 元数据列表
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaVectorStore(BaseVectorStore):
|
||||||
|
"""Chroma 向量数据库实现"""
|
||||||
|
|
||||||
|
def __init__(self, persist_directory: Path, embeddings_provider: BaseEmbeddings):
|
||||||
|
"""
|
||||||
|
初始化 Chroma 向量存储
|
||||||
|
|
||||||
|
Args:
|
||||||
|
persist_directory: 持久化目录
|
||||||
|
embeddings_provider: 嵌入模型提供者
|
||||||
|
"""
|
||||||
|
self.persist_directory = Path(persist_directory)
|
||||||
|
self.embeddings_provider = embeddings_provider
|
||||||
|
|
||||||
|
def persist(self, texts: List[str], metadatas: List[dict]) -> None:
|
||||||
|
"""
|
||||||
|
将文本向量化并持久化到 Chroma
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: 文本列表
|
||||||
|
metadatas: 元数据列表
|
||||||
|
"""
|
||||||
|
self.persist_directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
embeddings = self.embeddings_provider.get_embeddings()
|
||||||
|
|
||||||
|
print(f"[INFO] Writing {len(texts)} papers to Chroma: {self.persist_directory}")
|
||||||
|
vectorstore = Chroma.from_texts(
|
||||||
|
texts=texts,
|
||||||
|
metadatas=metadatas,
|
||||||
|
embedding=embeddings,
|
||||||
|
persist_directory=str(self.persist_directory),
|
||||||
|
)
|
||||||
|
print("[OK] Chroma persistence complete.")
|
||||||
Reference in New Issue
Block a user