Source code for nlpstack.integrations.torch.picklable

import tempfile
from typing import Any, ClassVar, Dict, Sequence

import torch


[docs]class TorchPicklable: # type: ignore[misc] cuda_dependent_attributes: ClassVar[Sequence[str]] = [] def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() cuda_attrs: Dict[str, Any] = {} for attr in self.cuda_dependent_attributes: if attr in state: cuda_attrs[attr] = state.pop(attr) with tempfile.SpooledTemporaryFile() as f: torch.save(cuda_attrs, f) f.seek(0) state["__cuda_dependent_attributes__"] = f.read() return state def __setstate__(self, state: Dict[str, Any]) -> None: with tempfile.SpooledTemporaryFile() as f: f.write(state.pop("__cuda_dependent_attributes__")) f.seek(0) cuda_attrs = torch.load(f, map_location=torch.device("cpu")) state.update(cuda_attrs) self.__dict__.update(state)