Callback 함수는 머신러닝 학습을 진행중일때 자주 사용하는 함수이다
예를들어,
내가 epochs값을 임의로 15이라 할당하고 모델 학습을 진행시켰다
내가 예측한 accuracy값은 95%라서 그 이상 굳이 학습시킬 필요가 없다고 생각한다
history = model.fit(train_generator, steps_per_epoch= 8, epochs= 15, verbose = 1, callbacks=callbacks)
자, 모델을 학습시키는데 epochs값이 15로 지정되어 있기 때문에, 5번째 학습때 이미 내가 원하는 95%를 넘겼음에도 계속해서 학습을 진행시키는 모습을 볼 수 있다
학습에 일정한 패턴이 예상되거나, 내가 원하는 예측수준에 도달하면 학습을 종료시킬 수 있도록 하는 기능이
바로 오늘 소개할 'Callback'함수이다
class myCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs={}):
if logs.get('accuracy')>0.95 :
print( "Reached 99% accuracy so cancelling training!" )
self.model.stop_training = True
callbacks = myCallback()
위 코드가 callback 코드이다
하나하나 풀어보자면
class myCallback(tf.keras.callbacks.Callback):
class에서 myCallback이라는 함수를 불러온 후 괄호안에 텐서플로 모델을 입력해준다
def on_epoch_end(self, epoch, logs={}):
if logs.get('accuracy')>0.95 :
print( "Reached 95% accuracy so cancelling training!" )
self.model.stop_training = True
callbacks = myCallback()
내가 on-epoch-end라는 또다른 함수를 만들어(def)주었다
만약 학습중인 데이터 로그 중 'accuracy'값이 95%를 넘게된다면
"목표치 95%에 도달했습니다. 학습을 종료합니다!"를 출력하고 학습을 실제로 종료토록 하는 조건을 넣었다
그래서 실제로 해당 콜백 함수를 사용하여 학습을 시키다보면
이렇게 목표치에 도달하면서 학습을 종료시키는걸 볼 수 있다
나는 지금까지 Convolutional neural network를 학습시키면서 사용해보았고,
epochs값을 100 이상 주었을때, 콜백함수를 활용하여 긴 기다림 없이 원하는 학습을 시킬 수 있었다
'& 프로그래밍 > & 머신러닝' 카테고리의 다른 글
자주 사용하는 파이썬 라이브러리 모음 (0) | 2021.03.05 |
---|---|
Facebook이 개발한 오픈소스 Prophet 라이브러리 사용하기 (0) | 2021.03.03 |
ImageDataGenerator [이미지 전처리] (0) | 2021.03.03 |
Categorical Data Encoding(데이터 전처리 작업) (0) | 2021.03.02 |
이미지 학습 테스트 코드 (0) | 2021.03.02 |