如何在 PyTorch 中对张量的元素进行排序?
要在PyTorch中对张量的元素进行排序,我们可以使用方法。此方法返回两个张量。第一个张量是具有元素排序值的张量,第二个张量是原始张量中元素索引的张量。我们可以按行和列计算二维张量。torch.sort()
脚步
导入所需的库。在以下所有Python示例中,所需的Python库是torch。确保您已经安装了它。
创建一个PyTorch张量并打印它。
要对上面创建的张量的元素进行排序,请计算。将此值分配给一个新变量"v"。这里,输入是输入张量,dim是元素排序的维度。对元素按行排序,dim设置为1,按列对元素排序,dim设置为0。torch.sort(input,dim)
具有排序值的张量可以作为v[0]访问,排序元素的索引张量作为v[1]访问。
打印带有排序值的张量和带有排序值索引的张量。
示例1
以下Python程序展示了如何对一维张量的元素进行排序。
# Python program to sort elements of a tensor # import necessary library import torch # Create a tensor T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443]) print("Original Tensor:\n", T) # sort the tensor T # it sorts the tensor in ascending order v = torch.sort(T) # print(v) # print tensor of sorted value print("Tensor with sorted value:\n", v[0]) # print indices of sorted value print("Indices of sorted value:\n", v[1])输出结果
Original Tensor: tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430]) Tensor with sorted value: tensor([-4.3300, -0.4330, 2.3340, 4.4330, 4.4430, 5.0000]) Indices of sorted value: tensor([2, 3, 0, 1, 5, 4])
示例2
以下Python程序显示了如何对2D张量的元素进行排序。
# Python program to sort elements of a 2-D tensor # import the library import torch # Create a 2-D tensor T = torch.Tensor([[2,3,-32], [43,4,-53], [4,37,-4], [3,-75,34]]) print("Original Tensor:\n", T) # sort tensor T # it sorts the tensor in ascending order v = torch.sort(T) # print(v) # print tensor of sorted value print("Tensor with sorted value:\n", v[0]) # print indices of sorted value print("Indices of sorted value:\n", v[1]) print("Sort tensor Column-wise") v = torch.sort(T, 0) # print(v) # print tensor of sorted value print("Tensor with sorted value:\n", v[0]) # print indices of sorted value print("Indices of sorted value:\n", v[1]) print("Sort tensor Row-wise") v = torch.sort(T, 1) # print(v) # print tensor of sorted value print("Tensor with sorted value:\n", v[0]) # print indices of sorted value print("Indices of sorted value:\n", v[1])输出结果
Original Tensor: tensor([[ 2., 3., -32.], [ 43., 4., -53.], [ 4., 37., -4.], [ 3., -75., 34.]]) Tensor with sorted value: tensor([[-32., 2., 3.], [-53., 4., 43.], [ -4., 4., 37.], [-75., 3., 34.]]) Indices of sorted value: tensor([[2, 0, 1], [2, 1, 0], [2, 0, 1], [1, 0, 2]]) Sort tensor Column-wise Tensor with sorted value: tensor([[ 2., -75., -53.], [ 3., 3., -32.], [ 4., 4., -4.], [ 43., 37., 34.]]) Indices of sorted value: tensor([[0, 3, 1], [3, 0, 0], [2, 1, 2], [1, 2, 3]]) Sort tensor Row-wise Tensor with sorted value: tensor([[-32., 2., 3.], [-53., 4., 43.], [ -4., 4., 37.], [-75., 3., 34.]]) Indices of sorted value: tensor([[2, 0, 1], [2, 1, 0], [2, 0, 1], [1, 0, 2]])