Edge AI: BERT for Text Classification I

Text classification is one of the most fundamental NLP tasks, with wide-ranging applications such as sentiment analysis, spam filtering, and news categorization. This article focuses on fake news detection: classifying a news article as REAL or FAKE.

Here we detect fake news using a state-of-the-art model, namely BERT. The tutorial can be extended to essentially any text classification task.

The Transformer is the basic building block of most current state-of-the-art NLP architectures. Its primary advantage is its multi-head attention mechanism, which allows for better performance and significantly more parallelization than previous competing models such as recurrent neural networks.

In this tutorial, we will use pre-trained BERT, one of the most popular transformer models, and fine-tune it for fake news detection. [@chengBERTText2020]

Problem Statement

We use the Kaggle “REAL and FAKE news dataset” (https://www.kaggle.com/nopdev/real-and-fake-news-dataset).

The dataset has shape 6335×4: the first column is an ID, the second the title, the third the article text, and the fourth the label (REAL or FAKE).
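
A quick way to verify the shape and columns after downloading the csv (the column names in the comments are how pandas typically reads this file; treat them as an assumption until you check):

import pandas as pd

df = pd.read_csv('./Data/news.csv')
print(df.shape)    # (6335, 4)
print(df.columns)  # Index(['Unnamed: 0', 'title', 'text', 'label'], dtype='object')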

Step 1: Preprocess Dataset

The raw data looks like this: [@chengLSTMText2020]

The dataset contains an arbitrary index, title, text, and the corresponding label.

import pandas as pd
from sklearn.model_selection import train_test_split

raw_data_path = './Data/news.csv'
destination_folder = './Data'

train_test_ratio = 0.10   # fraction of each class held out for the test set
train_valid_ratio = 0.80  # fraction of the remainder used for training (the rest is validation)

first_n_words = 200

def trim_string(x):
    # Keep only the first `first_n_words` words of the string
    x = x.split(maxsplit=first_n_words)
    x = ' '.join(x[:first_n_words])
    return x

For preprocessing, we use pandas and scikit-learn to split the data into train/valid/test csv files.
The trim_string function defined above cuts each article down to its first first_n_words words (200 here). Trimming the samples in a dataset is not strictly necessary, but it enables faster training for heavier models and is normally enough to predict the outcome.
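
As a quick sanity check of trim_string:

# A 500-word string is reduced to its first 200 words
sample = ' '.join(['word'] * 500)
print(len(trim_string(sample).split()))  # 200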

# Read raw data
df_raw = pd.read_csv(raw_data_path)

# Prepare columns
df_raw['label'] = (df_raw['label'] == 'FAKE').astype('int')
df_raw['titletext'] = df_raw['title'] + ". " + df_raw['text']
df_raw = df_raw.reindex(columns=['label', 'title', 'text', 'titletext'])

# Drop rows with empty text
df_raw.drop(df_raw[df_raw.text.str.len() < 5].index, inplace=True)

# Trim text and titletext to first_n_words
df_raw['text'] = df_raw['text'].apply(trim_string)
df_raw['titletext'] = df_raw['titletext'].apply(trim_string) 

# Split according to label
df_real = df_raw[df_raw['label'] == 0]
df_fake = df_raw[df_raw['label'] == 1]

# Train-test split (hold out train_test_ratio of each class for the test set)
df_real_full_train, df_real_test = train_test_split(df_real, test_size=train_test_ratio, random_state=1)
df_fake_full_train, df_fake_test = train_test_split(df_fake, test_size=train_test_ratio, random_state=1)

# Train-valid split
df_real_train, df_real_valid = train_test_split(df_real_full_train, train_size = train_valid_ratio, random_state = 1)
df_fake_train, df_fake_valid = train_test_split(df_fake_full_train, train_size = train_valid_ratio, random_state = 1)

# Concatenate splits of different labels
df_train = pd.concat([df_real_train, df_fake_train], ignore_index=True, sort=False)
df_valid = pd.concat([df_real_valid, df_fake_valid], ignore_index=True, sort=False)
df_test = pd.concat([df_real_test, df_fake_test], ignore_index=True, sort=False)

# Write preprocessed data
df_train.to_csv(destination_folder + '/train.csv', index=False)
df_valid.to_csv(destination_folder + '/valid.csv', index=False)
df_test.to_csv(destination_folder + '/test.csv', index=False)

Step 2: Libraries

We import PyTorch for model construction, torchtext for loading data, matplotlib and seaborn for plotting, and sklearn for evaluation.
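
Concretely, the imports used in the rest of this tutorial would look like the following (a sketch assuming the legacy torchtext API, version 0.8 or earlier, and the HuggingFace transformers package):

import torch
import torch.nn as nn

import matplotlib.pyplot as plt
import seaborn as sns

from torchtext.data import Field, TabularDataset, BucketIterator, Iterator
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.metrics import classification_report, confusion_matrix

# Run on GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')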

Step 3: Load and Tokenize the Dataset

See [@chengBERTText2020].

Tokenization converts plain text into encodings. The process does not turn words into word vectors; it merely splits the raw text into tokens, adds special markers such as [MASK], [SEP], and [CLS], and then maps each token to its index in the vocabulary.

As the example below illustrates, [CLS] is placed at the very front of the input sequence, while zero padding fills the sequence up to a fixed length (it was covered in detail in the earlier Transformer article). The [MASK] token is generally not used during fine-tuning or feature extraction; it only appears in the cloze (masked language modeling) task of the pre-training stage.

For the tokenizer, we use the “bert-base-uncased” version of BertTokenizer.
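
A small illustration (the special-token ids below are those of the bert-base-uncased vocabulary: [CLS]=101, [SEP]=102, [PAD]=0):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# encode() tokenizes, adds [CLS]/[SEP], and maps tokens to vocabulary indices
ids = tokenizer.encode("bert detects fake news")
print(ids)                     # starts with 101 ([CLS]) and ends with 102 ([SEP])
print(tokenizer.pad_token_id)  # 0, used for zero padding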

With torchtext, we first create the Text Field and the Label Field. The Text Field will contain the news articles and the Label Field the true target. We limit each article to the first 128 tokens for BERT input, and we pass the tokenizer's encode method to the Field as its tokenize function, so the encoding happens inside the Field.

Finally, we create a TabularDataset from our csv files, using the two Fields to produce the train, validation, and test sets, and then create Iterators to prepare them in batches.

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Model parameter
MAX_SEQ_LEN = 128
PAD_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
UNK_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.unk_token)

# Fields

label_field = Field(sequential=False, use_vocab=False, batch_first=True, dtype=torch.float)
text_field = Field(use_vocab=False, tokenize=tokenizer.encode, lower=False, include_lengths=False, batch_first=True,
                   fix_length=MAX_SEQ_LEN, pad_token=PAD_INDEX, unk_token=UNK_INDEX)
fields = [('label', label_field), ('title', text_field), ('text', text_field), ('titletext', text_field)]

# TabularDataset

train, valid, test = TabularDataset.splits(path=destination_folder, train='train.csv', validation='valid.csv',
                                           test='test.csv', format='CSV', fields=fields, skip_header=True)

# Iterators

train_iter = BucketIterator(train, batch_size=16, sort_key=lambda x: len(x.text),
                            device=device, train=True, sort=True, sort_within_batch=True)
valid_iter = BucketIterator(valid, batch_size=16, sort_key=lambda x: len(x.text),
                            device=device, train=True, sort=True, sort_within_batch=True)
test_iter = Iterator(test, batch_size=16, device=device, train=False, shuffle=False, sort=False)
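
To check the pipeline end to end, we can peek at one batch. The unpacking pattern relies on the legacy torchtext Batch yielding its input fields first; it is the same pattern used in the evaluation function later:

# Shapes assume batch_size=16 and MAX_SEQ_LEN=128
(labels, title, text, titletext), _ = next(iter(train_iter))
print(titletext.shape)  # torch.Size([16, 128])
print(labels[:4])       # 0 = REAL, 1 = FAKE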

Step 4: Build Model

class BERT(nn.Module):

    def __init__(self):
        super(BERT, self).__init__()

        options_name = "bert-base-uncased"
        self.encoder = BertForSequenceClassification.from_pretrained(options_name)

    def forward(self, text, label):
        # When labels are supplied, BertForSequenceClassification returns (loss, logits)
        loss, text_fea = self.encoder(text, labels=label)[:2]

        return loss, text_fea

We are using the “bert-base-uncased” version of BERT, the smaller BERT model trained on lower-cased English text (12 layers, 768 hidden units, 12 attention heads, 110M parameters).
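
As a quick check of the model size (the total includes the classification head, so expect slightly under 110M):

model = BERT().to(device)
n_params = sum(p.numel() for p in model.parameters())
print(f'{n_params:,} parameters')  # roughly 110M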

Step 5: Training

# Save and Load Functions

def save_checkpoint(save_path, model, valid_loss):

    if save_path is None:
        return
    
    state_dict = {'model_state_dict': model.state_dict(),
                  'valid_loss': valid_loss}
    
    torch.save(state_dict, save_path)
    print(f'Model saved to ==> {save_path}')

def load_checkpoint(load_path, model):
    
    if load_path is None:
        return
    
    state_dict = torch.load(load_path, map_location=device)
    print(f'Model loaded from <== {load_path}')
    
    model.load_state_dict(state_dict['model_state_dict'])
    return state_dict['valid_loss']


def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):

    if save_path is None:
        return
    
    state_dict = {'train_loss_list': train_loss_list,
                  'valid_loss_list': valid_loss_list,
                  'global_steps_list': global_steps_list}
    
    torch.save(state_dict, save_path)
    print(f'Metrics saved to ==> {save_path}')


def load_metrics(load_path):

    if load_path is None:
        return
    
    state_dict = torch.load(load_path, map_location=device)
    print(f'Metrics loaded from <== {load_path}')
    
    return state_dict['train_loss_list'], state_dict['valid_loss_list'], state_dict['global_steps_list']

We fine-tune BERT for 5 epochs using the Adam optimizer and a suitable learning rate. Fake news detection is a two-class problem, and because we pass the labels into BertForSequenceClassification, the cross-entropy loss over its two output logits is computed inside the model's forward pass; there is no need for a separate Sigmoid or loss layer, and we simply back-propagate the loss the model returns.
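
A minimal training-loop sketch consistent with the code above (the learning rate of 2e-5 and the once-per-epoch validation schedule are assumptions, not taken from the original):

def train(model, optimizer, train_loader, valid_loader, num_epochs=5,
          file_path=destination_folder, best_valid_loss=float('inf')):
    model.train()
    for epoch in range(num_epochs):
        for (labels, title, text, titletext), _ in train_loader:
            labels = labels.type(torch.LongTensor).to(device)
            titletext = titletext.type(torch.LongTensor).to(device)
            loss, _ = model(titletext, labels)  # loss is computed inside the model

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validate once per epoch
        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for (labels, title, text, titletext), _ in valid_loader:
                labels = labels.type(torch.LongTensor).to(device)
                titletext = titletext.type(torch.LongTensor).to(device)
                loss, _ = model(titletext, labels)
                valid_loss += loss.item()
        valid_loss /= len(valid_loader)
        model.train()

        print(f'Epoch [{epoch+1}/{num_epochs}], Valid Loss: {valid_loss:.4f}')

        # Keep the checkpoint with the best validation loss
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            save_checkpoint(file_path + '/model.pt', model, best_valid_loss)

model = BERT().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)  # assumed learning rate
train(model, optimizer, train_iter, valid_iter)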

Step 6: Evaluation

# Evaluation Function

def evaluate(model, test_loader):
    y_pred = []
    y_true = []

    model.eval()
    with torch.no_grad():
        for (labels, title, text, titletext), _ in test_loader:
            labels = labels.type(torch.LongTensor).to(device)
            titletext = titletext.type(torch.LongTensor).to(device)

            # The model returns (loss, logits); we only need the logits here
            _, output = model(titletext, labels)
            y_pred.extend(torch.argmax(output, 1).tolist())
            y_true.extend(labels.tolist())
    
    print('Classification Report:')
    print(classification_report(y_true, y_pred, labels=[1,0], digits=4))
    
    cm = confusion_matrix(y_true, y_pred, labels=[1, 0])
    ax = plt.subplot()
    sns.heatmap(cm, annot=True, ax=ax, cmap='Blues', fmt='d')

    ax.set_title('Confusion Matrix')

    ax.set_xlabel('Predicted Labels')
    ax.set_ylabel('True Labels')

    ax.xaxis.set_ticklabels(['FAKE', 'REAL'])
    ax.yaxis.set_ticklabels(['FAKE', 'REAL'])
    
best_model = BERT().to(device)

load_checkpoint(destination_folder + '/model.pt', best_model)

evaluate(best_model, test_iter)

After evaluation, our model achieves an impressive accuracy of 96.99%, compared with only 77.5% for the same task with an LSTM model [@chengLSTMText2020]!

References

Cheng, Raymond. 2020a. “BERT Text Classification Using Pytorch.” Medium, July 22, 2020. https://towardsdatascience.com/bert-text-classification-using-pytorch-723dfb8b6b5b.

———. 2020b. “LSTM Text Classification Using Pytorch.” Medium, July 22, 2020. https://towardsdatascience.com/lstm-text-classification-using-pytorch-2c6c657f8fc0.
