Gradient Update - pai-plznw4me/tensorflow_basic GitHub Wiki
### 1. Update weights with `optimizer.apply_gradients`
def train(model, x, y, optimizer):
    """Run one optimization step: forward pass, loss, gradients, weight update.

    Args:
        model: Keras-style model exposing ``trainable_variables`` and ``__call__``.
        x: input batch fed to ``model``.
        y: target batch, compared against ``model(x)`` by ``crossentropy``.
        optimizer: a ``tf.keras`` optimizer; its ``apply_gradients`` performs
            the update using the optimizer's own learning rate.

    Returns:
        The loss tensor for this batch.
    """
    # Record the forward pass so the tape can differentiate the loss.
    with tf.GradientTape() as tape:
        current_loss = crossentropy(y=y, y_hat=model(x))
    # Gradients of the loss w.r.t. every trainable variable.
    deltas = tape.gradient(current_loss, model.trainable_variables)
    # NOTE: the optimizer carries its own learning rate, so the original
    # unused local `lr = 0.01` was removed — it had no effect here.
    optimizer.apply_gradients(zip(deltas, model.trainable_variables))
    return current_loss
# Adam optimizer with an explicit learning rate; `train` hands it the
# gradients via `apply_gradients`, so no manual update rule is needed.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
# NOTE(review): `simple_dnn`, `batch_xs`, `batch_ys` are presumably defined
# earlier on this wiki page — confirm against the surrounding examples.
loss = train(simple_dnn, batch_xs, batch_ys, optimizer)
### 2. Update weights manually with `Variable.assign_sub`
def train(model, x, y):
    """Run one training step using a hand-rolled gradient-descent update.

    Unlike the optimizer-based variant, this applies plain SGD directly:
    ``w <- w - lr * grad`` for every trainable variable.

    Args:
        model: Keras-style model exposing ``trainable_variables`` and ``__call__``.
        x: input batch fed to ``model``.
        y: target batch, compared against ``model(x)`` by ``crossentropy``.

    Returns:
        The loss tensor for this batch.
    """
    # get gradient
    with tf.GradientTape() as tape:
        current_loss = crossentropy(y=y, y_hat=model(x))
    deltas = tape.gradient(current_loss, model.trainable_variables)
    # update weights (the original list comprehension was missing its
    # opening bracket — a SyntaxError; a plain loop is the idiomatic form
    # for side effects anyway)
    lr = 0.01
    for var, delta in zip(model.trainable_variables, deltas):
        var.assign_sub(lr * delta)
    return current_loss
# This variant needs no optimizer object: `train` applies the SGD update
# itself via `assign_sub`. The original line that built a list of
# `assign_sub` results at module level referenced `deltas`, `lr`, and
# `model` out of scope (NameError) and shadowed the name `optimizer`,
# and the call below passed it as a fourth argument to the 3-parameter
# `train` (TypeError) — both fixed here.
loss = train(simple_dnn, batch_xs, batch_ys)