Conversion: PyTorch-specific Tools
This module contains PyTorch-specific conversion tools.
- tad_mctc.convert.pytorch.any_to_tensor(x, device=None, dtype=None)
Convert various types of inputs to a PyTorch tensor.
The device and dtype of the tensor can be specified. If the input is already a tensor, it’s converted to the specified device and dtype.
- Parameters:
x (Any) – The input to convert. Can be of type Tensor, float, int, bool, or str, or a list containing float, int, or bool.
device (torch.device, optional) – The device on which to place the created tensor. If None, the default device is used.
dtype (torch.dtype, optional) – The desired data type for the tensor. If None, the default data type is used or inferred from the input.
- Returns:
A PyTorch tensor representation of the input.
- Return type:
Tensor
- Raises:
ValueError – If x is a string that cannot be converted to a float or if the list contains elements other than float, int, or bool.
TypeError – If x is of a type that cannot be converted to a tensor.
Examples
>>> any_to_tensor(3.14)
tensor(3.1400)
>>> any_to_tensor(42, dtype=torch.float32)
tensor(42.)
>>> any_to_tensor(True)
tensor(True)
>>> any_to_tensor('2.718')
tensor(2.7180)
>>> any_to_tensor('not_a_number')
ValueError: Cannot convert string 'not_a_number' to float
>>> any_to_tensor(["1", "2"])
TypeError: Tensor-incompatible type '<class 'list'>' of variable ["1", "2"].
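The dispatch rules described above can be sketched as a plain function. This is a minimal, hypothetical re-implementation (`any_to_tensor_sketch` is an illustrative name, not part of tad_mctc), written only to make the documented behaviour concrete. It follows the examples, so a list of strings ends in the final TypeError; the library itself may order its checks differently.

```python
from __future__ import annotations

from typing import Any

import torch


def any_to_tensor_sketch(
    x: Any,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Hypothetical sketch of the documented rules, not the library's code."""
    # Existing tensors are only moved and/or cast.
    if isinstance(x, torch.Tensor):
        return x.to(device=device, dtype=dtype)

    # Python scalars (bool is a subclass of int) become 0-d tensors.
    if isinstance(x, (float, int)):
        return torch.tensor(x, device=device, dtype=dtype)

    # Strings must parse as a float; otherwise raise ValueError.
    if isinstance(x, str):
        try:
            value = float(x)
        except ValueError as err:
            raise ValueError(f"Cannot convert string '{x}' to float") from err
        return torch.tensor(value, device=device, dtype=dtype)

    # Lists are accepted only if every element is a float, int, or bool.
    if isinstance(x, list) and all(isinstance(v, (float, int)) for v in x):
        return torch.tensor(x, device=device, dtype=dtype)

    # Everything else (including lists of strings, as in the last example).
    raise TypeError(f"Tensor-incompatible type '{type(x)}' of variable {x}.")
```

Checking `torch.Tensor` first means an existing tensor is never rebuilt from scratch; it is simply passed through `Tensor.to`, which is a no-op when both `device` and `dtype` are None.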
- tad_mctc.convert.pytorch.normalize_device(s)
Convert any device input to torch.device. Critically, this also sets the index for CUDA devices to torch.cuda.current_device().
- Parameters:
s (torch.device | str | None) – The device, given as a torch.device, a device name string, or None.
- Returns:
The device as a torch.device object.
- Return type:
torch.device
- Raises:
KeyError – If an unknown device name is given.
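A minimal sketch of the contract described for normalize_device, written as a hypothetical standalone function (`normalize_device_sketch`) rather than the library's implementation. The accepted names in `_KNOWN_TYPES` and the CPU fallback for None are assumptions chosen for illustration. CUDA is only queried when a bare "cuda" device has to be pinned, so the sketch also runs on CPU-only machines.

```python
from __future__ import annotations

import torch

# Hypothetical set of accepted names; the library's own list may differ.
# Indexed strings such as "cuda:1" are deliberately not handled here.
_KNOWN_TYPES = {"cpu", "cuda"}


def normalize_device_sketch(s: torch.device | str | None) -> torch.device:
    """Sketch of the documented behaviour, not the library's code."""
    if s is None:
        # Assumption: None falls back to the CPU.
        return torch.device("cpu")
    if isinstance(s, str):
        if s not in _KNOWN_TYPES:
            # Mirrors the documented KeyError for unknown names.
            raise KeyError(f"Unknown device '{s}' given.")
        s = torch.device(s)
    # Pin a bare "cuda" device to the currently selected GPU index.
    if s.type == "cuda" and s.index is None:
        return torch.device("cuda", torch.cuda.current_device())
    return s
```

Devices that already carry an index, such as `torch.device("cuda", 1)`, pass through unchanged in this sketch.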
- tad_mctc.convert.pytorch.str_to_device(s)
Convert a device name to torch.device. Critically, this also sets the index for CUDA devices to torch.cuda.current_device().
- Parameters:
s (str) – Name of the device as string.
- Returns:
The device as a torch.device object.
- Return type:
torch.device
- Raises:
KeyError – If an unknown device name is given.
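Setting the CUDA index matters because torch.device equality compares both the device type and the index. A bare "cuda" device has no index, so it does not compare equal to the "cuda:0" that tensors on the first GPU report. The snippet below only constructs device objects, so it runs without a GPU.

```python
import torch

# A bare "cuda" device carries no index ...
bare = torch.device("cuda")
# ... whereas a tensor allocated on a GPU always reports a concrete index.
pinned = torch.device("cuda", 0)

print(bare.index)      # None
print(pinned)          # cuda:0
print(bare == pinned)  # False: the indices differ
```

Pinning the index up front lets such comparisons (for example, checking whether two tensors share a device) behave as expected.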