71 lines
2.1 KiB
Python
71 lines
2.1 KiB
Python
"""
|
|
嵌入模型模块
|
|
|
|
提供多种文本嵌入模型的统一接口
|
|
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Optional
|
|
|
|
from langchain_huggingface import HuggingFaceEmbeddings
|
|
from langchain_openai import OpenAIEmbeddings
|
|
|
|
|
|
class BaseEmbeddings(ABC):
    """Abstract base class for embedding model providers.

    Subclasses must implement :meth:`get_embeddings`; the class cannot be
    instantiated until they do.
    """

    @abstractmethod
    def get_embeddings(self):
        """Return a LangChain-compatible embeddings instance."""
|
|
|
|
|
|
class HuggingFaceEmbeddingsProvider(BaseEmbeddings):
    """Provider for local HuggingFace embedding models."""

    def __init__(self, model_name: str):
        """Record the model to use; the embeddings object is built on first use.

        Args:
            model_name: Model name or path.
        """
        self.model_name = model_name
        # Filled in by get_embeddings() the first time it is called.
        self._embeddings: Optional[HuggingFaceEmbeddings] = None

    def get_embeddings(self) -> HuggingFaceEmbeddings:
        """Return the HuggingFace embeddings instance, creating it lazily.

        The instance is constructed on the first call and the same object is
        returned on every later call.
        """
        if self._embeddings is None:
            print(f"[INFO] Initializing HuggingFace embeddings: {self.model_name}")
            self._embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
        return self._embeddings
|
|
|
|
|
|
class OpenAIEmbeddingsProvider(BaseEmbeddings):
    """Provider for embedding models served through the OpenAI API."""

    def __init__(self, api_key: str, base_url: str, model: str = "text-embedding-3-large"):
        """Record the connection settings; the client is built on first use.

        Args:
            api_key: OpenAI API key.
            base_url: Base URL of the API.
            model: Name of the embedding model.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.model = model
        # Filled in by get_embeddings() the first time it is called.
        self._embeddings: Optional[OpenAIEmbeddings] = None

    def get_embeddings(self) -> OpenAIEmbeddings:
        """Return the OpenAI embeddings instance, creating it lazily.

        The instance is constructed on the first call from the stored model
        name, API key and base URL; later calls return the same object.
        """
        if self._embeddings is None:
            print(f"[INFO] Initializing OpenAI embeddings: {self.model}")
            self._embeddings = OpenAIEmbeddings(
                model=self.model,
                api_key=self.api_key,
                openai_api_base=self.base_url,
            )
        return self._embeddings
|