#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    : utils.py
@Time    : 2019/06/25 04:01:19
@Author  : Liuyuqi
@Version : 1.0
@Contact : liuyuqi.gov@msn.cn
@License : (C)Copyright 2019
@Desc    : Utility functions
'''
import json

import numpy as np
from tqdm import tqdm

from sqlnet.lib.dbengine import DBEngine

def load_data(sql_paths, table_paths, use_small=False):
    '''
    Load question/SQL records and table schemas from line-delimited JSON files.
    '''
    if not isinstance(sql_paths, list):
        sql_paths = (sql_paths, )
    if not isinstance(table_paths, list):
        table_paths = (table_paths, )
    sql_data = []
    table_data = {}
    for SQL_PATH in sql_paths:
        with open(SQL_PATH, encoding='utf-8') as inf:
            for idx, line in enumerate(inf):
                if use_small and idx >= 1000:
                    break
                sql = json.loads(line.strip())
                sql_data.append(sql)
        print("Loaded %d records from %s" % (len(sql_data), SQL_PATH))
    for TABLE_PATH in table_paths:
        with open(TABLE_PATH, encoding='utf-8') as inf:
            for line in inf:
                tab = json.loads(line.strip())
                table_data[tab[u'id']] = tab
        print("Loaded %d records from %s" % (len(table_data), TABLE_PATH))
    # keep only the questions whose table schema was actually loaded
    ret_sql_data = []
    for sql in sql_data:
        if sql[u'table_id'] in table_data:
            ret_sql_data.append(sql)
    return ret_sql_data, table_data
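
# Example usage (a sketch; the paths are the defaults used by load_dataset()
# below and are assumptions about the local data layout):
#   sql_data, table_data = load_data('data/val/val.json', 'data/val/val.tables.json')
#   # sql_data: list of question/SQL dicts; table_data: dict of table_id -> schema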

def load_dataset(toy=False, use_small=False, mode='train'):
    print("Loading dataset")
    # dev_sql: list, dev_table: dict
    dev_sql, dev_table = load_data('data/val/val.json', 'data/val/val.tables.json', use_small=use_small)
    dev_db = 'data/val/val.db'
    if mode == 'train':
        train_sql, train_table = load_data('data/train/train.json', 'data/train/train.tables.json', use_small=use_small)
        train_db = 'data/train/train.db'
        return train_sql, train_table, train_db, dev_sql, dev_table, dev_db
    elif mode == 'test':
        test_sql, test_table = load_data('data/test/test.json', 'data/test/test.tables.json', use_small=use_small)
        test_db = 'data/test/test.db'
        return dev_sql, dev_table, dev_db, test_sql, test_table, test_db
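
# Example (sketch): in 'train' mode the call unpacks to six values:
#   train_sql, train_table, train_db, dev_sql, dev_table, dev_db = load_dataset(mode='train')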

def to_batch_seq(sql_data, table_data, idxes, st, ed, ret_vis_data=False):
    q_seq = []
    col_seq = []
    col_num = []
    ans_seq = []
    gt_cond_seq = []
    vis_seq = []
    sel_num_seq = []
    for i in range(st, ed):
        sql = sql_data[idxes[i]]
        sel_num = len(sql['sql']['sel'])
        sel_num_seq.append(sel_num)
        conds_num = len(sql['sql']['conds'])
        q_seq.append([char for char in sql['question']])
        col_seq.append([[char for char in header] for header in table_data[sql['table_id']]['header']])
        col_num.append(len(table_data[sql['table_id']]['header']))
        ans_seq.append((
            len(sql['sql']['agg']),
            sql['sql']['sel'],
            sql['sql']['agg'],
            conds_num,
            tuple(x[0] for x in sql['sql']['conds']),
            tuple(x[1] for x in sql['sql']['conds']),
            sql['sql']['cond_conn_op'],
        ))
        gt_cond_seq.append(sql['sql']['conds'])
        vis_seq.append((sql['question'], table_data[sql['table_id']]['header']))
    if ret_vis_data:
        return q_seq, sel_num_seq, col_seq, col_num, ans_seq, gt_cond_seq, vis_seq
    else:
        return q_seq, sel_num_seq, col_seq, col_num, ans_seq, gt_cond_seq
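
# Example (sketch of one batch element, assuming a TableQA-style 'sql' dict):
# for {'sql': {'sel': [0], 'agg': [5], 'conds': [[1, 2, 'xxx']], 'cond_conn_op': 0}}
# the corresponding ans_seq entry is (1, [0], [5], 1, (1,), (2,), 0), i.e.
# (agg num, sel cols, agg ops, cond num, cond cols, cond ops, connector).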

def to_batch_seq_test(sql_data, table_data, idxes, st, ed):
    q_seq = []
    col_seq = []
    col_num = []
    raw_seq = []
    table_ids = []
    for i in range(st, ed):
        sql = sql_data[idxes[i]]
        q_seq.append([char for char in sql['question']])
        col_seq.append([[char for char in header] for header in table_data[sql['table_id']]['header']])
        col_num.append(len(table_data[sql['table_id']]['header']))
        raw_seq.append(sql['question'])
        table_ids.append(sql['table_id'])
    return q_seq, col_seq, col_num, raw_seq, table_ids

def to_batch_query(sql_data, idxes, st, ed):
    query_gt = []
    table_ids = []
    for i in range(st, ed):
        query_gt.append(sql_data[idxes[i]]['sql'])
        table_ids.append(sql_data[idxes[i]]['table_id'])
    return query_gt, table_ids

def epoch_train(model, optimizer, batch_size, sql_data, table_data):
    '''
    Train the model for one epoch and return the average loss per example.
    '''
    model.train()
    perm = np.random.permutation(len(sql_data))  # shuffle the training order each epoch
    cum_loss = 0.0
    for st in tqdm(range(len(sql_data) // batch_size + 1)):
        ed = (st + 1) * batch_size if (st + 1) * batch_size < len(perm) else len(perm)
        st = st * batch_size
        if st >= ed:  # skip the empty tail batch when len(sql_data) divides evenly
            break
        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq = to_batch_seq(sql_data, table_data, perm, st, ed)
        # q_seq: char-based sequence of each question
        # gt_sel_num: number of selected columns and aggregation functions
        # col_seq: char-based column names
        # col_num: number of headers in one table
        # ans_seq: (agg num, sel cols, agg ops, cond num, cond cols, cond ops, connector)
        # gt_cond_seq: ground truth of conds
        gt_where_seq = model.generate_gt_where_seq_test(q_seq, gt_cond_seq)
        gt_sel_seq = [x[1] for x in ans_seq]
        score = model.forward(q_seq, col_seq, col_num, gt_where=gt_where_seq, gt_cond=gt_cond_seq, gt_sel=gt_sel_seq, gt_sel_num=gt_sel_num)
        # score: sel_num_score, sel_col_score, sel_agg_score, cond_score, cond_rela_score
        # compute loss, weighted by batch size to average per example at the end
        loss = model.loss(score, ans_seq, gt_where_seq)
        cum_loss += loss.data.cpu().numpy() * (ed - st)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return cum_loss / len(sql_data)
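
# Example training loop (sketch; a SQLNet-style model with a PyTorch optimizer
# is an assumption, not something this file pins down):
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   for epoch in range(n_epochs):
#       avg_loss = epoch_train(model, optimizer, 16, train_sql, train_table)
#       print('epoch %d, avg loss %.4f' % (epoch + 1, avg_loss))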

def predict_test(model, batch_size, sql_data, table_data, output_path):
    model.eval()
    perm = list(range(len(sql_data)))
    with open(output_path, 'w', encoding='utf-8') as fw:
        for st in tqdm(range(len(sql_data) // batch_size + 1)):
            ed = (st + 1) * batch_size if (st + 1) * batch_size < len(perm) else len(perm)
            st = st * batch_size
            if st >= ed:
                break
            q_seq, col_seq, col_num, raw_q_seq, table_ids = to_batch_seq_test(sql_data, table_data, perm, st, ed)
            score = model.forward(q_seq, col_seq, col_num)
            sql_preds = model.gen_query(score, q_seq, col_seq, raw_q_seq)
            for sql_pred in sql_preds:
                fw.write(json.dumps(sql_pred, ensure_ascii=False) + '\n')
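
# Example (sketch; the output path is hypothetical): writes one JSON object per
# line, in the same {'sel', 'agg', 'conds', 'cond_conn_op'} format the model
# predicts:
#   predict_test(model, 16, test_sql, test_table, 'output/predict.json')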

def epoch_acc(model, batch_size, sql_data, table_data, db_path):
    engine = DBEngine(db_path)
    model.eval()
    perm = list(range(len(sql_data)))
    badcase = 0
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0
    for st in tqdm(range(len(sql_data) // batch_size + 1)):
        ed = (st + 1) * batch_size if (st + 1) * batch_size < len(perm) else len(perm)
        st = st * batch_size
        if st >= ed:
            break
        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        # q_seq: char-based sequence of each question
        # gt_sel_num: number of selected columns and aggregation functions
        # col_seq: char-based column names
        # col_num: number of headers in one table
        # ans_seq: (agg num, sel cols, agg ops, cond num, cond cols, cond ops, connector)
        # gt_cond_seq: ground truth of conditions
        # raw_data: (original question, headers) pairs
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        # query_gt: ground truth sql dict with sel, agg, cond_conn_op, and conds
        # entries of the form [col, op, value]
        raw_q_seq = [x[0] for x in raw_data]  # original questions
        try:
            score = model.forward(q_seq, col_seq, col_num)
            pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq)
            # generate the predicted format
            one_err, tot_err = model.check_acc(raw_data, pred_queries, query_gt)
        except Exception:
            badcase += 1
            print('badcase', badcase)
            continue
        one_acc_num += (ed - st - one_err)
        tot_acc_num += (ed - st - tot_err)
        # execution accuracy: compare the results of running gold and predicted SQL
        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'], sql_gt['conds'], sql_gt['cond_conn_op'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'], sql_pred['conds'], sql_pred['cond_conn_op'])
            except Exception:
                ret_pred = None
            ex_acc_num += (ret_gt == ret_pred)
    return one_acc_num / len(sql_data), tot_acc_num / len(sql_data), ex_acc_num / len(sql_data)
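
# Example (sketch): the three returned values are per-example ratios:
#   one_acc, tot_acc, ex_acc = epoch_acc(model, 16, dev_sql, dev_table, dev_db)
#   # one_acc: component-level logical-form accuracy (from check_acc's one_err)
#   # tot_acc: full logical-form accuracy; ex_acc: execution accuracy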

def load_word_emb(file_name):
    print('Loading word embedding from %s' % file_name)
    with open(file_name, encoding='utf-8') as f:
        ret = json.load(f)
    return ret
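

if __name__ == '__main__':
    # Minimal smoke test (a sketch: the paths match the defaults in
    # load_dataset() and are assumptions about the local data layout).
    dev_sql, dev_table = load_data('data/val/val.json',
                                   'data/val/val.tables.json', use_small=True)
    print('%d questions over %d tables' % (len(dev_sql), len(dev_table)))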