{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "`The NPS Chat Corpus, which was demonstrated in 1, consists of over 10,000 posts from instant messaging sessions. These posts have all been labeled with one of 15 dialogue act types, such as \"Statement,\" \"Emotion,\" \"ynQuestion\", and \"Continuer.\"\n", "`\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2021-03-24T14:12:11.395128Z", "start_time": "2021-03-24T14:12:11.389313Z" } }, "outputs": [], "source": [ "import ssl\n", "\n", "try:\n", " _create_unverified_https_context = ssl._create_unverified_context\n", "except AttributeError:\n", " pass\n", "else:\n", " ssl._create_default_https_context = _create_unverified_https_context\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2021-03-24T14:12:25.704220Z", "start_time": "2021-03-24T14:12:11.399153Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[nltk_data] Downloading package punkt to /Users/subir/nltk_data...\n", "[nltk_data] Package punkt is already up-to-date!\n", "[nltk_data] Downloading package nps_chat to /Users/subir/nltk_data...\n", "[nltk_data] Package nps_chat is already up-to-date!\n" ] } ], "source": [ "import nltk\n", "nltk.download('punkt')\n", "nltk.download('nps_chat')\n", "posts = nltk.corpus.nps_chat.xml_posts()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2021-03-24T14:12:25.720510Z", "start_time": "2021-03-24T14:12:25.709038Z" } }, "outputs": [], "source": [ "def dialogue_act_features(post):\n", " features = {}\n", " for word in nltk.word_tokenize(post):\n", " features['contains({})'.format(word.lower())] = True\n", " return features" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2021-03-24T14:12:28.495548Z", "start_time": "2021-03-24T14:12:25.731278Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.6685606060606061\n" ] } ], "source": [ "featuresets = [(dialogue_act_features(post.text), post.get('class')) for post in posts]\n", "size = int(len(featuresets) * 0.1)\n", "train_set, test_set = featuresets[size:], featuresets[:size]\n", "classifier = nltk.NaiveBayesClassifier.train(train_set)\n", "print(nltk.classify.accuracy(classifier, test_set))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2021-03-24T14:12:28.941035Z", "start_time": "2021-03-24T14:12:28.498937Z" } }, "outputs": [], "source": [ "# save the model to disk\n", "import pickle\n", "filename = '../models/question_classification.sav'\n", "pickle.dump(classifier, open(filename, 'wb'))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2021-03-24T15:37:32.956836Z", "start_time": "2021-03-24T15:37:31.985215Z" } }, "outputs": [], "source": [ "# load the model from disk\n", "loaded_model = pickle.load(open(filename, 'rb'))\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2021-03-24T15:38:01.080846Z", "start_time": "2021-03-24T15:38:01.068742Z" } }, "outputs": [ { "data": { "text/plain": [ "'whQuestion'" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loaded_model.classify(dialogue_act_features(\"how are you\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "base" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 2 }