select_number.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import json
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from torch.autograd import Variable
  6. import numpy as np
  7. from sqlnet.model.modules.net_utils import run_lstm, col_name_encode
  class SelNumPredictor(nn.Module):
      """Score the "select number" of a question with a column-primed LSTM.

      The column-name encodings are attention-pooled into one vector per
      example, projected into the initial (h0, c0) state of a bidirectional
      LSTM, and the LSTM output over the question embeddings is
      attention-pooled again and fed to a small MLP with 4 output logits.

      NOTE(review): the meaning of the 4 output classes is not visible in
      this file -- confirm against the training code.
      """

      def __init__(self, N_word, N_h, N_depth, use_ca):
          """Build the layers.

          Args:
              N_word: input feature size of the LSTM (word-embedding width).
              N_h: hidden width; the bidirectional LSTM uses ``N_h // 2`` per
                  direction, so its concatenated output is ``N_h`` wide.
              N_depth: number of stacked LSTM layers.
              use_ca: stored on the instance and announced on stdout; it does
                  not otherwise change the computation in this class.
          """
          super(SelNumPredictor, self).__init__()
          self.N_h = N_h
          self.use_ca = use_ca

          # Layers are created in the same order as before so that seeded
          # parameter initialization is unchanged.
          self.sel_num_lstm = nn.LSTM(
              input_size=N_word,
              hidden_size=N_h // 2,
              num_layers=N_depth,
              batch_first=True,
              dropout=0.3,
              bidirectional=True,
          )
          self.sel_num_att = nn.Linear(N_h, 1)
          self.sel_num_col_att = nn.Linear(N_h, 1)
          self.sel_num_out = nn.Sequential(
              nn.Linear(N_h, N_h),
              nn.Tanh(),
              nn.Linear(N_h, 4),
          )
          self.softmax = nn.Softmax(dim=-1)

          # The LSTM needs (2 * N_depth) initial states of width N_h // 2 per
          # example, i.e. N_depth * N_h features. For N_depth == 2 this is the
          # previously hard-coded 2 * N_h.
          self.sel_num_col2hid1 = nn.Linear(N_h, N_depth * N_h)
          self.sel_num_col2hid2 = nn.Linear(N_h, N_depth * N_h)

          if self.use_ca:
              print ("Using column attention on select number predicting")

      def _attention_pool(self, att_layer, enc, lengths):
          """Return the attention-weighted sum of ``enc`` over dimension 1.

          Args:
              att_layer: module mapping the last dimension of ``enc`` to a
                  single score (here one of the ``nn.Linear(N_h, 1)`` layers).
              enc: encoded sequence batch; assumed (batch, seq, N_h) --
                  TODO confirm against ``run_lstm`` / ``col_name_encode``.
              lengths: per-example number of valid positions along dim 1.

          Returns:
              Tensor with dimension 1 of ``enc`` summed out.
          """
          # squeeze(-1) drops only the score axis; a bare squeeze() would also
          # drop the batch axis when batch == 1 (or the sequence axis when it
          # has length 1) and break the broadcasting below.
          att_val = att_layer(enc).squeeze(-1)

          positions = torch.arange(att_val.size(1), device=att_val.device)
          valid = torch.as_tensor(lengths, device=att_val.device).view(-1, 1)
          # Padded positions get a large negative score so softmax gives them
          # (numerically) zero weight.
          att_val = att_val.masked_fill(positions.unsqueeze(0) >= valid, -1000000)

          att = self.softmax(att_val)
          return (enc * att.unsqueeze(2)).sum(1)

      def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
          """Compute the 4 select-number logits for each example.

          Args:
              x_emb_var: question embeddings passed to ``run_lstm``.
              x_len: per-example question lengths; ``len(x_len)`` is the
                  batch size.
              col_inp_var, col_name_len, col_len: column-name inputs forwarded
                  unchanged to ``col_name_encode``.
              col_num: accepted for interface compatibility; it is replaced
                  by the value returned from ``col_name_encode``.

          Returns:
              Output of ``self.sel_num_out``; last dimension has size 4.
          """
          B = len(x_len)

          # Encode the column names with the same LSTM, then pool them into a
          # single vector per example.
          e_num_col, col_num = col_name_encode(
              col_inp_var, col_name_len, col_len, self.sel_num_lstm
          )
          K_num_col = self._attention_pool(self.sel_num_col_att, e_num_col, col_num)

          # Project the pooled column vector into the LSTM's initial state,
          # laid out as (num_layers * num_directions, batch, hidden_size).
          n_states = self.sel_num_lstm.num_layers * 2
          state_width = self.N_h // 2
          sel_num_h1 = (
              self.sel_num_col2hid1(K_num_col)
              .view(B, n_states, state_width)
              .transpose(0, 1)
              .contiguous()
          )
          sel_num_h2 = (
              self.sel_num_col2hid2(K_num_col)
              .view(B, n_states, state_width)
              .transpose(0, 1)
              .contiguous()
          )

          # Run the question through the LSTM primed with the column state,
          # pool over question positions and score.
          h_num_enc, _ = run_lstm(
              self.sel_num_lstm, x_emb_var, x_len, hidden=(sel_num_h1, sel_num_h2)
          )
          K_sel_num = self._attention_pool(self.sel_num_att, h_num_enc, x_len)
          sel_num_score = self.sel_num_out(K_sel_num)
          return sel_num_score