상세 컨텐츠

본문 제목

Keras get_config 에러

카테고리 없음

by 리치윈드 - windFlex 2022. 7. 15. 22:02

본문

반응형

케라스 커스텀 모델 (keras Custom Model) 만들기. get_config() 에러

 

 

 

Keras Custom Model 에러 : get_config Not Implemented Error

 

 

Keras로 Custom 모델을 생성하고 학습을 하다 보면, "get_config(self)"가 정의되지 않았다는 에러를 만날 때가 있다. 

에러가 발생한 구체적인 사례는 다음과 같다.

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-87-6c20ab69a27d> in <module>
     13 
     14 
---> 15 history = custom_model.fit( X_train , y_train , batch_size = 32, epochs=nEpochs, 
     16                        validation_data=(X_test, y_test)  ,
     17                        # class_weight=class_weight,

~/anaconda3/lib/python3.8/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

~/anaconda3/lib/python3.8/site-packages/keras/engine/training.py in get_config(self)
   2435 
   2436   def get_config(self):
-> 2437     raise NotImplementedError
   2438 
   2439   @classmethod

NotImplementedError:

 

그런데, 당혹스러운 것은 Keras 공식문서에서 Custom Model을 만들때는, get_config()를 구현을 필요로 하지 않기 때문이다.  "tf.keras.Model"을 상속받아 클래스를 만들고, 생성자(__init__)와 "call()"함수를 구현해 주면 된다. 다음은 Keras 공식문서에서 가이드 하고 있는 Custom Model을 만드는 템플릿이다. 

import tensorflow as tf

class MyModel(tf.keras.Model):

  def __init__(self):
    super().__init__()
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

  def call(self, inputs):
    x = self.dense1(inputs)
    return self.dense2(x)

model = MyModel()

 

get_config()를 구현해 주어야 하는 경우는, 모델 또는 모델의 weights를 저장해 주는 경우인데, 에러가 발생하는 곳은 Training/Evaluation 경우에 발생해서 의아해 하고 있었다. 다음은 공식 가이드의 내용이다. 

Configuring the SavedModel

New in TensorFlow 2.4 The argument save_traces has been added to model.save, which allows you to toggle SavedModel function tracing. Functions are saved to allow the Keras to re-load custom objects without the original class definitions, so when save_traces=False, all custom objects must have defined get_config/from_config methods. When loading, the custom objects must be passed to the custom_objects argument. save_traces=False reduces the disk space used by the SavedModel and saving time.
https://keras.io/guides/serialization_and_saving/
내용은 요약하면, Tensorflow 2.4 이후 부터 "model.save" 할 때, get_config/from_config method가 정의되어 있어야 한다는 내용이다. ("save_traces=False")

잘 찾아보니, Training 할 때도 자동으로 save가 되는 경우를 하나 발견했다.  답은...

 

ModelCheckpoint

 

따라서, 해결책은 2가지로 볼 수 있다. 

1) ModelCheckPoint 사용을 중지한다. 

2) get_config함수를 구현해 준다. 

 

 

 

해결방안 1) Model Check Point 제거

 

ModelCheckPoint는 fit 함수 적용 시 checkPoint를 제외해 주면 되므로 간단한다. (대신 CheckPoint 저장이 되지 않을 것이다. )

 

일반적으로 Fit()함수가 다음과 같은 형태를 취하고 있다고 가정하면, 

history = fc_model.fit( X_train , y_train , batch_size = 32, epochs=nEpochs, 
                       validation_data=(X_test, y_test), callbacks=[checkpointer] )

아래와 같이 checkpointer를 제거해 주면 된다. 

history = fc_model.fit( X_train , y_train , batch_size = 32, epochs=nEpochs, 
                       validation_data=(X_test, y_test))

 

 

해결방안 2) get_config() 메소드 구현

두번째 방법은 Custome 모델이 제대로 저장될 수 있도록 get_config()를 구현해주는 방법이다. 

다음은 Keras공식 문서에서 제공하는 Custom Model 저장을 지원하는 Class 구현이다.  "get_config()" 주목하도록 하자.

class CustomModel(keras.Model):
    def __init__(self, hidden_units):
        super(CustomModel, self).__init__()
        self.hidden_units = hidden_units
        self.dense_layers = [keras.layers.Dense(u) for u in hidden_units]

    def call(self, inputs):
        x = inputs
        for layer in self.dense_layers:
            x = layer(x)
        return x

    def get_config(self):
        return {"hidden_units": self.hidden_units}

    @classmethod
    def from_config(cls, config):
        return cls(**config)

 

참조 : https://keras.io/guides/serialization_and_saving/

get_config()는 저장하고자 하는 노드의 값을 Dictionary형태로 반환해 주도록 구현해 주면 된다. 상기의 예제에서는 hidden_units의 파라미터만 저장하고 있다. 

 

필자는 아래와 같이 모델에서 사용하는 모든 변수를 다 정의해 주었다. 

def get_config(self):
        config = {
            'rateDropout' : self.rateDropout,
            'conv_block1': self.conv_block1,
            'conv_block2': self.conv_block2,
            'conv_block3': self.conv_block3,
            'conv_block4': self.conv_block4,
            'conv_block5': self.conv_block5,
            'conv_block6': self.conv_block6,
            'fc1': self.fc1,
        }
        return config

 

 

 

 

 

 

 

 

반응형

댓글 영역