aggregator_predict.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : aggregator_predict.py
@Time : 2019/07/07 23:42:10
@Author : Liuyuqi
@Version : 1.0
@Contact : liuyuqi.gov@msn.cn
@License : (C)Copyright 2019
@Desc : Aggregation operator predictor for the SQLNet text-to-SQL model.
'''
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from sqlnet.model.modules.net_utils import run_lstm, col_name_encode

class AggPredictor(nn.Module):
    def __init__(self, N_word, N_h, N_depth, use_ca):
        super(AggPredictor, self).__init__()
        self.use_ca = use_ca

        # Bi-LSTM encoder for the question tokens.
        self.agg_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h / 2),
                                num_layers=N_depth, batch_first=True,
                                dropout=0.3, bidirectional=True)
        if use_ca:
            print("Using column attention on aggregator predicting")
            # Bi-LSTM encoder for column-name tokens, plus the column-attention map.
            self.agg_col_name_enc = nn.LSTM(input_size=N_word, hidden_size=int(N_h / 2),
                                            num_layers=N_depth, batch_first=True,
                                            dropout=0.3, bidirectional=True)
            self.agg_att = nn.Linear(N_h, N_h)
        else:
            print("Not using column attention on aggregator predicting")
            self.agg_att = nn.Linear(N_h, 1)
        # Scores the 6 WikiSQL aggregation operators (NONE, MAX, MIN, COUNT, SUM, AVG).
        self.agg_out = nn.Sequential(nn.Linear(N_h, N_h), nn.Tanh(),
                                     nn.Linear(N_h, 6))
        self.softmax = nn.Softmax(dim=-1)
        self.agg_out_K = nn.Linear(N_h, N_h)
        self.col_out_col = nn.Linear(N_h, N_h)

    def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None,
                col_len=None, col_num=None, gt_sel=None, gt_sel_num=None):
        B = len(x_emb_var)
        max_x_len = max(x_len)

        # Encode column names and question tokens.
        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.agg_col_name_enc)
        h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)

        # Gather the embeddings of the (ground-truth) selected columns,
        # padding each example to 4 slots with the first column's embedding.
        col_emb = []
        for b in range(B):
            cur_col_emb = torch.stack([e_col[b, x] for x in gt_sel[b]] +
                                      [e_col[b, 0]] * (4 - len(gt_sel[b])))
            col_emb.append(cur_col_emb)
        col_emb = torch.stack(col_emb)

        # Column attention over question tokens: (B, 4, max_x_len).
        att_val = torch.matmul(self.agg_att(h_enc).unsqueeze(1),
                               col_emb.unsqueeze(3)).squeeze()
        # Mask padded question positions before the softmax.
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -100
        att = self.softmax(att_val.view(B * 4, -1)).view(B, 4, -1)

        # Attention-weighted question representation per selected column.
        K_agg = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
        # Final score per column slot over the 6 aggregation operators.
        agg_score = self.agg_out(self.agg_out_K(K_agg) +
                                 self.col_out_col(col_emb)).squeeze()
        return agg_score
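

# --- Illustrative smoke test (not part of the original SQLNet training code). ---
# A minimal sketch of how AggPredictor might be driven, assuming run_lstm and
# col_name_encode behave as in the upstream SQLNet repo. The dimensions, random
# embeddings and gt_sel indices below are made-up placeholders; the real pipeline
# builds these tensors from GloVe embeddings and the WikiSQL data loaders.
if __name__ == '__main__':
    N_word, N_h, N_depth = 300, 100, 2
    model = AggPredictor(N_word, N_h, N_depth, use_ca=True)

    B, max_q_len = 2, 7
    x_emb_var = torch.randn(B, max_q_len, N_word)   # question word embeddings
    x_len = np.array([7, 5])                        # true question lengths

    # Two tables with 3 and 2 columns; 5 column names in total, up to 2 tokens each.
    col_inp_var = torch.randn(5, 2, N_word)         # column-name token embeddings
    col_name_len = np.array([2, 1, 2, 1, 1])        # tokens per column name
    col_len = np.array([3, 2])                      # columns per table
    gt_sel = [[0], [1, 0]]                          # selected-column indices per example

    agg_score = model(x_emb_var, x_len, col_inp_var, col_name_len,
                      col_len, col_num=None, gt_sel=gt_sel)
    print(agg_score.shape)                          # expected: (2, 4, 6)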