I am trying to implement the Sharpness-Aware Minimization (SAM) method in a custom TensorFlow training loop. The algorithm follows these steps:
- Compute the gradient of the loss with respect to the model's trainable weights
- Compute epsilon-hat using Equation 2 (shown at the bottom)
- Compute the gradients at model.trainable_weights + epsilon-hat
- Update model.trainable_weights using the new gradients
My training loop is:
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam()
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
val_acc_metric = tf.keras.metrics.CategoricalAccuracy()

for epoch in range(epochs):
    # Iterate over the batches of the train dataset
    for batch, (inputs, targets) in enumerate(train_ds):
        with tf.GradientTape(persistent=True) as tape:
            # Forward pass
            predictions = model(inputs)
            # Compute loss value
            loss = loss_fn(targets, predictions)
        # Update accuracy
        train_acc_metric.update_state(targets, predictions)
        # Gradient wrt model's weights
        gradient = tape.gradient(loss, model.trainable_weights)
        # USING EQ 2
        numerator1 = list(map(lambda g: tf.math.pow(tf.math.abs(g), q - 1), gradient))
        numerator2 = list(map(lambda g: rho * tf.math.sign(g), gradient))
        numerator = list(map(lambda n1, n2: n1 * n2, numerator1, numerator2))
        denominator = list(map(lambda g: tf.math.pow(tf.norm(g, ord=q), q), gradient))
        epsilon = list(map(lambda n, d: n / d, numerator, denominator))
        # Compute gradient at weights + epsilon
        modified_weights = list(map(lambda e, w: w + e, epsilon, model.trainable_weights))
        gradient = tape.gradient(loss, modified_weights)
        # Update weights (ValueError: No gradients provided for any variable)
        optimizer.apply_gradients(zip(gradient, model.trainable_weights))
Upon inspecting the gradients returned by tape.gradient(loss, modified_weights), I find that they are None for every layer. I am unable to figure out how to avoid this disconnection in the graph.
A similar question has already been asked here but without any answers.
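In case it helps, here is a minimal standalone snippet that, as far as I can tell, reproduces the same behaviour: the tape only knows about tensors that were involved in computing loss, so asking it for gradients with respect to freshly created tensors returns None. The scalar variable below is just a toy stand-in for the model weights, not my actual setup:

import tensorflow as tf

w = tf.Variable(2.0)                       # toy stand-in for a model weight

with tf.GradientTape(persistent=True) as tape:
    loss = w * w                           # loss is computed from w only

# New tensor created outside the tape; it was never used to compute loss
w_shifted = w + 0.1

print(tape.gradient(loss, w))              # tf.Tensor(4.0, ...) as expected
print(tape.gradient(loss, w_shifted))      # None -- same as what I see per layer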
Equation 2 (epsilon-hat, from the SAM paper):

$$\hat{\epsilon}(w) = \rho \,\mathrm{sign}\!\left(\nabla_w L_S(w)\right) \frac{\left|\nabla_w L_S(w)\right|^{q-1}}{\left(\lVert \nabla_w L_S(w) \rVert_q^q\right)^{1/p}}, \qquad \frac{1}{p} + \frac{1}{q} = 1$$
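Spelled out for a single gradient tensor g, this is how I read Equation 2 (in my loop I apply the same thing per weight tensor via map). The values of rho and q below are placeholders, not the ones from my actual runs; for q = 2 the expression should reduce to rho * g / ||g||_2:

import tensorflow as tf

rho, q = 0.05, 2                     # placeholder hyperparameters
p = q / (q - 1)                      # conjugate exponent, 1/p + 1/q = 1

g = tf.constant([0.3, -1.2, 0.7])    # toy stand-in for one layer's gradient

numerator = rho * tf.math.sign(g) * tf.math.pow(tf.math.abs(g), q - 1)
denominator = tf.math.pow(tf.math.pow(tf.norm(g, ord=q), q), 1.0 / p)
epsilon_hat = numerator / denominator

# Sanity check: for q = 2 this matches rho * g / ||g||_2
print(epsilon_hat)
print(rho * g / tf.norm(g))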