Gradient Update - pai-plznw4me/tensorflow_basic GitHub Wiki

1. Updating weights with a Keras optimizer (`apply_gradients`)


def train(model, x, y, optimizer):
    """Run one training step and let ``optimizer`` apply the gradients.

    Args:
        model: Callable model; ``model(x)`` produces the predictions and
            ``model.trainable_variables`` lists the weights to update.
        x: Input batch passed to ``model``.
        y: Target batch passed to ``crossentropy`` as ``y``.
        optimizer: Object exposing ``apply_gradients`` (for example
            ``tf.keras.optimizers.Adam``).

    Returns:
        The loss computed on this batch before the weights are updated.
    """
    # Record the forward pass so the tape can differentiate the loss.
    with tf.GradientTape() as tape:
        current_loss = crossentropy(y=y, y_hat=model(x))
    deltas = tape.gradient(current_loss, model.trainable_variables)

    # The step size is the optimizer's own learning_rate; no local lr is used here.
    optimizer.apply_gradients(zip(deltas, model.trainable_variables))

    return current_loss
# Adam owns the step size, so train() only has to hand it the gradients.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss = train(model=simple_dnn, x=batch_xs, y=batch_ys, optimizer=optimizer)

2. Updating weights manually with `assign_sub` (plain gradient descent)


def train(model, x, y, *, lr=0.01):
    """Run one step of plain gradient descent on ``model``.

    Each trainable variable is updated in place as ``var -= lr * grad``.

    Args:
        model: Callable model; ``model(x)`` produces the predictions and
            ``model.trainable_variables`` lists the weights to update.
        x: Input batch passed to ``model``.
        y: Target batch passed to ``crossentropy`` as ``y``.
        lr: Step size (keyword-only). Defaults to 0.01, the value that was
            previously hard-coded.

    Returns:
        The loss computed on this batch before the weights are updated.
    """
    # Get the gradient of the loss w.r.t. every trainable variable.
    with tf.GradientTape() as tape:
        current_loss = crossentropy(y=y, y_hat=model(x))
    deltas = tape.gradient(current_loss, model.trainable_variables)

    # Update weights in place. assign_sub is called for its side effect, so a
    # plain loop is used rather than building a throwaway list.
    for var, delta in zip(model.trainable_variables, deltas):
        # GradientTape.gradient returns None for variables the loss does not
        # depend on; skip them instead of failing on `lr * None`.
        if delta is None:
            continue
        var.assign_sub(lr * delta)

    return current_loss

# The manual weight update happens inside train(), so no optimizer is passed.
# (The previous module-level assign_sub comprehension referenced `lr`, `model`
# and `deltas`, none of which exist outside train().)
loss = train(simple_dnn, batch_xs, batch_ys)