Summary: this post introduces the basic ideas of the Dynamic Memory Network (DMN) and implements a simple QA model with PyTorch.

Paper: Ask Me Anything: Dynamic Memory Networks for Natural Language Processing

Model overview

Summary

Many NLP problems can be cast as question answering. The Dynamic Memory Network consists of four modules.

Input module

Embeds the input sentences (facts) and encodes them with a GRU, producing encoded_facts, which are passed on to the episodic memory module.

Question module

Encodes the input question with a GRU to obtain encoded_question, which is passed to both the episodic memory module and the answer module.

Episodic memory module

The episodic memory module consists of a memory component and an attention mechanism.

  • attention: selects the more important facts
  • memory: generates a new memory from the question, the facts, and the old memory. Initially memory = encoded_question

The module iterates over the facts several times to compute the memory; each pass extracts new information.

The final memory is passed on to the answer module.

Answer module

Combines the memory and the question, and iterates a GRUCell for as many steps as the answer length to produce the final prediction.

Input module

Input

  • A single sentence with \(T_I\) words
  • \(T_I\) sentences: concatenate them into one long sequence, appending an end-of-sentence token </s> after each sentence (shown in blue in the figure in the paper)

Hidden states via GRU

As the sequence passes through the RNN, each word \(w_t\) at time step \(t\) produces a hidden state \(h_t\): \[ h_t = \rm{RNN}(w_t, h_{t-1}) \]

Output

The RNN hidden states \(h\) serve as the vector representation of the input, i.e. encoded_facts.

  • A single sentence: output the hidden states \(h_t\) at every time step
  • Multiple sentences: output the hidden state \(h_t\) at each end-of-sentence token </s> (a short sketch follows this list)
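
For the multi-sentence case, here is a minimal sketch (toy sizes, made-up </s> positions; not the post's implementation, which appears later) of reading one vector per sentence out of nn.GRU:

import torch
import torch.nn as nn

embed_size, hidden_size = 8, 16
gru = nn.GRU(embed_size, hidden_size, batch_first=True)

# one input: 3 sentences concatenated into 10 word embeddings (random toy values)
words = torch.randn(1, 10, embed_size)      # [batch=1, T_I=10, embed]
eos_positions = torch.tensor([2, 6, 9])     # indices of the </s> tokens (illustrative)

outputs, _ = gru(words)                     # [1, 10, hidden]
encoded_facts = outputs[0, eos_positions]   # [3, hidden], one vector per sentence
print(encoded_facts.shape)                  # torch.Size([3, 16])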

Question module

Input

The input is a single question sentence with \(T_Q\) words.

Hidden states via GRU \[ q_t = \rm{RNN}(w_t^Q, q_{t-1}) \]

Output: question encoding

The hidden state at the final time step, \(q_{T_Q}\), is used as the encoding of the question.
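
The final hidden state returned by nn.GRU is exactly \(q_{T_Q}\); a tiny sketch with toy sizes:

import torch
import torch.nn as nn

question_gru = nn.GRU(8, 16, batch_first=True)
question = torch.randn(1, 5, 8)            # T_Q = 5 word embeddings (toy values)
outputs, h_n = question_gru(question)      # h_n: [num_layers=1, batch=1, hidden=16]
encoded_question = h_n[-1]                 # [1, 16], i.e. q_{T_Q}
print(torch.allclose(encoded_question, outputs[:, -1]))  # True for a single-layer, unidirectional GRU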


Episodic memory module

Overall idea

The memory module receives two encoded representations: encoded_facts and encoded_question, i.e. \(h\) and \(q\).

The module maintains a memory, initialized as memory = encoded_question.

The memory module iterates over encoded_facts for several passes; each pass extracts a new piece of information, an episode, and updates the memory:

  • Iterate over all facts; for each fact, keep updating the information e of the current pass
  • Compute candidate information from the current fact and the current e: \(e_{new}=\rm{RNN}(fact, e)\)
  • Compute the attention gate \(g\), the proportion of the new information to keep
  • Update the information: \(e = g * e_{new} + (1-g) * e\)
  • To compute the gate g: build several features from the current fact, the memory and the question, then pass them through a two-layer feed-forward network G to get a scalar

  • After the pass, update the memory: \(m^i = \rm{GRU}(e, m^{i-1})\) (a minimal sketch of one pass follows this list)
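
As referenced above, a minimal sketch of one episodic pass with toy tensors (module and variable names here are illustrative, not the post's code, which comes later):

import torch
import torch.nn as nn

b, h, n_fact = 2, 16, 4                    # toy sizes
facts = torch.randn(n_fact, b, h)          # encoded facts, one [b, h] slice per fact
q = torch.randn(b, h)                      # encoded question
gate = nn.Sequential(nn.Linear(4 * h, h), nn.Tanh(), nn.Linear(h, 1), nn.Sigmoid())
attn_cell, mem_cell = nn.GRUCell(h, h), nn.GRUCell(h, h)

memory = q                                 # m^0 = q
for episode in range(3):                   # T_M = 3 passes
    e = torch.zeros(b, h)
    for c in facts:                        # c: current fact, [b, h]
        z = torch.cat([c * q, c * memory, (c - q).abs(), (c - memory).abs()], dim=1)
        g = gate(z)                        # [b, 1] attention gate
        e = g * attn_cell(c, e) + (1 - g) * e
    memory = mem_cell(e, memory)           # m^i = GRU(e, m^{i-1})
print(memory.shape)                        # torch.Size([2, 16])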

Feature function and feed-forward network

The gate g acts as the attention mechanism.

The feature function is \(z(c, m, q)\), where c is the current fact (the paper uses nine features; four are used here): \[ z(c, m, q) = [c \circ q, c \circ m, \vert c-q\vert, \vert c-m\vert] \] The feed-forward network computes \(g=G(c, m, q)\): \[ t = \rm{tanh}(W^1z(c, m, q) + b^1) \\ g = G(c, m, q) = \sigma(W^2 t + b^2) \]

Updating e

While iterating over the facts, e combines the current fact with the old e to produce the new information \(e_{new}\); the old \(e\) and the new \(e_{new}\) are then combined into the updated \(e^i\): \[ e_{new}=\rm{RNN}(fact, e) \]

\[ e = g * e_{new} + (1-g) * e \]

Memory update

After each pass, the old memory and the current pass's information e are combined to update the memory: \[ m^i = \rm{GRU}(e, m^{i-1}) \]

Stopping criteria

  • Set a maximum number of passes \(T_M\)
  • Append a special stop signal to the input; if the attention gate selects it, iteration stops

Answer module

The answer module combines the memory and the question to generate the answer, again using a GRU.

Let a be the hidden state of answer_gru, initialized as \(a_0 = m^{T_M}\): \[ y_t = \rm{softmax}(W^a a_t) \\ a_t = \rm{GRU}([y_{t-1}, q], a_{t-1}) \] The cross-entropy loss between the prediction and the true answer is used for optimization.
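
A minimal sketch of this greedy decoding loop (toy sizes; the <s> id is assumed to be 0 here; the post's actual answer code appears below):

import torch
import torch.nn as nn
import torch.nn.functional as F

b, h, vocab, alen = 2, 16, 30, 3
embed = nn.Embedding(vocab, h)
answer_cell = nn.GRUCell(2 * h, h)                   # input is [previous word embedding ; q]
answer_fc = nn.Linear(h, vocab)

memory, q = torch.randn(b, h), torch.randn(b, h)     # stand-ins for the memory / question encodings
a = memory                                           # a_0 = m^{T_M}
last_word = embed(torch.zeros(b, dtype=torch.long))  # embedding of the <s> token
for t in range(alen):
    a = answer_cell(torch.cat([last_word, q], dim=1), a)
    y = F.log_softmax(answer_fc(a), dim=1)           # distribution over the vocabulary
    last_word = embed(y.argmax(dim=1))               # greedy: feed the predicted word back in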

Implementation details

My source code is on GitHub; the implementation follows DSKSD's code.

Data processing

Raw data

The data used is Facebook's bAbI Tasks Data 1-20, specifically qa5_three-arg-relations_train.txt under en-10k, plus the corresponding test file.

1 Bill travelled to the office.
2 Bill picked up the football there.
3 Bill went to the bedroom.
4 Bill gave the football to Fred.
5 What did Bill give to Fred? football 4
6 Fred handed the football to Bill.
7 Jeff went back to the office.
8 Who received the football? Bill 6
9 Bill travelled to the office.
10 Bill got the milk there.
11 Who received the football? Bill 6
12 Fred travelled to the garden.
13 Fred went to the hallway.
14 Bill journeyed to the bedroom.
15 Jeff moved to the hallway.
16 Jeff journeyed to the bathroom.
17 Bill journeyed to the office.
18 Fred travelled to the bathroom.
19 Mary journeyed to the kitchen.
20 Jeff took the apple there.
21 Jeff gave the apple to Fred.
22 Who did Jeff give the apple to? Fred 21
23 Bill went back to the bathroom.
24 Bill left the milk.
25 Who received the apple? Fred 21
1 Mary travelled to the garden.
2 Mary journeyed to the kitchen.
3 Bill went back to the office.
4 Bill journeyed to the hallway.
5 Jeff went back to the bedroom.
6 Fred moved to the hallway.
7 Bill moved to the bathroom.
8 Jeff went back to the garden.
9 Jeff went back to the kitchen.
10 Fred went back to the garden.
11 Mary got the football there.
12 Mary handed the football to Jeff.
13 What did Mary give to Jeff? football 12

For example, lines 1-25 form one large story.

  • Lines without a question mark are declarative sentences, the facts of the story; they are simple sentences containing only periods and spaces
  • Lines with a question mark are questions; they also carry the answer and the index of the supporting sentence, separated by tabs (see the short example below)
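
A question line is split on tabs, mirroring the parsing code below (a tiny illustration; the exact whitespace in the raw file may differ slightly):

line = "5 What did Bill give to Fred?\tfootball\t4"
text, answer, support = line.split('\t')
question = text.strip().replace('?', '').split(' ')[1:] + ['?']
print(question)   # ['What', 'did', 'Bill', 'give', 'to', 'Fred', '?']
print(answer)     # 'football'
print(support)    # '4', index of the supporting fact sentence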

Loading the raw data

from copy import deepcopy

def load_raw_data(file_path, seq_end='</s>'):
    '''Read the text file and assemble it into [facts, question, answer] triples (raw word form).
    Args:
        file_path -- path to the data file
        seq_end -- end-of-sentence token
    Returns:
        data -- list whose elements are [facts, question, answer]
    '''
    source_data = open(file_path).readlines()
    print(file_path, ":", len(source_data), "lines")
    # strip the trailing newline
    source_data = [line[:-1] for line in source_data]
    data = []
    for line in source_data:
        index = line.split(' ')[0]
        if index == '1':
            # a new QA story begins
            facts = []
        if '?' in line:
            # a question line of the current story:
            # question \t answer \t index of the supporting sentence
            tmp = line.split('\t')
            question = tmp[0].strip().replace('?', '').split(' ')[1:] + ['?']
            answer = tmp[1].split() + [seq_end]
            facts_for_q = deepcopy(facts)
            data.append([facts_for_q, question, answer])
        else:
            # an ordinary declarative fact sentence; only periods and spaces
            sentence = line.replace('.', '').split(' ')[1:] + [seq_end]
            facts.append(sentence)
    return data
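
A quick usage sketch; if the training file begins with the story shown above, the first triple looks like this:

train_data = load_raw_data("./datasets/tasks_1-20_v1-2/en-10k/qa5_three-arg-relations_train.txt")
facts, question, answer = train_data[0]
print(len(facts))   # 4: the facts seen before the first question
print(question)     # ['What', 'did', 'Bill', 'give', 'to', 'Fred', '?']
print(answer)       # ['football', '</s>']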

Converting the data to ids

def triple_word2id(triple_word_data, th):
    '''Convert words to ids.
    Args:
        triple_word_data -- [(facts, q, a)] in word form
        th -- TextHelper
    Returns:
        triple_id_data -- [(facts, q, a)] in index form (converted in place)
    '''
    # convert every word to its numeric id
    for t in triple_word_data:
        # fact sentences
        for i, fact in enumerate(t[0]):
            t[0][i] = th.sentence2indices(fact)
        # question and answer
        t[1] = th.sentence2indices(t[1])
        t[2] = th.sentence2indices(t[2])
    return triple_word_data

Fetching data by batch_size

def get_data_loader(data, batch_size=1, shuffle=False):
    '''Yield the data in batches.
    Args:
        data -- the data as a list
        batch_size --
        shuffle -- shuffle the data at the start of each epoch
    Returns:
        an iterator over the data
    '''
    if shuffle:
        random.shuffle(data)
    start = 0
    end = batch_size
    while start < len(data):
        # slicing past the end is safe, so the final partial batch is included
        batch = data[start:end]
        start, end = end, end + batch_size
        yield batch
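
Usage sketch (get_data_loader relies on the random module being imported at the top of the file):

for batch in get_data_loader(train_data, batch_size=128, shuffle=True):
    # each element of `batch` is one [facts, question, answer] triple
    print(len(batch))   # 128 for full batches; the last one may be smaller
    break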

Padding each batch

This part is a bit involved. Within a batch, the questions must share one length, the answers another, and the facts a third; every example must also have the same number of facts.

This is also tied to the model, which has a design flaw: all the facts of one example should be concatenated into a single long sequence, with an end-of-sentence marker after each fact, and fed to the GRU in one pass. Instead, I kept the facts separate as individually marked sentences; I did not read the paper carefully enough at the time and followed someone else's implementation. The resulting Python loops also make training very slow. Since I am busy looking for an internship, I will leave it as-is for now and do it properly in DMN+ later. A sketch of the concatenated-facts variant follows.
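
Under that concatenated-facts design (a sketch of the idea above, not what the code below does): join each example's facts into one id sequence with a </s> id after every fact, run the GRU once, and gather the outputs at the </s> positions. All ids and sizes here are illustrative.

import torch
import torch.nn as nn

embed_size, hidden_size, eos_id = 8, 16, 2            # toy sizes; eos_id is the </s> word id
embed = nn.Embedding(30, embed_size)
input_gru = nn.GRU(embed_size, hidden_size, batch_first=True)

# one example: 3 facts, each already ending in the </s> id, concatenated into one sequence
fused = torch.tensor([[4, 5, 2, 6, 7, 8, 2, 9, 2]])   # [1, T_I]
outputs, _ = input_gru(embed(fused))                  # [1, T_I, hidden]
eos_positions = (fused[0] == eos_id).nonzero(as_tuple=True)[0]
encoded_facts = outputs[0, eos_positions]             # [n_fact=3, hidden]; one GRU call, no Python loop over facts
print(encoded_facts.shape)                            # torch.Size([3, 16])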

def pad_batch_data(raw_batch_data, th):
    '''Pad a batch so that questions, answers and facts each share one length,
    and every example has the same number of facts. Needed before feeding the network.
    Args:
        raw_batch_data -- [[facts, q, a]], everything as lists of word ids
        th -- TextHelper
    Returns:
        all_facts -- [b, nfact, flen], padded facts, Variable
        all_facts_mask -- [b, nfact, flen], mask for the facts, Variable
        questions -- [b, qlen], padded questions, Variable
        questions_mask -- [b, qlen], mask for the questions, Variable
        answers -- [b, alen], padded answers, Variable
    '''
    # `flatten` and `get_variable` are small helpers defined elsewhere in the repo
    # (flatten a nested list / wrap a tensor as a Variable, optionally on GPU).
    all_facts, questions, answers = [list(i) for i in zip(*raw_batch_data)]
    batch_size = len(raw_batch_data)

    # 1. Lengths: number of facts per example, and max lengths of fact / question / answer
    n_fact = max([len(facts) for facts in all_facts])
    flen = max([len(f) for f in flatten(all_facts)])
    qlen = max([len(q) for q in questions])
    alen = max([len(a) for a in answers])
    padid = th.word2index(th.pad)

    # 2. Padding
    all_facts_mask = []
    for i in range(batch_size):
        # 2.1 pad each fact to flen
        facts = all_facts[i]
        for j in range(len(facts)):
            t = flen - len(facts[j])
            if t > 0:
                all_facts[i][j] = facts[j] + [padid] * t
        # pad the number of facts to n_fact
        while len(facts) < n_fact:
            all_facts[i].append([padid] * flen)

        # mask: 1 where the position is padding, 0 otherwise
        mask = [tuple(map(lambda v: v == padid, fact)) for fact in all_facts[i]]
        all_facts_mask.append(mask)

        # 2.2 pad the question
        q = questions[i]
        if len(q) < qlen:
            questions[i] = q + [padid] * (qlen - len(q))
        # 2.3 pad the answer
        a = answers[i]
        if len(a) < alen:
            answers[i] = a + [padid] * (alen - len(a))

    # 3. Convert the lists to Variables
    all_facts = get_variable(torch.LongTensor(all_facts))
    all_facts_mask = get_variable(torch.ByteTensor(all_facts_mask))
    answers = get_variable(torch.LongTensor(answers))
    questions = torch.LongTensor(questions)
    questions_mask = [tuple(map(lambda v: v == padid, q)) for q in questions]
    questions_mask = torch.ByteTensor(questions_mask)
    questions, questions_mask = get_variable(questions), get_variable(questions_mask)
    return all_facts, all_facts_mask, questions, questions_mask, answers

Model

Model definition

class DMN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, padding_idx, seqbegin_id, dropout_p=0.1):
        '''
        Args:
            vocab_size -- vocabulary size
            embed_size -- word embedding dimension
            hidden_size -- GRU hidden dimension
            padding_idx -- word id of the pad token
            seqbegin_id -- word id of the sentence-begin token
            dropout_p -- dropout rate
        '''
        super(DMN, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.seqbegin_id = seqbegin_id

        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=padding_idx)
        self.input_gru = nn.GRU(embed_size, hidden_size, batch_first=True)
        self.question_gru = nn.GRU(embed_size, hidden_size, batch_first=True)
        self.gate = nn.Sequential(
            nn.Linear(hidden_size * 4, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )
        self.attention_grucell = nn.GRUCell(hidden_size, hidden_size)
        self.memory_grucell = nn.GRUCell(hidden_size, hidden_size)
        self.answer_grucell = nn.GRUCell(hidden_size * 2, hidden_size)
        self.answer_fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout_p)

        self.init_weight()

    def init_hidden(self, batch_size):
        '''Initial hidden state for the GRUs (single layer, unidirectional).'''
        hidden = torch.zeros(1, batch_size, self.hidden_size)
        hidden = get_variable(hidden)
        return hidden

    def init_weight(self):
        nn.init.xavier_uniform(self.embed.state_dict()['weight'])
        components = [self.input_gru, self.question_gru, self.gate, self.attention_grucell,
                      self.memory_grucell, self.answer_grucell]
        for component in components:
            for name, param in component.state_dict().items():
                if 'weight' in name:
                    nn.init.xavier_normal(param)
        nn.init.xavier_uniform(self.answer_fc.state_dict()['weight'])
        self.answer_fc.bias.data.fill_(0)
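
Note the constructor's argument order: padding_idx comes before seqbegin_id. A quick sanity-check sketch with toy values (the ids are illustrative):

# assume id 0 is <pad> and id 1 is <s>
model = DMN(vocab_size=30, embed_size=64, hidden_size=64, padding_idx=0, seqbegin_id=1)
print(sum(p.numel() for p in model.parameters()))   # total number of trainable parameters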

Forward pass: arguments

def forward(self, allfacts, allfacts_mask, questions, questions_mask, alen, n_episode=3):
    '''
    Args:
        allfacts -- [b, n_fact, flen], the input fact sentences
        allfacts_mask -- [b, n_fact, flen], mask=1 means the position is padding
        questions -- [b, qlen], the questions
        questions_mask -- [b, qlen], mask=1 means padding
        alen -- answer length
        n_episode -- number of episodic memory passes
    Returns:
        preds -- [b * alen, vocab_size], the predictions; b and alen are flattened together
                 so the cross-entropy loss is easy to compute afterwards
    '''
    # 0. Common sizes: batch size, number of facts per example, fact length, question length
    bsize = allfacts.size(0)
    nfact = allfacts.size(1)
    flen = allfacts.size(2)
    qlen = questions.size(1)

Input module

    # 1. Input module: encode the fact sentences with the GRU
    # TODO two nested Python loops, to be optimized
    encoded_facts = []
    # for every example in the batch, encode its facts
    for facts, facts_mask in zip(allfacts, allfacts_mask):
        facts_embeds = self.embed(facts)
        facts_embeds = self.dropout(facts_embeds)
        hidden = self.init_hidden(nfact)
        # 1.1 feed the facts (multiple sentences) through the GRU
        # b=nf, [nf, flen, h], [1, nf, h]
        outputs, hidden = self.input_gru(facts_embeds, hidden)
        # 1.2 take the output at each sentence's real end (real_len) as that sentence's
        #     hidden state; for a GRU, output == hidden
        real_hiddens = []

        for i, o in enumerate(outputs):
            real_len = facts_mask[i].data.tolist().count(0)
            real_hiddens.append(o[real_len - 1])
        # 1.3 stack the per-fact vectors; unsqueeze(0) so the batch can be concatenated later
        hiddens = torch.cat(real_hiddens).view(nfact, -1).unsqueeze(0)
        encoded_facts.append(hiddens)
    # [b, nfact, h]
    encoded_facts = torch.cat(encoded_facts)

Question module

    # 2. Question module: encode the question with the GRU
    questions_embeds = self.embed(questions)
    questions_embeds = self.dropout(questions_embeds)
    hidden = self.init_hidden(bsize)
    # [b, qlen, h], [1, b, h]
    outputs, hidden = self.question_gru(questions_embeds, hidden)
    real_questions = []
    for i, o in enumerate(outputs):
        real_len = questions_mask[i].data.tolist().count(0)
        real_questions.append(o[real_len - 1])
    encoded_questions = torch.cat(real_questions).view(bsize, -1)
    # 3. Memory module
    memory = encoded_questions
    for i in range(n_episode):
        # e: the information gathered during this pass
        e = self.init_hidden(bsize).squeeze(0)
        # [nfact, b, h]
        encoded_facts_t = encoded_facts.transpose(0, 1)
        # update e at every time step from the memory and the facts; the final e and the
        # old memory are then combined into the new memory
        for t in range(nfact):
            # [b, h]
            bfact = encoded_facts_t[t]
            # TODO only 4 features are computed here; the paper uses 9
            f1 = bfact * encoded_questions
            f2 = bfact * memory
            f3 = torch.abs(bfact - encoded_questions)
            f4 = torch.abs(bfact - memory)
            z = torch.cat([f1, f2, f3, f4], dim=1)
            # [b, 1], attention over this fact
            gt = self.gate(z)
            e = gt * self.attention_grucell(bfact, e) + (1 - gt) * e
        # the e of this pass and the old memory produce the new memory
        memory = self.memory_grucell(e, memory)

Answer module

    # 4. Answer module
    # [b, h]
    answer_hidden = memory
    begin_tokens = get_variable(torch.LongTensor([self.seqbegin_id] * bsize))
    # [b, h]
    last_word = self.embed(begin_tokens)
    preds = []
    for i in range(alen):
        inputs = torch.cat([last_word, encoded_questions], dim=1)
        answer_hidden = self.answer_grucell(inputs, answer_hidden)
        # project to vocab_size
        probs = self.answer_fc(answer_hidden)
        # [b, v]
        probs = F.log_softmax(probs, dim=1)
        _, indices = torch.max(probs, 1)
        last_word = self.embed(indices)
        # collected for the cross-entropy loss
        preds.append(probs.view(bsize, 1, -1))
    preds = torch.cat(preds, dim=1)
    return preds.view(bsize * alen, -1)

Configuration

class DefaultConfig(object):
    '''Configuration'''
    # data
    train_file = "./datasets/tasks_1-20_v1-2/en-10k/qa5_three-arg-relations_train.txt"
    test_file = "./datasets/tasks_1-20_v1-2/en-10k/qa5_three-arg-relations_test.txt"

    # special tokens
    seq_end = '</s>'
    seq_begin = '<s>'
    pad = '<pad>'
    unk = '<unk>'

    # DataLoader
    batch_size = 128
    shuffle = False
    # TODO
    num_workers = 1

    # model
    embed_size = 64
    hidden_size = 64
    # number of episodic passes over the inputs
    n_episode = 3
    dropout_p = 0.1

    # training
    max_epoch = 500
    learning_rate = 0.001
    min_loss = 0.01
    print_every_epoch = 5

    # cuda
    use_cuda = True
    device_id = 0

    # model_path
    model_path = "./models/DMN.pkl"

Training

def train(opt, th, train_data):
    '''Training loop.
    Args:
        opt -- configuration
        th -- TextHelper instance
        train_data -- training data, [[facts, question, answer]]
    '''
    seqbegin_id = th.word2index(th.seq_begin)

    # note the argument order of DMN: (vocab_size, embed_size, hidden_size, padding_idx, seqbegin_id)
    model = DMN(th.vocab_size, opt.embed_size, opt.hidden_size, th.word2index(th.pad), seqbegin_id)
    if opt.use_cuda:
        model = model.cuda(opt.device_id)

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate)
    loss_func = nn.CrossEntropyLoss(ignore_index=th.word2index(th.pad))

    for e in range(opt.max_epoch):
        losses = []
        for batch_data in get_data_loader(train_data, opt.batch_size, opt.shuffle):
            # pad the batch and convert it to Variables
            allfacts, allfacts_mask, questions, questions_mask, answers = \
                pad_batch_data(batch_data, th)

            # forward
            preds = model(allfacts, allfacts_mask, questions, questions_mask,
                          answers.size(1), opt.n_episode)
            # loss
            optimizer.zero_grad()
            loss = loss_func(preds, answers.view(-1))
            losses.append(loss.data.tolist()[0])
            # backward
            loss.backward()
            optimizer.step()

        avg_loss = np.mean(losses)

        if avg_loss <= opt.min_loss or e % opt.print_every_epoch == 0 or e == opt.max_epoch - 1:
            info = "e={}, loss={}".format(e, avg_loss)
            losses = []
            print(info)
        if e == opt.max_epoch - 1 and avg_loss > opt.min_loss:
            print("epoch finish, loss > min_loss")
            torch.save(model, opt.model_path)
            break
        elif avg_loss <= opt.min_loss:
            print("Early stop")
            torch.save(model, opt.model_path)
            break

Prediction and results

def cal_test_accuracy(model, test_data, th, n_episode=DefaultConfig.n_episode):
    '''Evaluate on the test data.'''
    batch_size = 1
    model.eval()
    correct = 0
    for item in get_data_loader(test_data, batch_size, False):
        facts, facts_mask, question, question_mask, answer = pad_batch_data(item, th)
        preds = model(facts, facts_mask, question, question_mask, answer.size(1), n_episode)
        preds = preds.max(1)[1].data.tolist()
        answer = answer.view(-1).data.tolist()
        if preds == answer:
            correct += 1
    print("accuracy = ", correct / len(test_data))


def test_one_data(model, item, th, n_episode=DefaultConfig.n_episode):
    '''Run a single example.
    Args:
        model -- the DMN model
        item -- [facts, question, answer]
        th -- TextHelper
    Returns:
        None
    '''
    # batch_size = 1
    model.eval()
    item = [item]
    facts, facts_mask, question, question_mask, answer = pad_batch_data(item, th)
    preds = model(facts, facts_mask, question, question_mask, answer.size(1), n_episode)

    item = item[0]
    preds = preds.max(1)[1].data.tolist()
    facts = [th.indices2sentence(fact) for fact in item[0]]
    facts = [" ".join(fact) for fact in facts]
    q = " ".join(th.indices2sentence(item[1]))
    a = " ".join(th.indices2sentence(item[2]))
    preds = " ".join(th.indices2sentence(preds))

    print("Facts:")
    print("\n".join(facts))
    print("Question:", q)
    print("Answer:", a)
    print("Predict:", preds)
    print()
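
Putting the pieces together might look like the sketch below. How the TextHelper is built lives in the repo and is not shown in this post, so build_text_helper here is only a hypothetical placeholder.

opt = DefaultConfig()
th = build_text_helper(opt)   # hypothetical: builds the vocabulary (word2index, <pad>, <s>, ...)

train_data = triple_word2id(load_raw_data(opt.train_file, opt.seq_end), th)
test_data = triple_word2id(load_raw_data(opt.test_file, opt.seq_end), th)

train(opt, th, train_data)
model = torch.load(opt.model_path)
cal_test_accuracy(model, test_data, th, opt.n_episode)
test_one_data(model, test_data[0], th)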

The model does well on this dataset, but the data is small and the sentences are simple; I have not tested it on other datasets yet and will do so when I find the time.