Source code for nlpstack.integrations.torch.model
import typing
from typing import Any, Generic, Optional, Protocol, Tuple, Type, TypeVar, cast
import torch
# Type of a model's prediction; carried by the ``inference`` attribute of
# TorchModelOutput and returned by TorchModel.infer.
Inference = typing.TypeVar("Inference")
# Type of the inputs object handed to TorchModel.forward; must be a tuple
# type (NamedTuple subclasses included) because of the bound.
ModelInputs = typing.TypeVar("ModelInputs", bound=typing.Tuple[typing.Any, ...])
# Type of the optional per-call ``params`` argument of forward/infer.
PredictionParams = typing.TypeVar("PredictionParams")
@typing.runtime_checkable
class LazySetup(typing.Protocol):
    """Structural type for objects that expose a deferred ``setup`` step.

    Because the protocol is ``runtime_checkable``, any object that has a
    ``setup`` attribute passes an ``isinstance`` check; the signature itself
    is not verified at runtime. ``TorchModel.setup`` uses this check to find
    the submodules it should forward its arguments to.
    """

    def setup(self, *args: typing.Any, **kwargs: typing.Any) -> None:
        """Perform deferred initialisation with the forwarded arguments."""
@typing.runtime_checkable
class TorchModelOutput(typing.Protocol[Inference]):
    """Structural type for the value returned by ``TorchModel.forward``.

    Members are declared only; any object carrying both an ``inference`` and
    a ``loss`` attribute satisfies an ``isinstance`` check against it.
    """

    # The model's prediction, typed by the ``Inference`` parameter.
    inference: Inference
    # Loss tensor, or ``None``.
    loss: typing.Optional[torch.FloatTensor]
class TorchModel(torch.nn.Module, Generic[Inference, ModelInputs, PredictionParams]):
    """Base class for nlpstack PyTorch models.

    Subclasses implement :meth:`forward`. It receives a ``ModelInputs`` value
    and optional ``PredictionParams`` and returns an object exposing
    ``inference`` and ``loss`` attributes (a ``TorchModelOutput``).
    """

    # Class of the inputs object accepted by forward(); annotation only here,
    # presumably assigned by concrete subclasses — confirm against subclasses.
    Inputs: Type[ModelInputs]

    def __call__(
        self,
        inputs: ModelInputs,
        params: Optional[PredictionParams] = None,
    ) -> TorchModelOutput[Inference]:
        """Run the model through ``torch.nn.Module.__call__``.

        This override exists only to give the call a precise return type.
        Dispatch (forward hooks and then :meth:`forward`) is done by the base
        class.
        """
        return cast(TorchModelOutput[Inference], super().__call__(inputs, params))

    def forward(
        self,
        inputs: ModelInputs,
        params: Optional[PredictionParams] = None,
    ) -> TorchModelOutput[Inference]:
        """Compute the model output for ``inputs``.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

    @torch.no_grad()
    def infer(
        self,
        inputs: ModelInputs,
        params: Optional[PredictionParams] = None,
    ) -> Inference:
        """Run the model without gradient tracking and return only the inference.

        The module instance is called rather than ``forward`` directly, so any
        registered forward pre/post hooks run, as PyTorch recommends. This
        method does not switch the model to eval mode.
        """
        return self(inputs, params).inference

    def get_device(self) -> torch.device:
        """Return the device on which the model's tensors live.

        The first parameter decides the device. A model with no parameters
        falls back to its first buffer. A model with neither reports the CPU
        device instead of leaking ``StopIteration`` from an exhausted iterator.
        """
        for param in self.parameters():
            return param.device
        for buffer in self.buffers():
            return buffer.device
        return torch.device("cpu")

    def setup(self, *args: Any, **kwargs: Any) -> None:
        """Forward ``setup(*args, **kwargs)`` to every ``LazySetup`` submodule.

        The module tree is walked in pre-order (parents before children). Each
        shared submodule is visited once, and ``self`` is skipped.

        A nested ``TorchModel`` is delegated to through its own ``setup`` and
        its subtree is not walked again here. Walking it again would call
        ``setup`` twice on every lazy module inside it. Subclasses that
        override ``setup`` should call ``super().setup(...)`` so that their
        submodules are still initialised.
        """
        visited: typing.Set[torch.nn.Module] = {self}

        def _visit(module: torch.nn.Module) -> None:
            for child in module.children():
                if child in visited:
                    continue
                visited.add(child)
                if isinstance(child, TorchModel):
                    child.setup(*args, **kwargs)
                    # The nested model has already handled its whole subtree.
                    visited.update(child.modules())
                    continue
                if isinstance(child, LazySetup):
                    child.setup(*args, **kwargs)
                _visit(child)

        _visit(self)