pytorch中的Torch.gather函数的含义

阿里云双11来了!从本博客参与阿里云,服务器最低只要86元/年!

在动手学习深度学习中学到了一个函数gather,原文是说可以通过gather得到标签的预测概率。

y_hat = torch.tensor([[0.1,0.3,0.6],[0.3,0.2,0.5]])
y = torch.LongTensor([0,2])
y_hat.gather(1,y.view(-1,1)) 
tensor([[0.1000],
        [0.5000]])

开始我看到这个输出一头雾水 不知道怎么回事

查了查 gather的时候我才知道

torch.gather(input,dim,index,out=None)

example:
t = torch.Tensor([1,2],[3,4])
torch.gather(t,1,torchLongTensor([[0,0],[1,0]]))
1,1
4,3

可以看出gather的作用是根据索引返回该项元素,首先先输入一个Tensor 然后根据dim进行判断是是行的还是列的,当dim=0 时候竖行查找,当dim=1的时候是横向查找

上题中,dim=1,那么索引就是列号。index的大小就是输出的大小,比如index是[1,0;0,0]其实就是第一行的第二个元素和第一个元素,第二行的第一个元素也就是返回的是2,1 3,3

所以例子中是[0,0],[1,0] 返回的就是[1,1],[4,3]

在例题中的他是通过view函数来返回index的,开始不知道view的意思,查过后知道了,他实际上和resize的意思差不多。

a = torch.Tensor([[1,2,3],[4,5,6]])
b = torch.Tensor([1,2,3,4,5,6])

print(a.view(1,6))
print(b.view(1,6))

得到的都是
tensor([[1,2,3,4,5,6]])

再看一个例子

a = torch.Tensor([[1,2,3],[4,5,6]])
print(a.view(3,2))

将会得到

tensor([[1,2],
[3,4],
[5,6]
])

相当于就是从1,2,3,4,5,6 顺序的拿数组来填充需要的形状。

参数中的-1就代表这个位置由其他位置的数字来进行推断,只要不在歧义的情况下,view参数就可以推断出来,也就是人可以推断出形状的情况下,view也是可以推断出来的,比如a tensor的数据个数是6个,如果view(1,-1)我们就可以推断出来-1代表6。而如果view(-1,-1,2)的话,人也不知道的话,机器也不会知道的,所以就会报错

https://www.jianshu.com/p/eddaa933365f

Python量化投资网携手4326手游为资深游戏玩家推荐:《跑跑卡丁车下载

「点点赞赏,手留余香」

    还没有人赞赏,快来当第一个赞赏的人吧!
0 条回复 A 作者 M 管理员
    所有的伟大,都源于一个勇敢的开始!
欢迎您,新朋友,感谢参与互动!欢迎您 {{author}},您在本站有{{commentsCount}}条评论