PyTorch 中张量操作对应的数学表示

99 阅读1分钟

张量就是多重线性函数。以下以 2 阶张量为例，说明 PyTorch 中各张量操作对应的数学变换：

$$l: V^{*}\times V^{*} \rightarrow \mathbb{R}$$

$$l=\sum_{i,j} k_{ij}\, e_{i}\otimes e_{j}$$


t = torch.Tensor([[1, 2],[3, 4]])

$$l=1e_{1}\otimes e_{1}+2e_{1}\otimes e_{2}+3e_{2}\otimes e_{1}+4e_{2}\otimes e_{2}$$

$$l=\sum_{i,j\in\{1,2\}} l(e_{i}^{*}, e_{j}^{*})\, e_{i}\otimes e_{j}, \qquad k_{ij}=l(e_{i}^{*}, e_{j}^{*})$$


torch.transpose(t, 0, 1)
[[1, 3], [2, 4]]

$$l_{out}=1e_{1}\otimes e_{1}+2e_{2}\otimes e_{1}+3e_{1}\otimes e_{2}+4e_{2}\otimes e_{2}$$

$$l_{out}=1e_{1}\otimes e_{1}+3e_{1}\otimes e_{2}+2e_{2}\otimes e_{1}+4e_{2}\otimes e_{2}$$


torch.cat((t, torch.Tensor([[5, 6]])), 0)
[[1, 2], [3, 4], [5, 6]]

$$l_{in2}=5e_{3}\otimes e_{1}+6e_{3}\otimes e_{2}$$

$$l_{out}=1e_{1}\otimes e_{1}+2e_{1}\otimes e_{2}+3e_{2}\otimes e_{1}+4e_{2}\otimes e_{2}+5e_{3}\otimes e_{1}+6e_{3}\otimes e_{2}$$


torch.cat((t, torch.Tensor([[5], [6]])), 1)
[[1, 2, 5], [3, 4, 6]]

$$l_{in2}=5e_{1}\otimes e_{3}+6e_{2}\otimes e_{3}$$

$$l_{out}=1e_{1}\otimes e_{1}+2e_{1}\otimes e_{2}+3e_{2}\otimes e_{1}+4e_{2}\otimes e_{2}+5e_{1}\otimes e_{3}+6e_{2}\otimes e_{3}$$

$$l_{out}=1e_{1}\otimes e_{1}+2e_{1}\otimes e_{2}+5e_{1}\otimes e_{3}+3e_{2}\otimes e_{1}+4e_{2}\otimes e_{2}+6e_{2}\otimes e_{3}$$


torch.split(t, 1, 0)
[[1, 2]], [[3, 4]]

$$l=1e_{1}\otimes e_{1}+2e_{1}\otimes e_{2}+3e_{2}\otimes e_{1}+4e_{2}\otimes e_{2}$$

$$l_{out1}=1e_{1}\otimes e_{1}+2e_{1}\otimes e_{2}$$

$$l_{out2}=3e_{1}\otimes e_{1}+4e_{1}\otimes e_{2}$$


torch.split(t, 1, 1)
[[1], [3]], [[2], [4]]

$$l=1e_{1}\otimes e_{1}+2e_{1}\otimes e_{2}+3e_{2}\otimes e_{1}+4e_{2}\otimes e_{2}$$

$$l_{out1}=1e_{1}\otimes e_{1}+3e_{2}\otimes e_{1}$$

$$l_{out2}=2e_{1}\otimes e_{1}+4e_{2}\otimes e_{1}$$


torch.gather(t, 0, torch.tensor([[1, 0], [1, 1]]))
[[3, 2], [3, 4]]

$$l=1e_{1}\otimes e_{1}+2e_{1}\otimes e_{2}+3e_{2}\otimes e_{1}+4e_{2}\otimes e_{2}$$

$$l_{in2}=2e_{1}\otimes e_{1}+1e_{1}\otimes e_{2}+2e_{2}\otimes e_{1}+2e_{2}\otimes e_{2}$$

（index 张量 `[[1, 0], [1, 1]]` 中的下标从 0 开始，这里统一加 1，换成从 1 开始的基向量编号，因此系数为 2、1、2、2。）

$$l_{out}=\sum_{i,j\in\{1,2\}} l\left(e_{l_{in2}(e_{i}^{*}, e_{j}^{*})}^{*},\, e_{j}^{*}\right) e_{i}\otimes e_{j}$$

$$l_{out}=3e_{1}\otimes e_{1}+2e_{1}\otimes e_{2}+3e_{2}\otimes e_{1}+4e_{2}\otimes e_{2}$$


torch.gather(t, 1, torch.tensor([[1, 0], [1, 1]]))
[[2, 1], [4, 4]]

$$l=1e_{1}\otimes e_{1}+2e_{1}\otimes e_{2}+3e_{2}\otimes e_{1}+4e_{2}\otimes e_{2}$$

$$l_{in2}=2e_{1}\otimes e_{1}+1e_{1}\otimes e_{2}+2e_{2}\otimes e_{1}+2e_{2}\otimes e_{2}$$

（同上，index 张量 `[[1, 0], [1, 1]]` 的下标从 0 开始，这里统一加 1 换成从 1 开始的编号。）

$$l_{out}=\sum_{i,j\in\{1,2\}} l\left(e_{i}^{*},\, e_{l_{in2}(e_{i}^{*}, e_{j}^{*})}^{*}\right) e_{i}\otimes e_{j}$$

$$l_{out}=2e_{1}\otimes e_{1}+1e_{1}\otimes e_{2}+4e_{2}\otimes e_{1}+4e_{2}\otimes e_{2}$$