sqlnet_condition_predict.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. import json
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from torch.autograd import Variable
  6. import numpy as np
  7. from sqlnet.model.modules.net_utils import run_lstm, col_name_encode
  8. class SQLNetCondPredictor(nn.Module):
  9. def __init__(self, N_word, N_h, N_depth, max_col_num, max_tok_num, use_ca, gpu):
  10. super(SQLNetCondPredictor, self).__init__()
  11. self.N_h = N_h
  12. self.max_tok_num = max_tok_num
  13. self.max_col_num = max_col_num
  14. self.gpu = gpu
  15. self.use_ca = use_ca
  16. self.cond_num_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
  17. self.cond_num_att = nn.Linear(N_h, 1)
  18. self.cond_num_out = nn.Sequential(nn.Linear(N_h, N_h),
  19. nn.Tanh(), nn.Linear(N_h, 5))
  20. self.cond_num_name_enc = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
  21. self.cond_num_col_att = nn.Linear(N_h, 1)
  22. self.cond_num_col2hid1 = nn.Linear(N_h, 2*N_h)
  23. self.cond_num_col2hid2 = nn.Linear(N_h, 2*N_h)
  24. self.cond_col_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
  25. if use_ca:
  26. print ("Using column attention on where predicting")
  27. self.cond_col_att = nn.Linear(N_h, N_h)
  28. else:
  29. print ("Not using column attention on where predicting")
  30. self.cond_col_att = nn.Linear(N_h, 1)
  31. self.cond_col_name_enc = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
  32. self.cond_col_out_K = nn.Linear(N_h, N_h)
  33. self.cond_col_out_col = nn.Linear(N_h, N_h)
  34. self.cond_col_out = nn.Sequential(nn.ReLU(), nn.Linear(N_h, 1))
  35. self.cond_op_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
  36. if use_ca:
  37. self.cond_op_att = nn.Linear(N_h, N_h)
  38. else:
  39. self.cond_op_att = nn.Linear(N_h, 1)
  40. self.cond_op_out_K = nn.Linear(N_h, N_h)
  41. self.cond_op_name_enc = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
  42. self.cond_op_out_col = nn.Linear(N_h, N_h)
  43. self.cond_op_out = nn.Sequential(nn.Linear(N_h, N_h), nn.Tanh(),
  44. nn.Linear(N_h, 4))
  45. self.cond_str_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
  46. self.cond_str_decoder = nn.LSTM(input_size=self.max_tok_num, hidden_size=N_h, num_layers=N_depth, batch_first=True, dropout=0.3)
  47. self.cond_str_name_enc = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
  48. self.cond_str_out_g = nn.Linear(N_h, N_h)
  49. self.cond_str_out_h = nn.Linear(N_h, N_h)
  50. self.cond_str_out_col = nn.Linear(N_h, N_h)
  51. self.cond_str_out = nn.Sequential(nn.ReLU(), nn.Linear(N_h, 1))
  52. self.softmax = nn.Softmax(dim=-1)
  53. def gen_gt_batch(self, split_tok_seq):
  54. B = len(split_tok_seq)
  55. max_len = max([max([len(tok) for tok in tok_seq]+[0]) for
  56. tok_seq in split_tok_seq]) - 1 # The max seq len in the batch.
  57. if max_len < 1:
  58. max_len = 1
  59. ret_array = np.zeros((
  60. B, 4, max_len, self.max_tok_num), dtype=np.float32)
  61. ret_len = np.zeros((B, 4))
  62. for b, tok_seq in enumerate(split_tok_seq):
  63. idx = 0
  64. for idx, one_tok_seq in enumerate(tok_seq):
  65. out_one_tok_seq = one_tok_seq[:-1]
  66. ret_len[b, idx] = len(out_one_tok_seq)
  67. for t, tok_id in enumerate(out_one_tok_seq):
  68. ret_array[b, idx, t, tok_id] = 1
  69. if idx < 3:
  70. ret_array[b, idx+1:, 0, 1] = 1
  71. ret_len[b, idx+1:] = 1
  72. ret_inp = torch.from_numpy(ret_array)
  73. if self.gpu:
  74. ret_inp = ret_inp.cuda()
  75. ret_inp_var = Variable(ret_inp)
  76. return ret_inp_var, ret_len #[B, IDX, max_len, max_tok_num]
  77. def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
  78. col_len, col_num, gt_where, gt_cond, reinforce):
  79. max_x_len = max(x_len)
  80. B = len(x_len)
  81. if reinforce:
  82. raise NotImplementedError('Our model doesn\'t have RL')
  83. # Predict the number of conditions
  84. # First use column embeddings to calculate the initial hidden unit
  85. # Then run the LSTM and predict condition number.
  86. e_num_col, col_num = col_name_encode(col_inp_var, col_name_len, col_len, self.cond_num_name_enc)
  87. num_col_att_val = self.cond_num_col_att(e_num_col).squeeze()
  88. for idx, num in enumerate(col_num):
  89. if num < max(col_num):
  90. num_col_att_val[idx, num:] = -100
  91. num_col_att = self.softmax(num_col_att_val)
  92. K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
  93. cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(B, 4, self.N_h//2).transpose(0, 1).contiguous()
  94. cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(B, 4, self.N_h//2).transpose(0, 1).contiguous()
  95. h_num_enc, _ = run_lstm(self.cond_num_lstm, x_emb_var, x_len,
  96. hidden=(cond_num_h1, cond_num_h2))
  97. num_att_val = self.cond_num_att(h_num_enc).squeeze()
  98. for idx, num in enumerate(x_len):
  99. if num < max_x_len:
  100. num_att_val[idx, num:] = -100
  101. num_att = self.softmax(num_att_val)
  102. K_cond_num = (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
  103. cond_num_score = self.cond_num_out(K_cond_num)
  104. #Predict the columns of conditions
  105. e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len, self.cond_col_name_enc)
  106. h_col_enc, _ = run_lstm(self.cond_col_lstm, x_emb_var, x_len)
  107. if self.use_ca:
  108. col_att_val = torch.bmm(e_cond_col,
  109. self.cond_col_att(h_col_enc).transpose(1, 2))
  110. for idx, num in enumerate(x_len):
  111. if num < max_x_len:
  112. col_att_val[idx, :, num:] = -100
  113. col_att = self.softmax(col_att_val.view(
  114. (-1, max_x_len))).view(B, -1, max_x_len)
  115. K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)
  116. else:
  117. col_att_val = self.cond_col_att(h_col_enc).squeeze()
  118. for idx, num in enumerate(x_len):
  119. if num < max_x_len:
  120. col_att_val[idx, num:] = -100
  121. col_att = self.softmax(col_att_val)
  122. K_cond_col = (h_col_enc *
  123. col_att_val.unsqueeze(2)).sum(1).unsqueeze(1)
  124. cond_col_score = self.cond_col_out(self.cond_col_out_K(K_cond_col) +
  125. self.cond_col_out_col(e_cond_col)).squeeze()
  126. max_col_num = max(col_num)
  127. for b, num in enumerate(col_num):
  128. if num < max_col_num:
  129. cond_col_score[b, num:] = -100
  130. #Predict the operator of conditions
  131. chosen_col_gt = []
  132. if gt_cond is None:
  133. cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
  134. col_scores = cond_col_score.data.cpu().numpy()
  135. chosen_col_gt = [list(np.argsort(-col_scores[b])[:cond_nums[b]])
  136. for b in range(len(cond_nums))]
  137. else:
  138. # print gt_cond
  139. chosen_col_gt = [[x[0] for x in one_gt_cond] for one_gt_cond in gt_cond]
  140. e_cond_col, _ = col_name_encode(col_inp_var, col_name_len,
  141. col_len, self.cond_op_name_enc)
  142. h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
  143. col_emb = []
  144. for b in range(B):
  145. cur_col_emb = torch.stack([e_cond_col[b, x]
  146. for x in chosen_col_gt[b]] + [e_cond_col[b, 0]] *
  147. (4 - len(chosen_col_gt[b]))) # Pad the columns to maximum (4)
  148. col_emb.append(cur_col_emb)
  149. col_emb = torch.stack(col_emb)
  150. if self.use_ca:
  151. op_att_val = torch.matmul(self.cond_op_att(h_op_enc).unsqueeze(1),
  152. col_emb.unsqueeze(3)).squeeze()
  153. for idx, num in enumerate(x_len):
  154. if num < max_x_len:
  155. op_att_val[idx, :, num:] = -100
  156. op_att = self.softmax(op_att_val.view(B*4, -1)).view(B, 4, -1)
  157. K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
  158. else:
  159. op_att_val = self.cond_op_att(h_op_enc).squeeze()
  160. for idx, num in enumerate(x_len):
  161. if num < max_x_len:
  162. op_att_val[idx, num:] = -100
  163. op_att = self.softmax(op_att_val)
  164. K_cond_op = (h_op_enc * op_att.unsqueeze(2)).sum(1).unsqueeze(1)
  165. cond_op_score = self.cond_op_out(self.cond_op_out_K(K_cond_op) +
  166. self.cond_op_out_col(col_emb)).squeeze()
  167. #Predict the string of conditions
  168. h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)
  169. e_cond_col, _ = col_name_encode(col_inp_var, col_name_len,
  170. col_len, self.cond_str_name_enc)
  171. col_emb = []
  172. for b in range(B):
  173. cur_col_emb = torch.stack([e_cond_col[b, x] for x in chosen_col_gt[b]] +
  174. [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b])))
  175. col_emb.append(cur_col_emb)
  176. col_emb = torch.stack(col_emb)
  177. if gt_where is not None:
  178. gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where)
  179. g_str_s_flat, _ = self.cond_str_decoder(
  180. gt_tok_seq.view(B*4, -1, self.max_tok_num))
  181. g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)
  182. h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
  183. g_ext = g_str_s.unsqueeze(3)
  184. col_ext = col_emb.unsqueeze(2).unsqueeze(2)
  185. cond_str_score = self.cond_str_out(
  186. self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
  187. self.cond_str_out_col(col_ext)).squeeze()
  188. for b, num in enumerate(x_len):
  189. if num < max_x_len:
  190. cond_str_score[b, :, :, num:] = -100
  191. else:
  192. h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
  193. col_ext = col_emb.unsqueeze(2).unsqueeze(2)
  194. scores = []
  195. t = 0
  196. init_inp = np.zeros((B*4, 1, self.max_tok_num), dtype=np.float32)
  197. init_inp[:,0,0] = 1 #Set the <BEG> token
  198. if self.gpu:
  199. cur_inp = Variable(torch.from_numpy(init_inp).cuda())
  200. else:
  201. cur_inp = Variable(torch.from_numpy(init_inp))
  202. cur_h = None
  203. while t < 50:
  204. if cur_h:
  205. g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp, cur_h)
  206. else:
  207. g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
  208. g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h)
  209. g_ext = g_str_s.unsqueeze(3)
  210. cur_cond_str_score = self.cond_str_out(
  211. self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext)
  212. + self.cond_str_out_col(col_ext)).squeeze()
  213. for b, num in enumerate(x_len):
  214. if num < max_x_len:
  215. cur_cond_str_score[b, :, num:] = -100
  216. scores.append(cur_cond_str_score)
  217. _, ans_tok_var = cur_cond_str_score.view(B*4, max_x_len).max(1)
  218. ans_tok = ans_tok_var.data.cpu()
  219. data = torch.zeros(B*4, self.max_tok_num).scatter_(
  220. 1, ans_tok.unsqueeze(1), 1)
  221. if self.gpu: #To one-hot
  222. cur_inp = Variable(data.cuda())
  223. else:
  224. cur_inp = Variable(data)
  225. cur_inp = cur_inp.unsqueeze(1)
  226. t += 1
  227. cond_str_score = torch.stack(scores, 2)
  228. for b, num in enumerate(x_len):
  229. if num < max_x_len:
  230. cond_str_score[b, :, :, num:] = -100 #[B, IDX, T, TOK_NUM]
  231. cond_score = (cond_num_score,
  232. cond_col_score, cond_op_score, cond_str_score)
  233. return cond_score