AI/Deep Learning
[DL] DistilBERT 모델을 사용한 GLUE SST-2 데이터 추론
운호(Noah)
2022. 2. 15. 14:44
들어가기 앞서,
- DistilBERT 모델은 transformers 모델 중 하나이고,
- GLUE SST-2 데이터는 영화 리뷰에 대한 감정 분류 데이터셋입니다. (1:긍정, 0:부정)
- 해당 포스트에서는, DistilBERT 모델을 사용한 Batch 단위의 GLUE SST-2 데이터 추론 성능 측정 코드를 다루고 있으며,
- 각 코드별 실행 시간을 측정하기 위한 디버깅 코드가 포함되어있습니다.
- 모델과 데이터셋은 HuggingFace API 통해 사용했습니다.
예제 코드
# Install the required package
# !pip install datasets
# !pip install transformers
# Import Library
import datasets
from transformers import pipeline
from tqdm.auto import tqdm
import time
import pandas as pd
import numpy as np
import tensorflow as tf
model = None
load_model_time = None
result_df = pd.DataFrame(columns=['batch_size', 'accuracy', 'load_model_time', 'load_dataset_time','total_inference_time', 'avg_inference_time','ips', 'ips_inf'])
X_test = None
y_test = None
def load_model():
global load_model_time
global model
load_model_time = time.time()
model = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", framework="tf", device=0) # devices -1 : CPU, 0 : GPU
load_model_time = time.time() - load_model_time
return model
def load_test_batch(batch_size):
global X_test
global y_test
dataset = datasets.load_dataset("glue", "sst2", split='validation')
X_test = dataset[:len(dataset)]["sentence"]
y_test = dataset[:len(dataset)]["label"]
X_test_preprocess = []
for i in range(len(X_test)):
# 영어문장을 utf-8 로 인코딩한 뒤, ascii 로 디코딩
# 이때, 오류가 발생하는 문자열은 무시하고 정상적인 문자열만 리턴
X_test_preprocess.append(X_test[i].encode('utf-8').decode('ascii', 'ignore'))
test_batch = tf.data.Dataset.from_tensor_slices((X_test_preprocess, y_test)).batch(batch_size)
return test_batch
def inference(batch_size):
# 전체 데이터에 대한 예측라벨 및 실제라벨 저장
pred_labels = []
real_labels = []
# 배치 단위에 따른 추론 시간 저장
iter_times = []
# 배치 단위의 테스트 데이터 로드
load_dataset_time = time.time()
test_batch = load_test_batch(batch_size)
load_dataset_time = time.time() - load_dataset_time
# 디버깅용 변수
success = 0
# 전체 데이터에 대한 추론 시작
inference_time = time.time()
# 전체 데이터를 배치 단위로 묶어서 사용 (반복문 한번당 배치 단위 추론 한번)
for i, (X_test_batch, y_test_batch) in enumerate(test_batch):
X_test_batch = X_test_batch.numpy().astype('str').tolist()
# 배치별 데이터 추론 시간
inference_time_per_batch = time.time()
# 배치 단위별 데이터셋 분류
y_pred_batch = model(X_test_batch)
# 배치별 데이터 추론 시간 저장
iter_times.append(time.time() - inference_time_per_batch)
# 배치 사이즈 만큼의 예측 라벨 저장
pred_labels.extend([*y_pred_batch])
# 배치 사이즈 만큼의 실제 라벨 저장
real_labels.extend([*(y_test_batch.numpy().tolist())])
# 디버깅
success += batch_size
if (success % 500 == 0):
print("{}/{}".format(success,len(test_batch)*batch_size))
inference_time = time.time() - inference_time
# 모든 데이터에 대한 배치별 추론 시간을 배열화
iter_times = np.array(iter_times)
# 모든 데이터에 대한 실제라벨과 예측라벨을 비교한 뒤, 정확도 계산
labeling = {'POSITIVE': 1, 'NEGATIVE': 0}
accuracy = len([1 for pred, real in zip(pred_labels, real_labels) if labeling[pred['label']] == real ]) / len(real_labels)
# Metric 결과 저장
global result_df
result_df = result_df.append({'batch_size' : batch_size ,
'accuracy' : accuracy,
'load_model_time' : round(load_model_time, 6),
'load_dataset_time' : round(load_dataset_time, 6),
'total_inference_time' : round(inference_time, 6),
'avg_inference_time' : round(inference_time / len(X_test), 6),
'ips' : round(len(X_test) / (load_model_time + load_dataset_time + inference_time), 6),
'ips_inf' : round(len(X_test) / inference_time, 6)}, ignore_index=True)
# 모델명
model_name = 'distilbert_sst2'
# 모델 로드
load_model()
# 배치 단위로 추론
for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
inference(batch_size)
# 배치 단위 추론 결과 데이터를 저장할 경로
result_csv=f'./{model_name}_result.csv'
# 배치 단위 추론 결과 데이터 저장
result_df.to_csv(result_csv, index=False)