基于Bert的NL2SQL模型:一个简明的Baseline
By 苏剑林 | 2019-06-29 | 145422位读者 |在之前的文章《当Bert遇上Keras:这可能是Bert最简单的打开姿势》中,我们介绍了基于微调Bert的三个NLP例子,算是体验了一把Bert的强大和Keras的便捷。而在这篇文章中,我们再添一个例子:基于Bert的NL2SQL模型。
NL2SQL的NL也就是Natural Language,所以NL2SQL的意思就是“自然语言转SQL语句”,近年来也颇多研究,它算是人工智能领域中比较实用的一个任务。而笔者做这个模型的契机,则是今年我司举办的首届“中文NL2SQL挑战赛”:
首届中文NL2SQL挑战赛,使用金融以及通用领域的表格数据作为数据源,提供在此基础上标注的自然语言与SQL语句的匹配对,希望选手可以利用数据训练出可以准确转换自然语言到SQL的模型。
这个NL2SQL比赛算是今年比较大型的NLP赛事了,赛前投入了颇多人力物力进行宣传推广,比赛的奖金也颇丰富,唯一的问题是NL2SQL本身算是偏冷门的研究领域,所以注定不会太火爆,为此主办方也放出了一个Baseline,基于Pytorch写的,希望能降低大家的入门难度。
抱着“Baseline怎么能少得了Keras版”的心态,我抽时间自己用Keras做了做这个比赛,为了简化模型并且提升效果也加载了预训练的Bert模型,最终形成此文。
数据示例 #
每个数据样本如下:
{
"table_id": "a1b2c3d4", # 相应表格的id
"question": "世茂茂悦府新盘容积率大于1,请问它的套均面积是多少?", # 自然语言问句
"sql":{ # 真实SQL
"sel": [7], # SQL选择的列
"agg": [0], # 选择的列相应的聚合函数, '0'代表无
"cond_conn_op": 0, # 条件之间的关系
"conds": [
[1, 2, "世茂茂悦府"], # 条件列, 条件类型, 条件值,col_1 == "世茂茂悦府"
[6, 0, "1"]
]
}
}
# 其中条件运算符、聚合符、连接符分别如下
op_sql_dict = {0:">", 1:"<", 2:"==", 3:"!="}
agg_sql_dict = {0:"", 1:"AVG", 2:"MAX", 3:"MIN", 4:"COUNT", 5:"SUM"}
conn_sql_dict = {0:"", 1:"and", 2:"or"}
然后每个样本都对应着一个数据表,里边包含该表的所有列名,以及相应的数据记录,而原则上生成的SQL语句在对应的数据表上是可以执行的,并且都能返回有效的结果。
可以看到,虽然说是NL2SQL,但事实上主办方已经将SQL语句做了十分清晰的格式化,这样一来这个任务就可以相当大地简化了。比如sel这个字段,其实就是一个多标签分类模型,只不过类别可能会随时变化,因为这里的类别实际上对应着数据表的列,而每个样本的数据表及其含义都不尽相同,所以我们要根据表的列名来动态地编码一个类别向量;至于agg则跟sel一一对应,并且类别是固定的,cond_conn_op则是一个单标签分类问题。
最后的conds相对复杂一些,它需要结合字标注和分类,因为要同时决定哪一列是条件,条件的运算关系,以及条件对应的值,并且要注意的是,条件值并不总是question的一个片段,它有可能是格式化后的结果,比如question里边是“16年”,那么条件值可能是格式化之后的“2016”,不过,由于主办方保证生成的sql是能够在对应数据表中执行并且出有效结果的,所以如果条件运算符是“==”的时候,那么条件值肯定会出现在数据表对应列的值之中,比如上述示例样本中,数据表的第一列肯定会有“世茂茂悦府”这个值,而通过这个信息我们也可以去校正预测结果。
模型结构 #
正式看本模型之前,读者不妨自己思考一番,想一下自己会怎么做。只有经过思考之后,才会明白其中的困难所在,也就这样才能理解本模型的一些处理技巧的要点。
本文的模型示意图如下:
作为一个SQL,最基本的是要决定哪些列会被select,而每个表的列的含义均不一样,所以我们要将question句子连同数据表的所有表头拼起来,一同输入到Bert模型中进行实时编码,其中每一个表头也视为一个句子,用[CLS]***[SEP]括住。经过Bert之后,我们就得到了一系列编码向量,然后就看我们怎么去用这些向量了。
第一个[CLS]对应的向量,我们可以认为是整个问题的句向量,我们用它来预测conds的连接符。后面的每个[CLS]对应的向量,我们认为是每个表头的编码向量,我们把它拿出来,用来预测该表头表示的列是否应该被select。注意,此处预测有一个技巧,就是前面说了除了预测sel外,还要预测对应的agg,其中agg一共有6个类别,代表不同的距离运算,既然如此,我们干脆多增加一个类别,第7个类别代表着此列不被select,这样一来,每一列都对应着一个7分类问题,如果分到前6类,那么代表着此类被select而且同时预测到了agg,如果分到了第7类,那意味着此类不被select。
现在就剩下比较复杂的conds了,就是where col_1 == value_1
这样子的,col_1、value_1以及运算符==都要找出来。conds的预测分两步,第一步预测条件值,第二步预测条件列,预测条件值其实就是一个序列标注问题,而条件值对应的运算符有4个,我们同样新增一类变成5类,第5类代表着当前字不被标注,否则就被标注,这样我们就能预测出条件值和运算符了。剩下的就是预测条件值对应的列,我们将标注出来的值的字向量跟每个表头的向量一一算相似度,然后softmax。我这里算相似度的方法是最简单的,直接将字向量和表头向量拼接起来,然后过一个全连接层后再接一个Dense(1)
,做得这么简单,一是因为本文主要目的是给出一个基本可行的demo而非一个完善的程序,需要留些改进空间给读者,二是因为做得复杂了,就很容易显存不足而OOM了。
顺便说一下,本文的模型是自己根据比赛任务“闭门造车”出来的,如果读者需要跟我讨论主流的NL2SQL模型,我可能就无能为力了,请见谅。
实验结果 #
本文的模型代码位于:
https://github.com/bojone/bert_in_keras/blob/master/nl2sql_baseline.py
注意,如果你执行此代码报错,那么你可能需要修改一下Keras的backend/tensorflow_backend.py
,将sparse_categorical_crossentropy
函数中原本是
logits = tf.reshape(output, [-1, int(output_shape[-1])])
的那一行改为
logits = tf.reshape(output, [-1, tf.shape(output)[-1]])
我已经向官方提了这个修正,并且已经通过了(请看这里),在未来的版本应该就自动包含这个特性了。
还是那句话,只要你认真地观察过比赛数据、独立思考过这个任务,上面本文介绍的模型其实很容易理解,而简单的模型能够有不错的效果,则得益于Bert强大的语义编码能力。在线下valid集上,本文的模型生成的SQL全匹配率大概为58%左右,而官方的评估指标是(全匹配率 + 执行匹配率) / 2,也就是说你有可能写出跟标注答案不同的SQL语句,但执行结果是一致的,这也算正确一半。
这样一来最终得分肯定会比58%要高,我估计是65%左右吧,我看了看当前榜单,如果65%的话还可以排在前几名(第一名的大佬现在是70%)。由于公司员工不允许打榜评测,所以我没参与过评测,也不知道线上提交会有多少,有兴趣试用的选手自行提交测试吧。
对了,要跑这个脚本最好有个1080ti或者以上的显卡,如果你没有那么多显存,你可以试着降低maxlen和batch size。还有,现在Bert可用的中文预训练权重有两个,一个是官方版的,一个是哈工大版的,两个的最终效果差不多,但是哈工大版的收敛更快。
纵观整个模型,在实现上,最困难的应当是要精细地考虑各种mask,在上述脚本中,xm, hm, cm
是三个mask变量,就是要去除训练过程中padding部分带来的效应。注意mask不是Keras独有,不管你用Tensorflow还是Pytorch,理论上你都要仔细地处理好mask。如果读者实在读不懂mask部分,欢迎留言提问讨论,但是在提问之前,请你回答以下问题:
mask之前的序列大概是怎样的?mask之后序列的哪些位置发生了变化?变成了怎么样?
回答这个问题是证明“你已经能明白程序做了什么运算了,只是不明白为什么这样运算”。而如果你连这个运算本身都看不明白,我想我们很难沟通下去了(哪部分发生了变化总能知道吧...),还是好好学学Keras或者Tensorflow再来玩吧,不能想着一蹴而就。
前后处理 #
对于模型来说,实现的困难部分在于mask。不过如果看整个脚本的话,其实代码占比最多的就是数据的读取和预处理、结果的后处理这两部分罢了,真正搭建模型也就只有二十行左右(再次惊叹Keras的简洁和Bert的强大吧)。
其中我们说过条件值不一定出现在question中,那如何用对question字标注的方法来抽取条件值呢?
我的方法是,如果条件值不出现在question,那么我就对question分词,然后找出question的所有1gram、2gram、3gram,然后根据条件值找出一个最相近的ngram作为标注片段。而在预测的时候,如果找到了一个ngram作为条件值、并且运算符为==时,我们会判断这个ngram有没有在数据库出现过,如果有直接保留,如果没有则在数据库中找一个最相近的值。
相应的过程在代码中都有体现,欢迎细读。
文章小结 #
欢迎大家来玩~
祝大家取得好成绩!
转载到请包括本文地址:https://kexue.fm/archives/6771
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Jun. 29, 2019). 《 基于Bert的NL2SQL模型:一个简明的Baseline 》[Blog post]. Retrieved from https://kexue.fm/archives/6771
@online{kexuefm-6771,
title={ 基于Bert的NL2SQL模型:一个简明的Baseline},
author={苏剑林},
year={2019},
month={Jun},
url={\url{https://kexue.fm/archives/6771}},
}
July 19th, 2019
苏神,我看了你的代码,想做一些改动,比如cls输出得全局向量和batch_gather获得的None h_len 768向量做一个简单的attention,这里需要repeat cls 扩大他的维度,但是您全输入的都是None,所以想问一下 ,批次喂入是保证了一个批次内长度一致 不易浪费 单数也导致了不同批次 长度不一致,所以输入input的维度全是None 这样的话 该怎么操作呢!对未知的维度,有些指定的操作就不行了,比如repeat需要获得一个repeat的数量!这该怎么解决阿 还是推倒全部重来!
K.shape(x)[1]可以动态获取seq的长度,K.tile可以实现重复。
你能遇到的问题,别人都想到过了...所以不用“阿”也不用“!”,冷静找找资料吧。
哈哈哈哈 爱死你了 苏神,我是试过什么K.int_shape()啊 hm.shape(1)啊等 都没成功,才来回复的呢!感谢感谢。以后好好学keras好好爱keras
DANDAN 你好
我想请教下你对苏神代码的改进 是如何实现的?我现在有同样的需求,但是不太清楚怎么实现...
September 10th, 2019
"剩下的就是预测条件值对应的列,我们将标注出来的值的字向量跟每个表头的向量一一算相似度,然后softmax。"
这里利用向量拼接然后 Dense(1) 求"相似度",然后 softmax 的方式,也许有点点小问题。
这事实上等价于将一个相同的值 "Dense(1)(字向量)",加到每一个 "Dense(1)(表头向量)" 中,但其实这个操作并不影响最终的 softmax 值。您在源码中确实也是这么实现的。
https://github.com/bojone/bert_in_keras/blob/master/nl2sql_baseline.py#L226-L232
```python
x = Lambda(lambda x: K.expand_dims(x, 2))(x)
x4h = Lambda(lambda x: K.expand_dims(x, 1))(x4h)
pcsel_1 = Dense(1)(x)
pcsel_2 = Dense(1)(x4h)
pcsel = Lambda(lambda x: x[0] + x[1])([pcsel_1, pcsel_2])
pcsel = Lambda(lambda x: x[0][..., 0] - (1 - x[1]) * 1e10)([pcsel, hm])
pcsel = Activation('softmax')(pcsel)
```
假设 header 有 n 个 column,
pcsel_1 + pcsel_2 的时候,pcsel_1 在 dim=1 这个维度上复制了 n 次,所以相当于 pcsel_2 中每个值加上了一个常数,不影响 softmax 之后的值,其后果就是 cond 中 col 的选择和当前的字向量无关。
通过如下代码观察 pcsel 的值:
```python
model_pcsel = Model(
[x1_in, x2_in, h_in, hm_in],
[pcsel]
)
```
可以发现,对于同一个句子,question 中所有的字选择的 column 是完全一样的。
是的,我明白你说的问题了,感谢你的细致思考和反馈,确实是我的疏忽。
我已经修正了代码,并在文章中修正了评测结果~
April 2nd, 2020
所以这个数据默认最多2个条件吗?(2个conds,一个cond_conn_op)?如果是这样那其实这个数据集还是有点太简单了。
是的
April 21st, 2020
苏神您好,想上手试试,苦于错过了比赛,无法得到数据集,请问您这边有数据集可以分享下吗?
没有
April 26th, 2020
苏神这个不出个 bert4keras版吗
这个不难呀
March 14th, 2022
return K.tf.batch_gather(seq, idxs)这个一直报错,怎么改呢?
自己import tensorflow as tf,然后改为tf
January 30th, 2024
想知道NL2sql这个数据集是怎么标注的,全部手工标注吗,没有工具辅助吗
有带界面的标注工具吧
February 21st, 2024
https://github.com/ZhuiyiTechnology/nl2sql_baseline.git--项目中用到的数据集没有找到,请问如何获取数据集?
我现在也没有了,可以尝试联系主办方~