Keras에서 weighted cross entropy 구현하기(class_weight)

2022. 4. 26. 12:36스터디/Python

불량 데이터는 정상 데이터보다 현저히 적은 수를 가질 수 밖에 없다.

이와 같이 데이터의 밸런스가 맞지 않는 경우, under-sampling을 수행하거나, weighted loss를 필요로 한다.

 

아래 내용은 불량 여부를 판단하는 이미지 모델에 weight binary cross entropy를 만든 것이다.

이 때 keras을 활용하였으며, 모델은 ResNet50을 구현되었다.

 

1. 모델 정의

from keras.applications.resnet import ResNet50
from keras.layers import Input, Dense
from keras.models import Model

model_input = Input(shape=train_X.shape[1:])
model = ResNet50(input_tensor=model_input, include_top=False, weights=args['weight'], pooling=args['pooling'])

x = model.layers[-1].output
x = Dense(args['n_class'], activation='sigmoid')(x)

model = Model(model.input, x)

 

2. 모델 컴파일

이 때, loss는 가중 여부와 관계없이 사용할 loss를 입력한다.

from tensorflow.keras import optimizers

model.compile(loss='binary_crossentropy',
              optimizer = optimizers.Adam(learning_rate=0.001),
              metrics = ['acc']
              )

 

3. 모델 학습

기존의 학습 코드에서 class_weight가 추가되었다.

우선 sklearn의 class_weight.compute_class_weight를 사용하여 라벨의 weight를 얻는다.

그리고, fit에 대입하여 활용하는데, 이 때 값은 라벨 별 딕셔너리 형태로 활용되어야 한다.

from sklearn.utils import class_weight

weights = class_weight.compute_class_weight(class_weight = 'balanced',
                                            classes = np.unique(train_y),
                                            y = train_y)
                                            
batch_size = 16
model.fit(train_X,
 		  train_y,
          batch_size = batch_size,
          epochs=epochs,
		  class_weight={i:weights[i] for i in range(len(weights))})