Conversion: PyTorch-specific Tools

This module contains PyTorch-specific conversion tools.

tad_mctc.convert.pytorch.any_to_tensor(x, device=None, dtype=None)

Convert various types of inputs to a PyTorch tensor.

The device and dtype of the resulting tensor can be specified. If the input is already a tensor, it is moved to the specified device and cast to the specified dtype.

Parameters:
  • x (Any) – The input to convert. Can be of type Tensor, float, int, bool, or str, or a list containing float, int, or bool.

  • device (torch.device, optional) – The device on which to place the created tensor. If None, the default device is used.

  • dtype (torch.dtype, optional) – The desired data type for the tensor. If None, the default data type is used or inferred from the input.

Returns:

A PyTorch tensor representation of the input.

Return type:

Tensor

Raises:
  • ValueError – If x is a string that cannot be converted to a float, or a list containing elements other than float, int, or bool.

  • TypeError – If x is of a type that cannot be converted to a tensor.

Examples

>>> any_to_tensor(3.14)
tensor(3.1400)
>>> any_to_tensor(42, dtype=torch.float32)
tensor(42.)
>>> any_to_tensor(True)
tensor(True)
>>> any_to_tensor('2.718')
tensor(2.7180)
>>> any_to_tensor('not_a_number')
Traceback (most recent call last):
    ...
ValueError: Cannot convert string 'not_a_number' to float
>>> any_to_tensor(["1", "2"])
Traceback (most recent call last):
    ...
TypeError: Tensor-incompatible type '<class 'list'>' of variable ["1", "2"].
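The conversion rules described above can be sketched as a short type dispatch. This is a minimal, hypothetical re-implementation (named any_to_tensor_sketch), not the tad_mctc source. It assumes the checks run in the order tensor, scalar, list of scalars, numeric string, and that any other input raises TypeError, which matches the last example.

```python
from __future__ import annotations

from typing import Any

import torch


def any_to_tensor_sketch(
    x: Any,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Hypothetical sketch of the documented conversion rules."""
    # Existing tensors are only moved/cast; None keeps the current setting.
    if isinstance(x, torch.Tensor):
        return x.to(device=device, dtype=dtype)

    # Python scalars map directly (bool is a subclass of int).
    if isinstance(x, (float, int, bool)):
        return torch.tensor(x, device=device, dtype=dtype)

    # Lists are accepted only if every element is a numeric scalar.
    if isinstance(x, list) and all(isinstance(i, (float, int, bool)) for i in x):
        return torch.tensor(x, device=device, dtype=dtype)

    # Strings must parse as a float, e.g. '2.718'.
    if isinstance(x, str):
        try:
            value = float(x)
        except ValueError as e:
            raise ValueError(f"Cannot convert string '{x}' to float") from e
        return torch.tensor(value, device=device, dtype=dtype)

    # Everything else (including lists of strings) is rejected.
    raise TypeError(f"Tensor-incompatible type '{type(x)}' of variable {x}.")
```

Checking for a tensor first keeps already-converted inputs cheap: Tensor.to returns the tensor itself when its device and dtype already match.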

tad_mctc.convert.pytorch.normalize_device(s)

Convert any device input to torch.device. Critically, this also sets the index for CUDA devices to torch.cuda.current_device().

Parameters:

s (torch.device | str | None) – The device to normalize, given as a torch.device, a device name string, or None.

Returns:

The device as a torch.device object.

Return type:

torch.device

Raises:

KeyError – If an unknown device name is given.
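A minimal sketch of the normalization, written as a hypothetical normalize_device_sketch rather than the tad_mctc source. It assumes that strings are checked against a small allow-list (the hypothetical _KNOWN_DEVICES) so that unknown names raise KeyError as documented, and that None falls back to the CPU device; the handling of None is an assumption, since it is not specified above.

```python
from __future__ import annotations

import torch

# Hypothetical allow-list; the real function may accept other names.
_KNOWN_DEVICES = ("cpu", "cuda")


def normalize_device_sketch(s: torch.device | str | None) -> torch.device:
    """Hypothetical sketch: coerce s to torch.device, pinning bare CUDA devices."""
    if s is None:
        # Assumption: None falls back to the CPU device.
        return torch.device("cpu")

    if isinstance(s, str):
        if s not in _KNOWN_DEVICES:
            raise KeyError(f"Unknown device '{s}' given.")
        s = torch.device(s)

    # A bare "cuda" device has index None; pin it to the current GPU.
    if s.type == "cuda" and s.index is None:
        return torch.device("cuda", torch.cuda.current_device())

    return s
```

Resolving the index matters because torch.device("cuda") and torch.device("cuda", 0) do not compare equal, whereas the .device of a CUDA tensor always carries a concrete index.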

tad_mctc.convert.pytorch.str_to_device(s)

Convert device name to torch.device. Critically, this also sets the index for CUDA devices to torch.cuda.current_device().

Parameters:

s (str) – Name of the device as string.

Returns:

The device as a torch.device object.

Return type:

torch.device

Raises:

KeyError – If an unknown device name is given.
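Because the documented error is a KeyError, a dictionary lookup is a natural way to sketch this function. The version below (hypothetical str_to_device_sketch, covering only "cpu" and "cuda") is an assumption, not the tad_mctc source. It stores factory functions in the table so that torch.cuda.current_device() is queried only when "cuda" is requested, which keeps "cpu" usable on machines without a GPU.

```python
import torch


def str_to_device_sketch(s: str) -> torch.device:
    """Hypothetical sketch: map a device name to a torch.device."""
    # Hypothetical name table; lambdas defer the CUDA query until needed.
    table = {
        "cpu": lambda: torch.device("cpu"),
        "cuda": lambda: torch.device("cuda", torch.cuda.current_device()),
    }
    # An unknown name fails the dict lookup, which raises KeyError.
    return table[s]()
```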