LSTM을 이용한 Text의 multi-class classification 예제


뉴스 타이틀을 4개 분야로 분류


소스: https://www.kaggle.com/ngyptr/multi-class-classification-with-lstm


데이터 파일: https://www.kaggle.com/uciml/news-aggregator-dataset









# This Python 3 environment comes with many helpful analytics libraries installed

# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python

# For example, here are several helpful packages to load:


import numpy as np # linear algebra

import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

from keras.layers import Dense, Embedding, LSTM, SpatialDropout1D

from keras.models import Sequential

from sklearn.feature_extraction.text import CountVectorizer

from keras.preprocessing.text import Tokenizer

from keras.preprocessing.sequence import pad_sequences

from sklearn.model_selection import train_test_split

from keras.utils.np_utils import to_categorical

from keras.callbacks import EarlyStopping


# Input data files are available in the "../input/" directory.

# For example, running this (by clicking run or pressing Shift+Enter) will list the files in the input directory


import os

#print(os.listdir("../input"))


# Any results you write to the current directory are saved as output.

# Load only the two columns the classifier needs: the news title (input text)
# and its category code (target).
data = pd.read_csv('uci-news-aggregator.csv', usecols=['TITLE', 'CATEGORY'])

# The 'm' class has far less data than the others, so the classes are
# unbalanced. The counts are printed explicitly: a bare expression is only
# displayed inside a notebook and is silently discarded when run as a script.
print(data.CATEGORY.value_counts())


# Balance the classes: the 'm' class has far fewer rows than the others, so
# the same number of rows is taken from every category.
# NOTE(review): despite its name this is the number of rows kept *per
# category*, not a number of categories. It presumably must not exceed the
# size of the smallest class, otherwise the subsets are unbalanced again --
# verify against the value counts.
num_of_categories = 45000

# Fixed seed so that subsampling and shuffling are reproducible between runs.
random_seed = 42

# Category code -> integer class id. The dict order is also the order in
# which the per-category subsets are concatenated. The class id decides which
# column of the one-hot vectors (built by to_categorical below) belongs to
# each category.
label_map = {'e': 0, 'b': 1, 't': 2, 'm': 3}

# Shuffle all rows first, so the rows kept per category are a random sample
# rather than the first ones in file order.
shuffled = data.sample(frac=1, random_state=random_seed)

concated = pd.concat(
    [shuffled[shuffled['CATEGORY'] == category].head(num_of_categories)
     for category in label_map],
    ignore_index=True,
)

# Shuffle the balanced dataset so the categories are interleaved instead of
# appearing as contiguous runs.
concated = concated.sample(frac=1, random_state=random_seed)

# Integer-encode the labels. Every CATEGORY value here is a key of label_map,
# because the subsets above were selected by exactly those keys.
concated['LABEL'] = concated['CATEGORY'].map(label_map)
print(concated['LABEL'].head(10))

# One-hot encode the integer labels:
#   [1. 0. 0. 0.] e
#   [0. 1. 0. 0.] b
#   [0. 0. 1. 0.] t
#   [0. 0. 0. 1.] m
labels = to_categorical(concated['LABEL'], num_classes=len(label_map))
print(labels[:10])

# CATEGORY is redundant once LABEL exists. DataFrame.drop returns a new frame
# instead of modifying in place, so the result must be assigned back;
# calling drop without assigning the result would leave the column in place.
if 'CATEGORY' in concated.columns:
    concated = concated.drop(['CATEGORY'], axis=1)


# Tokenization settings: vocabulary cap passed to Tokenizer(num_words=...) and
# the fixed length every title is padded/truncated to.
n_most_common_words = 8000
max_len = 130

# The backslash in the filter set is written as '\\' explicitly. The former
# '\]' is an invalid escape sequence (DeprecationWarning, SyntaxWarning since
# Python 3.12), even though it evaluates to the same two characters.
tokenizer = Tokenizer(
    num_words=n_most_common_words,
    filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~',
    lower=True,
)

titles = concated['TITLE'].values
# NOTE(review): the tokenizer is fitted on all titles before the train/test
# split, so test-set titles contribute to the vocabulary. Consider fitting it
# on the training titles only.
tokenizer.fit_on_texts(titles)
sequences = tokenizer.texts_to_sequences(titles)

word_index = tokenizer.word_index
print('Found %s unique tokens.' % len(word_index))

X = pad_sequences(sequences, maxlen=max_len)

# Hold out 25% of the rows for the final test evaluation.
X_train, X_test, y_train, y_test = train_test_split(X, labels, test_size=0.25, random_state=42)

# Training hyperparameters.
epochs = 10
emb_dim = 128
batch_size = 256

print((X_train.shape, y_train.shape, X_test.shape, y_test.shape))


# LSTM classifier: word embeddings -> spatial dropout -> LSTM -> softmax.
# The output layer has one unit per one-hot label column (y_train.shape[1])
# instead of a hard-coded class count.
model = Sequential()
model.add(Embedding(n_most_common_words, emb_dim, input_length=X.shape[1]))
model.add(SpatialDropout1D(0.7))
model.add(LSTM(64, dropout=0.7, recurrent_dropout=0.7))
model.add(Dense(y_train.shape[1], activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])

# summary() prints the table itself and returns None, so wrapping it in
# print() only added a stray "None" line.
model.summary()

# Stop training once val_loss has not improved by min_delta for `patience`
# epochs.
# NOTE(review): patience=7 with epochs=10 means early stopping can hardly ever
# trigger; consider a smaller patience or more epochs.
early_stopping = EarlyStopping(monitor='val_loss', patience=7, min_delta=0.0001)
history = model.fit(
    X_train, y_train,
    epochs=epochs,
    batch_size=batch_size,
    validation_split=0.2,
    callbacks=[early_stopping],
)

# evaluate returns [loss, acc] because compile() was given metrics=['acc'].
accr = model.evaluate(X_test, y_test)
print('Test set\n  Loss: {:0.3f}\n  Accuracy: {:0.3f}'.format(accr[0], accr[1]))


import matplotlib.pyplot as plt


# Plot the per-epoch training curves recorded by model.fit, one figure per
# metric. The history keys are 'acc'/'loss' for the training data and
# 'val_acc'/'val_loss' for the validation split.
# A separate name is used for the x axis so the `epochs` hyperparameter is
# not rebound to a range object.
epoch_numbers = range(1, len(history.history['acc']) + 1)

for metric, title in (('acc', 'accuracy'), ('loss', 'loss')):
    plt.figure()
    plt.plot(epoch_numbers, history.history[metric], 'bo', label='Training ' + metric)
    plt.plot(epoch_numbers, history.history['val_' + metric], 'b', label='Validation ' + metric)
    plt.title('Training and validation ' + title)
    plt.legend()

plt.show()




+ Recent posts