We-Co

[We-Co] CheckpointManager API - TensorFlow 본문

Python/Tensorflow

[We-Co] CheckpointManager API - TensorFlow

위기의코딩맨 2021. 10. 6. 22:32
반응형

안녕하세요. 위기의코딩맨입니다.

오늘은 Tensorflow의 CheckpointManager API에 대해 알아보도록 하겠습니다.

학습도중에 끊겨버리면 어떻게 될까요..? 이만큼 암울한 일이 없습니다..

얼마전에 학습을 시키는 도중에 잠깐 인터넷이 끊겼는지 에러가 나더라구요..

나름 오랜시간 학습을 하고있었는데 다시 처음부터 해야해서 너무 답답했었는데

CheckpointManager API를 이용해서 학습 도중 내용을 저장해서 

학습이 중간에 끊겼을때, 다시 학습내용을 가져와서 그 지점부터 학습을 진행을 도와주는 API입니다.

 

먼저, 사용방법을 알아보도록 하겠습니다.

 

예제는 CNN에서 사용했던 예제로 확인해보겠습니다.

 

[ 저장 방법 ]

먼저 tf.train.Checkpoint 클래스를 사용하여 사용되는 인자 값들을 저장해야합니다.

첫번째 인자로는 전역 반복 횟수를 나타내며, model은 학습 모델을 적용하면 됩니다.  

Training_CK = tf.train.Checkpoint(step=tf.Variable(0), model = Model) 

 

두번째로, tf.CheckPointManager에 인자 값으로 위에서 선언한

Training_CK와 학습내용을 저장할 경로를 설정해야합니다.

첫번째는 Checkpoint 인자, directory는 저장폴더,

max_to_keep은 가장 최근에 학습내용 몇개를 저장할지 정하는 부분입니다. 

Training_CK_Manager = tf.train.CheckpointManager(Training_CK, directory= save_path, max_to_keep= 5)

 

다음으로, 저장하고자 하는시점에서 전역 반복횟수를 인자 값으로 선언한 

tf.train.CheckpointManager.save()를 호출합니다.

Training_CK_Manager.save(checkpoint_number = Training_CK.step)

 

마지막으로, 학습을 진행하면서 tf.train.Checkpoint의 step의 횟수를 반복적으로 1씩 증가시킵니다.

Training_CK.step.assign_add(1)

 

[ 불러오는 방법 ]

위에 부분은 저장하는 부분이고, 저장 데이터를 불러오는 부분을 알아보도록 하겠습니다.

저장된 경로를 통해서 파일을 가져옵니다.

Training_Last = tf.train.latest_checkpoint(save_path)

 

가져온 데이터를 tf.train.CheckpointManager.restore()의 인자로 가져와 넣어주도록 합니다.

Training_CK.restore(Trainig_Last )

 

 

[ 예제 ]

기본의 CNN예제를 통해서 알아보도록 하겠습니다.

 

import tensorflow as tf

 

(X_Train,Y_Train),(X_Test,Y_Test) = tf.keras.datasets.mnist.load_data()

X_Train,X_Test = X_Train.astype('float32'), X_Test.astype('float32')

X_Train, X_Test = X_Train.reshape([-1,784]), X_Test.reshape([-1784])

X_Train, X_Test = X_Train/255., X_Test/255.

Y_Train, Y_Test = tf.one_hot(Y_Train, depth= 10), tf.one_hot(Y_Test, depth= 10)

 

Train_Data= tf.data.Dataset.from_tensor_slices((X_Train,Y_Train))

Train_Data= Train_Data.repeat().shuffle(60000).batch(50)

Train_Data_Iter = iter(Train_Data)

 

class CNN(tf.keras.Model):

  def __init__(self):

    super(CNN, self).__init__()

    # Convolution layer 

    # 5 X 5 kernel Size / 32개 Filter

    self.Conv_Layer_1 = tf.keras.layers.Conv2D(filters=32, kernel_size=5, strides=1, padding='same', activation= 'relu')

    self.Pool_Layer_1 = tf.keras.layers.MaxPool2D(pool_size=(2,2), strides =2)

 

    # Convolustion Layer 2

    # 5 X 5 kernel Size / 64개 Filter

    self.Conv_Layer_2 = tf.keras.layers.Conv2D(filters=64, kernel_size=5, strides=1, padding='same', activation= 'relu')

    self.Pool_Layer_2 = tf.keras.layers.MaxPool2D(pool_size=(2,2), strides =2)

 

    # Fully Connected Layer

    # 7 X 7 크기 64개의 activation map 1024개로 변환

    self.Flatten_Layer = tf.keras.layers.Flatten()

    self.FC_Layer_1 = tf.keras.layers.Dense(1024, activation='relu')

 

    # Output Layer

    # 1024개 특징을 10개의 클래스 one-hot-encoding으로 표현된 숫자로 변환

    self.Output_Layer = tf.keras.layers.Dense(10, activation=None)

  

  def call(selfx):

    #MNIST 데이터를 3차원 형태로 reshape

    X_Image = tf.reshape(x, [-128281])

    # 28X28X1을 28X28X32로 변환, -> 14X14X32

    Conv_1 = self.Conv_Layer_1(X_Image)

    Pool_1 = self.Pool_Layer_1(Conv_1)

    # 14X14X32를 14X14X64 변환, -> 7X7X64

    Conv_2 = self.Conv_Layer_2(Pool_1)

    Pool_2 = self.Pool_Layer_2(Conv_2)

 

    #7X7X64 -> 1024

    Pool_2_Flat = self.Flatten_Layer(Pool_2)

    FC_1 = self.FC_Layer_1(Pool_2_Flat)

 

    # 1024 -> 10

    Logits = self.Output_Layer(FC_1)

    Y_Pred = tf.nn.softmax(Logits)

    return Y_Pred, Logits

 

 

def Cross_Entropy_Loss(LogitsY):

  return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=Logits, labels=Y))

#최적화 Adam

optimizer = tf.optimizers.Adam(1e-4)

 

# TensorBoard Summary 정보/ 저장할 폴더 경로 설정 및 FileWriter 

Train_Summary_Write = tf.summary.create_file_writer('./TensorBoard/Train')

Test_Summary_Write = tf.summary.create_file_writer('./TensorBoard/Test')

def train_step(modelxy):

  with tf.GradientTape() as tape:

    y_pred, logits = model(x)

    loss = Cross_Entropy_Loss(logits, y)

  # 매 step마다 tf.summary.scalar, tf.summary.image 텐서보드 로그를 기록합니다.

  with Train_Summary_Write.as_default():

    tf.summary.scalar('loss', loss, step=optimizer.iterations)

    x_image = tf.reshape(x, [-128281])

    tf.summary.image('training image', x_image, max_outputs=10, step=optimizer.iterations) 

  gradients = tape.gradient(loss, model.trainable_variables)

  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

 

def Compute_Accuracy(Y_PredYSummary_writer):

  correct_pr = tf.equal(tf.argmax(Y_Pred,1), tf.argmax(Y,1))

  accuracy = tf.reduce_mean(tf.cast(correct_pr, tf.float32))

 

  with Summary_writear.as_default():

    tf.summary.scalar('accuracy',accuracy, step= optimizer.iterations)

  return accuracy

 

CNN_Model = CNN()

save_path ="./SaveModel"

Training_CK = tf.train.Checkpoint(step=tf.Variable(0), model = CNN_Model) 

Training_CK_Manager = tf.train.CheckpointManager(Training_CK, directory= save_path, max_to_keep= 5)

Training_Last = tf.train.latest_checkpoint(save_path)

 

# 저장파일이 존재하면 불러와서 Restored하여 학습진행

if Training_Last:

  Training_CK.restore(Training_Last) 

  print("Epoch: %d, Train data 정확도: %f" %(Training_CK.step, train_accuracy))  

 

while int(Training_CK.step) <(1000 + 1):

  batch_x, batch_y = next(Train_Data_Iter)

  if Training_CK.step % 50 == 0:

    Training_CK_Manager.save(checkpoint_number = Training_CK.step)

    train_accuracy = Compute_Accuracy(CNN_Model(batch_x)[0],batch_y, Train_Summary_Write) 

    print("Epoch: %d, Train data 정확도: %f" %(Training_CK.step, train_accuracy))  

  

  train_step(CNN_Model, batch_x,batch_y)

  Training_CK.step.assign_add(1

 

변경된 부분은 굵은 표시를 해놨습니다.

위에서 설명했듯이, Training_CK를 선언하고, Manager에 저장, Training_Last로 데이터를 불러옵니다.

그리고 학습된 내용이 있으면,   Training_CK.restore(Training_Last)로 학습내용을 저장합니다.

 

while문은 Training_CK의 Steb의 크기를 설정하고 1000번의 학습을 진행했습니다.

Training_CK_Manager.save(checkpoint_number = Training_CK.step)를 Steb 50번마다 저장하도록 설정했습니다.

마지막 줄에 Step을 +1해주는 역할을 넣어주어 1001이 되면 while이 종료되게 됩니다.

 

SavaData

설정한 폴더안에 내용이 저장된 것을 확인할 수 있습니다.

checkpoint파일이 중요한데 이 파일을 통해서 학습된 내용의 최신파일을 가져오도록 합니다.

그럼 저장된 것은 확인 되었으니 테스트를 진행해 보겠습니다.

 

 

[ 테스트 ]

임의로 테스트를 위해 100번째까지만 돌리고 정지시켰습니다.

Test

Training_Last에 데이터가 들어가 저장된 것을 확인할 수 있습니다.

확인 후, 다시 학습을 진행 했을때, 100번째 부터 다시 돌아가는 것을 확인할 수 있습니다.

 

 

오늘은 CheckpointManager API를 통해서 학습된 내용을 저장하고 불러오는 방법에 대해서 알아보았습니다.

이제는 학습 도중에 끊겨도 걱정이 없습니다..ㅎㅎㅎ

반응형