Seq2SQL :使用强化学习通过自然语言生成SQL
Seq2SQL属于natural language interface (NLI)的领域,方便普通用户接入并查询数据库中的内容,即用户不需要了解SQL语句,只需要通过自然语言,就可查询所需内容。
Seq2SQL借鉴的是Seq2Seq的思想,与Seq2Seq应用于机器翻译与Chatbot类似,Seq2SQL将输入的语句encode后再decode成结构化的SQL语言输出,强化学习是在Seq2SQL中的最后一个模块中应用。同时,这篇论文还推出一个数据集WikiSQL,数据集内有人工标注好的问句及其对应SQL语句。
试验结果显示,Seq2SQL的准确率也不是特别的高,只有60.3%
Seq2SQL结构:
Seq2SQL由三部分组成:
data:image/s3,"s3://crabby-images/dd726/dd726e08c59f4fc6b73fe8f2ab40650aeccd9ec2" alt=""
第一部分: Aggregation classifier 这一部分其实是一个分类器,将用户输入的语句分类成是select count/max/min 等统计相关的约束条件
在此处采用的Augmented Pointer Network,Augmented Pointer Network总体而言也是ecoder-to-decoder的结构,
encoder采用的是两层的bi-LSTM, decoder 采用的是两层的unidirectional LSTM,
encoder输出h,ht对应的是第t个词的输出状态
decoder的每一步是,输入y s-1,输出状态gs,接着,decoder为每个位置t生成一个attention的score
data:image/s3,"s3://crabby-images/d0f79/d0f7965cba0b10601bf96b59a50a0349afec7442" alt=""
data:image/s3,"s3://crabby-images/38e99/38e99bbaa0b7b29ee2d802817919851f1d61921e" alt=""
在Seq2SQL中,首先为input生成一个表征向量
(agg:aggregation clasifier, inp:input,enc:encoder) 首先为Augmented Pointer Network类似,计算出一个attention的分数,
,data:image/s3,"s3://crabby-images/a8ab1/a8ab1bbb572d9526e1cac06280abd80608556bdd" alt=""
data:image/s3,"s3://crabby-images/f61f7/f61f737c5b726cfe95e20556b5e2d95129855f7c" alt=""
data:image/s3,"s3://crabby-images/0c3db/0c3db2736978e1dc4005a126849377cfdf9c5a65" alt=""
data:image/s3,"s3://crabby-images/a8ab1/a8ab1bbb572d9526e1cac06280abd80608556bdd" alt=""
量化后,通过softmax函数 data:image/s3,"s3://crabby-images/1f18b/1f18bb01e9106118e018875fd5e5e943d8d05b15" alt=""
data:image/s3,"s3://crabby-images/1f18b/1f18bb01e9106118e018875fd5e5e943d8d05b15" alt=""
input的表征向量 data:image/s3,"s3://crabby-images/13771/137716c591034ca420b5cdb2d1781fb428688ed1" alt=""
data:image/s3,"s3://crabby-images/13771/137716c591034ca420b5cdb2d1781fb428688ed1" alt=""
通过一个多层的网络和softmax完成分类任务
data:image/s3,"s3://crabby-images/0935e/0935e03da68c4538017f6151b3539786e03060a7" alt=""
data:image/s3,"s3://crabby-images/89641/89641b4f617fba2bdebfe2e9ae91a9c492bf5e59" alt=""
第二部分: select column 这一部分是看用户输入的问句命中了哪个column
首先将每个column name 通过LSTM encode
data:image/s3,"s3://crabby-images/3f859/3f8590dfa7ea4d7803a9c15154e93f79ee48faea" alt=""
WikiSQL: data:image/s3,"s3://crabby-images/7b889/7b8896f9eb40bbe8095778793b53d58626bcf324" alt=""
data:image/s3,"s3://crabby-images/087ab/087ab9a352301d82ba88ac2f8f6882c13c5f3a8f" alt=""
将用户输入encode成与第一部分
类似的data:image/s3,"s3://crabby-images/9459b/9459be6d30ece10a75e24e8807a7457d4ed7698a" alt=""
data:image/s3,"s3://crabby-images/16ab0/16ab039100fbd07d5ca6e2f33185d4e145596030" alt=""
data:image/s3,"s3://crabby-images/9459b/9459be6d30ece10a75e24e8807a7457d4ed7698a" alt=""
最终通过一个多层的神经元和softmax确定是命中哪一行
data:image/s3,"s3://crabby-images/d64b4/d64b4e63edd958257d6bd2c4076fb6ca682be3e8" alt=""
data:image/s3,"s3://crabby-images/3f859/3f8590dfa7ea4d7803a9c15154e93f79ee48faea" alt=""
第三部分:where clause 确定约束条件,因为最终生成的SQL可能与标注中的不太一样,但是依旧有一样的结果,所以不能像前两部分一样使用交叉熵作为loss训练,因此使用强化训练中reward函数 (g: ground-truth), loss使用梯度
data:image/s3,"s3://crabby-images/871bb/871bb5fb970be15da2d38a36710314423aec2eb5" alt=""
WikiSQL包含一系列与SQL相关的问题集以及SQL table
data:image/s3,"s3://crabby-images/7b889/7b8896f9eb40bbe8095778793b53d58626bcf324" alt=""
data:image/s3,"s3://crabby-images/da9e4/da9e4ea02ebf22f8cbbde0eae174dce26f6e1568" alt=""
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步