如何在 PyTorch 中执行展开操作?
Tensor.expand()属性用于执行展开操作。它沿着单例维度将张量扩展到新维度。
扩展张量只会创建原始张量的新视图;它不会复制原始张量。
如果您将特定维度设置为-1,则张量将不会沿此维度展开。
例如,如果我们有一个大小为(3,1)的张量,我们可以沿着大小为1的维度扩展这个张量。
脚步
要扩展张量,可以按照以下步骤操作-
导入火炬库。确保您已经安装了它。
import torch
将至少具有一维的张量定义为单例。
t = torch.tensor([[1],[2],[3]])
沿单例维度展开张量。沿非单一维度展开将引发运行时错误(参见示例3)。
t_exp = t.expand(3,2)
显示扩展的张量。
print("Tensor after expand:\n", t_exp)
示例1
以下Python程序展示了如何将大小为(3,1)的张量扩展为大小为(3,2)的张量。它沿维度大小1扩展张量。大小为3的另一个维度保持不变。
# import required libraries import torch # create a tensor t = torch.tensor([[1],[2],[3]]) # display the tensor print("Tensor:\n", t) print("Size of Tensor:\n", t.size()) # expand the tensor exp = t.expand(3,2) print("Tensor after expansion:\n", exp)输出结果
Tensor: tensor([[1], [2], [3]]) Size of Tensor: torch.Size([3, 1]) Tensor after expansion: tensor([[1, 1], [2, 2], [3, 3]])
示例2
以下Python程序将大小为(1,3)的张量扩展为大小为(3,3)的张量。它沿维度大小为1扩展张量。
# import required libraries import torch # create a tensor t = torch.tensor([[1,2,3]]) # display the tensor print("Tensor:\n", t) # size of tensor is [1,3] print("Size of Tensor:\n", t.size()) # expand the tensor expandedTensor = t.expand(3,-1) print("Expanded Tensor:\n", expandedTensor) print("Size of expanded tensor:\n", expandedTensor.size())输出结果
Tensor: tensor([[1, 2, 3]]) Size of Tensor: torch.Size([1, 3]) Expanded Tensor: tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) Size of expanded tensor: torch.Size([3, 3])
示例3
在下面的Python程序中,我们尝试沿非单一维度扩展张量,因此它引发了运行时错误。
# import required libraries import torch # create a tensor t = torch.tensor([[1,2,3]]) # display the tensor print("Tensor:\n", t) # size of tensor is [1,3] print("Size of Tensor:\n", t.size()) t.expand(3,4)输出结果
Tensor: tensor([[1, 2, 3]]) Size of Tensor: torch.Size([1, 3]) RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 1. Target sizes: [3, 4]. Tensor sizes: [1, 3]