Welcome to ShenZhenJia Knowledge Sharing Community for programmer and developer-Open, Learning and Share

I am trying to implement the Sharpness-Aware Minimization (SAM) method in a custom TensorFlow training loop. The algorithm follows these steps:

  • Calculate the gradient of the loss with respect to the weights
  • Calculate epsilon-hat using Equation 2
  • Calculate the gradients at model.trainable_weights + epsilon-hat
  • Update model.trainable_weights using the new gradients
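The four steps above can be sketched on a toy problem before bringing TensorFlow in. This is a minimal NumPy illustration, assuming a hypothetical quadratic loss with a hand-written gradient and the common p = q = 2 case (where epsilon-hat reduces to rho * g / ||g||_2); the names `loss_and_grad`, `sam_step`, and the constant `C` are illustrative only.

```python
import numpy as np

# Hypothetical toy loss L(w) = 0.5 * sum(C * w**2); its gradient is C * w.
C = np.array([1.0, 10.0])

def loss_and_grad(w):
    return 0.5 * np.sum(C * w ** 2), C * w

def sam_step(w, lr=0.1, rho=0.05):
    _, g = loss_and_grad(w)                      # step 1: gradient at w
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # step 2: epsilon-hat for q = 2
    _, g_sam = loss_and_grad(w + eps)            # step 3: gradient at w + eps
    return w - lr * g_sam                        # step 4: update the ORIGINAL w
```

The key point is step 4: the gradient is evaluated at the perturbed point, but it is applied to the unperturbed weights.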

My training loop is:

import tensorflow as tf

loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam()
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
val_acc_metric = tf.keras.metrics.CategoricalAccuracy()

for epoch in range(epochs):
    # Iterate over the batches of train dataset
    for batch, (inputs, targets) in enumerate(train_ds):
        with tf.GradientTape(persistent = True) as tape:
            # Forward pass
            predictions = model(inputs)
            # Compute loss value
            loss = loss_fn(targets, predictions)
        # Update accuracy
        train_acc_metric.update_state(targets, predictions)
        # Gradient wrt model's weights
        gradient = tape.gradient(loss, model.trainable_weights)

        # USING EQ 2 
        numerator1 = list(map(lambda g: tf.math.pow(tf.math.abs(g),q-1), gradient))
        numerator2 = list(map(lambda g: rho*tf.math.sign(g), gradient))
        numerator = list(map(lambda n1, n2: n1*n2, numerator1,numerator2))
        denominator = list(map(lambda g: tf.math.pow(tf.norm(g, ord=q),q), gradient))
        epsilon = list(map(lambda n, d: n/d, numerator, denominator))
        # Compute gradient at weights+epsilon
        modified_weights = list(map(lambda e, w: w+e, epsilon, model.trainable_weights))
        gradient = tape.gradient(loss, modified_weights)

        # Update weights (raises ValueError: No gradients provided for any variable)
        optimizer.apply_gradients(zip(gradient, model.trainable_weights))

When I inspect the result of tape.gradient(loss, modified_weights), the gradient for every layer is None. I can't figure out how to keep the graph from being disconnected.
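A likely cause: `tape.gradient` can only differentiate with respect to variables or tensors that took part in computing `loss` while the tape was recording. `modified_weights` is built after the tape has closed, and `loss` was computed from the original weights, so `loss` does not depend on `modified_weights` at all and every gradient comes back as None. One way around this is to move the variables themselves to w + epsilon-hat, run a second forward pass under a fresh tape, and then move them back before applying the update. Below is a minimal sketch of that idea, assuming p = q = 2, a single global norm over all trainable tensors, and that every trainable variable receives a gradient; `sam_train_step` is a hypothetical helper name, not a TensorFlow API.

```python
import tensorflow as tf

def sam_train_step(model, optimizer, loss_fn, inputs, targets, rho=0.05):
    """One SAM step (p = q = 2). Assumes no trainable variable gets a None gradient."""
    params = model.trainable_variables

    # 1) Gradient of the loss at the current weights w.
    with tf.GradientTape() as tape:
        loss = loss_fn(targets, model(inputs, training=True))
    grads = tape.gradient(loss, params)

    # 2) epsilon-hat = rho * g / ||g||_2, with one norm over all tensors.
    scale = rho / (tf.linalg.global_norm(grads) + 1e-12)
    eps = [g * scale for g in grads]

    # 3) Shift the variables themselves to w + eps so the second forward
    #    pass is recorded as a function of the perturbed weights.
    for v, e in zip(params, eps):
        v.assign_add(e)
    with tf.GradientTape() as tape:
        perturbed_loss = loss_fn(targets, model(inputs, training=True))
    sam_grads = tape.gradient(perturbed_loss, params)

    # 4) Restore w, then update it with the gradient taken at w + eps.
    for v, e in zip(params, eps):
        v.assign_sub(e)
    optimizer.apply_gradients(zip(sam_grads, params))
    return loss
```

In the loop from the question, the body of `for batch, (inputs, targets) in enumerate(train_ds):` would then shrink to a call like `sam_train_step(model, optimizer, loss_fn, inputs, targets)`; the accuracy metric can still be updated from a separate `model(inputs)` call. The design choice here is that the perturbation lives in the variables, not in new tensors, which is what keeps the second loss connected to `model.trainable_variables`.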

A similar question has already been asked here, but it has no answers.

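Equation 2:

$$\hat{\epsilon}(w) = \rho \,\operatorname{sign}\big(\nabla_w L_S(w)\big)\,\frac{\left|\nabla_w L_S(w)\right|^{q-1}}{\left(\left\|\nabla_w L_S(w)\right\|_q^q\right)^{1/p}}, \qquad \frac{1}{p} + \frac{1}{q} = 1$$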
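Equation 2 can be checked numerically with a small NumPy sketch, assuming p is the dual exponent of q (1/p + 1/q = 1, so q > 1) and that all gradient tensors are treated as one flattened vector. Unlike the loop above, it uses a single q-norm accumulated over every tensor rather than one norm per layer, and it raises the denominator to the power 1/p. The helper name `sam_epsilon` and the `tiny` guard against division by zero are illustrative assumptions.

```python
import numpy as np

def sam_epsilon(grads, rho=0.05, q=2.0, tiny=1e-12):
    """epsilon-hat from Equation 2, treating all gradient arrays as one vector."""
    p = q / (q - 1.0)  # dual exponent: 1/p + 1/q = 1 (requires q > 1)
    # ||g||_q^q accumulated over every tensor (one global norm, not per layer)
    norm_q_pow_q = sum(np.sum(np.abs(g) ** q) for g in grads)
    denom = norm_q_pow_q ** (1.0 / p) + tiny
    return [rho * np.sign(g) * np.abs(g) ** (q - 1.0) / denom for g in grads]
```

A useful sanity check is that the resulting perturbation always has p-norm equal to rho; for q = 2 it reduces to rho * g / ||g||_2, so gradients [3, 0] and [4] with rho = 1 give [0.6, 0] and [0.8].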

