Tensorflow error: unexpected keyword argument during loading of trained model

I have trained a TensorFlow model, saved it on another machine, and now want to load it locally. When I try to load it, I get the error Agent.__init__() got an unexpected keyword argument 'name'. My Agent class is the neural net I want to load, but no keyword called name is passed to it. My Agent class code is:

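# Imports assumed by this snippet (not shown in the original question)
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (BatchNormalization, Conv2D, Dense, Dropout,
                                     Flatten, MaxPooling2D, ReLU)

class Agent(Model):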

    """
    Defines a class for the actors used in reinforcement learning where the states are represented as a 2-D image

    params:
    number_of_outputs: the number of outputs the neural net should return
    number_of_hidden_units: the number of hidden units in the neural net
    """

    def __init__(self, number_of_outputs: int, number_of_hidden_units: int):
        super(Agent, self).__init__()

        self.number_of_outputs = number_of_outputs

        self.number_of_hidden_units = number_of_hidden_units

        self.first_block = Sequential(
            [
                Conv2D(number_of_hidden_units, kernel_size=2, padding='same', strides=1, activation='relu', data_format='channels_last', kernel_initializer='he_normal'),
                Conv2D(number_of_hidden_units, kernel_size=2, padding='same', strides=1, activation='relu', data_format='channels_last', kernel_initializer='he_normal'),
                MaxPooling2D(pool_size=3, padding='same'),
            ]
        )

        self.second_block = Sequential(
            [
                Conv2D(number_of_hidden_units, kernel_size=2, padding='same', strides=1, activation='relu', data_format='channels_last', kernel_initializer='he_normal'),
                MaxPooling2D(pool_size=3, padding='same'),
            ]
        )

        self.prediction_block = Sequential(
            [
                Flatten(),
                Dense(128, activation='linear'),
                Dense(number_of_outputs, activation='linear'),
            ]
        )

        self.relu = ReLU()

        self.dropout = Dropout(0.25)

        self.normalize = BatchNormalization()

    def call(self, data):
        x = self.first_block(data)
        x = self.normalize(x)
        x = self.second_block(x)
        x = self.normalize(x)

        x = self.prediction_block(x)

        return x


    def get_config(self):
        base_config = super().get_config()

        config = {
            "number_of_outputs": self.number_of_outputs,
            "number_of_hidden_units": self.number_of_hidden_units,
        }
        return {**base_config, **config}

The code used to save the model is:

def save_full_model(self, episode):
    self.model.save(f'dqn_model_{episode}.h5')

And the code used to load the saved model is:

from tensorflow.keras.models import load_model

def load_full_model(self, path_to_model):
    self.model = load_model(path_to_model, custom_objects={'Agent': Agent})

Note that the exact same Agent class was used during training.

Asked Mar 12 at 23:29 by borjanob; edited Mar 14 at 16:12 by desertnaut.

1 Answer


This is likely happening because you override __init__ in a way that does not accept the base Model class's arguments, such as name, while still using get_config from the superclass. As a result, the saved config contains the usual Model arguments like name, and when the model is rebuilt from that config during loading, they are passed to Agent.__init__, which does not accept them. This should be fixed by adding the necessary arguments to __init__, most easily with **kwargs:

def __init__(self, number_of_outputs: int, number_of_hidden_units: int, **kwargs):
    super(Agent, self).__init__(**kwargs)
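To make this mechanism concrete without needing TensorFlow, here is a minimal pure-Python sketch. ToyModel, BrokenAgent, FixedAgent and toy_load are hypothetical stand-ins, not Keras classes. They only assume what the answer describes: the base get_config() contributes a name entry, and rebuilding the model ends up calling the class with **config. Real Keras adds further entries (for example trainable or dtype, depending on the version) and may wrap the error message.

```python
# Hypothetical toy classes (NOT the real Keras API) that mimic only the
# config round-trip performed when a saved subclassed model is reloaded.


class ToyModel:
    """Stand-in for keras.Model: stores generic arguments such as `name`."""

    def __init__(self, name=None, trainable=True):
        self.name = name or type(self).__name__.lower()
        self.trainable = trainable

    def get_config(self):
        # The base config records generic arguments, like `name` in Keras.
        return {"name": self.name, "trainable": self.trainable}

    @classmethod
    def from_config(cls, config):
        # Roughly what the default from_config does for a subclassed model.
        return cls(**config)


class BrokenAgent(ToyModel):
    # Same signature as the question's Agent: it cannot accept `name`.
    def __init__(self, number_of_outputs, number_of_hidden_units):
        super().__init__()
        self.number_of_outputs = number_of_outputs
        self.number_of_hidden_units = number_of_hidden_units

    def get_config(self):
        return {**super().get_config(),
                "number_of_outputs": self.number_of_outputs,
                "number_of_hidden_units": self.number_of_hidden_units}


class FixedAgent(ToyModel):
    # The answer's fix: forward any extra keyword arguments to the base class.
    def __init__(self, number_of_outputs, number_of_hidden_units, **kwargs):
        super().__init__(**kwargs)
        self.number_of_outputs = number_of_outputs
        self.number_of_hidden_units = number_of_hidden_units

    def get_config(self):
        return {**super().get_config(),
                "number_of_outputs": self.number_of_outputs,
                "number_of_hidden_units": self.number_of_hidden_units}


def toy_load(config, class_name, custom_objects):
    # Mimics custom_objects mapping the saved class name to a Python class.
    return custom_objects[class_name].from_config(config)


try:
    toy_load(BrokenAgent(4, 32).get_config(), "Agent", {"Agent": BrokenAgent})
except TypeError as err:
    print("broken:", err)  # unexpected keyword argument 'name'

restored = toy_load(FixedAgent(4, 32, name="agent").get_config(),
                    "Agent", {"Agent": FixedAgent})
print("fixed:", restored.name, restored.number_of_outputs)  # prints "fixed: agent 4"
```

The broken variant fails with the same kind of TypeError as in the question, while the **kwargs variant round-trips because name and trainable are handed on to the base class instead of landing in the subclass's own parameters.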
