Source code for nlpstack.integrations.transformers.cache
from logging import getLogger
from os import PathLike
from typing import Dict, NamedTuple, Optional, Type, Union
try:
import transformers
except ModuleNotFoundError:
transformers = None
logger = getLogger(__name__)
class _ModelSpec(NamedTuple):
model_name: str
auto_cls: Type
class _TokenizerSpec(NamedTuple):
model_name: str
_model_cache: Dict[_ModelSpec, "transformers.PretrainedModel"] = {}
_tokenizer_cache: Dict[_TokenizerSpec, "transformers.PreTrainedTokenizer"] = {}
[docs]def get_pretrained_model(
pretrained_model_name_or_path: Union[str, PathLike],
auto_cls: Optional[Type] = None,
) -> "transformers.PretrainedModel":
global _model_cache
if transformers is None:
raise ModuleNotFoundError("transformers is not installed.")
auto_cls = auto_cls or transformers.AutoModel
spec = _ModelSpec(str(pretrained_model_name_or_path), auto_cls)
if spec in _model_cache:
logger.debug(f"Found cached model: {spec}")
else:
_model_cache[spec] = auto_cls.from_pretrained(pretrained_model_name_or_path)
return _model_cache[spec]
[docs]def get_pretrained_tokenizer(
pretrained_model_name_or_path: Union[str, PathLike],
) -> "transformers.PreTrainedTokenizer":
global _tokenizer_cache
if transformers is None:
raise ModuleNotFoundError("transformers is not installed.")
spec = _TokenizerSpec(str(pretrained_model_name_or_path))
if spec in _tokenizer_cache:
logger.debug(f"Found cached tokenizer: {spec}")
else:
_tokenizer_cache[spec] = transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
return _tokenizer_cache[spec]
[docs]def clear() -> None:
global _model_cache
global _tokenizer_cache
_model_cache.clear()
_tokenizer_cache.clear()