Files
paper-embedding/embeddings.py
2026-01-11 16:09:16 +08:00

71 lines
2.1 KiB
Python

"""
嵌入模型模块
提供多种文本嵌入模型的统一接口
"""
from abc import ABC, abstractmethod
from typing import Any, Optional

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import OpenAIEmbeddings
class BaseEmbeddings(ABC):
    """Common interface shared by all embedding-model providers.

    Each concrete provider wraps one embedding backend and hands it out
    through :meth:`get_embeddings`.
    """

    @abstractmethod
    def get_embeddings(self):
        """Return a LangChain-compatible embeddings instance.

        Concrete providers must override this method; the base class cannot
        be instantiated until they do.
        """
        ...
class HuggingFaceEmbeddingsProvider(BaseEmbeddings):
    """Provider for locally-run HuggingFace embedding models.

    The ``HuggingFaceEmbeddings`` object is created lazily on the first call
    to :meth:`get_embeddings` and then cached on the instance, so building a
    provider does not construct the model.
    """

    def __init__(self, model_name: str, **embedding_kwargs: Any):
        """Configure (but do not yet construct) a HuggingFace embedding model.

        Args:
            model_name: Model name or local path, passed to
                ``HuggingFaceEmbeddings`` as ``model_name``.
            **embedding_kwargs: Optional extra keyword arguments forwarded
                unchanged to ``HuggingFaceEmbeddings`` (for example
                ``model_kwargs={"device": "cuda"}`` or
                ``encode_kwargs={"normalize_embeddings": True}``). Passing
                none reproduces the original behavior.
        """
        self.model_name = model_name
        self.embedding_kwargs = embedding_kwargs
        # Populated on first use by get_embeddings().
        self._embeddings: Optional[HuggingFaceEmbeddings] = None

    def get_embeddings(self) -> HuggingFaceEmbeddings:
        """Return the cached ``HuggingFaceEmbeddings``, creating it on first use."""
        if self._embeddings is None:
            print(f"[INFO] Initializing HuggingFace embeddings: {self.model_name}")
            self._embeddings = HuggingFaceEmbeddings(
                model_name=self.model_name,
                **self.embedding_kwargs,
            )
        return self._embeddings
class OpenAIEmbeddingsProvider(BaseEmbeddings):
    """Provider for embedding models served through an OpenAI-compatible API.

    The ``OpenAIEmbeddings`` client is created lazily on the first call to
    :meth:`get_embeddings` and then cached on the instance.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "text-embedding-3-large",
        **embedding_kwargs: Any,
    ):
        """Configure (but do not yet create) an OpenAI embeddings client.

        Args:
            api_key: API key. When ``None`` it is not passed on, so
                ``OpenAIEmbeddings`` applies its own default (per the
                langchain_openai docs, the ``OPENAI_API_KEY`` env variable).
            base_url: API base URL. When ``None`` it is not passed on, so
                ``OpenAIEmbeddings`` applies its own default endpoint (per
                the langchain_openai docs, ``OPENAI_API_BASE`` if set).
            model: Embedding model name.
            **embedding_kwargs: Optional extra keyword arguments forwarded
                unchanged to ``OpenAIEmbeddings`` (e.g. ``dimensions=1024``).
        """
        self.api_key = api_key
        self.base_url = base_url
        self.model = model
        self.embedding_kwargs = embedding_kwargs
        # Populated on first use by get_embeddings().
        self._embeddings: Optional[OpenAIEmbeddings] = None

    def get_embeddings(self) -> OpenAIEmbeddings:
        """Return the cached ``OpenAIEmbeddings``, creating it on first use."""
        if self._embeddings is None:
            print(f"[INFO] Initializing OpenAI embeddings: {self.model}")
            options: dict[str, Any] = {"model": self.model, **self.embedding_kwargs}
            # Only forward credentials/endpoint that were actually supplied:
            # an explicit None would override the library's own defaults.
            # `api_key` / `base_url` are the documented aliases of
            # `openai_api_key` / `openai_api_base`; use them consistently.
            if self.api_key is not None:
                options["api_key"] = self.api_key
            if self.base_url is not None:
                options["base_url"] = self.base_url
            self._embeddings = OpenAIEmbeddings(**options)
        return self._embeddings