Finetuning Language Models for Text Classification - Patent Dataset 
This notebook was submitted by NYU student Sky Achitoff
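The notebook finetunes DistilBERT on the sample split of the Harvard USPTO Patent Dataset (HUPD/hupd), learning to predict from an application's abstract and claims whether the patent was accepted or rejected.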
Code 
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader

import logging
# Quiet the warning emitted when a fresh classification head is attached
# to the pretrained encoder.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.WARNING)
 
 
Code 
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

BATCH_SIZE = 32

def encodeText(text, max_length=512):
    # Tokenize one document: add [CLS]/[SEP], truncate or pad to max_length,
    # and return PyTorch tensors of shape (1, max_length).
    encodedDict = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_attention_mask=True,
        return_tensors='pt'
    )

    return encodedDict
 
def encodeData(data, max_length=512):
    # Build (input_ids, attention_mask, label) triples from the abstract and
    # claims sections; applications that are neither accepted nor rejected
    # (e.g. still pending) are skipped.
    labelMap = {'ACCEPTED': 1, 'REJECTED': 0}
    encodedData = []
    for example in data:
        decision = example['decision']
        if decision not in labelMap:
            continue
        text = ' '.join(example[section] for section in ['abstract', 'claims'])
        encodedExample = encodeText(text, max_length=max_length)
        encodedData.append((encodedExample['input_ids'], encodedExample['attention_mask'], labelMap[decision]))
    return encodedData
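
# Quick sanity check on a hypothetical toy example (uncomment to run): each
# triple should hold two (1, 512) tensors and an integer label.
# sample = {'abstract': 'A widget.', 'claims': 'I claim a widget.', 'decision': 'ACCEPTED'}
# ids, mask, label = encodeData([sample])[0]
# print(ids.shape, mask.shape, label)  # torch.Size([1, 512]) torch.Size([1, 512]) 1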
 
def oversampleData(data):
    # Naive oversampling: duplicate every minority-class example once. This
    # yields an exact balance only when the classes are roughly 2:1, but it
    # reduces the imbalance in general.
    classCounts = {}
    for example in data:
        label = example[2]
        classCounts[label] = classCounts.get(label, 0) + 1

    minClass = min(classCounts, key=classCounts.get)

    oversampledData = []
    for example in data:
        oversampledData.append(example)
        if example[2] == minClass:
            oversampledData.append(example)

    return oversampledData
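
# Illustration on hypothetical dummy data (uncomment to run): with a 2:1
# imbalance the minority class is duplicated to give an exact 2-vs-2 balance.
# toy = [(None, None, 1), (None, None, 1), (None, None, 0)]
# print([x[2] for x in oversampleData(toy)])  # [1, 1, 0, 0]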
 
def getDataLoader(oversampledTrainData, valDataset, BATCH_SIZE):
    trainDataLoader = DataLoader(oversampledTrainData, batch_size=BATCH_SIZE, shuffle=True)
    valDataLoader = None

    if valDataset is not None:
        # Drop the leading singleton dimension from each (1, 512) tensor and
        # stack the validation set so batches arrive with shape (batch, 512).
        valDatasetTensors = []
        for x in valDataset:
            inputIds = x[0].squeeze()
            attentionMask = x[1].squeeze()
            label = torch.tensor(x[2])
            valDatasetTensors.append((inputIds, attentionMask, label))
        valDatasetTensors = tuple(torch.stack(t) for t in zip(*valDatasetTensors))
        valDataLoader = DataLoader(TensorDataset(*valDatasetTensors), batch_size=BATCH_SIZE)

    return trainDataLoader, valDataLoader
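
A note on batch shapes: the train loader wraps the raw list of triples, so the default collate stacks each (1, 512) tensor into batches of shape (batch, 1, 512); the training loop in main() squeezes that singleton dimension away before the forward pass. A minimal check, assuming trainData has already been produced by encodeData:

Code
trainLoader, _ = getDataLoader(trainData, None, BATCH_SIZE)
batch = next(iter(trainLoader))
print(batch[0].shape)  # torch.Size([32, 1, 512]) for a full batch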
 
 
 
Code 
def evaluate(model, dataloader):
    model.eval()
    totalCount = 0
    totalCorrect = 0
    tp = 0
    fp = 0
    fn = 0

    with torch.no_grad():
        for batch in dataloader:
            batchInputIds = batch[0].to(device)
            batchAttentionMask = batch[1].to(device)
            batchLabels = batch[2].to(device)

            outputs = model(batchInputIds, attention_mask=batchAttentionMask)
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=1)

            totalCount += batchLabels.size(0)
            totalCorrect += torch.sum(predictions == batchLabels).item()

            tp += ((predictions == 1) & (batchLabels == 1)).sum().item()
            fp += ((predictions == 1) & (batchLabels == 0)).sum().item()
            fn += ((predictions == 0) & (batchLabels == 1)).sum().item()

    accuracy = totalCorrect / totalCount

    # Guard against division by zero when a class is never predicted (or absent).
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0

    print(f'Accuracy = {accuracy}')
    print(f'Precision = {precision}')
    print(f'Recall = {recall}')

    return accuracy
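
Precision and recall can be folded into a single F1 score when one number is preferred; a small helper along those lines (not part of the original run):

Code
def f1(precision, recall):
    # Harmonic mean of the precision and recall computed by evaluate().
    if precision + recall == 0:
        return 0
    return 2 * precision * recall / (precision + recall)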
 
 
 
Code 
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo


def print_gpu_utilization():
    # Report how much memory the first GPU is currently using.
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used // 1024 ** 2} MB.")


def print_summary(result):
    # Summarize a transformers Trainer result object.
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()
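
These helpers assume an NVIDIA GPU at device index 0; on a machine without one, nvmlInit() raises an NVML error. Example call:

Code
print_gpu_utilization()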
 
 
Code 
def main():

    datasetDict = load_dataset('HUPD/hupd',
                               name='sample',
                               #data_files="https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather",
                               train_filing_start_date='2016-01-01',
                               train_filing_end_date='2016-01-21',
                               val_filing_start_date='2016-01-22',
                               val_filing_end_date='2016-01-31')

    #datasetDict = load_dataset('HUPD/hupd', name='sample')
    trainData = datasetDict['train']
    valData = datasetDict['validation']
    trainData = encodeData(trainData)
    valData = encodeData(valData)
    trainData = oversampleData(trainData)
    trainDataLoader, valDataLoader = getDataLoader(trainData, valData, BATCH_SIZE)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    # Binary cross-entropy on the logit of the ACCEPTED class; with two labels
    # this plays the same role as the model's built-in cross-entropy loss.
    lossFn = torch.nn.BCEWithLogitsLoss()
    model.train()
    for epoch in range(40):
        totalLoss = 0
        for batch in trainDataLoader:
            # The default collate stacks the (1, 512) tensors into (batch, 1, 512),
            # so squeeze out the singleton dimension before the forward pass.
            batchInputIds = batch[0].to(device).squeeze(1)
            batchAttentionMask = batch[1].to(device).squeeze(1)
            batchLabels = batch[2].to(device)
            optimizer.zero_grad()
            outputs = model(batchInputIds, attention_mask=batchAttentionMask)
            logits = outputs.logits
            loss = lossFn(logits[:, 1], batchLabels.float())
            loss.backward()
            optimizer.step()
            totalLoss += loss.item()
        print(f'Total loss = {totalLoss}')
        evaluate(model, valDataLoader)
        print(f'Epoch {epoch}')

    model.save_pretrained('Patent-Tuned-distilbert')
    tokenizer.save_pretrained('Patent-Tuned-distilbert')
 
 
Code 
if __name__ == '__main__':
    main()
 
Loading dataset with config: PatentsConfig(name='sample', version=0.0.0, data_dir='sample', data_files=None, description='Patent data from January 2016, for debugging')

Using metadata file: /home/vscode/.cache/huggingface/datasets/downloads/bac34b767c2799633010fa78ecd401d2eeffd62eff58abdb4db75829f8932710

Reading metadata file: /home/vscode/.cache/huggingface/datasets/downloads/bac34b767c2799633010fa78ecd401d2eeffd62eff58abdb4db75829f8932710
Filtering train dataset by filing start date: 2016-01-01
Filtering train dataset by filing end date: 2016-01-21
Filtering val dataset by filing start date: 2016-01-22
Filtering val dataset by filing end date: 2016-01-31
 
Total loss = 200.8989379107952
Accuracy = 0.7837561368942261
Precision = 0.8339613754121527
Recall = 0.9095812997688159
Epoch 0
Total loss = 183.6143274307251
Accuracy = 0.7722995281219482
Precision = 0.8295400663821716
Recall = 0.8987927048548677
Epoch 1
Total loss = 141.71546256542206
Accuracy = 0.7536824941635132
Precision = 0.8283272283272284
Recall = 0.8713074749550476
Epoch 2
Total loss = 76.70414833724499
Accuracy = 0.7825286388397217
Precision = 0.8172645739910314
Recall = 0.9362959157462112
Epoch 3
Total loss = 29.495045435614884
Accuracy = 0.7606382966041565
Precision = 0.8184795321637427
Recall = 0.8987927048548677
Epoch 4
Total loss = 17.47426499146968
Accuracy = 0.7426350116729736
Precision = 0.8190845240978445
Recall = 0.868738761880298
Epoch 5
Total loss = 8.936163732083514
Accuracy = 0.7299509048461914
Precision = 0.8257786781463662
Recall = 0.8376573336758284
Epoch 6
Total loss = 8.666093362495303
Accuracy = 0.7557283043861389
Precision = 0.8188046302858493
Recall = 0.8903159517081942
Epoch 7
Total loss = 4.926793378661387
Accuracy = 0.7571604251861572
Precision = 0.826496138996139
Recall = 0.8797842281017211
Epoch 8
Total loss = 6.068470927071758
Accuracy = 0.7268821597099304
Precision = 0.8236336032388664
Recall = 0.8361161058309787
Epoch 9
Total loss = 3.362050237308722
Accuracy = 0.7538870573043823
Precision = 0.8191741813004272
Recall = 0.8867197534035448
Epoch 10
Total loss = 3.10541268682573
Accuracy = 0.7340425848960876
Precision = 0.8202025191405286
Recall = 0.8530696121243257
Epoch 11
Total loss = 3.3352542413631454
Accuracy = 0.7563420534133911
Precision = 0.8172851103804603
Recall = 0.8939121500128435
Epoch 12
Total loss = 8.131262744718697
Accuracy = 0.786415696144104
Precision = 0.8170487424883152
Recall = 0.9429745697405599
Epoch 13
Total loss = 1.475531026662793
Accuracy = 0.7729132771492004
Precision = 0.8211400876990538
Recall = 0.9139481119958901
Epoch 14
Total loss = 5.878371243306901
Accuracy = 0.7686170339584351
Precision = 0.81572930955647
Recall = 0.9165168250706396
Epoch 15
Total loss = 2.953659498860361
Accuracy = 0.746931254863739
Precision = 0.8228001944579485
Recall = 0.8695093758027228
Epoch 16
Total loss = 1.0596487982256804
Accuracy = 0.7677987217903137
Precision = 0.817741935483871
Recall = 0.9116362702286155
Epoch 17
Total loss = 0.5237317680730484
Accuracy = 0.7829378247261047
Precision = 0.8153674832962138
Recall = 0.9404058566658104
Epoch 18
Total loss = 0.5803493535640882
Accuracy = 0.7776186466217041
Precision = 0.8164185836716283
Recall = 0.9298741330593373
Epoch 19
Total loss = 9.755138978056493
Accuracy = 0.7581833004951477
Precision = 0.8256065337496997
Recall = 0.8828666837914205
Epoch 20
Total loss = 1.4521945840097032
Accuracy = 0.7708674669265747
Precision = 0.8195436736575248
Recall = 0.9134343693809401
Epoch 21
Total loss = 1.2082649125077296
Accuracy = 0.7675940990447998
Precision = 0.818728323699422
Recall = 0.9095812997688159
Epoch 22
Total loss = 4.308195063465973
Accuracy = 0.7827332615852356
Precision = 0.8165958398568552
Recall = 0.9378371435910609
Epoch 23
Total loss = 2.21302125803777
Accuracy = 0.7600245475769043
Precision = 0.8183520599250936
Recall = 0.8980220909324429
Epoch 24
Total loss = 0.3716375867370516
Accuracy = 0.7608429193496704
Precision = 0.8183730715287517
Recall = 0.8993064474698176
Epoch 25
 
---------------------------------------------------------------------------
KeyboardInterrupt                          Traceback (most recent call last)
Cell In[6], line 2
      1 if __name__ == '__main__':
----> 2     main()

Cell In[5], line 37, in main()
     35     loss.backward()
     36     optimizer.step()
---> 37     totalLoss += loss.item()
     38 print(f'Total loss = {totalLoss}')
     39 evaluate(model, valDataLoader)

KeyboardInterrupt:
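
The run was interrupted at epoch 25 of 40; if it is allowed to finish, main() saves the finetuned weights, which can then be reloaded for inference. A minimal sketch, assuming the 'Patent-Tuned-distilbert' directory written at the end of main():

Code
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tok = AutoTokenizer.from_pretrained('Patent-Tuned-distilbert')
clf = AutoModelForSequenceClassification.from_pretrained('Patent-Tuned-distilbert')
clf.eval()

# Same preprocessing as training: abstract and claims joined into one string.
text = 'An example abstract. 1. An example claim.'
enc = tok(text, truncation=True, padding='max_length', max_length=512, return_tensors='pt')
with torch.no_grad():
    logits = clf(**enc).logits
print('ACCEPTED' if logits.argmax(dim=1).item() == 1 else 'REJECTED')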