#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   train.py
@Time    :   2019/06/25 04:00:53
@Author  :   Liuyuqi
@Version :   1.0
@Contact :   liuyuqi.gov@msn.cn
@License :   (C)Copyright 2019
@Desc    :   Train SQLNet and evaluate it on the dev set after every epoch.
'''
import torch
from sqlnet.utils import *
from sqlnet.model.sqlnet import SQLNet
import argparse
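# NOTE: the wildcard import is assumed to supply load_dataset, load_word_emb,
# epoch_train and epoch_acc, judging from how they are used below.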

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--bs', type=int, default=16, help='Batch size')
    parser.add_argument('--epoch', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--gpu', action='store_true', help='Whether to use the GPU for training')
    parser.add_argument('--toy', action='store_true', help='If set, use a small dataset for fast debugging')
    parser.add_argument('--ca', action='store_true', help='Whether to use column attention')
    parser.add_argument('--train_emb', action='store_true', help='Train word embeddings for SQLNet')
    parser.add_argument('--restore', action='store_true', help='Whether to restore a previously trained model')
    parser.add_argument('--logdir', type=str, default='', help='Path for saving the experiment log')
    args = parser.parse_args()
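    # Example invocation (illustrative): python train.py --bs 32 --epoch 50 --gpu --ca
    # Note: --logdir is parsed above but never used in this script.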
    n_word = 300
    if args.toy:
        use_small = True
        gpu = args.gpu
        batch_size = 16
    else:
        use_small = False
        gpu = args.gpu
        batch_size = args.bs
    learning_rate = 1e-3
    # Load the training and dev data.
    train_sql, train_table, train_db, dev_sql, dev_table, dev_db = load_dataset(use_small=use_small)
    # word_emb is a dict loaded from the pre-trained character embedding file.
    word_emb = load_word_emb('data/char_embedding.json')

    model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
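    # The value passed as N_word (300) is assumed to match the dimensionality
    # of the vectors in data/char_embedding.json.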
    if args.restore:
        model_path = 'saved_model/best_model'
        print("Loading trained model from %s" % model_path)
        model.load_state_dict(torch.load(model_path))
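        # If the checkpoint was saved on a GPU and training resumes on CPU,
        # torch.load may need map_location='cpu' (not applied here).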
    # Used to record the best score of each sub-task.
    best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv, best_wr = 0, 0, 0, 0, 0, 0, 0, 0
    best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx, best_wr_idx = 0, 0, 0, 0, 0, 0, 0, 0
    best_lf, best_lf_idx = 0.0, 0
    best_ex, best_ex_idx = 0.0, 0
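    # Sub-task abbreviations, matching the accuracy print below: sn = Sel-Num, sc = Sel-Col,
    # sa = Sel-Agg, wn = W-Num, wc = W-Col, wo = W-Op, wv = W-Val, wr = W-Rel.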
    print("#" * 20 + " Start to Train " + "#" * 20)
    for i in range(args.epoch):
        print('Epoch %d' % (i + 1))
        # Train on the training set.
        train_loss = epoch_train(model, optimizer, batch_size, train_sql, train_table)
        # Evaluate on the dev set.
        dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table, dev_db)
        # Accuracy of each sub-task.
        print('Sel-Num: %.3f, Sel-Col: %.3f, Sel-Agg: %.3f, W-Num: %.3f, W-Col: %.3f, W-Op: %.3f, W-Val: %.3f, W-Rel: %.3f' % (
            dev_acc[0][0], dev_acc[0][1], dev_acc[0][2], dev_acc[0][3], dev_acc[0][4], dev_acc[0][5], dev_acc[0][6], dev_acc[0][7]))
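        # dev_acc is assumed to be (per-sub-task accuracies, logic-form accuracy,
        # execution accuracy), judging from how the result of epoch_acc is indexed here.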
        # Save the model whenever it reaches a new best logic-form accuracy.
        if dev_acc[1] > best_lf:
            best_lf = dev_acc[1]
            best_lf_idx = i + 1
            # Assumes the saved_model/ directory already exists; torch.save does not create it.
            torch.save(model.state_dict(), 'saved_model/best_model')
        if dev_acc[2] > best_ex:
            best_ex = dev_acc[2]
            best_ex_idx = i + 1
        # Record the best score of each sub-task.
        if dev_acc[0][0] > best_sn:
            best_sn = dev_acc[0][0]
            best_sn_idx = i + 1
        if dev_acc[0][1] > best_sc:
            best_sc = dev_acc[0][1]
            best_sc_idx = i + 1
        if dev_acc[0][2] > best_sa:
            best_sa = dev_acc[0][2]
            best_sa_idx = i + 1
        if dev_acc[0][3] > best_wn:
            best_wn = dev_acc[0][3]
            best_wn_idx = i + 1
        if dev_acc[0][4] > best_wc:
            best_wc = dev_acc[0][4]
            best_wc_idx = i + 1
        if dev_acc[0][5] > best_wo:
            best_wo = dev_acc[0][5]
            best_wo_idx = i + 1
        if dev_acc[0][6] > best_wv:
            best_wv = dev_acc[0][6]
            best_wv_idx = i + 1
        if dev_acc[0][7] > best_wr:
            best_wr = dev_acc[0][7]
            best_wr_idx = i + 1
        print('Train loss = %.3f' % train_loss)
        print('Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2]))
        print('Best Logic Form: %.3f at epoch %d' % (best_lf, best_lf_idx))
        print('Best Execution: %.3f at epoch %d' % (best_ex, best_ex_idx))
        if (i + 1) % 10 == 0:
            print('Best val acc: %s\nOn epoch individually %s' % (
                (best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv, best_wr),
                (best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx, best_wr_idx)))