sqlnet.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. #!/usr/bin/env python
  2. # -*- encoding: utf-8 -*-
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from torch.autograd import Variable
  7. import numpy as np
  8. from sqlnet.model.modules.word_embedding import WordEmbedding
  9. from sqlnet.model.modules.aggregator_predict import AggPredictor
  10. from sqlnet.model.modules.selection_predict import SelPredictor
  11. from sqlnet.model.modules.sqlnet_condition_predict import SQLNetCondPredictor
  12. from sqlnet.model.modules.select_number import SelNumPredictor
  13. from sqlnet.model.modules.where_relation import WhereRelationPredictor
# Define the SQLNet model.
  15. class SQLNet(nn.Module):
  16. def __init__(self, word_emb, N_word, N_h=100, N_depth=2,
  17. gpu=False, use_ca=True, trainable_emb=False):
  18. super(SQLNet, self).__init__()
  19. self.use_ca = use_ca
  20. self.trainable_emb = trainable_emb
  21. self.gpu = gpu
  22. self.N_h = N_h
  23. self.N_depth = N_depth
  24. self.max_col_num = 45
  25. self.max_tok_num = 200
  26. self.SQL_TOK = ['<UNK>', '<END>', 'WHERE', 'AND', 'OR', '==', '>', '<', '!=', '<BEG>']
  27. self.COND_OPS = ['>', '<', '==', '!=']
  28. # Word embedding
  29. self.embed_layer = WordEmbedding(word_emb, N_word, gpu, self.SQL_TOK, our_model=True, trainable=trainable_emb)
  30. # Predict the number of selected columns
  31. self.sel_num = SelNumPredictor(N_word, N_h, N_depth, use_ca=use_ca)
  32. #Predict which columns are selected
  33. self.sel_pred = SelPredictor(N_word, N_h, N_depth, self.max_tok_num, use_ca=use_ca)
  34. #Predict aggregation functions of corresponding selected columns
  35. self.agg_pred = AggPredictor(N_word, N_h, N_depth, use_ca=use_ca)
  36. #Predict number of conditions, condition columns, condition operations and condition values
  37. self.cond_pred = SQLNetCondPredictor(N_word, N_h, N_depth, self.max_col_num, self.max_tok_num, use_ca, gpu)
  38. # Predict condition relationship, like 'and', 'or'
  39. self.where_rela_pred = WhereRelationPredictor(N_word, N_h, N_depth, use_ca=use_ca)
  40. self.CE = nn.CrossEntropyLoss()
  41. self.softmax = nn.Softmax(dim=-1)
  42. self.log_softmax = nn.LogSoftmax()
  43. self.bce_logit = nn.BCEWithLogitsLoss()
  44. if gpu:
  45. self.cuda()
  46. def generate_gt_where_seq_test(self, q, gt_cond_seq):
  47. ret_seq = []
  48. for cur_q, ans in zip(q, gt_cond_seq):
  49. temp_q = u"".join(cur_q)
  50. cur_q = [u'<BEG>'] + cur_q + [u'<END>']
  51. record = []
  52. record_cond = []
  53. for cond in ans:
  54. if cond[2] not in temp_q:
  55. record.append((False, cond[2]))
  56. else:
  57. record.append((True, cond[2]))
  58. for idx, item in enumerate(record):
  59. temp_ret_seq = []
  60. if item[0]:
  61. temp_ret_seq.append(0)
  62. temp_ret_seq.extend(list(range(temp_q.index(item[1])+1,temp_q.index(item[1])+len(item[1])+1)))
  63. temp_ret_seq.append(len(cur_q)-1)
  64. else:
  65. temp_ret_seq.append([0,len(cur_q)-1])
  66. record_cond.append(temp_ret_seq)
  67. ret_seq.append(record_cond)
  68. return ret_seq
  69. def forward(self, q, col, col_num, gt_where = None, gt_cond=None, reinforce=False, gt_sel=None, gt_sel_num=None):
  70. B = len(q)
  71. sel_num_score = None
  72. agg_score = None
  73. sel_score = None
  74. cond_score = None
  75. #Predict aggregator
  76. if self.trainable_emb:
  77. x_emb_var, x_len = self.agg_embed_layer.gen_x_batch(q, col)
  78. col_inp_var, col_name_len, col_len = self.agg_embed_layer.gen_col_batch(col)
  79. max_x_len = max(x_len)
  80. agg_score = self.agg_pred(x_emb_var, x_len, col_inp_var,
  81. col_name_len, col_len, col_num, gt_sel=gt_sel)
  82. x_emb_var, x_len = self.sel_embed_layer.gen_x_batch(q, col)
  83. col_inp_var, col_name_len, col_len = self.sel_embed_layer.gen_col_batch(col)
  84. max_x_len = max(x_len)
  85. sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
  86. col_name_len, col_len, col_num)
  87. x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(q, col)
  88. col_inp_var, col_name_len, col_len = self.cond_embed_layer.gen_col_batch(col)
  89. max_x_len = max(x_len)
  90. cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_where, gt_cond, reinforce=reinforce)
  91. where_rela_score = None
  92. else:
  93. x_emb_var, x_len = self.embed_layer.gen_x_batch(q, col)
  94. col_inp_var, col_name_len, col_len = self.embed_layer.gen_col_batch(col)
  95. sel_num_score = self.sel_num(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num)
  96. # x_emb_var: embedding of each question
  97. # x_len: length of each question
  98. # col_inp_var: embedding of each header
  99. # col_name_len: length of each header
  100. # col_len: number of headers in each table, array type
  101. # col_num: number of headers in each table, list type
  102. if gt_sel_num:
  103. pr_sel_num = gt_sel_num
  104. else:
  105. pr_sel_num = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
  106. sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num)
  107. if gt_sel:
  108. pr_sel = gt_sel
  109. else:
  110. num = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
  111. sel = sel_score.data.cpu().numpy()
  112. pr_sel = [list(np.argsort(-sel[b])[:num[b]]) for b in range(len(num))]
  113. agg_score = self.agg_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_sel=pr_sel, gt_sel_num=pr_sel_num)
  114. where_rela_score = self.where_rela_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num)
  115. cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_where, gt_cond, reinforce=reinforce)
  116. return (sel_num_score, sel_score, agg_score, cond_score, where_rela_score)
  117. def loss(self, score, truth_num, gt_where):
  118. sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score
  119. B = len(truth_num)
  120. loss = 0
  121. # Evaluate select number
  122. # sel_num_truth = map(lambda x:x[0], truth_num)
  123. sel_num_truth = [x[0] for x in truth_num]
  124. sel_num_truth = torch.from_numpy(np.array(sel_num_truth))
  125. if self.gpu:
  126. sel_num_truth = Variable(sel_num_truth.cuda())
  127. else:
  128. sel_num_truth = Variable(sel_num_truth)
  129. loss += self.CE(sel_num_score, sel_num_truth)
  130. # Evaluate select column
  131. T = len(sel_score[0])
  132. truth_prob = np.zeros((B,T), dtype=np.float32)
  133. for b in range(B):
  134. truth_prob[b][list(truth_num[b][1])] = 1
  135. data = torch.from_numpy(truth_prob)
  136. if self.gpu:
  137. sel_col_truth_var = Variable(data.cuda())
  138. else:
  139. sel_col_truth_var = Variable(data)
  140. sigm = nn.Sigmoid()
  141. sel_col_prob = sigm(sel_score)
  142. bce_loss = -torch.mean(
  143. 3*(sel_col_truth_var * torch.log(sel_col_prob+1e-10)) +
  144. (1-sel_col_truth_var) * torch.log(1-sel_col_prob+1e-10)
  145. )
  146. loss += bce_loss
  147. # Evaluate select aggregation
  148. for b in range(len(truth_num)):
  149. data = torch.from_numpy(np.array(truth_num[b][2]))
  150. if self.gpu:
  151. sel_agg_truth_var = Variable(data.cuda())
  152. else:
  153. sel_agg_truth_var = Variable(data)
  154. sel_agg_pred = agg_score[b, :len(truth_num[b][1])]
  155. loss += (self.CE(sel_agg_pred, sel_agg_truth_var)) / len(truth_num)
  156. cond_num_score, cond_col_score, cond_op_score, cond_str_score = cond_score
  157. # Evaluate the number of conditions
  158. # cond_num_truth = map(lambda x:x[3], truth_num)
  159. cond_num_truth = [x[3] for x in truth_num]
  160. data = torch.from_numpy(np.array(cond_num_truth))
  161. if self.gpu:
  162. try:
  163. cond_num_truth_var = Variable(data.cuda())
  164. except:
  165. print ("cond_num_truth_var error")
  166. print (data)
  167. exit(0)
  168. else:
  169. cond_num_truth_var = Variable(data)
  170. loss += self.CE(cond_num_score, cond_num_truth_var)
  171. # Evaluate the columns of conditions
  172. T = len(cond_col_score[0])
  173. truth_prob = np.zeros((B, T), dtype=np.float32)
  174. for b in range(B):
  175. if len(truth_num[b][4]) > 0:
  176. truth_prob[b][list(truth_num[b][4])] = 1
  177. data = torch.from_numpy(truth_prob)
  178. if self.gpu:
  179. cond_col_truth_var = Variable(data.cuda())
  180. else:
  181. cond_col_truth_var = Variable(data)
  182. sigm = nn.Sigmoid()
  183. cond_col_prob = sigm(cond_col_score)
  184. bce_loss = -torch.mean(
  185. 3*(cond_col_truth_var * torch.log(cond_col_prob+1e-10)) +
  186. (1-cond_col_truth_var) * torch.log(1-cond_col_prob+1e-10) )
  187. loss += bce_loss
  188. # Evaluate the operator of conditions
  189. for b in range(len(truth_num)):
  190. if len(truth_num[b][5]) == 0:
  191. continue
  192. data = torch.from_numpy(np.array(truth_num[b][5]))
  193. if self.gpu:
  194. cond_op_truth_var = Variable(data.cuda())
  195. else:
  196. cond_op_truth_var = Variable(data)
  197. cond_op_pred = cond_op_score[b, :len(truth_num[b][5])]
  198. try:
  199. loss += (self.CE(cond_op_pred, cond_op_truth_var) / len(truth_num))
  200. except:
  201. print (cond_op_pred)
  202. print (cond_op_truth_var)
  203. exit(0)
  204. #Evaluate the strings of conditions
  205. for b in range(len(gt_where)):
  206. for idx in range(len(gt_where[b])):
  207. cond_str_truth = gt_where[b][idx]
  208. if len(cond_str_truth) == 1:
  209. continue
  210. data = torch.from_numpy(np.array(cond_str_truth[1:]))
  211. if self.gpu:
  212. cond_str_truth_var = Variable(data.cuda())
  213. else:
  214. cond_str_truth_var = Variable(data)
  215. str_end = len(cond_str_truth)-1
  216. cond_str_pred = cond_str_score[b, idx, :str_end]
  217. loss += (self.CE(cond_str_pred, cond_str_truth_var) \
  218. / (len(gt_where) * len(gt_where[b])))
  219. # Evaluate condition relationship, and / or
  220. # where_rela_truth = map(lambda x:x[6], truth_num)
  221. where_rela_truth = [x[6] for x in truth_num]
  222. data = torch.from_numpy(np.array(where_rela_truth))
  223. if self.gpu:
  224. try:
  225. where_rela_truth = Variable(data.cuda())
  226. except:
  227. print ("where_rela_truth error")
  228. print (data)
  229. exit(0)
  230. else:
  231. where_rela_truth = Variable(data)
  232. loss += self.CE(where_rela_score, where_rela_truth)
  233. return loss
  234. def check_acc(self, vis_info, pred_queries, gt_queries):
  235. def gen_cond_str(conds, header):
  236. if len(conds) == 0:
  237. return 'None'
  238. cond_str = []
  239. for cond in conds:
  240. cond_str.append(header[cond[0]] + ' ' +
  241. self.COND_OPS[cond[1]] + ' ' + unicode(cond[2]).lower())
  242. return 'WHERE ' + ' AND '.join(cond_str)
  243. tot_err = sel_num_err = agg_err = sel_err = 0.0
  244. cond_num_err = cond_col_err = cond_op_err = cond_val_err = cond_rela_err = 0.0
  245. for b, (pred_qry, gt_qry) in enumerate(zip(pred_queries, gt_queries)):
  246. good = True
  247. sel_pred, agg_pred, where_rela_pred = pred_qry['sel'], pred_qry['agg'], pred_qry['cond_conn_op']
  248. sel_gt, agg_gt, where_rela_gt = gt_qry['sel'], gt_qry['agg'], gt_qry['cond_conn_op']
  249. if where_rela_gt != where_rela_pred:
  250. good = False
  251. cond_rela_err += 1
  252. if len(sel_pred) != len(sel_gt):
  253. good = False
  254. sel_num_err += 1
  255. pred_sel_dict = {k:v for k,v in zip(list(sel_pred), list(agg_pred))}
  256. gt_sel_dict = {k:v for k,v in zip(sel_gt, agg_gt)}
  257. if set(sel_pred) != set(sel_gt):
  258. good = False
  259. sel_err += 1
  260. agg_pred = [pred_sel_dict[x] for x in sorted(pred_sel_dict.keys())]
  261. agg_gt = [gt_sel_dict[x] for x in sorted(gt_sel_dict.keys())]
  262. if agg_pred != agg_gt:
  263. good = False
  264. agg_err += 1
  265. cond_pred = pred_qry['conds']
  266. cond_gt = gt_qry['conds']
  267. if len(cond_pred) != len(cond_gt):
  268. good = False
  269. cond_num_err += 1
  270. else:
  271. cond_op_pred, cond_op_gt = {}, {}
  272. cond_val_pred, cond_val_gt = {}, {}
  273. for p, g in zip(cond_pred, cond_gt):
  274. cond_op_pred[p[0]] = p[1]
  275. cond_val_pred[p[0]] = p[2]
  276. cond_op_gt[g[0]] = g[1]
  277. cond_val_gt[g[0]] = g[2]
  278. if set(cond_op_pred.keys()) != set(cond_op_gt.keys()):
  279. cond_col_err += 1
  280. good=False
  281. where_op_pred = [cond_op_pred[x] for x in sorted(cond_op_pred.keys())]
  282. where_op_gt = [cond_op_gt[x] for x in sorted(cond_op_gt.keys())]
  283. if where_op_pred != where_op_gt:
  284. cond_op_err += 1
  285. good=False
  286. where_val_pred = [cond_val_pred[x] for x in sorted(cond_val_pred.keys())]
  287. where_val_gt = [cond_val_gt[x] for x in sorted(cond_val_gt.keys())]
  288. if where_val_pred != where_val_gt:
  289. cond_val_err += 1
  290. good=False
  291. if not good:
  292. tot_err += 1
  293. return np.array((sel_num_err, sel_err, agg_err, cond_num_err, cond_col_err, cond_op_err, cond_val_err , cond_rela_err)), tot_err
  294. def gen_query(self, score, q, col, raw_q, reinforce=False, verbose=False):
  295. """
  296. :param score:
  297. :param q: token-questions
  298. :param col: token-headers
  299. :param raw_q: original question sequence
  300. :return:
  301. """
  302. def merge_tokens(tok_list, raw_tok_str):
  303. tok_str = raw_tok_str# .lower()
  304. alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$('
  305. special = {'-LRB-':'(',
  306. '-RRB-':')',
  307. '-LSB-':'[',
  308. '-RSB-':']',
  309. '``':'"',
  310. '\'\'':'"',
  311. '--':u'\u2013'}
  312. ret = ''
  313. double_quote_appear = 0
  314. for raw_tok in tok_list:
  315. if not raw_tok:
  316. continue
  317. tok = special.get(raw_tok, raw_tok)
  318. if tok == '"':
  319. double_quote_appear = 1 - double_quote_appear
  320. if len(ret) == 0:
  321. pass
  322. elif len(ret) > 0 and ret + ' ' + tok in tok_str:
  323. ret = ret + ' '
  324. elif len(ret) > 0 and ret + tok in tok_str:
  325. pass
  326. elif tok == '"':
  327. if double_quote_appear:
  328. ret = ret + ' '
  329. # elif tok[0] not in alphabet:
  330. # pass
  331. elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) \
  332. and (ret[-1] != '"' or not double_quote_appear):
  333. ret = ret + ' '
  334. ret = ret + tok
  335. return ret.strip()
  336. sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score
  337. # [64,4,6], [64,14], ..., [64,4]
  338. sel_num_score = sel_num_score.data.cpu().numpy()
  339. sel_score = sel_score.data.cpu().numpy()
  340. agg_score = agg_score.data.cpu().numpy()
  341. where_rela_score = where_rela_score.data.cpu().numpy()
  342. ret_queries = []
  343. B = len(agg_score)
  344. cond_num_score,cond_col_score,cond_op_score,cond_str_score =\
  345. [x.data.cpu().numpy() for x in cond_score]
  346. for b in range(B):
  347. cur_query = {}
  348. cur_query['sel'] = []
  349. cur_query['agg'] = []
  350. sel_num = np.argmax(sel_num_score[b])
  351. max_col_idxes = np.argsort(-sel_score[b])[:sel_num]
  352. # find the most-probable columns' indexes
  353. max_agg_idxes = np.argsort(-agg_score[b])[:sel_num]
  354. cur_query['sel'].extend([int(i) for i in max_col_idxes])
  355. cur_query['agg'].extend([i[0] for i in max_agg_idxes])
  356. cur_query['cond_conn_op'] = np.argmax(where_rela_score[b])
  357. cur_query['conds'] = []
  358. cond_num = np.argmax(cond_num_score[b])
  359. all_toks = ['<BEG>'] + q[b] + ['<END>']
  360. max_idxes = np.argsort(-cond_col_score[b])[:cond_num]
  361. for idx in range(cond_num):
  362. cur_cond = []
  363. cur_cond.append(max_idxes[idx]) # where-col
  364. cur_cond.append(np.argmax(cond_op_score[b][idx])) # where-op
  365. cur_cond_str_toks = []
  366. for str_score in cond_str_score[b][idx]:
  367. str_tok = np.argmax(str_score[:len(all_toks)])
  368. str_val = all_toks[str_tok]
  369. if str_val == '<END>':
  370. break
  371. cur_cond_str_toks.append(str_val)
  372. cur_cond.append(merge_tokens(cur_cond_str_toks, raw_q[b]))
  373. cur_query['conds'].append(cur_cond)
  374. ret_queries.append(cur_query)
  375. return ret_queries