Finetuning Language Models for Text Classification - Patent Dataset

Author

Pantelis Monogioudis

Finetuning Language Models for Text Classification - Patent Dataset

Open In Colab

This notebook was submitted by NYU student Sky Achitoff

Code
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from collections import defaultdict
import random

import logging
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.WARNING)
Code
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

BATCH_SIZE = 32

def encodeText(text, max_length=512):
    encodedDict = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_attention_mask=True,
        return_tensors='pt'
    )

    return encodedDict

def encodeData(data, max_length=512):
    encodedData = []
    for example in data:
        text = ' '.join([example[section] for section in ['abstract', 'claims']])
        decision = example['decision']
        if decision == 'ACCEPTED':
            label = 1
            encodedExample = encodeText(text, max_length=max_length)
            encodedData.append((encodedExample['input_ids'], encodedExample['attention_mask'], label))
        elif decision == 'REJECTED':
            label = 0
            encodedExample = encodeText(text, max_length=max_length)
            encodedData.append((encodedExample['input_ids'], encodedExample['attention_mask'], label))
        else:
            continue
    return encodedData

def oversampleData(data):

    classCounts = {}
    for example in data:
        label = example[2]
        if label not in classCounts:
            classCounts[label] = 0
        classCounts[label] += 1

    minCount = min(classCounts.values())
    minClass = None
    for label, count in classCounts.items():
        if count == minCount:
            minClass = label

    oversampledData = []
    for example in data:
        oversampledData.append(example)
        if example[2] == minClass:
            oversampledData.append(example)

    return oversampledData

def getDataLoader(oversampledTrainData, valDataset, BATCH_SIZE):
    trainDataLoader = DataLoader(oversampledTrainData, batch_size=BATCH_SIZE, shuffle=True)
    valDataLoader = None
    
    if valDataset is not None:
        valDataLoader = DataLoader(valDataset, batch_size=BATCH_SIZE)
        valDatasetTensors = []
        for x in valDataset:
            inputIds = torch.tensor(x[0]).squeeze()
            attentionMask = torch.tensor(x[1]).squeeze()
            label = torch.tensor(x[2])
            valDatasetTensors.append((inputIds, attentionMask, label))
        valDatasetTensors = tuple([torch.stack(t) for t in zip(*valDatasetTensors)])
        valDataLoader = DataLoader(TensorDataset(*valDatasetTensors), batch_size=BATCH_SIZE)
    
    return trainDataLoader, valDataLoader
Code
def evaluate(model, dataloader):
    model.eval()
    totalCount = 0
    totalCorrect = 0
    tp = 0
    fp = 0
    fn = 0

    with torch.no_grad():
        for batch in dataloader:
            batchInputIds = batch[0].to(device)
            batchAttentionMask = batch[1].to(device)
            batchLabels = batch[2].to(device)

            outputs = model(batchInputIds, attention_mask=batchAttentionMask, labels=batchLabels)
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=1)

            totalCount += batchLabels.size(0)
            totalCorrect += torch.sum(predictions == batchLabels)

            tp += ((predictions == 1) & (batchLabels == 1)).sum().item()
            fp += ((predictions == 1) & (batchLabels == 0)).sum().item()
            fn += ((predictions == 0) & (batchLabels == 1)).sum().item()

    accuracy = totalCorrect / totalCount
    
    if (tp + fp) == 0:
        precision = 0
    else:
        precision = tp / (tp + fp)
    
    if (tp + fn) == 0:
        recall = 0
    else:
        recall = tp / (tp + fn)

    print(f'Accuracy = {accuracy}')
    print(f'Precision = {precision}')
    print(f'Recall = {recall}')
Code
from pynvml import *


def print_gpu_utilization():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used//1024**2} MB.")


def print_summary(result):
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()
Code
def main():

    datasetDict = load_dataset('HUPD/hupd',
                                name='sample',
                                #data_files="https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather",
                                train_filing_start_date='2016-01-01',
                                train_filing_end_date='2016-01-21',
                                val_filing_start_date='2016-01-22',
                                val_filing_end_date='2016-01-31')

    #datasetDict = load_dataset('HUPD/hupd', name='sample')
    trainData = datasetDict['train']
    valData = datasetDict['validation']
    trainData = encodeData(trainData)
    valData = encodeData(valData)
    valDataloader = getDataLoader(trainData, valData, BATCH_SIZE)
    trainData = oversampleData(trainData)
    trainDataLoader, valDataLoader = getDataLoader(trainData, valData, BATCH_SIZE)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    lossFn = torch.nn.BCEWithLogitsLoss()
    model.train()
    for epoch in range(40):
        totalLoss = 0
        for batch in trainDataLoader:
            batchInputIds = batch[0].to(device)
            batchAttentionMask = batch[1].to(device)
            batchInputIds = batchInputIds.squeeze(1)
            batchAttentionMask = batchAttentionMask.squeeze(1)
            batchLabels = batch[2].to(device)
            optimizer.zero_grad()
            outputs = model(batchInputIds, attention_mask=batchAttentionMask, labels=batchLabels)
            logits = outputs.logits
            loss = lossFn(logits[:, 1], batchLabels.float())
            loss.backward()
            optimizer.step()
            totalLoss += loss.item()
        print(f'Total loss = {totalLoss}')
        evaluate(model, valDataLoader)
        valAccuracy = evaluate(model, valDataLoader)
        print(f'Epoch {epoch}')

    model.save_pretrained('Patent-Tuned-distilbert')
    tokenizer.save_pretrained('Patent-Tuned-distilbert')
Code
if __name__ == '__main__':
    main()
Loading dataset with config: PatentsConfig(name='sample', version=0.0.0, data_dir='sample', data_files=None, description='Patent data from January 2016, for debugging')
Using metadata file: /home/vscode/.cache/huggingface/datasets/downloads/bac34b767c2799633010fa78ecd401d2eeffd62eff58abdb4db75829f8932710
Reading metadata file: /home/vscode/.cache/huggingface/datasets/downloads/bac34b767c2799633010fa78ecd401d2eeffd62eff58abdb4db75829f8932710
Filtering train dataset by filing start date: 2016-01-01
Filtering train dataset by filing end date: 2016-01-21
Filtering val dataset by filing start date: 2016-01-22
Filtering val dataset by filing end date: 2016-01-31
/tmp/ipykernel_496554/3803315954.py:69: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  inputIds = torch.tensor(x[0]).squeeze()
/tmp/ipykernel_496554/3803315954.py:70: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  attentionMask = torch.tensor(x[1]).squeeze()
Total loss = 200.8989379107952
Accuracy = 0.7837561368942261
Precision = 0.8339613754121527
Recall = 0.9095812997688159
Accuracy = 0.7837561368942261
Precision = 0.8339613754121527
Recall = 0.9095812997688159
Epoch 0
Total loss = 183.6143274307251
Accuracy = 0.7722995281219482
Precision = 0.8295400663821716
Recall = 0.8987927048548677
Accuracy = 0.7722995281219482
Precision = 0.8295400663821716
Recall = 0.8987927048548677
Epoch 1
Total loss = 141.71546256542206
Accuracy = 0.7536824941635132
Precision = 0.8283272283272284
Recall = 0.8713074749550476
Accuracy = 0.7536824941635132
Precision = 0.8283272283272284
Recall = 0.8713074749550476
Epoch 2
Total loss = 76.70414833724499
Accuracy = 0.7825286388397217
Precision = 0.8172645739910314
Recall = 0.9362959157462112
Accuracy = 0.7825286388397217
Precision = 0.8172645739910314
Recall = 0.9362959157462112
Epoch 3
Total loss = 29.495045435614884
Accuracy = 0.7606382966041565
Precision = 0.8184795321637427
Recall = 0.8987927048548677
Accuracy = 0.7606382966041565
Precision = 0.8184795321637427
Recall = 0.8987927048548677
Epoch 4
Total loss = 17.47426499146968
Accuracy = 0.7426350116729736
Precision = 0.8190845240978445
Recall = 0.868738761880298
Accuracy = 0.7426350116729736
Precision = 0.8190845240978445
Recall = 0.868738761880298
Epoch 5
Total loss = 8.936163732083514
Accuracy = 0.7299509048461914
Precision = 0.8257786781463662
Recall = 0.8376573336758284
Accuracy = 0.7299509048461914
Precision = 0.8257786781463662
Recall = 0.8376573336758284
Epoch 6
Total loss = 8.666093362495303
Accuracy = 0.7557283043861389
Precision = 0.8188046302858493
Recall = 0.8903159517081942
Accuracy = 0.7557283043861389
Precision = 0.8188046302858493
Recall = 0.8903159517081942
Epoch 7
Total loss = 4.926793378661387
Accuracy = 0.7571604251861572
Precision = 0.826496138996139
Recall = 0.8797842281017211
Accuracy = 0.7571604251861572
Precision = 0.826496138996139
Recall = 0.8797842281017211
Epoch 8
Total loss = 6.068470927071758
Accuracy = 0.7268821597099304
Precision = 0.8236336032388664
Recall = 0.8361161058309787
Accuracy = 0.7268821597099304
Precision = 0.8236336032388664
Recall = 0.8361161058309787
Epoch 9
Total loss = 3.362050237308722
Accuracy = 0.7538870573043823
Precision = 0.8191741813004272
Recall = 0.8867197534035448
Accuracy = 0.7538870573043823
Precision = 0.8191741813004272
Recall = 0.8867197534035448
Epoch 10
Total loss = 3.10541268682573
Accuracy = 0.7340425848960876
Precision = 0.8202025191405286
Recall = 0.8530696121243257
Accuracy = 0.7340425848960876
Precision = 0.8202025191405286
Recall = 0.8530696121243257
Epoch 11
Total loss = 3.3352542413631454
Accuracy = 0.7563420534133911
Precision = 0.8172851103804603
Recall = 0.8939121500128435
Accuracy = 0.7563420534133911
Precision = 0.8172851103804603
Recall = 0.8939121500128435
Epoch 12
Total loss = 8.131262744718697
Accuracy = 0.786415696144104
Precision = 0.8170487424883152
Recall = 0.9429745697405599
Accuracy = 0.786415696144104
Precision = 0.8170487424883152
Recall = 0.9429745697405599
Epoch 13
Total loss = 1.475531026662793
Accuracy = 0.7729132771492004
Precision = 0.8211400876990538
Recall = 0.9139481119958901
Accuracy = 0.7729132771492004
Precision = 0.8211400876990538
Recall = 0.9139481119958901
Epoch 14
Total loss = 5.878371243306901
Accuracy = 0.7686170339584351
Precision = 0.81572930955647
Recall = 0.9165168250706396
Accuracy = 0.7686170339584351
Precision = 0.81572930955647
Recall = 0.9165168250706396
Epoch 15
Total loss = 2.953659498860361
Accuracy = 0.746931254863739
Precision = 0.8228001944579485
Recall = 0.8695093758027228
Accuracy = 0.746931254863739
Precision = 0.8228001944579485
Recall = 0.8695093758027228
Epoch 16
Total loss = 1.0596487982256804
Accuracy = 0.7677987217903137
Precision = 0.817741935483871
Recall = 0.9116362702286155
Accuracy = 0.7677987217903137
Precision = 0.817741935483871
Recall = 0.9116362702286155
Epoch 17
Total loss = 0.5237317680730484
Accuracy = 0.7829378247261047
Precision = 0.8153674832962138
Recall = 0.9404058566658104
Accuracy = 0.7829378247261047
Precision = 0.8153674832962138
Recall = 0.9404058566658104
Epoch 18
Total loss = 0.5803493535640882
Accuracy = 0.7776186466217041
Precision = 0.8164185836716283
Recall = 0.9298741330593373
Accuracy = 0.7776186466217041
Precision = 0.8164185836716283
Recall = 0.9298741330593373
Epoch 19
Total loss = 9.755138978056493
Accuracy = 0.7581833004951477
Precision = 0.8256065337496997
Recall = 0.8828666837914205
Accuracy = 0.7581833004951477
Precision = 0.8256065337496997
Recall = 0.8828666837914205
Epoch 20
Total loss = 1.4521945840097032
Accuracy = 0.7708674669265747
Precision = 0.8195436736575248
Recall = 0.9134343693809401
Accuracy = 0.7708674669265747
Precision = 0.8195436736575248
Recall = 0.9134343693809401
Epoch 21
Total loss = 1.2082649125077296
Accuracy = 0.7675940990447998
Precision = 0.818728323699422
Recall = 0.9095812997688159
Accuracy = 0.7675940990447998
Precision = 0.818728323699422
Recall = 0.9095812997688159
Epoch 22
Total loss = 4.308195063465973
Accuracy = 0.7827332615852356
Precision = 0.8165958398568552
Recall = 0.9378371435910609
Accuracy = 0.7827332615852356
Precision = 0.8165958398568552
Recall = 0.9378371435910609
Epoch 23
Total loss = 2.21302125803777
Accuracy = 0.7600245475769043
Precision = 0.8183520599250936
Recall = 0.8980220909324429
Accuracy = 0.7600245475769043
Precision = 0.8183520599250936
Recall = 0.8980220909324429
Epoch 24
Total loss = 0.3716375867370516
Accuracy = 0.7608429193496704
Precision = 0.8183730715287517
Recall = 0.8993064474698176
Accuracy = 0.7608429193496704
Precision = 0.8183730715287517
Recall = 0.8993064474698176
Epoch 25
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[6], line 2
      1 if __name__ == '__main__':
----> 2     main()

Cell In[5], line 37, in main()
     35     loss.backward()
     36     optimizer.step()
---> 37     totalLoss += loss.item()
     38 print(f'Total loss = {totalLoss}')
     39 evaluate(model, valDataLoader)

KeyboardInterrupt: 
Back to top