Finetuning Language Models for Text Classification - Patent Dataset
This notebook was submitted by NYU student Sky Achitoff
Code
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader
import logging

# Quiet the verbose transformers logs (e.g. weight-initialization messages)
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.WARNING)
Code
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
BATCH_SIZE = 32
def encodeText(text, max_length=512):
    encodedDict = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_attention_mask=True,
        return_tensors='pt'
    )
    return encodedDict
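
# Sanity check (hypothetical example text): with return_tensors='pt', encode_plus
# returns tensors with a leading batch dimension, so each encoded example is
# shaped [1, 512]. This is why the training loop below squeezes dim 1.
_demo = encodeText("A method for improving battery life in mobile devices.")
print(_demo['input_ids'].shape)       # torch.Size([1, 512])
print(_demo['attention_mask'].shape)  # torch.Size([1, 512])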
def encodeData(data, max_length=512):
    # Map the HUPD decision field to binary labels; skip all other outcomes
    labelMap = {'ACCEPTED': 1, 'REJECTED': 0}
    encodedData = []
    for example in data:
        decision = example['decision']
        if decision not in labelMap:
            continue
        text = ' '.join(example[section] for section in ['abstract', 'claims'])
        encodedExample = encodeText(text, max_length=max_length)
        encodedData.append((encodedExample['input_ids'], encodedExample['attention_mask'], labelMap[decision]))
    return encodedData
def oversampleData(data):
    # Count examples per label (index 2 of each tuple is the label)
    classCounts = {}
    for example in data:
        label = example[2]
        classCounts[label] = classCounts.get(label, 0) + 1
    minClass = min(classCounts, key=classCounts.get)
    # Duplicate every minority-class example once (doubles the minority class,
    # but does not bring it fully level with the majority class)
    oversampledData = []
    for example in data:
        oversampledData.append(example)
        if example[2] == minClass:
            oversampledData.append(example)
    return oversampledData
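
# Toy sanity check for oversampleData (hypothetical placeholder tuples; index 2
# is the label, matching encodeData's output).
_toy = [('ids', 'mask', 1), ('ids', 'mask', 1), ('ids', 'mask', 1), ('ids', 'mask', 0)]
print([ex[2] for ex in oversampleData(_toy)])  # -> [1, 1, 1, 0, 0]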
def getDataLoader(oversampledTrainData, valDataset, BATCH_SIZE):
    trainDataLoader = DataLoader(oversampledTrainData, batch_size=BATCH_SIZE, shuffle=True)
    valDataLoader = None
    if valDataset is not None:
        # input_ids / attention_mask are already tensors shaped [1, 512]; drop the
        # extra batch dim and stack everything into a TensorDataset
        valDatasetTensors = []
        for inputIds, attentionMask, label in valDataset:
            valDatasetTensors.append((inputIds.squeeze(), attentionMask.squeeze(), torch.tensor(label)))
        valDatasetTensors = tuple(torch.stack(t) for t in zip(*valDatasetTensors))
        valDataLoader = DataLoader(TensorDataset(*valDatasetTensors), batch_size=BATCH_SIZE)
    return trainDataLoader, valDataLoader
Code
def evaluate(model, dataloader):
    model.eval()
    totalCount = 0
    totalCorrect = 0
    tp = 0
    fp = 0
    fn = 0
    with torch.no_grad():
        for batch in dataloader:
            batchInputIds = batch[0].to(device)
            batchAttentionMask = batch[1].to(device)
            batchLabels = batch[2].to(device)
            outputs = model(batchInputIds, attention_mask=batchAttentionMask, labels=batchLabels)
            predictions = torch.argmax(outputs.logits, dim=1)
            totalCount += batchLabels.size(0)
            totalCorrect += (predictions == batchLabels).sum().item()
            # Treat label 1 ('ACCEPTED') as the positive class
            tp += ((predictions == 1) & (batchLabels == 1)).sum().item()
            fp += ((predictions == 1) & (batchLabels == 0)).sum().item()
            fn += ((predictions == 0) & (batchLabels == 1)).sum().item()
    accuracy = totalCorrect / totalCount
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    print(f'Accuracy = {accuracy}')
    print(f'Precision = {precision}')
    print(f'Recall = {recall}')
Code
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo

def print_gpu_utilization():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used // 1024**2} MB.")

def print_summary(result):
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()
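These two helpers are defined but never called in main below: print_summary expects the result object returned by a transformers Trainer run, which this notebook does not use. print_gpu_utilization can be called on its own; a minimal usage sketch, assuming an NVIDIA GPU and the pynvml package are installed:
Code
print_gpu_utilization()  # prints something like "GPU memory occupied: 1234 MB."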
Code
def main():
    datasetDict = load_dataset('HUPD/hupd',
                               name='sample',
                               #data_files="https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather",
                               train_filing_start_date='2016-01-01',
                               train_filing_end_date='2016-01-21',
                               val_filing_start_date='2016-01-22',
                               val_filing_end_date='2016-01-31')
    #datasetDict = load_dataset('HUPD/hupd', name='sample')
    trainData = encodeData(datasetDict['train'])
    valData = encodeData(datasetDict['validation'])
    # Balance the classes in the training split before building the loaders
    trainData = oversampleData(trainData)
    trainDataLoader, valDataLoader = getDataLoader(trainData, valData, BATCH_SIZE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    # Binary cross-entropy on the 'ACCEPTED' logit instead of the model's built-in loss
    lossFn = torch.nn.BCEWithLogitsLoss()
    for epoch in range(40):
        model.train()  # evaluate() switches to eval mode, so re-enable dropout each epoch
        totalLoss = 0
        for batch in trainDataLoader:
            batchInputIds = batch[0].to(device).squeeze(1)
            batchAttentionMask = batch[1].to(device).squeeze(1)
            batchLabels = batch[2].to(device)
            optimizer.zero_grad()
            outputs = model(batchInputIds, attention_mask=batchAttentionMask, labels=batchLabels)
            loss = lossFn(outputs.logits[:, 1], batchLabels.float())
            loss.backward()
            optimizer.step()
            totalLoss += loss.item()
        print(f'Total loss = {totalLoss}')
        evaluate(model, valDataLoader)
        print(f'Epoch {epoch}')
    model.save_pretrained('Patent-Tuned-distilbert')
    tokenizer.save_pretrained('Patent-Tuned-distilbert')
Code
if __name__ == '__main__':
    main()
Loading dataset with config: PatentsConfig(name='sample', version=0.0.0, data_dir='sample', data_files=None, description='Patent data from January 2016, for debugging')
Using metadata file: /home/vscode/.cache/huggingface/datasets/downloads/bac34b767c2799633010fa78ecd401d2eeffd62eff58abdb4db75829f8932710
Reading metadata file: /home/vscode/.cache/huggingface/datasets/downloads/bac34b767c2799633010fa78ecd401d2eeffd62eff58abdb4db75829f8932710
Filtering train dataset by filing start date: 2016-01-01
Filtering train dataset by filing end date: 2016-01-21
Filtering val dataset by filing start date: 2016-01-22
Filtering val dataset by filing end date: 2016-01-31
Total loss = 200.8989379107952
Accuracy = 0.7837561368942261
Precision = 0.8339613754121527
Recall = 0.9095812997688159
Epoch 0
Total loss = 183.6143274307251
Accuracy = 0.7722995281219482
Precision = 0.8295400663821716
Recall = 0.8987927048548677
Epoch 1
Total loss = 141.71546256542206
Accuracy = 0.7536824941635132
Precision = 0.8283272283272284
Recall = 0.8713074749550476
Epoch 2
Total loss = 76.70414833724499
Accuracy = 0.7825286388397217
Precision = 0.8172645739910314
Recall = 0.9362959157462112
Epoch 3
Total loss = 29.495045435614884
Accuracy = 0.7606382966041565
Precision = 0.8184795321637427
Recall = 0.8987927048548677
Epoch 4
Total loss = 17.47426499146968
Accuracy = 0.7426350116729736
Precision = 0.8190845240978445
Recall = 0.868738761880298
Epoch 5
Total loss = 8.936163732083514
Accuracy = 0.7299509048461914
Precision = 0.8257786781463662
Recall = 0.8376573336758284
Epoch 6
Total loss = 8.666093362495303
Accuracy = 0.7557283043861389
Precision = 0.8188046302858493
Recall = 0.8903159517081942
Epoch 7
Total loss = 4.926793378661387
Accuracy = 0.7571604251861572
Precision = 0.826496138996139
Recall = 0.8797842281017211
Epoch 8
Total loss = 6.068470927071758
Accuracy = 0.7268821597099304
Precision = 0.8236336032388664
Recall = 0.8361161058309787
Epoch 9
Total loss = 3.362050237308722
Accuracy = 0.7538870573043823
Precision = 0.8191741813004272
Recall = 0.8867197534035448
Epoch 10
Total loss = 3.10541268682573
Accuracy = 0.7340425848960876
Precision = 0.8202025191405286
Recall = 0.8530696121243257
Epoch 11
Total loss = 3.3352542413631454
Accuracy = 0.7563420534133911
Precision = 0.8172851103804603
Recall = 0.8939121500128435
Epoch 12
Total loss = 8.131262744718697
Accuracy = 0.786415696144104
Precision = 0.8170487424883152
Recall = 0.9429745697405599
Epoch 13
Total loss = 1.475531026662793
Accuracy = 0.7729132771492004
Precision = 0.8211400876990538
Recall = 0.9139481119958901
Epoch 14
Total loss = 5.878371243306901
Accuracy = 0.7686170339584351
Precision = 0.81572930955647
Recall = 0.9165168250706396
Epoch 15
Total loss = 2.953659498860361
Accuracy = 0.746931254863739
Precision = 0.8228001944579485
Recall = 0.8695093758027228
Epoch 16
Total loss = 1.0596487982256804
Accuracy = 0.7677987217903137
Precision = 0.817741935483871
Recall = 0.9116362702286155
Epoch 17
Total loss = 0.5237317680730484
Accuracy = 0.7829378247261047
Precision = 0.8153674832962138
Recall = 0.9404058566658104
Epoch 18
Total loss = 0.5803493535640882
Accuracy = 0.7776186466217041
Precision = 0.8164185836716283
Recall = 0.9298741330593373
Epoch 19
Total loss = 9.755138978056493
Accuracy = 0.7581833004951477
Precision = 0.8256065337496997
Recall = 0.8828666837914205
Epoch 20
Total loss = 1.4521945840097032
Accuracy = 0.7708674669265747
Precision = 0.8195436736575248
Recall = 0.9134343693809401
Epoch 21
Total loss = 1.2082649125077296
Accuracy = 0.7675940990447998
Precision = 0.818728323699422
Recall = 0.9095812997688159
Epoch 22
Total loss = 4.308195063465973
Accuracy = 0.7827332615852356
Precision = 0.8165958398568552
Recall = 0.9378371435910609
Epoch 23
Total loss = 2.21302125803777
Accuracy = 0.7600245475769043
Precision = 0.8183520599250936
Recall = 0.8980220909324429
Epoch 24
Total loss = 0.3716375867370516
Accuracy = 0.7608429193496704
Precision = 0.8183730715287517
Recall = 0.8993064474698176
Epoch 25
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[6], line 2
      1 if __name__ == '__main__':
----> 2     main()

Cell In[5], line 37, in main()
     35 loss.backward()
     36 optimizer.step()
---> 37 totalLoss += loss.item()
     38 print(f'Total loss = {totalLoss}')
     39 evaluate(model, valDataLoader)

KeyboardInterrupt:
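Because the run above was interrupted during epoch 26, the save_pretrained calls at the end of main never executed; training has to run to completion for the Patent-Tuned-distilbert directory to exist. Assuming it does, the fine-tuned model can be reloaded for inference. A minimal sketch (the sample text is a hypothetical placeholder):
Code
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tok = AutoTokenizer.from_pretrained('Patent-Tuned-distilbert')
clf = AutoModelForSequenceClassification.from_pretrained('Patent-Tuned-distilbert')
clf.eval()

# Concatenated abstract + claims of a new application (placeholder text)
sampleText = "A method for improving battery life in mobile devices..."
enc = tok(sampleText, truncation=True, padding='max_length', max_length=512, return_tensors='pt')
with torch.no_grad():
    logits = clf(**enc).logits
print('ACCEPTED' if logits.argmax(dim=1).item() == 1 else 'REJECTED')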