Welcome to ShenZhenJia Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

I wrote a function to train a model, but calling it fails with the error below, and I can't figure out why:

`'<' not supported between instances of 'ExponentialDecay' and 'int'`

```python
import time

import matplotlib.pyplot as plt
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers.schedules import ExponentialDecay


def train(mdl, train_dataset, valid_dataset, epochs, mdlname):
    print("start training ...")

    # learning rate starts at 0.01 and is multiplied by 0.96 every 20 steps
    initial_learning_rate = 0.01
    lr_schedule = ExponentialDecay(initial_learning_rate, decay_steps=20, decay_rate=0.96, staircase=True)

    checkpoint_cb = ModelCheckpoint("./100.temp/saved_temp_model.h5", save_best_only=True)

    early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
    start = time.time()

    # compile model
    mdl.compile(loss='categorical_crossentropy', optimizer=optimizers.RMSprop(lr=lr_schedule), metrics=['accuracy'])
    #mdl.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

    history = mdl.fit(train_dataset,
                      validation_data=valid_dataset,
                      #batch_size=batch_size,
                      #shuffle=True, #False,
                      callbacks=[early_stop],
                      epochs=epochs,
                      verbose=2)

    training_time = time.time() - start
    print('total training_time:', training_time)

    # plot training loss vs. validation loss per epoch
    print("plot train loss vs valid loss...")
    plt.figure(figsize=(10, 6))
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model Train vs Validation Loss for cnn')
    plt.ylabel('Loss')
    plt.xlabel('epoch')
    plt.legend(['Train loss', 'Validation loss'], loc='upper right')
    plt.show()

    mdl.save(mdlname)
```
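For reference, `ExponentialDecay` with `staircase=True` computes `initial_learning_rate * decay_rate ** (step // decay_steps)`, where `step` counts optimizer updates (batches), not epochs. The sketch below evaluates the schedule configured above in plain Python; the helper name `staircase_exponential_decay` is hypothetical and not part of Keras.

```python
def staircase_exponential_decay(step, initial_learning_rate=0.01,
                                decay_steps=20, decay_rate=0.96):
    """Hypothetical helper mirroring ExponentialDecay(..., staircase=True)."""
    # With staircase=True the exponent is floored (integer division),
    # so the rate stays constant within each block of `decay_steps` steps.
    exponent = step // decay_steps
    return initial_learning_rate * decay_rate ** exponent


for step in (0, 19, 20, 40, 100):
    print(step, staircase_exponential_decay(step))
```

Steps 0 through 19 share the initial rate of 0.01; the first drop (to 0.0096) happens at step 20.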

He who fights with dragons too long becomes a dragon himself; gaze too long into the abyss, and the abyss gazes back…
264 views

1 Answer

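For a constant learning rate you can pass `optimizer=optimizers.RMSprop(lr=lr_constant)`, but for a schedule you should use `optimizer=optimizers.RMSprop(learning_rate=lr_schedule)`. In the TF 2.x optimizers, `lr` is only a deprecated alias collected through `**kwargs`, and the base optimizer checks each such keyword for non-negativity with `value < 0`. An `ExponentialDecay` object cannot be compared with `0`, which is exactly the `TypeError` you see. The named `learning_rate` parameter accepts either a float or a `LearningRateSchedule`.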
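To see why the keyword name matters, here is a simplified, hypothetical model of the check, assuming the TF 2.x optimizer behavior where legacy aliases such as `lr` arrive through `**kwargs` and are compared against `0`. `ToySchedule` and `ToyOptimizer` are illustrative stand-ins, not Keras classes, and this is not the actual Keras source.

```python
class ToySchedule:
    """Hypothetical stand-in for a LearningRateSchedule: callable, no '<'."""

    def __call__(self, step):
        return 0.01 * 0.96 ** (step // 20)


class ToyOptimizer:
    """Hypothetical, simplified model of a TF 2.x optimizer constructor."""

    def __init__(self, learning_rate=0.001, **kwargs):
        for name, value in kwargs.items():
            # Legacy aliases such as `lr` arrive via **kwargs and are
            # checked for non-negativity with a plain '<' comparison.
            if value is not None and value < 0:
                raise ValueError(f"Expected {name} >= 0, received: {value}")
        # The deprecated `lr` alias overrides `learning_rate` when given.
        self.learning_rate = kwargs.get("lr", learning_rate)

    def current_lr(self, step):
        # A schedule is called with the step; a float is used as-is.
        lr = self.learning_rate
        return lr(step) if callable(lr) else lr


# Works: the schedule goes through the named `learning_rate` parameter.
opt = ToyOptimizer(learning_rate=ToySchedule())
print(opt.current_lr(40))

# Fails: the schedule goes through **kwargs and hits `value < 0`.
try:
    ToyOptimizer(lr=ToySchedule())
except TypeError as err:
    print(err)
```

A float passed as `lr` survives the comparison, which is why `lr=lr_constant` works while `lr=lr_schedule` does not.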


...