Text Classification using CNN - Yicong-Huang/Wildfires GitHub Wiki
🚘 How to use a CNN to classify tweet text?
Contents
- Data Labeling Standard
- Dataset Preparation
- TextCNN Module Architecture
- Word Embedding
- Training & Testing
- Performance
- Author
Data Labeling Standard [TO BE IMPROVED]
Each tweet record needs three labelers: `Label1`, `Label2`, and `Judge`. Each label must be `0`, `1`, or `2`, where `0` means true, `1` means false, and `2` means not sure.
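The label scheme can be written down as a small enum, which keeps the meaning of `0`, `1`, and `2` in one place when reading labels back from the database. This is a sketch; the class name `Label` and its member names are hypothetical and do not appear in the project code.

```python
from enum import IntEnum


class Label(IntEnum):
    """Hypothetical enum for the values used by Label1, Label2 and Judge."""
    TRUE = 0      # real, recent, nearby wildfire or wildfire smoke
    FALSE = 1     # not a real wildfire, or historical / very far away
    NOT_SURE = 2  # seems related, but when or where is ambiguous


label = Label(2)  # e.g. a value read from the label1 column
```

Because `IntEnum` members compare equal to plain integers, `Label.TRUE == 0` holds, so the enum can be used directly against integer columns.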
- When should a labeler label a record as true?

  If a tweet talks about a real wildfire or wildfire smoke that is recent (not historical) and nearby, label the record as true. Examples:
  - "Wildfire just to the north of us Windy as hell #sonomacounty #kron4 #wildfire"
  - "Biggest threat to homes this am is wildfire coming down from Annadel State Park to White Oak. Calm winds mean fire slow moving. @CBSSF"
  - "Oregons Wildfire right now is hella fucking up washingtons Sky like I can almost stare at the sun and not be affected."
- When should a labeler label a record as false?

  If a tweet is not about a real wildfire or smoke, or is about a historical wildfire or one very far away, label it as false. Examples:
  - "Wow, Wildfire is such a good song @ddlovato"
  - "BREAKING: Illegal Muslim From Iran Arrested For Starting California Wildfire"
- When should a labeler label a record as not sure?

  If a tweet seems to be about a real wildfire but is ambiguous, so that you can't tell exactly when or where the wildfire is happening, label it as not sure. Examples:
  - "Wildfire...."
  - "It's so hazy out here in Hood River. I wonder if it's smoke from the wildfires? I thought they were pretty far away... A new wildfire?"
Dataset Preparation
For details on building the dataset and data loader, check out the `yutong-text-cnn` branch and see `Wildfires/yutong_nlp/textcnn_mine/mydatasets_embed.py`.
Training dataset

The following SQL selects 468 tweet records that the first labeler marked true (`0`):

```sql
SELECT text FROM records WHERE label1 = 0;
```

The following SQL selects 532 tweet records that the first labeler marked false (`1`):

```sql
SELECT text FROM records WHERE label1 = 1;
```

In total, the training dataset contains 1,000 distinct tweet records.
Testing dataset

A similar SQL statement selects 528 tweet records that have no `label1` and whose `label2` is `0` or `1`, i.e. `(label2 = 0 AND label1 IS NULL) OR (label2 = 1 AND label1 IS NULL)`. Because every training record has a `label1`, this keeps the testing data disjoint from the training data.
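To make the two splits concrete, here is a minimal sketch that runs the queries described above and attaches a label to each tweet. It assumes a table named `records` with columns `text`, `label1`, and `label2`, as in the queries; `sqlite3` is only an in-memory stand-in for the project's actual database, the rows are toy data, and `load_splits` is a hypothetical helper, not the code in `mydatasets_embed.py`.

```python
import sqlite3


def load_splits(conn):
    """Hypothetical helper: return (train, test) lists of (text, label) pairs.

    Label 0 = true (real wildfire), 1 = false.
    """
    cur = conn.cursor()
    train = []
    for label in (0, 1):
        # Training split: records the first labeler marked true or false.
        cur.execute("SELECT text FROM records WHERE label1 = ?", (label,))
        train.extend((text, label) for (text,) in cur.fetchall())
    # Testing split: records with no label1 that the second labeler marked true or false.
    cur.execute(
        "SELECT text, label2 FROM records "
        "WHERE label1 IS NULL AND label2 IN (0, 1)"
    )
    test = cur.fetchall()
    return train, test


# Toy in-memory table standing in for the real database.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE records (text TEXT, label1 INTEGER, label2 INTEGER)")
conn.executemany(
    "INSERT INTO records VALUES (?, ?, ?)",
    [
        ("Wildfire just to the north of us", 0, None),
        ("Wow, Wildfire is such a good song @ddlovato", 1, None),
        ("Wildfire....", None, 2),  # not sure: excluded from both splits
        ("Smoke from the wildfire is everywhere", None, 0),
    ],
)
train, test = load_splits(conn)
```

Records labeled `2` (not sure) never enter either split, because both queries accept only `0` or `1`.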
TextCNN Module Architecture
The CNN used to classify tweet text follows the architecture of "Convolutional Neural Networks for Sentence Classification" (Kim, 2014).
The following code defines the CNN.
- The parameter `vocab_len` is the vocabulary size of the word embedding, and `weights` holds the pre-trained vectors used to initialize it.
- The first layer of the network is an embedding layer. Users can choose a single embedding layer or, with `args.multichannel`, add a second, frozen embedding initialized with the same pre-trained weights as an extra input channel. After embedding, each word (represented by its vocabulary index) becomes a vector of length 300 (`args.embed_dim`).
- Next come three parallel convolutional layers, one per kernel size in `args.kernel_sizes`, each followed by a ReLU layer and a max-pooling layer. The pooled outputs are concatenated.
- The concatenated features pass through dropout and a fully connected layer, and the final softmax layer outputs a length-2 vector per tweet: the probabilities that the tweet is, or is not, about a real wildfire. Index 0 holds the probability that it is about a real wildfire.
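```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN_Text(nn.Module):
    def __init__(self, args, vocab_len, weights):
        super().__init__()
        self.args = args
        # parameters
        V = vocab_len            # vocabulary size
        D = args.embed_dim       # embedding dimension (300 for the GoogleNews vectors)
        C = 2                    # number of classes: wildfire / not wildfire
        Ci = 1                   # conv layer input channels
        Co = args.kernel_num     # conv layer output channels per kernel size
        Ks = [int(k) for k in args.kernel_sizes.split(',')]  # comma-separated kernel sizes

        # embedding layer: input_size = V, output_size = D
        self.embed = nn.Embedding(V, D)
        if args.glove_embed:
            # initialize with pre-trained vectors; fine-tune them only if glove_embed_train is set
            self.embed.weight = nn.Parameter(torch.FloatTensor(weights),
                                             requires_grad=args.glove_embed_train)
        if args.multichannel:
            # second, frozen copy of the pre-trained embedding used as an extra input channel
            self.embed1 = nn.Embedding(V, D)
            self.embed1.weight = nn.Parameter(torch.FloatTensor(weights), requires_grad=False)
            Ci = 2

        # one convolution per kernel size K; each kernel spans K words x the full width D
        self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (K, D)) for K in Ks])
        self.dropout = nn.Dropout(args.dropout)
        self.fc1 = nn.Linear(len(Ks) * Co, C)

    def forward(self, tokens):
        # tokens: (N, W) word indices -> (N, 1, W, D)
        x = self.embed(tokens).unsqueeze(1)
        if self.args.multichannel:
            x1 = self.embed1(tokens).unsqueeze(1)
            x = torch.cat((x, x1), 1)  # (N, 2, W, D)
        # convolution + ReLU: one (N, Co, W - K + 1) map per kernel size
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]
        # max-over-time pooling: one (N, Co) vector per kernel size
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
        x = torch.cat(x, 1)  # (N, len(Ks) * Co)
        x = self.dropout(x)
        # class probabilities; index 0 = real wildfire.
        # If the loss is nn.CrossEntropyLoss, it applies log-softmax itself and
        # should be given self.fc1(x) rather than these probabilities.
        probs = F.softmax(self.fc1(x), dim=-1)
        return probs
```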
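To see what each convolution branch computes, the sketch below reimplements a single-channel forward pass in NumPy. Following the paper's design, a kernel that spans K words by the full embedding width is just a dot product with every window of K consecutive word vectors, followed by ReLU and max-over-time pooling. This is an illustration, not the project's code: biases and dropout are omitted, the weights are random, and the sizes (padding length 64, embedding size 300, 100 filters per kernel size, kernel sizes 3, 4, 5) are assumptions chosen for the example.

```python
import numpy as np


def textcnn_features(emb, filters):
    """Single-channel TextCNN feature extractor (illustrative, no biases).

    emb:     (W, D) embedded tweet, one row per (padded) word position.
    filters: dict mapping kernel size K -> weight array of shape (Co, K, D).
    Returns the concatenated max-pooled features, shape (len(filters) * Co,).
    """
    W, _ = emb.shape
    pooled = []
    for K, w in filters.items():
        # All windows of K consecutive word vectors: (W - K + 1, K, D).
        windows = np.stack([emb[t:t + K] for t in range(W - K + 1)])
        # Convolution = dot product of each window with each filter, then ReLU: (Co, W - K + 1).
        fmap = np.maximum(np.einsum("tkd,ckd->ct", windows, w), 0.0)
        # Max-over-time pooling keeps each filter's strongest response: (Co,).
        pooled.append(fmap.max(axis=1))
    return np.concatenate(pooled)


def softmax(z):
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()


rng = np.random.default_rng(0)
W, D, Co = 64, 300, 100  # padding length, embedding size, filters per kernel size (assumed)
emb = rng.standard_normal((W, D))
filters = {K: 0.01 * rng.standard_normal((Co, K, D)) for K in (3, 4, 5)}
feats = textcnn_features(emb, filters)            # shape (300,) = 3 kernel sizes x 100 filters
fc = 0.01 * rng.standard_normal((2, feats.size))  # hypothetical fully connected layer
probs = softmax(fc @ feats)                       # index 0: probability of a real wildfire
```

The PyTorch model computes the same quantities in batches with `nn.Conv2d`, `F.relu`, and `F.max_pool1d`; max-over-time pooling is what reduces each feature map to a fixed-size vector regardless of tweet length.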
Word Embedding
To embed each word as a vector, I downloaded the pre-trained embedding file `GoogleNews-vectors-negative300.bin`, whose vocabulary contains 3,000,000 words. Embedding turns each word from a single index (shape `[1, 1]`) into a 300-dimensional vector (shape `[1, 300]`). The file is loaded with gensim's `KeyedVectors`:
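```python
from gensim.models import KeyedVectors

# self.read_path: directory containing the downloaded binary file
gensim_model = KeyedVectors.load_word2vec_format(
    self.read_path + 'GoogleNews-vectors-negative300.bin', binary=True)
vocab = gensim_model.key_to_index  # word -> row index (gensim >= 4.0; use .vocab in gensim 3.x)
vocab_len = len(vocab)             # 3,000,000 words
weights = gensim_model.vectors     # (vocab_len, 300) array of embedding vectors
```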
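The model consumes word indices rather than strings. A minimal sketch of turning a tweet into a fixed-length index sequence, using the padding length of 64 listed under the training parameters, assuming whitespace tokenization and a `vocab` dict that maps each word to its row in `weights` (such as gensim's `key_to_index`). Mapping both padding and unknown words to index 0 is an assumption made only for this example; `encode` is a hypothetical helper, and the actual tokenization lives in `mydatasets_embed.py`.

```python
import numpy as np


def encode(text, vocab, pad_len=64, unk_idx=0, pad_idx=0):
    """Hypothetical helper: map a tweet to exactly pad_len vocabulary indices."""
    ids = [vocab.get(tok, unk_idx) for tok in text.split()]  # unknown words -> unk_idx
    ids = ids[:pad_len]                                       # truncate long tweets
    return ids + [pad_idx] * (pad_len - len(ids))             # right-pad short tweets


# Toy vocabulary and 4-dimensional embedding matrix standing in for the GoogleNews vectors.
vocab = {"<pad>": 0, "Wildfire": 1, "just": 2, "north": 3}
weights = np.arange(16, dtype=np.float32).reshape(4, 4)

ids = encode("Wildfire just to the north", vocab, pad_len=8)
vectors = weights[ids]  # embedding lookup: one row per position, shape (8, 4)
```

Indexing the weight matrix with the index list is exactly what `nn.Embedding` does for each batch.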
Training & Testing
For details of the training and testing process, check out the `yutong-text-cnn` branch and see `Wildfires/yutong_nlp/textcnn_mine/train.py`. For the values of all arguments, see `Wildfires/yutong_nlp/textcnn_mine/main.py` on the same branch.
Here are the key training parameters:
- learning rate: 0.001
- epochs: 4
- batch size: 64
- padding length: 64
- optimizer: Adam
- loss function: Cross Entropy
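`CNN_Text` reads its settings from an `args` object. A minimal sketch of how such an object could be built with `argparse`, using the values listed above as defaults. The flag names are hypothetical (the real ones are defined in `main.py`); `--embed-dim` defaults to 300 to match the GoogleNews vectors, and the model options for which the wiki gives no values are left required. The values passed in the example call are illustrative only.

```python
import argparse


def build_parser():
    # Hypothetical flag names; argparse turns "--kernel-num" into args.kernel_num, etc.
    p = argparse.ArgumentParser(description="TextCNN tweet classifier (sketch)")
    p.add_argument("--lr", type=float, default=0.001)
    p.add_argument("--epochs", type=int, default=4)
    p.add_argument("--batch-size", type=int, default=64)
    p.add_argument("--pad-len", type=int, default=64)
    # Attributes read by CNN_Text.
    p.add_argument("--embed-dim", type=int, default=300)
    p.add_argument("--kernel-num", type=int, required=True)
    p.add_argument("--kernel-sizes", type=str, required=True)  # comma-separated sizes
    p.add_argument("--dropout", type=float, required=True)
    p.add_argument("--glove-embed", action="store_true")
    p.add_argument("--glove-embed-train", action="store_true")
    p.add_argument("--multichannel", action="store_true")
    return p


# Illustrative values, not the project's settings.
args = build_parser().parse_args(
    ["--kernel-num", "100", "--kernel-sizes", "3,4,5", "--dropout", "0.5", "--glove-embed"]
)
```

Keeping the defaults in one parser means the training script and the model see the same values.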
To train the model, call the `train` function from `train.py`:

```python
import train

train.train(train_loader, validate_loader, test_loader, model, my_loss, weight, args)
```
To evaluate the model during training, call `train.eval`, which computes the model's accuracy and loss.
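A minimal NumPy sketch of those two numbers for a batch of softmax outputs such as those returned by `CNN_Text`, assuming the loss is the mean negative log-likelihood of the correct class (cross entropy computed on probabilities). `evaluate` is a hypothetical helper, not the code in `train.py`.

```python
import numpy as np


def evaluate(probs, labels, eps=1e-12):
    """Hypothetical helper: accuracy and mean cross-entropy loss.

    probs:  (N, 2) softmax outputs, column 0 = real wildfire.
    labels: (N,) integer labels, 0 = true, 1 = false.
    """
    probs = np.asarray(probs, dtype=np.float64)
    labels = np.asarray(labels)
    preds = probs.argmax(axis=1)  # predicted class per tweet
    accuracy = float((preds == labels).mean())
    # Cross entropy: -log of the probability assigned to the correct class.
    picked = probs[np.arange(len(labels)), labels]
    loss = float(-np.log(np.clip(picked, eps, 1.0)).mean())
    return accuracy, loss


acc, loss = evaluate([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]], [0, 1, 1])
```

Here the third tweet is misclassified, so accuracy is 2/3, and its low probability for the correct class (0.4) contributes most of the loss.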
Performance
Model Name | Accuracy | Specificity | Precision | Recall (TPR) | F1 | False Positive Rate (FPR) |
---|---|---|---|---|---|---|
TextCNN | 0.8201 | 0.8658 | 0.8826 | 0.7845 | 0.8307 | 0.1342 |
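Every column of the table follows from confusion-matrix counts, treating a real wildfire (label `0`) as the positive class. A small sketch of the formulas follows. The counts used (TP = 233, FN = 64, TN = 200, FP = 31, totaling the 528 test tweets) are not reported in this wiki; they are one set of counts worked back from the table that reproduces all six figures to four decimal places, used here only as a consistency check.

```python
def metrics(tp, fn, tn, fp):
    """Binary classification metrics from confusion-matrix counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)  # true positive rate
    specificity = tn / (tn + fp)
    return {
        "accuracy": (tp + tn) / (tp + fn + tn + fp),
        "specificity": specificity,
        "precision": precision,
        "recall": recall,
        "f1": 2 * precision * recall / (precision + recall),
        "fpr": fp / (fp + tn),  # equals 1 - specificity
    }


m = metrics(tp=233, fn=64, tn=200, fp=31)
rounded = {k: round(v, 4) for k, v in m.items()}
```

Rounded to four decimals, these values match the TextCNN row; the identity FPR = 1 - specificity (0.1342 = 1 - 0.8658) is also visible in the table.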
Author
Yutong Wang / @RainyTong