@@ -0,0 +1,57 @@
import os
import json

import nltk


def splitSubSection(sent):
    # Sub-section headings in the wiki dump look like '===History==='.
    # Splitting on the delimiter leaves the headings at the odd indices,
    # so keeping the even-indexed chunks keeps only the body text.
    def takeEven(lst):
        for i, elem in enumerate(lst):
            if i % 2 == 0:
                yield elem

    for tok in ['======', '=====', '====', '===']:
        sent = '. '.join(takeEven(sent.split(tok)))
    return sent
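
# Illustrative example (hypothetical input, not taken from the dataset):
#   splitSubSection('intro ===History=== early work ===Uses=== examples')
#   -> 'intro .  early work .  examples'
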

def parseUniSection(concept, uniSection):
    # Wrap every occurrence of the page title in concept markers ...
    text = f'[@@){concept}(@@]'.join(uniSection['text'].split(concept))
    # ... and every linked phrase in link ("reliance") markers.
    for i, link in enumerate(uniSection['links']):
        text = text.replace(
            uniSection['text'][link['pos_start']:link['pos_end']],
            f'[@|){link["text"]}(|@]', 1
        )
    text = text.replace(f'=={uniSection["title"]}==', '').replace('\n', '. ')
    text = splitSubSection(text)
    yield from nltk.sent_tokenize(text)
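
# Marker conventions (illustrative, hypothetical output): parseUniSection
# yields sentences such as
#   'a [@@)group(@@] is a [@|)set(|@] together with a binary operation ...'
# where [@@)...(@@] wraps the page's own concept and [@|)...(|@] wraps a
# linked term that the sentence relies on.
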

def parseUniJSON(uniJSON):
    for uniConcept in uniJSON:
        concept = uniConcept['title']
        for section in uniConcept['sections']:
            for sent in parseUniSection(concept, section):
                # Keep only reasonably long sentences (32+ words).
                if len(sent.split()) < 32:
                    continue
                yield sent


if __name__ == '__main__':
    dirList = os.listdir('Dataset')
    with open('train.csv', 'w', encoding='utf-8') as out:
        out.write('0\n')  # single-column header
        for dirName in dirList:
            with open(f'Dataset/{dirName}', 'r', encoding='utf-8') as f:
                # Lower-case everything to match bert-base-uncased.
                uniJSON = json.loads(f.read().lower())
            for sent in parseUniJSON(uniJSON):
                out.write(
                    '"' +
                    sent.replace('\n', ' ')
                        .replace('"', '\\"')
                        .replace('*', '')
                    + '"\n\n'
                )
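
# Resulting train.csv layout (illustrative):
#   0
#   "a [@@)martingale(@@] can be thought of as the fortune at time n ..."
#
#   "f is usually denoted ..."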
@@ -0,0 +1,369 @@
from random import random, randint

import spacy
import torch as tch
from transformers import BertModel, BertTokenizer


class SentTokenizer(tch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        BERT_NAME = 'bert-base-uncased'
        self.tok = BertTokenizer.from_pretrained(BERT_NAME)
        # Dummy parameter whose .device tells us where the module lives.
        self.dummy = tch.nn.parameter.Parameter(tch.tensor(0.0))
        self.en_grammar = spacy.load('en_core_web_sm')

    @staticmethod
    def unifySent(sentence):
        # Strip the concept/link markers, recovering the plain sentence.
        for tok in ['[@@)', '(@@]', '[@|)', '(|@]']:
            sentence = ''.join(sentence.split(tok))
        return sentence

    def maskNoun(self, sentence):
        # Randomly replace ~20% of the noun chunks with a [||] placeholder.
        doc = self.en_grammar(sentence)
        for chunk in doc.noun_chunks:
            text = chunk.text
            if random() > 0.8:
                sentence = (sentence.replace(text, '[||]', 1)
                            .replace('[||][||]', '[||]'))
        return sentence

    def getLabel(self, tokList, tokLabeledList):
        # Walk the plain token list against the marked token list. Each
        # marker tokenizes to exactly four word pieces (e.g. '[', '@',
        # '@', ')'), so `offset` records how far the marked list runs
        # ahead of the plain one.
        label = tch.zeros(len(tokList), dtype=tch.int8,
                          device=self.dummy.device)
        idx, offset = 0, 0
        flagMaskIsDefinition = False
        flagMaskIsApplication = False
        while idx < len(tokList):
            idx_ = idx + offset
            if idx_ + 4 > len(tokLabeledList):
                idx += 1
                continue
            flag = ''.join(tokLabeledList[idx_:idx_+4])
            offset += 4
            if flag == '[@@)':
                flagMaskIsDefinition = True
            elif flag == '(@@]':
                flagMaskIsDefinition = False
            elif flag == '[@|)':
                flagMaskIsApplication = True
            elif flag == '(|@]':
                flagMaskIsApplication = False
            else:
                offset -= 4
            if flagMaskIsApplication:
                label[idx] = 2
            elif flagMaskIsDefinition:
                label[idx] = 1
            idx += 1
        return label
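
    # Worked example (hypothetical): for the marked sentence
    #   'a [@@)group(@@] is a [@|)set(|@] ...'
    # tokList is ['a', 'group', 'is', 'a', 'set', ...] and getLabel
    # returns [0, 1, 0, 0, 2, ...]: 1 on concept tokens, 2 on linked
    # ("reliance") tokens, 0 elsewhere.
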
    @staticmethod
    def maskTokList(tokList, tokMasked):
        # Align the original tokens with the noun-masked stream: wherever
        # the masked stream reads '[' '|' '|' ']' (the [||] placeholder),
        # replace the covered original tokens with BERT's [MASK].
        off = 0
        for idx, elem in enumerate(tokList):
            idx_ = idx + off
            flag = ''.join(tokMasked[idx_: idx_+4])
            if flag != '[||]':
                continue
            if idx_ + 4 < len(tokMasked) and tokMasked[idx_+4] == elem:
                # The placeholder ends here; skip past its four pieces.
                off += 4
                continue
            tokList[idx] = '[MASK]'
            off -= 1
        return tokList
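
    # Illustrative alignment (hypothetical token lists): with
    #   tokList   = ['the', 'drop', 'cease', '##s']
    #   tokMasked = ['[', '|', '|', ']', 'cease', '##s']
    # the two tokens covered by [||] become '[MASK]' and the rest pass
    # through unchanged.
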
    @staticmethod
    def randMaskConcept(tokList, label):
        # Independently mask ~10% of the labelled (concept/reliance) tokens.
        for idx, elem in enumerate(tokList):
            if label[idx] != 0 and random() < 0.1:
                tokList[idx] = '[MASK]'
        return tokList

    def forward(self, sentence):
        # label[i]: 0 = nothing, 1 = concept, 2 = reliance
        uniSent = self.unifySent(sentence)
        tokLabeledList = self.tok.tokenize(sentence)
        tokList = self.tok.tokenize(uniSent)
        label = self.getLabel(tokList, tokLabeledList)
        tokList = self.randMaskConcept(tokList, label)
        # tokList = ['[CLS]'] + tokList + ['[SEP]']
        tokIdList = self.tok.convert_tokens_to_ids(tokList)
        tokIdList = tch.tensor([tokIdList], device=self.dummy.device)
        # return tokIdList, tokList[1:-1], label
        return tokIdList, tokList, label
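
# Usage sketch (illustrative, assuming a CUDA device is available):
#   tokenizer = SentTokenizer().to(tch.device('cuda:0'))
#   ids, toks, label = tokenizer('a [@@)group(@@] is a [@|)set(|@] ...')
#   ids:   (1, seq_len) tensor of word-piece ids
#   toks:  list of seq_len word pieces (some replaced by '[MASK]')
#   label: (seq_len,) int8 tensor over {0, 1, 2}
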

class Trapezoid(tch.nn.Module):
    # A stack of Linear + LeakyReLU layers whose widths interpolate
    # linearly from in_features to out_features.
    def __init__(self, in_features, out_features, layers):
        super().__init__()
        self.dummy = tch.nn.parameter.Parameter(tch.tensor(0.0))
        dim_diff = out_features - in_features
        self.layer = tch.nn.Sequential(*[
            tch.nn.Sequential(
                tch.nn.Linear(
                    in_features + i * dim_diff // layers,
                    in_features + (i + 1) * dim_diff // layers
                ),
                tch.nn.LeakyReLU())
            for i in range(layers)
        ])

    def forward(self, x):
        return self.layer(x)
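
# For example (illustrative): Trapezoid(768, 3, 5) chains linear layers
# 768 -> 615 -> 462 -> 309 -> 156 -> 3, each followed by a LeakyReLU.
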

class LinAttention(tch.nn.Module):
    # Dot-product attention in a reduced (attention_features) key/query
    # space, with LeakyReLU in place of softmax. (Defined here but not
    # used by NERModel below, which uses KLAttention instead.)
    def __init__(self, in_features, out_features, attention_features):
        super().__init__()
        self.dimDown = tch.nn.Linear(in_features, attention_features)
        self.matKQ = tch.nn.Linear(attention_features, attention_features,
                                   bias=False)
        self.matV = tch.nn.Linear(in_features, out_features)
        self.leakyRELU = tch.nn.LeakyReLU()

    def forward(self, x, y):
        xp = self.dimDown(x)
        # attention[i, k] = <matKQ(xp)[i], xp[k]>
        attention = tch.einsum(
            '...ij, ...kj -> ...ik',
            self.matKQ(xp), xp
        )
        attention = self.leakyRELU(attention)
        return tch.einsum(
            '...ik, ...kj -> ...ij',
            attention, self.matV(y)
        )


class KLAttention(tch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, p, x):
        # attention[i0, i1] = \sum_j p[i0, j] (\log p[i0, j] - \log p[i1, j])
        # i.e. KL(p[i0] || p[i1]): how much of the information in p[i0]
        # is lost when p[i0] is represented by p[i1].
        EPS = 1e-40
        plog = -(p + EPS).log()
        crs_entropy = tch.einsum('...ij, ...kj -> ...ik', p, plog)
        uni_entropy = (tch.einsum('...kj, ...kj -> ...k', p, plog)
                       .unsqueeze(-1))
        # Clamp the small negatives produced by floating-point error to 0.
        attention = (crs_entropy - uni_entropy).relu()
        return tch.einsum('...ik, ...kj -> ...ij', attention, x)
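
# Numeric sanity check (illustrative; cf. the accompanying notebook):
#   p = tch.tensor([[1., 2., 3., 4.], [2., 4., 6., 8.]]).softmax(-1)
#   KLAttention()(p, tch.eye(2))
# returns the attention matrix itself (since x is the identity); its
# diagonal is zero because KL(p_i || p_i) = 0.
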

class KLTransformer(tch.nn.Module):
    # A residual block: an MLP maps each token vector to a distribution,
    # KL attention mixes the token vectors, and a second MLP transforms
    # the result before it is added back to the input.
    def __init__(self):
        super().__init__()
        self.attention_layer = KLAttention()
        self.mlp_layer0 = Trapezoid(768, 768, 5)
        self.mlp_layer1 = Trapezoid(768, 768, 5)

    def forward(self, x):
        p = self.mlp_layer0(x).softmax(-1)
        attended_x = self.attention_layer(p, x)
        transformed_x = self.mlp_layer1(attended_x)
        return x + transformed_x


class DefDiscriminator(tch.nn.Module):
    # Classification head: per-token softmax over the three classes
    # {nothing, concept, reliance}.
    def __init__(self):
        super().__init__()
        self.transform = KLTransformer()
        self.mlp_layer = Trapezoid(768, 3, 5)

    def forward(self, x):
        y = self.transform(x)
        y = self.mlp_layer(y).softmax(-1)
        return y


class NERModel(tch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dummy = tch.nn.Parameter(tch.tensor(0.0))
        BERT_NAME = 'bert-base-uncased'
        self.bert = BertModel.from_pretrained(BERT_NAME)
        self.head = DefDiscriminator()
        # Running count of tokens seen per class; requires_grad has to be
        # set on the Parameter itself, not on the wrapped tensor.
        self.type_cnt = tch.nn.Parameter(
            tch.tensor([1.0, 1.0, 1.0], dtype=tch.double),
            requires_grad=False
        )

    def criterion(self, y, label, giveRate):
        # ys[i]: the model's probability for class i on the tokens whose
        # true label is i.
        ys = [y[..., label == i, i] for i in range(3)]

        def uni_criterion(t):
            # Log-likelihood with two tweaks: already-confident terms
            # (t > 1 - eps) contribute nothing, and 80% of the near-zero
            # terms are randomly dropped. Clamping before the log avoids
            # 0 * log(0) = NaN.
            eps = 1e-2 / (t.shape[-2] + 1)
            randMask = tch.rand(size=t.shape, device=self.dummy.device) < 0.8
            clip = (t < eps) * randMask
            return ~clip * (t < 1 - eps) * t.clamp_min(eps).log()

        loss = [uni_criterion(ys[i]) for i in range(3)]
        # Weight each class by its inverse running frequency so that
        # class 0 does not dominate the loss.
        with tch.no_grad():
            self.type_cnt += \
                tch.tensor([(label == i).sum() for i in range(3)],
                           device=self.dummy.device)
        tok_cnt = self.type_cnt.sum()
        tot_loss = (
            - loss[0].sum() * (tok_cnt / self.type_cnt[0]).to(float)
            - loss[1].sum() * (tok_cnt / self.type_cnt[1]).to(float)
            - loss[2].sum() * (tok_cnt / self.type_cnt[2]).to(float)
        )
        if giveRate:
            # rate: hit rate on labelled (concept/reliance) tokens only;
            # tot_rate: hit rate over all tokens.
            cnt = (ys[1] > 1/3).sum().item() + (ys[2] > 1/3).sum().item()
            label_cnt = (label > 0).sum().item()
            tot_rate = (cnt + (ys[0] > 1/3).sum().item()) / label.shape[-1]
            if label_cnt > 0:
                rate = cnt / label_cnt
            else:
                rate = -1
            return tot_loss, rate, tot_rate
        else:
            return tot_loss

    def forward(self, x, label=None, giveRate=True):
        # BERT is frozen (no_grad); only the head is trained.
        with tch.no_grad():
            y = self.bert(x)[0]
        y = self.head(y)
        if label is None:
            return y
        else:
            return self.criterion(y, label, giveRate)
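
# Usage sketch (illustrative):
#   model = NERModel().to(device)
#   y = model(tokIdList)                           # (1, seq_len, 3) scores
#   loss, rate, tot_rate = model(tokIdList, label) # training mode
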

if __name__ == '__main__':
    import pandas as pd
    from tqdm import tqdm
    import matplotlib.pyplot as plt

    # One quoted sentence per non-blank line, with header '0'.
    df = pd.read_csv('train.csv', sep='\n')['0']
    device = tch.device('cuda:0')

    # Model
    tokenizer = SentTokenizer().to(device)
    model = NERModel().to(device)
    try:
        with open('NER.model', 'rb') as f:
            print('find model, load state dict')
            model.load_state_dict(tch.load(f))
            print('load model state dict success')
    except Exception:
        pass

    # Training configuration
    optimizer = tch.optim.RMSprop(model.head.parameters(), lr=1e-5)
    BATCH_SIZE = 15
    SAVE_ONCE = 5000
    try:
        with open('NER.optimizer', 'rb') as f:
            print('find optimizer, load state dict')
            optimizer.load_state_dict(tch.load(f))
            print('load optimizer state dict success')
    except Exception:
        pass

    # Monitoring state
    running_loss = 0.0
    running_rate = 0.0
    running_tot_rate = 0.0
    history_loss = []
    history_rate = []
    history_tot_rate = []
    skippedIter = 0
    plt.ion()
    for epoch in range(1, 5):
        dataset_with_progress_bar = tqdm(
            enumerate(df.sample(frac=1)), total=len(df))
        skippedIter = 0
        for i, sentence in dataset_with_progress_bar:
            tokIdList, _, label = tokenizer(sentence)
            # Skip sentences that overflow BERT's 512-token limit or
            # contain no concept token at all.
            if tokIdList.shape[-1] > 512 or \
                    (int((label == 1).sum()) == 0):
                skippedIter += 1
            else:
                loss, rate, tot_rate = model(tokIdList, label)
                # Exponential moving averages for the progress bar.
                running_loss = \
                    9e-1*running_loss + 1e-1*loss.item() if running_loss > 0.0 \
                    else loss.item()
                running_rate = \
                    99e-2*running_rate + 1e-2*rate if rate > 0.0 \
                    else running_rate
                running_tot_rate = \
                    99e-2*running_tot_rate + 1e-2*tot_rate
                dataset_with_progress_bar\
                    .set_description(
                        'loss[%-1.5f] rate[%-2.2f%%] '
                        'tot_rate[%-2.2f%%] sent_len[%-3d] '
                        'skipped[%d] epoch[%d] '
                        % (running_loss, running_rate * 100,
                           running_tot_rate * 100, tokIdList.shape[-1],
                           skippedIter, epoch)
                    )
                loss.backward()
                if (i - skippedIter) % BATCH_SIZE == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    # Record the training history (last 20 points only).
                    history_loss.append(running_loss)
                    history_rate.append(running_rate)
                    history_tot_rate.append(running_tot_rate)
                    if len(history_loss) > 20:
                        history_loss = history_loss[-20:]
                    if len(history_rate) > 20:
                        history_rate = history_rate[-20:]
                    if len(history_tot_rate) > 20:
                        history_tot_rate = history_tot_rate[-20:]
                    # Plotting
                    plt.clf()
                    plt.subplot(1, 2, 1)
                    plt.plot(range(len(history_loss)),
                             history_loss, c='red',
                             label='loss (cross entropy loss)')
                    plt.legend()
                    plt.subplot(1, 2, 2)
                    plt.plot(range(len(history_rate)),
                             history_rate, c='blue',
                             label='rate (only concepts)')
                    plt.plot(range(len(history_tot_rate)),
                             history_tot_rate, c='green', label='rate (all)')
                    plt.legend()
                    plt.draw()
                    plt.pause(0.01)
                    BATCH_SIZE = randint(15, 25)
            if i % SAVE_ONCE == 0:
                # Save the model and optimizer periodically.
                dataset_with_progress_bar\
                    .set_description('saving ! ')
                tch.save(optimizer.state_dict(), 'NER.optimizer')
                tch.save(model.state_dict(), 'NER.model')
                dataset_with_progress_bar\
                    .set_description('done ! ')
        # Save again at the end of each epoch.
        dataset_with_progress_bar\
            .set_description('saving ! ')
        tch.save(optimizer.state_dict(), 'NER.optimizer')
        tch.save(model.state_dict(), 'NER.model')
        dataset_with_progress_bar\
            .set_description('done ! ')
@@ -0,0 +1,43 @@
import torch as tch

from model import NERModel, SentTokenizer

device = tch.device('cuda:0')
tokenizer = SentTokenizer().to(device)
model = NERModel().to(device)
print('loading model')
model.load_state_dict(tch.load('Saved/NER.model'))
print('ok')


if __name__ == '__main__':
    print('=' * 20)
    while True:
        text = input('>>>')
        # Reject empty input and anything too long for BERT.
        if len(text) >= 512 or len(text) <= 0:
            print('Sorry bro. I cannot do this. ')
            continue
        tokIdList, tokList, label = tokenizer(text)
        print('\n')
        print(tokIdList)
        print(tokList)
        prediction = model(tokIdList)
        # Threshold each class score at 1/3 and encode the result as a
        # class id (2 = reliance, 1 = concept, 0 = nothing).
        _prediction = (
            + 2 * (prediction[..., 2] > 1/3)
            + 1 * (prediction[..., 1] > 1/3)
            + 0 * (prediction[..., 0] > 1/3)
        )
        for i, pred in enumerate(_prediction[0].tolist()):
            if pred == 2:
                print('%-25s' % tokList[i], 'reliance',
                      prediction[..., i, 2].tolist())
            if pred == 1:
                print('%-25s' % tokList[i], 'concept',
                      prediction[..., i, 1].tolist())
            if pred == 0:
                print('%-25s' % tokList[i], 'nothing',
                      prediction[..., i, 0].tolist())
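
# Example session (illustrative; actual scores depend on training):
#   >>>a martingale can be thought of as the fortune of a player
#   martingale                concept [0.8...]
#   player                    nothing [0.7...]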
@@ -0,0 +1,3 @@
a martingale can be thought of as the fortune at time n of a player who is betting on a fair game.
f is usually denoted dν/dμ and called the Radon-Nikodym derivative.
@@ -0,0 +1,143 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import spacy\n",
    "\n",
    "en_grammar = spacy.load('en_core_web_sm')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\75872\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torchaudio\\backend\\utils.py:67: UserWarning: No audio backend is available.\n",
      "  warnings.warn('No audio backend is available.')\n"
     ]
    }
   ],
   "source": [
    "from transformers import *\n",
    "\n",
    "BERT_NAME = 'bert-base-uncased'\n",
    "tok = BertTokenizer.from_pretrained(BERT_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "when [[]] ceases and [[]] makes [[]] toward [[]], [[]] says to drop [[]] after [[]] starting at [[]].\n",
      "\n"
     ]
    }
   ],
   "source": [
    "sentence = \"\"\"\n",
    "when the drop ceases and the curve makes an elbow toward less steep decline, cattell's scree test says to drop all further components after the one starting at the elbow.\n",
    "\"\"\"\n",
    "def maskNoun(sentence: str):\n",
    "    doc = en_grammar(sentence)\n",
    "    for noun in doc.noun_chunks:\n",
    "        sentence = sentence.replace(noun.text, '[[]]', 1)\n",
    "    return sentence\n",
    "print(maskNoun(sentence))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['when', 'the', 'drop', 'cease', '##s', 'and', 'the', 'curve', 'makes', 'an', 'elbow', 'toward', 'less', 'steep', 'decline', ',', 'cat', '##tell', \"'\", 's', 'sc', '##ree', 'test', 'says', 'to', 'drop', 'all', 'further', 'components', 'after', 'the', 'one', 'starting', 'at', 'the', 'elbow', '.']\n",
      "['when', '[', '[', ']', ']', 'cease', '##s', 'and', '[', '[', ']', ']', 'makes', '[', '[', ']', ']', 'toward', '[', '[', ']', ']', ',', '[', '[', ']', ']', 'says', 'to', 'drop', '[', '[', ']', ']', 'after', '[', '[', ']', ']', 'starting', 'at', '[', '[', ']', ']', '.']\n"
     ]
    }
   ],
   "source": [
    "tokList = tok.tokenize(sentence)\n",
    "tokMask = tok.tokenize(maskNoun(sentence))\n",
    "\n",
    "print(tokList)\n",
    "print(tokMask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['when', '[MASK]', '[MASK]', 'cease', '##s', 'and', '[MASK]', '[MASK]', 'makes', '[MASK]', '[MASK]', 'toward', '[MASK]', '[MASK]', '[MASK]', ',', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '[MASK]', 'says', 'to', 'drop', '[MASK]', '[MASK]', '[MASK]', 'after', '[MASK]', '[MASK]', 'starting', 'at', '[MASK]', '[MASK]', '.']\n",
      "['when', '[', '[', ']', ']', 'cease', '##s', 'and', '[', '[', ']', ']', 'makes', '[', '[', ']', ']', 'toward', '[', '[', ']', ']', ',', '[', '[', ']', ']', 'says', 'to', 'drop', '[', '[', ']', ']', 'after', '[', '[', ']', ']', 'starting', 'at', '[', '[', ']', ']', '.']\n"
     ]
    }
   ],
   "source": [
"def maskTokList(self, tokList, tokMask):\n", | |||
" off = 0\n", | |||
" for idx, elem in enumerate(tokList):\n", | |||
" idx_ = idx + off\n", | |||
" flag = ''.join(tokMask[idx_: idx_+4])\n", | |||
" if flag == '[[]]':\n", | |||
" if tokMask[idx_+4] == elem:\n", | |||
" off += 4\n", | |||
" continue\n", | |||
" tokList[idx] = '[MASK]'\n", | |||
" off -= 1\n", | |||
" return tokList\n", | |||
"\n", | |||
"\n", | |||
"print(tokList)\n", | |||
"print(tokMask)\n" | |||
] | |||
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "f29e8b3fa2d991a6f8847b235850bc2cfc73e5042ba8efb84ff0f4dcd41902ea"
  },
  "kernelspec": {
   "display_name": "Python 3.9.6 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@@ -0,0 +1,212 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 4])"
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch as tch\n",
    "\n",
    "vec_seq = tch.tensor([i for i in range(4)])\n",
    "\n",
    "vec_seq.unsqueeze_(-2).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0001, 0.0004, 0.0009, 0.0013])\n"
     ]
    }
   ],
   "source": [
    "class KLAttention(tch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, x):\n",
" # p包含了多少q中的信息? KL[p||q] = \\sum_j q(j) (\\log q(j) - \\log p(j))\n", | |||
" # 现在 x 的每一列都表示一个概率分布, 也就是说 KL[x[i0] || x[i1]]\n", | |||
" # 表示 x[i0] 含有 多少 x[i1] 当中的信息\n", | |||
" # KL[x[i0] || x[i1]] = \\sum_j x[i0, j] (\\log x[i0, j] - \\log x[i1, j])\n", | |||
" EPS = 1e-40\n", | |||
" xlog = (x + EPS).log()\n", | |||
" crs_entropy = tch.einsum('...ij, ...kj -> ...ik', x, xlog)\n", | |||
" uni_entropy = (tch.einsum('...kj, ...kj -> ...k', x, xlog)\n", | |||
" .unsqueeze(-1))\n", | |||
" return uni_entropy - crs_entropy\n", | |||
"\n", | |||
"\n", | |||
"attention_layer = KLAttention()\n", | |||
"\n", | |||
"x = tch.tensor(\n", | |||
" [[(i + 1) * (j + 1) * 10 for i in range(128)]\n", | |||
" for j in range(4)],\n", | |||
" dtype=tch.float\n", | |||
").softmax(-1)\n", | |||
"\n", | |||
"print(attention_layer(x).relu().sum(-2))\n" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 113, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"crs: tensor(1.1598)\n", | |||
"entro: tensor(-0.9475)\n", | |||
"kl: tensor(0.2122)\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"import torch\n", | |||
"\n", | |||
"x = torch.tensor([1, 2, 3, 4], dtype=torch.float).softmax(-1)\n", | |||
"y = torch.tensor([2, 4, 6, 8], dtype=torch.float).softmax(-1)\n", | |||
"\n", | |||
"print('crs:', torch.einsum('...j, ...j', x, -y.log()))\n", | |||
"print('entro:', torch.einsum('...j, ...j', x, x.log()))\n", | |||
"print('kl:', torch.einsum('...j, ...j', x, x.log()-y.log()))" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 114, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"crs: tensor([[0.9475, inf],\n", | |||
" [0.4402, nan]])\n", | |||
"entro: tensor([[-0.9475],\n", | |||
" [ nan]])\n", | |||
"kl: tensor([[0., inf],\n", | |||
" [nan, nan]])\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"x = torch.tensor([[1, 2, 3, 4], [2, 4, 6, 1000]], \n", | |||
" dtype=torch.float).softmax(-1)\n", | |||
"\n", | |||
"xlog = x.log()\n", | |||
"crs_entropy = tch.einsum('...ij, ...kj -> ...ik', x, -xlog)\n", | |||
"print('crs:',crs_entropy)\n", | |||
"\n", | |||
"entropy = tch.einsum('...ij, ...ij -> ...i', x, xlog).unsqueeze(-1)\n", | |||
"print('entro:', entropy)\n", | |||
"\n", | |||
"print('kl:', crs_entropy + entropy)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 1, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"[Matrix([\n", | |||
" [ 1],\n", | |||
" [ 2],\n", | |||
" [-1]]),\n", | |||
" Matrix([\n", | |||
" [-5/3],\n", | |||
" [ 5/3],\n", | |||
" [ 5/3]]),\n", | |||
" Matrix([\n", | |||
" [2],\n", | |||
" [0],\n", | |||
" [2]])]" | |||
] | |||
}, | |||
"execution_count": 1, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
], | |||
"source": [ | |||
"import numpy as np\n", | |||
"from sympy.matrices import Matrix,GramSchmidt\n", | |||
"\n", | |||
"a = np.array([[1,2,-1], [-1,3,1], [4,-1,0]])\n", | |||
"a = [Matrix(col) for col in a]\n", | |||
"GramSchmidt(a)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 3, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"tensor([[0, 2],\n", | |||
" [0, 2]])" | |||
] | |||
}, | |||
"execution_count": 3, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
], | |||
"source": [ | |||
"import torch\n", | |||
"\n", | |||
"torch.tensor([[0, 1, 2], [0, 1, 2]])[..., torch.tensor([True, False, True])]" | |||
] | |||
} | |||
], | |||
"metadata": { | |||
"interpreter": { | |||
"hash": "f29e8b3fa2d991a6f8847b235850bc2cfc73e5042ba8efb84ff0f4dcd41902ea" | |||
}, | |||
"kernelspec": { | |||
"display_name": "Python 3.9.6 64-bit", | |||
"language": "python", | |||
"name": "python3" | |||
}, | |||
"language_info": { | |||
"codemirror_mode": { | |||
"name": "ipython", | |||
"version": 3 | |||
}, | |||
"file_extension": ".py", | |||
"mimetype": "text/x-python", | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.9.6" | |||
}, | |||
"orig_nbformat": 4 | |||
}, | |||
"nbformat": 4, | |||
"nbformat_minor": 2 | |||
} |