티스토리 뷰

반응형
[딥러닝] tensorflow에서 학습된 모델 저장, 불러오기

저장하기

저장할 때에는 2가지 방법이 있다. 모델 구조와 weight를 한 번에 저장하는 방법과, weight만을 저장하는 방법이다. 상황에 따라 필요한 방법을 사용하면 된다.

모델을 통째로 저장할 때에는 디렉토리 경로를 지정해주면 해당 경로에 모델이 저장된다. weight만을 저장할 때에는 아래와 같이 확장자 지정 없이 경로를 지정해 주면 weigh값만 저장이 된다.

# 1. 모델 통째로 저장
model.save('./my_model')

# 2. weight만 저장
model.save_weights('./my_model/epoch_001')

# 3. callbacks를 사용하여 저장
# 체크포인트 경로 지정({}변수 에 epoch 값이 들어가도록 epoch fotmat을 포함시켜야 한다.)
checkpoint_path = "./checkpoints/epoch_{epoch:03d}.ckpt"

# 체크포인트 콜백 만들기
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
																								 period=1, # 1개의 epoch마다 저장
                                                 verbose=1)
model.fit(train_x, train_y,
          epochs=10, callbacks=[cp_callback],
          verbose=1)

불러오기

모델 전체를 가져오는 건 저장된 경로만 알려주면 된다. weight만 불러오는 건 이미 model network가 구성이 되어 있어야 하고, load_weight 함수를 통해 weight만 update해준다.

# 1. 모델 통째로 불러오기
model = keras.models.load_model('./my_model')

# 2. weight만 불러오기
model = Model()
model.load_weight('./my_model/epoch_001')

# 3. 위 콜백으로 저장된 것 중 latest모델 가져오기
latest = tf.train.latest_checkpoints('./checkpoints')
model.load_weight(latest)

참고자료

모델 저장과 복원 | TensorFlow Core
32/32 - 0s - loss: 0.4108 - accuracy: 0.0890 복원된 모델의 정확도: 8.90% WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used.
https://www.tensorflow.org/tutorials/keras/save_and_load?hl=ko
반응형
댓글
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday