{
"cells": [
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 4])"
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch as tch\n",
"\n",
"vec_seq = tch.tensor([i for i in range(4)])\n",
"\n",
"vec_seq.unsqueeze_(-2).shape"
]
},
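{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small illustration added alongside the cell above (not part of the original run): `unsqueeze(-2)` turns a length-4 vector into a 1×4 row, `unsqueeze(-1)` into a 4×1 column, and the trailing-underscore variant `unsqueeze_` does the reshape in place on `vec_seq` itself.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch as tch\n",
"\n",
"v = tch.arange(4)      # shape: [4]\n",
"row = v.unsqueeze(-2)  # shape: [1, 4] -- a row vector\n",
"col = v.unsqueeze(-1)  # shape: [4, 1] -- a column vector\n",
"print(row.shape, col.shape)\n"
]
},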
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0.0001, 0.0004, 0.0009, 0.0013])\n"
]
}
],
"source": [
"class KLAttention(tch.nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" def forward(self, x):\n",
" # p包含了多少q中的信息? KL[p||q] = \\sum_j q(j) (\\log q(j) - \\log p(j))\n",
" # 现在 x 的每一列都表示一个概率分布, 也就是说 KL[x[i0] || x[i1]]\n",
" # 表示 x[i0] 含有 多少 x[i1] 当中的信息\n",
" # KL[x[i0] || x[i1]] = \\sum_j x[i0, j] (\\log x[i0, j] - \\log x[i1, j])\n",
" EPS = 1e-40\n",
" xlog = (x + EPS).log()\n",
" crs_entropy = tch.einsum('...ij, ...kj -> ...ik', x, xlog)\n",
" uni_entropy = (tch.einsum('...kj, ...kj -> ...k', x, xlog)\n",
" .unsqueeze(-1))\n",
" return uni_entropy - crs_entropy\n",
"\n",
"\n",
"attention_layer = KLAttention()\n",
"\n",
"x = tch.tensor(\n",
" [[(i + 1) * (j + 1) * 10 for i in range(128)]\n",
" for j in range(4)],\n",
" dtype=tch.float\n",
").softmax(-1)\n",
"\n",
"print(attention_layer(x).relu().sum(-2))\n"
]
},
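{
"cell_type": "markdown",
"metadata": {},
"source": [
"A cross-check added as an example (not in the original notebook): the pairwise KL scores are recomputed with explicit Python loops, using the same `EPS` smoothing as `KLAttention`, and compared against the vectorized einsum version. It reuses `x` and `attention_layer` from the cell above and should report `True` up to float tolerance.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Reference computation with explicit loops, same EPS as the module above.\n",
"EPS = 1e-40\n",
"n = x.shape[0]\n",
"ref = torch.empty(n, n)\n",
"for i in range(n):\n",
"    for k in range(n):\n",
"        ref[i, k] = (x[i] * ((x[i] + EPS).log() - (x[k] + EPS).log())).sum()\n",
"\n",
"print(torch.allclose(attention_layer(x), ref, atol=1e-5))\n"
]
},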
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"crs: tensor(1.1598)\n",
"entro: tensor(-0.9475)\n",
"kl: tensor(0.2122)\n"
]
}
],
"source": [
"import torch\n",
"\n",
"x = torch.tensor([1, 2, 3, 4], dtype=torch.float).softmax(-1)\n",
"y = torch.tensor([2, 4, 6, 8], dtype=torch.float).softmax(-1)\n",
"\n",
"print('crs:', torch.einsum('...j, ...j', x, -y.log()))\n",
"print('entro:', torch.einsum('...j, ...j', x, x.log()))\n",
"print('kl:', torch.einsum('...j, ...j', x, x.log()-y.log()))"
]
},
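{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sanity check added for illustration: the same quantities recomputed with `torch.distributions`. `Categorical.entropy()` is the negative of the `entro` value above, and `kl_divergence` should match the `kl` value (and `crs + entro`).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.distributions import Categorical, kl_divergence\n",
"\n",
"x = torch.tensor([1, 2, 3, 4], dtype=torch.float).softmax(-1)\n",
"y = torch.tensor([2, 4, 6, 8], dtype=torch.float).softmax(-1)\n",
"\n",
"# entropy() returns -sum_j p_j log p_j, i.e. the negative of 'entro' above.\n",
"print('entropy:', Categorical(probs=x).entropy())\n",
"# kl_divergence(P, Q) for Categoricals is KL[P || Q].\n",
"print('kl     :', kl_divergence(Categorical(probs=x), Categorical(probs=y)))\n"
]
},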
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"crs: tensor([[0.9475, inf],\n",
" [0.4402, nan]])\n",
"entro: tensor([[-0.9475],\n",
" [ nan]])\n",
"kl: tensor([[0., inf],\n",
" [nan, nan]])\n"
]
}
],
"source": [
"x = torch.tensor([[1, 2, 3, 4], [2, 4, 6, 1000]], \n",
" dtype=torch.float).softmax(-1)\n",
"\n",
"xlog = x.log()\n",
"crs_entropy = tch.einsum('...ij, ...kj -> ...ik', x, -xlog)\n",
"print('crs:',crs_entropy)\n",
"\n",
"entropy = tch.einsum('...ij, ...ij -> ...i', x, xlog).unsqueeze(-1)\n",
"print('entro:', entropy)\n",
"\n",
"print('kl:', crs_entropy + entropy)"
]
},
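{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `inf`/`nan` entries above come from taking `log` of probabilities that underflowed to exactly zero after the softmax. A workaround sketched below (added as an example, not part of the original run): compute `log_softmax` on the logits directly and exponentiate, so the log-probabilities stay finite for these inputs; `KLAttention` above avoids the same problem by adding a small `EPS` before the log.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch as tch\n",
"\n",
"logits = tch.tensor([[1, 2, 3, 4], [2, 4, 6, 1000]], dtype=tch.float)\n",
"\n",
"# log_softmax stays finite here even though softmax underflows to exact\n",
"# zeros in the extreme second row.\n",
"xlog = logits.log_softmax(-1)\n",
"x = xlog.exp()\n",
"\n",
"crs_entropy = tch.einsum('...ij, ...kj -> ...ik', x, -xlog)\n",
"entropy = tch.einsum('...ij, ...ij -> ...i', x, xlog).unsqueeze(-1)\n",
"\n",
"print('crs:', crs_entropy)\n",
"print('kl :', crs_entropy + entropy)\n"
]
},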
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Matrix([\n",
" [ 1],\n",
" [ 2],\n",
" [-1]]),\n",
" Matrix([\n",
" [-5/3],\n",
" [ 5/3],\n",
" [ 5/3]]),\n",
" Matrix([\n",
" [2],\n",
" [0],\n",
" [2]])]"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"from sympy.matrices import Matrix,GramSchmidt\n",
"\n",
"a = np.array([[1,2,-1], [-1,3,1], [4,-1,0]])\n",
"a = [Matrix(col) for col in a]\n",
"GramSchmidt(a)"
]
},
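{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, a plain NumPy sketch of Gram-Schmidt orthogonalization without normalization (added as an illustration); on the rows of `a` it should reproduce the basis returned by sympy's `GramSchmidt` above.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def gram_schmidt(vectors):\n",
"    \"\"\"Orthogonalize a sequence of vectors without normalizing them.\"\"\"\n",
"    basis = []\n",
"    for v in vectors:\n",
"        u = v.astype(float)\n",
"        # Subtract the projection onto every basis vector found so far.\n",
"        for b in basis:\n",
"            u = u - (u @ b) / (b @ b) * b\n",
"        basis.append(u)\n",
"    return basis\n",
"\n",
"a = np.array([[1, 2, -1], [-1, 3, 1], [4, -1, 0]])\n",
"for u in gram_schmidt(a):\n",
"    print(u)\n"
]
},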
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 2],\n",
" [0, 2]])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"\n",
"torch.tensor([[0, 1, 2], [0, 1, 2]])[..., torch.tensor([True, False, True])]"
]
}
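,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two equivalent ways to do the boolean column selection above, added as an illustration: `index_select` with integer indices keeps the 2-D row structure, while `masked_select` broadcasts the mask and flattens the result.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"t = torch.tensor([[0, 1, 2], [0, 1, 2]])\n",
"mask = torch.tensor([True, False, True])\n",
"\n",
"# Same column selection via integer indices; keeps the 2-D shape.\n",
"print(torch.index_select(t, dim=-1, index=mask.nonzero().squeeze(-1)))\n",
"\n",
"# masked_select broadcasts the mask and returns a flattened 1-D tensor.\n",
"print(t.masked_select(mask))\n"
]
}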
],
"metadata": {
"interpreter": {
"hash": "f29e8b3fa2d991a6f8847b235850bc2cfc73e5042ba8efb84ff0f4dcd41902ea"
},
"kernelspec": {
"display_name": "Python 3.9.6 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}