Unpacks packed tensors into their constituents and removes padding.
This function is the inverse of the pack operation: it splits a packed tensor into slices along the given axis and strips from each slice the trailing padding that packing added.
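The split-then-trim idea can be sketched without any dependencies. The version below is a simplified illustration, not the library's implementation: it works on 2-D nested Python lists instead of tensors, and it assumes padding only ever appears at the end of each slice, as in the example further down. `unpack_lists` and `strip_trailing` are hypothetical names chosen for the sketch.

```python
from typing import List, Sequence, Tuple, Union

Scalar = Union[int, float, bool]


def strip_trailing(row: Sequence[Scalar], value: Scalar = 0) -> List[Scalar]:
    """Drop trailing entries equal to the padding value (hypothetical helper)."""
    end = len(row)
    while end > 0 and row[end - 1] == value:
        end -= 1
    return list(row[:end])


def unpack_lists(
    packed: Sequence[Sequence[Scalar]], value: Scalar = 0, axis: int = 0
) -> Tuple[List[Scalar], ...]:
    """Split a 2-D nested list along `axis`, then strip padding from each slice.

    Sketch only: assumes rectangular 2-D input and trailing padding.
    """
    if axis < 0:  # allow -1 / -2 like tensor APIs do
        axis += 2
    if axis == 0:
        slices = [list(row) for row in packed]
    elif axis == 1:
        slices = [list(col) for col in zip(*packed)]
    else:
        raise ValueError("this sketch only supports 2-D input (axis 0 or 1)")
    return tuple(strip_trailing(s, value) for s in slices)


packed = [
    [1, 2, 3, 0, 0],
    [4, 5, 0, 0, 0],
    [6, 7, 8, 9, 0],
]
for part in unpack_lists(packed, value=0, axis=0):
    print(part)
# [1, 2, 3]
# [4, 5]
# [6, 7, 8, 9]
```

Splitting along `axis=1` instead yields the columns, each trimmed the same way, so a column made up entirely of padding comes back empty.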
- Parameters:
tensor (Tensor) – The tensor to be unpacked.
value (int | float | bool, optional) – The value used as padding, by default 0.
axis (int, optional) – The axis along which the tensor was packed, by default 0.
- Returns:
A tuple of the constituent tensors, each with its padding removed (deflated).
- Return type:
tuple[Tensor, …]
Examples
Suppose you have a tensor that has been packed along the first axis (axis=0), using 0 as the padding value:
>>> import torch
>>> packed_tensor = torch.tensor([
...     [1, 2, 3, 0, 0],
...     [4, 5, 0, 0, 0],
...     [6, 7, 8, 9, 0],
... ])
Unpacking this tensor recovers the original, unpadded rows:
>>> unpacked_tensors = unpack(packed_tensor, value=0, axis=0)
>>> for tensor in unpacked_tensors:
...     print(tensor)
...
tensor([1, 2, 3])
tensor([4, 5])
tensor([6, 7, 8, 9])
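Because padding is recognized only by its value, the round trip through pack and unpack is exact only when no constituent genuinely ends in that value. The dependency-free sketch below illustrates this on nested lists. `pack_rows` and `unpack_rows` are hypothetical stand-ins for the pack and unpack operations (axis 0 only, trailing padding assumed), not the library's functions, and whether the real `unpack` behaves identically in this edge case is an assumption worth checking.

```python
from typing import List, Sequence, Tuple


def pack_rows(rows: Sequence[Sequence[int]], value: int = 0) -> List[List[int]]:
    """Hypothetical pack: right-pad each row with `value` to the longest length."""
    width = max((len(r) for r in rows), default=0)
    return [list(r) + [value] * (width - len(r)) for r in rows]


def unpack_rows(packed: Sequence[Sequence[int]], value: int = 0) -> Tuple[List[int], ...]:
    """Hypothetical unpack: split along axis 0, drop trailing entries equal to `value`."""
    out = []
    for row in packed:
        end = len(row)
        while end > 0 and row[end - 1] == value:
            end -= 1
        out.append(list(row[:end]))
    return tuple(out)


original = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(unpack_rows(pack_rows(original)))  # ([1, 2, 3], [4, 5], [6, 7, 8, 9])

# A row whose last real element equals the padding value loses that element:
tricky = [[1, 0], [2, 3, 4]]
print(unpack_rows(pack_rows(tricky)))  # ([1], [2, 3, 4])
```

In this sketch, picking a padding value that never occurs at the end of real data (for example, `value=-1` for non-negative indices) makes the round trip lossless.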