{
|
|
"cells": [
|
|
{
 "cell_type": "code",
 "execution_count": 111,
 "metadata": {},
 "outputs": [
  {
   "data": {
    "text/plain": [
     "torch.Size([1, 4])"
    ]
   },
   "execution_count": 111,
   "metadata": {},
   "output_type": "execute_result"
  }
 ],
 "source": [
  "import torch as tch\n",
  "\n",
  "# Row vector [0, 1, 2, 3] (int64); unsqueeze(-2) inserts a leading axis of size 1.\n",
  "vec_seq = tch.arange(4).unsqueeze(-2)\n",
  "\n",
  "vec_seq.shape"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 112,
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "tensor([0.0001, 0.0004, 0.0009, 0.0013])\n"
   ]
  }
 ],
 "source": [
  "class KLAttention(tch.nn.Module):\n",
  "    \"\"\"Pairwise KL divergence between the rows of `x`.\n",
  "\n",
  "    Input `x` has shape (..., n, d). Every row `x[..., i, :]` is expected to\n",
  "    be a probability distribution over the last axis (e.g. the output of\n",
  "    `softmax(-1)`); this is not checked. The output has shape (..., n, n):\n",
  "\n",
  "        out[..., a, b] = KL[x[a] || x[b]]\n",
  "                       = sum_j x[a, j] * (log x[a, j] - log x[b, j])\n",
  "\n",
  "    i.e. how much information is lost when row b stands in for row a.\n",
  "\n",
  "    Args:\n",
  "        eps: added to `x` before the log so zero probabilities get a large\n",
  "            negative but finite log instead of -inf. Defaults to 1e-40.\n",
  "    \"\"\"\n",
  "\n",
  "    def __init__(self, eps=1e-40):\n",
  "        super().__init__()\n",
  "        self.eps = eps\n",
  "\n",
  "    def forward(self, x):\n",
  "        xlog = (x + self.eps).log()\n",
  "        # A tiny eps can vanish. The default 1e-40 rounds to 0 in float16,\n",
  "        # and in float32 it is a subnormal that is lost when denormals are\n",
  "        # flushed to zero (torch.set_flush_denormal(True)). Then log(0) = -inf\n",
  "        # and 0 * -inf = nan would poison the einsums below. Clamping to the\n",
  "        # most negative finite value keeps 0 * log(0) == 0, and it is a no-op\n",
  "        # whenever every entry of xlog is already finite.\n",
  "        xlog = xlog.clamp_min(tch.finfo(xlog.dtype).min)\n",
  "        # crs_entropy[..., i, k] = sum_j x[i, j] * log x[k, j]\n",
  "        crs_entropy = tch.einsum('...ij, ...kj -> ...ik', x, xlog)\n",
  "        # uni_entropy[..., i, 0] = sum_j x[i, j] * log x[i, j] (minus the entropy\n",
  "        # of row i); the trailing axis broadcasts it across every column k.\n",
  "        uni_entropy = (tch.einsum('...kj, ...kj -> ...k', x, xlog)\n",
  "                       .unsqueeze(-1))\n",
  "        return uni_entropy - crs_entropy\n",
  "\n",
  "\n",
  "attention_layer = KLAttention()\n",
  "\n",
  "# 4 rows of 128 logits each; softmax(-1) turns every row into a distribution.\n",
  "x = tch.tensor(\n",
  "    [[(i + 1) * (j + 1) * 10 for i in range(128)]\n",
  "     for j in range(4)],\n",
  "    dtype=tch.float\n",
  ").softmax(-1)\n",
  "\n",
  "# relu() clips any negative entries (KL is non-negative in exact arithmetic);\n",
  "# sum(-2) adds up KL[x[a] || x[b]] over all rows a for each column b.\n",
  "print(attention_layer(x).relu().sum(-2))\n"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 113,
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "crs: tensor(1.1598)\n",
    "entro: tensor(-0.9475)\n",
    "kl: tensor(0.2122)\n"
   ]
  }
 ],
 "source": [
  "import torch\n",
  "\n",
  "# Two distributions over 4 outcomes; y's logits are x's logits doubled.\n",
  "x = torch.tensor([1, 2, 3, 4], dtype=torch.float).softmax(-1)\n",
  "y = torch.tensor([2, 4, 6, 8], dtype=torch.float).softmax(-1)\n",
  "log_x, log_y = x.log(), y.log()\n",
  "\n",
  "# Cross entropy H(x, y) = -sum_j x_j log y_j\n",
  "print('crs:', torch.einsum('...j, ...j', x, -log_y))\n",
  "# sum_j x_j log x_j is the *negative* entropy -H(x), hence the negative value\n",
  "print('entro:', torch.einsum('...j, ...j', x, log_x))\n",
  "# KL[x || y] = sum_j x_j (log x_j - log y_j) = H(x, y) - H(x)\n",
  "print('kl:', torch.einsum('...j, ...j', x, log_x - log_y))"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 114,
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "crs: tensor([[0.9475, inf],\n",
    " [0.4402, nan]])\n",
    "entro: tensor([[-0.9475],\n",
    " [ nan]])\n",
    "kl: tensor([[0., inf],\n",
    " [nan, nan]])\n"
   ]
  }
 ],
 "source": [
  "import torch\n",
  "\n",
  "# The pairwise KL computation from KLAttention, but WITHOUT eps, to show why\n",
  "# eps is needed: 1000 dominates the softmax of the second row, so its other\n",
  "# entries underflow to exactly 0 and log(0) = -inf. That yields inf\n",
  "# (x * -log 0) and nan (0 * log 0) entries, as the printed output shows.\n",
  "x = torch.tensor([[1, 2, 3, 4], [2, 4, 6, 1000]],\n",
  "                 dtype=torch.float).softmax(-1)\n",
  "\n",
  "xlog = x.log()\n",
  "# crs_entropy[i, k] = -sum_j x[i, j] log x[k, j]  (cross entropy H(x[i], x[k]))\n",
  "crs_entropy = torch.einsum('...ij, ...kj -> ...ik', x, -xlog)\n",
  "print('crs:', crs_entropy)\n",
  "\n",
  "# sum_j x[i, j] log x[i, j] is the *negative* entropy -H(x[i]), kept as a column\n",
  "entropy = torch.einsum('...ij, ...ij -> ...i', x, xlog).unsqueeze(-1)\n",
  "print('entro:', entropy)\n",
  "\n",
  "# KL[x[i] || x[k]] = H(x[i], x[k]) - H(x[i])\n",
  "print('kl:', crs_entropy + entropy)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 1,
 "metadata": {},
 "outputs": [
  {
   "data": {
    "text/plain": [
     "[Matrix([\n",
     " [ 1],\n",
     " [ 2],\n",
     " [-1]]),\n",
     " Matrix([\n",
     " [-5/3],\n",
     " [ 5/3],\n",
     " [ 5/3]]),\n",
     " Matrix([\n",
     " [2],\n",
     " [0],\n",
     " [2]])]"
    ]
   },
   "execution_count": 1,
   "metadata": {},
   "output_type": "execute_result"
  }
 ],
 "source": [
  "import numpy as np\n",
  "from sympy.matrices import Matrix, GramSchmidt\n",
  "\n",
  "# Iterating a 2-D array yields its ROWS, so the vectors orthogonalised here are\n",
  "# (1, 2, -1), (-1, 3, 1) and (4, -1, 0); iterate over `.T` to use columns instead.\n",
  "# The result is orthogonal but not normalised (see output).\n",
  "a = [Matrix(row) for row in np.array([[1, 2, -1], [-1, 3, 1], [4, -1, 0]])]\n",
  "GramSchmidt(a)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 3,
 "metadata": {},
 "outputs": [
  {
   "data": {
    "text/plain": [
     "tensor([[0, 2],\n",
     " [0, 2]])"
    ]
   },
   "execution_count": 3,
   "metadata": {},
   "output_type": "execute_result"
  }
 ],
 "source": [
  "import torch\n",
  "\n",
  "# A 1-D boolean mask on the last axis keeps the positions where it is True,\n",
  "# i.e. columns 0 and 2 of each of the two identical rows.\n",
  "torch.tensor([[0, 1, 2]] * 2)[:, torch.tensor([True, False, True])]"
 ]
}
|
|
],
|
|
"metadata": {
|
|
"interpreter": {
|
|
"hash": "f29e8b3fa2d991a6f8847b235850bc2cfc73e5042ba8efb84ff0f4dcd41902ea"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3.9.6 64-bit",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.6"
|
|
},
|
|
"orig_nbformat": 4
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|