I wrote a function to train model, but got below error for executing the train function, could not figure out why?
'<' not supported between instances of 'ExponentialDecay' and 'int'
def train(mdl, train_dataset, valid_dataset, epochs, mdlname):
print("start training ...")
initial_learning_rate = 0.01
lr_schedule = ExponentialDecay(initial_learning_rate, decay_steps=20, decay_rate=0.96, staircase=True)
checkpoint_cb = ModelCheckpoint("./100.temp/saved_temp_model.h5", save_best_only=True)
early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
start = time.time()
# compile model
mdl.compile(loss='categorical_crossentropy', optimizer=optimizers.RMSprop(lr=lr_schedule), metrics=['accuracy'])
#mdl.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
history = mdl.fit( train_dataset,
validation_data=valid_dataset,
#batch_size=batch_size,
#shuffle=True, #False,
callbacks=[early_stop],
epochs=epochs,
verbose=2)
training_time = time.time() - start
print('total training_time:', training_time)
print("plot train loss vs valid loss...")
plt.figure(figsize = (10, 6))
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Train vs Validation Loss for cnn')
plt.ylabel('Loss')
plt.xlabel('epoch')
plt.legend(['Train loss', 'Validation loss'], loc='upper right')
plt.show()
mdl.save(mdlname)
``