Source code for modern_bert_score.inference

import gc
from typing import Any, List

import torch
from sentence_transformers import SentenceTransformer
from torch.nn import functional as F
from transformers import AutoTokenizer

# vLLM is an optional dependency: try to import it and record the outcome in
# VLLM_AVAILABLE so the vLLM backend can fail with a clear message later
# instead of this module failing at import time.
try:
    from vllm import LLM

    VLLM_AVAILABLE = True
except ImportError:
    # Bind a placeholder so the module-level name ``LLM`` still exists when
    # vllm is not installed (prevents a NameError on later references).
    LLM = object  # To prevent NameError if vllm is not installed
    VLLM_AVAILABLE = False


# TODO: Cache reference embeddings
class Inference:
    """Common interface implemented by every embedding backend.

    Subclasses are expected to assign a loaded model to ``model`` and to
    override :meth:`inference`.
    """

    # Backend model handle; remains ``None`` until a subclass assigns one.
    model: Any = None

    def inference(
        self, candidates: List[str], references: List[str], **kwargs: Any
    ) -> tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Compute embeddings for candidate and reference strings.

        Args:
            candidates: Candidate strings to embed.
            references: Reference strings to embed.
            **kwargs: Extra options forwarded to the underlying model.

        Returns:
            A ``(candidate_embeddings, reference_embeddings)`` tuple, each
            entry being a list of tensors.

        Raises:
            NotImplementedError: Always; the base class has no backend.
        """
        not_overridden = NotImplementedError(
            "Method must be implemented in Subclass."
        )
        raise not_overridden
class STInference(Inference):
    """Inference backend using SentenceTransformers."""

    def __init__(
        self,
        model_id: str,
        device: str = "cpu",
        batch_size: int = 64,
        **kwargs: Any,
    ):
        """Initializes the SentenceTransformers inference backend.

        Args:
            model_id: The model identifier or path.
            device: The device to load the model on (e.g., 'cpu', 'cuda').
            batch_size: Default batch size used when encoding; a
                ``batch_size`` passed to :meth:`inference` takes precedence.
            **kwargs: Additional arguments for SentenceTransformer.
        """
        self.model = SentenceTransformer(
            model_name_or_path=model_id, device=device, **kwargs
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id
        )  # TODO: Maybe switch to PreTrainedTokenizerFast for clarity?
        self.batch_size = batch_size
        # Lower bound on the norm used by F.normalize (avoids division by 0).
        self.eps: float = 1e-12

    def _encode_normalized(
        self, texts: List[str], **kwargs: Any
    ) -> List[torch.Tensor]:
        """Encodes ``texts`` to token embeddings and L2-normalizes each one."""
        # Apply the configured batch size unless the caller overrides it;
        # previously the constructor's ``batch_size`` was never forwarded.
        encode_kwargs = {"batch_size": self.batch_size, **kwargs}
        token_embeddings = self.model.encode(
            texts,
            output_value="token_embeddings",
            convert_to_tensor=True,
            **encode_kwargs,
        )
        return [
            F.normalize(e, p=2, dim=-1, eps=self.eps) for e in token_embeddings
        ]

    def inference(
        self, candidates: List[str], references: List[str], **kwargs: Any
    ) -> tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Computes embeddings using SentenceTransformers.

        Args:
            candidates: A list of candidate strings.
            references: A list of reference strings.
            **kwargs: Additional arguments passed to `model.encode`. A
                ``batch_size`` given here overrides the constructor value.

        Returns:
            A tuple containing two lists of tensors
            (candidate_embeddings, reference_embeddings), each tensor
            L2-normalized along its last dimension.

        Raises:
            RuntimeError: If no model is loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded.")

        # References are encoded first, matching the original call order.
        embds_refs = self._encode_normalized(references, **kwargs)
        embds_cnds = self._encode_normalized(candidates, **kwargs)
        return embds_cnds, embds_refs
class VLLMInference(Inference):
    """Inference backend using vLLM."""

    def __init__(self, **kwargs: Any):
        """Initializes the vLLM inference backend.

        Args:
            **kwargs: Arguments passed to the vLLM `LLM` constructor. A legacy
                ``task="embed"`` entry is translated by `_prepare_args`.

        Raises:
            ImportError: If vLLM is not installed.
            RuntimeError: If constructing `LLM` fails with a message that
                mentions both "Model architectures" and
                "ModernBertForMaskedLM". Any other construction error is
                re-raised unchanged.
        """
        if not VLLM_AVAILABLE:
            raise ImportError(
                "vLLM is not installed. To use the vLLM backend, please "
                "install it with `pip install vllm` or "
                "`pip install 'modern-bert-score[vllm]'`."
            )

        # Backward compatibility for old callsites that pass task="embed".
        kwargs = self._prepare_args(kwargs)

        try:
            self.model = LLM(**kwargs)
        except Exception as exc:
            message = str(exc)
            if (
                "Model architectures" in message
                and "ModernBertForMaskedLM" in message
            ):
                raise RuntimeError(
                    "vLLM does not accept the masked-LM ModernBERT checkpoint "
                    "directly. Export an encoder-only checkpoint first with "
                    "prepare_model.py, which rewrites the saved config to "
                    "advertise ModernBertModel, then load that local path in "
                    "VLLMInference. If you do not need vLLM specifically, use "
                    "STInference for the original HF checkpoint."
                ) from exc
            raise

        # Lower bound on the norm used by F.normalize (avoids division by 0).
        self.eps: float = 1e-12

    @staticmethod
    def _prepare_args(kwargs: Any) -> Any:
        """Prepares arguments for vLLM, setting defaults for embedding tasks.

        The ``task`` key is always removed. When it equals ``"embed"``,
        ``runner="pooling"`` and ``convert="embed"`` are filled in unless the
        caller already supplied those keys.

        Args:
            kwargs: Mapping of keyword arguments intended for `LLM`.

        Returns:
            A new dict with the adjustments applied; the input mapping is
            left unmodified.
        """
        # Work on a copy so the caller's mapping is never mutated.
        prepared = dict(kwargs)
        task = prepared.pop("task", None)
        if task == "embed":
            prepared.setdefault("runner", "pooling")
            prepared.setdefault("convert", "embed")
        return prepared

    def _normalize_outputs(self, outputs: Any) -> List[torch.Tensor]:
        """L2-normalizes the ``outputs.data`` tensor of each encode result."""
        # TODO (carried over from the original): check whether normalizing
        # here is superfluous.
        return [
            F.normalize(output.outputs.data, p=2, dim=-1, eps=self.eps)
            for output in outputs
        ]

    def inference(
        self, candidates: List[str], references: List[str], **kwargs: Any
    ) -> tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Computes embeddings using vLLM.

        Args:
            candidates: A list of candidate strings.
            references: A list of reference strings.
            **kwargs: Additional arguments passed to `model.encode`.

        Returns:
            A tuple containing two lists of tensors
            (candidate_embeddings, reference_embeddings), each tensor
            L2-normalized along its last dimension.

        Raises:
            RuntimeError: If no model is loaded (e.g. after `cleanup`).
        """
        if self.model is None:
            raise RuntimeError("Model not loaded.")

        outputs_cands = self.model.encode(
            candidates, pooling_task="token_embed", **kwargs
        )
        outputs_refs = self.model.encode(
            references, pooling_task="token_embed", **kwargs
        )

        # Normalize each side from its own outputs rather than slicing one
        # combined list by len(candidates), so the split never depends on the
        # backend returning exactly one output per input.
        return (
            self._normalize_outputs(outputs_cands),
            self._normalize_outputs(outputs_refs),
        )

    def cleanup(self) -> None:
        """Drops the vLLM model reference and asks for memory to be reclaimed.

        Safe to call repeatedly; does nothing when no model is loaded.
        """
        # Identity check instead of truthiness: the model object's own
        # __bool__/__len__ should not decide whether it gets released.
        if getattr(self, "model", None) is None:
            return

        # Reset to None explicitly so `inference` reports "Model not loaded."
        self.model = None
        gc.collect()
        torch.cuda.empty_cache()