본문 바로가기

& 프로그래밍/& 머신러닝

Callback 함수 활용하기

Callback 함수는 머신러닝 학습을 진행중일때 자주 사용하는 함수이다

 

예를들어,

내가 epochs값을 임의로 15이라 할당하고 모델 학습을 진행시켰다

내가 예측한 accuracy값은 95%라서 그 이상 굳이 학습시킬 필요가 없다고 생각한다

history = model.fit(train_generator, steps_per_epoch= 8, epochs= 15, verbose = 1, callbacks=callbacks)

자, 모델을 학습시키는데 epochs값이 15로 지정되어 있기 때문에, 5번째 학습때 이미 내가 원하는 95%를 넘겼음에도 계속해서 학습을 진행시키는 모습을 볼 수 있다

 

학습에 일정한 패턴이 예상되거나, 내가 원하는 예측수준에 도달하면 학습을 종료시킬 수 있도록 하는 기능이

바로 오늘 소개할 'Callback'함수이다

 

class myCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs={}):
        if logs.get('accuracy')>0.95 :
          print( "Reached 99% accuracy so cancelling training!" )
          self.model.stop_training = True

callbacks = myCallback()

위 코드가 callback 코드이다

 

하나하나 풀어보자면

class myCallback(tf.keras.callbacks.Callback):

class에서 myCallback이라는 함수를 불러온 후 괄호안에 텐서플로 모델을 입력해준다

  def on_epoch_end(self, epoch, logs={}):
        if logs.get('accuracy')>0.95 :
          print( "Reached 95% accuracy so cancelling training!" )
          self.model.stop_training = True

callbacks = myCallback()

내가 on-epoch-end라는 또다른 함수를 만들어(def)주었다

만약 학습중인 데이터 로그 중 'accuracy'값이 95%를 넘게된다면

"목표치 95%에 도달했습니다. 학습을 종료합니다!"를 출력하고 학습을 실제로 종료토록 하는 조건을 넣었다

 

그래서 실제로 해당 콜백 함수를 사용하여 학습을 시키다보면

이렇게 목표치에 도달하면서 학습을 종료시키는걸 볼 수 있다

 

나는 지금까지 Convolutional neural network를 학습시키면서 사용해보았고,

epochs값을 100 이상 주었을때, 콜백함수를 활용하여 긴 기다림 없이 원하는 학습을 시킬 수 있었다