Welcome to ShenZhenJia Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

Is there a built-in function to efficiently calculate all pairwise dot products of two tensors in PyTorch? E.g.
input - tensor A (shape NxD)
        tensor B (shape NxD)

output - tensor C (shape NxN) such that C_i,j = torch.dot(A_i, B_j)?

question from:https://stackoverflow.com/questions/65935952/all-pairwise-dot-product-pytorch


1 Answer

Isn't it simply:

C = torch.mm(A, B.T)  # same as C = A @ B.T

BTW, a very flexible tool for matrix/vector/tensor dot products is torch.einsum:

C = torch.einsum('id,jd->ij', A, B)
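
As a quick sanity check (not part of the original answer), the sketch below compares both one-liners against an explicit double loop over torch.dot. The helper name pairwise_dot, the seed, and the sizes N = 4, D = 3 are made up for illustration; A and B are assumed to be 2-D float tensors with the same number of columns.

```python
import torch


def pairwise_dot(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: C[i, j] = dot(A[i], B[j]) for 2-D A and B."""
    return A @ B.T


torch.manual_seed(0)  # arbitrary seed, only for reproducibility
N, D = 4, 3           # illustrative sizes
A = torch.randn(N, D)
B = torch.randn(N, D)

# The two variants from the answer above.
C_mm = torch.mm(A, B.T)
C_einsum = torch.einsum('id,jd->ij', A, B)

# Reference result: one torch.dot call per (i, j) pair. Slow, but it
# follows the definition in the question literally.
C_loop = torch.stack([
    torch.stack([torch.dot(A[i], B[j]) for j in range(N)])
    for i in range(N)
])

assert C_mm.shape == (N, N)
assert torch.allclose(C_mm, C_loop, atol=1e-5)
assert torch.allclose(C_einsum, C_loop, atol=1e-5)
assert torch.allclose(pairwise_dot(A, B), C_loop, atol=1e-5)
```

The matrix product avoids the Python-level loop entirely, so it should be the preferred form for anything but tiny inputs; the loop is only there to confirm that all variants agree.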
