The method keras.Model.set_weights seems to accept only trainable weights. Non-trainable weights, such as those from normalization layers, cannot be imported this way. This is a problem because, with the JAX backend, Keras can only perform stateless operations; that is, the weights are updated separately from the model. Therefore, before saving a model with keras.Model.save, the updated weights have to be loaded back into the model. However, since the non-trainable weights cannot be loaded in (and are therefore not saved in the .keras format), the saved model will underperform.
Is it possible to load/set the non-trainable weights in a Keras model? More generally, is there any way to save the complete model, including non-trainable weights, when using JAX as the backend?