word_embedding.py

import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable  # legacy PyTorch (< 0.4) tensor wrapper
class WordEmbedding(nn.Module):
    def __init__(self, word_emb, N_word, gpu, SQL_TOK, our_model,
                 trainable=False):
        super(WordEmbedding, self).__init__()
        self.trainable = trainable
        self.N_word = N_word
        self.our_model = our_model
        self.gpu = gpu
        self.SQL_TOK = SQL_TOK

        if trainable:
            print("Using trainable embedding")
            # word_emb is a (word-to-index dict, embedding matrix) pair;
            # the matrix initializes a learnable nn.Embedding layer.
            self.w2i, word_emb_val = word_emb
            self.embedding = nn.Embedding(len(self.w2i), N_word)
            self.embedding.weight = nn.Parameter(
                torch.from_numpy(word_emb_val.astype(np.float32)))
        else:
            print("Using fixed embedding")
            # word_emb is a word -> vector dict; vectors are looked up
            # as-is and never updated during training.
            self.word_emb = word_emb
    def gen_x_batch(self, q, col):
        # Build a padded batch of question embeddings. q is a list of
        # tokenized questions; col is the matching list of column names.
        B = len(q)
        val_embs = []
        val_len = np.zeros(B, dtype=np.int64)
        for i, (one_q, one_col) in enumerate(zip(q, col)):
            if self.trainable:
                # Map tokens to indices; unknown words fall back to index 0.
                q_val = [self.w2i.get(x, 0) for x in one_q]
                # Indices 1 and 2 mark <BEG> and <END>.
                val_embs.append([1] + q_val + [2])
            else:
                # Look up pretrained vectors; unknown words get zero vectors.
                q_val = [self.word_emb.get(x, np.zeros(self.N_word, dtype=np.float32))
                         for x in one_q]
                # Zero vectors stand in for <BEG> and <END>.
                val_embs.append([np.zeros(self.N_word, dtype=np.float32)] + q_val
                                + [np.zeros(self.N_word, dtype=np.float32)])
            val_len[i] = len(q_val) + 2
        max_len = max(val_len)

        if self.trainable:
            # Pad the token indices, then embed through the learnable layer.
            val_tok_array = np.zeros((B, max_len), dtype=np.int64)
            for i in range(B):
                for t in range(len(val_embs[i])):
                    val_tok_array[i, t] = val_embs[i][t]
            val_tok = torch.from_numpy(val_tok_array)
            if self.gpu:
                val_tok = val_tok.cuda()
            val_tok_var = Variable(val_tok)
            val_inp_var = self.embedding(val_tok_var)
        else:
            # Pad the precomputed embedding vectors directly.
            val_emb_array = np.zeros((B, max_len, self.N_word), dtype=np.float32)
            for i in range(B):
                for t in range(len(val_embs[i])):
                    val_emb_array[i, t, :] = val_embs[i][t]
            val_inp = torch.from_numpy(val_emb_array)
            if self.gpu:
                val_inp = val_inp.cuda()
            val_inp_var = Variable(val_inp)
        return val_inp_var, val_len
    def gen_col_batch(self, cols):
        # Flatten every example's column names into one list so they can be
        # embedded in a single batch; col_len records how many columns
        # belong to each example.
        names = []
        col_len = np.zeros(len(cols), dtype=np.int64)
        for b, one_cols in enumerate(cols):
            names = names + one_cols
            col_len[b] = len(one_cols)
        name_inp_var, name_len = self.str_list_to_batch(names)
        return name_inp_var, name_len, col_len
    def str_list_to_batch(self, str_list):
        # Embed a flat list of token sequences (e.g. column names) into a
        # padded batch; no <BEG>/<END> markers are added here.
        B = len(str_list)
        val_embs = []
        val_len = np.zeros(B, dtype=np.int64)
        for i, one_str in enumerate(str_list):
            if self.trainable:
                val = [self.w2i.get(x, 0) for x in one_str]
            else:
                val = [self.word_emb.get(x, np.zeros(self.N_word, dtype=np.float32))
                       for x in one_str]
            val_embs.append(val)
            val_len[i] = len(val)
        max_len = max(val_len)

        if self.trainable:
            # Pad the token indices, then embed through the learnable layer.
            val_tok_array = np.zeros((B, max_len), dtype=np.int64)
            for i in range(B):
                for t in range(len(val_embs[i])):
                    val_tok_array[i, t] = val_embs[i][t]
            val_tok = torch.from_numpy(val_tok_array)
            if self.gpu:
                val_tok = val_tok.cuda()
            val_tok_var = Variable(val_tok)
            val_inp_var = self.embedding(val_tok_var)
        else:
            # Pad the precomputed embedding vectors directly.
            val_emb_array = np.zeros((B, max_len, self.N_word), dtype=np.float32)
            for i in range(B):
                for t in range(len(val_embs[i])):
                    val_emb_array[i, t, :] = val_embs[i][t]
            val_inp = torch.from_numpy(val_emb_array)
            if self.gpu:
                val_inp = val_inp.cuda()
            val_inp_var = Variable(val_inp)
        return val_inp_var, val_len
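
A minimal usage sketch for the fixed-embedding path. The vocabulary, vectors, and SQL_TOK list below are invented for illustration (SQL_TOK is stored by the constructor but not used by these batching methods); real models would pass e.g. a 300-d GloVe dict instead of random 4-d vectors.

if __name__ == '__main__':
    # Toy fixed embedding: each word maps to a small random vector.
    N_word = 4
    vocab = ['how', 'many', 'players', 'name', 'team']
    word_emb = {w: np.random.rand(N_word).astype(np.float32) for w in vocab}

    emb = WordEmbedding(word_emb, N_word, gpu=False,
                        SQL_TOK=['<UNK>', '<BEG>', '<END>'],  # placeholder
                        our_model=True, trainable=False)

    # Two tokenized questions, and each question's tokenized column names.
    q = [['how', 'many', 'players'], ['name']]
    col = [[['name'], ['team']], [['name']]]

    x_emb, x_len = emb.gen_x_batch(q, col)
    print(x_emb.size())  # torch.Size([2, 5, 4]): longest question + <BEG>/<END>
    print(x_len)         # [5 3]

    name_emb, name_len, col_len = emb.gen_col_batch(col)
    print(name_emb.size())  # torch.Size([3, 1, 4]): 3 column names in the batch
    print(col_len)          # [2 1]: columns per example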