Keras model - pai-plznw4me/tensorflow_basic GitHub Wiki
A class that provides convenient features for managing several Layers together as a single module.
Keras Model offers a range of features for managing tf.Variables and for easily managing Keras Layers.
Depending on how a model is created, Keras models fall into two categories: the Functional Model API and the Custom Model API.
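The two creation styles can be compared side by side in a minimal sketch. It uses the built-in tf.keras.layers.Dense rather than the custom Dense from the base code, and the sizes (784 → 256 → 10) and the names functional_model, CustomModel and custom_model are illustrative assumptions.

```python
import numpy as np
import tensorflow as tf

# Functional Model API: call layers on symbolic tensors,
# then wrap the input and output tensors in a Model.
inputs = tf.keras.Input(shape=(784,))
hidden = tf.keras.layers.Dense(256, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(10)(hidden)
functional_model = tf.keras.Model(inputs, outputs)


# Custom Model API (subclassing): create layers in __init__()
# and write the forward pass imperatively in call().
class CustomModel(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense_hidden = tf.keras.layers.Dense(256, activation='relu')
        self.dense_out = tf.keras.layers.Dense(10)

    def call(self, x):
        return self.dense_out(self.dense_hidden(x))


custom_model = CustomModel()

x = np.zeros((2, 784), dtype=np.float32)
print(functional_model(x).shape)  # (2, 10)
print(custom_model(x).shape)      # (2, 10)
```

The Functional style knows its graph up front, which is what makes features such as get_config/from_config work out of the box; the subclassed style trades that for free-form Python in call().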
Base code
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets.mnist import load_data
from tensorflow.keras.layers import Layer


class Dense(Layer):
    """Fully connected layer whose weights are created lazily in build()."""

    def __init__(self, out_features, **kwargs):
        super().__init__(**kwargs)
        self.out_features = out_features
        self.w, self.b = None, None

    def build(self, input_shape):
        # Called automatically on the first call, once the input size is known.
        self.w = tf.Variable(tf.random.normal([input_shape[-1], self.out_features], stddev=0.1), name='w')
        self.b = tf.Variable(tf.zeros([self.out_features]), name='b')

    @tf.function
    def call(self, inputs, activation):
        return activation(tf.matmul(inputs, self.w) + self.b)


class DNN(tf.keras.Model):
    def __init__(self, name=None, **kwargs):
        # Forward the name to Model; otherwise it is silently dropped.
        super().__init__(name=name, **kwargs)
        self.dense_1 = Dense(256)
        self.dense_2 = Dense(256)

    def call(self, x):
        x = self.dense_1(x, activation=tf.nn.relu)
        return self.dense_2(x, activation=tf.nn.relu)


if __name__ == '__main__':
    (train_xs, train_ys), (test_xs, test_ys) = load_data()
    train_xs = train_xs.reshape(-1, 784)
    batch_ys = train_ys[:6]
    batch_xs = (train_xs[:6] / 255.).astype(np.float32)

    # A single custom layer
    dense = Dense(256, name='dynamic')
    print(dense(batch_xs, tf.nn.relu))

    # DNN Model
    dnn = DNN('dynamic_dnn')
    print(dnn(batch_xs))
1. Managing tf.Variables
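- Managing the variables created inside a keras.Model. In the snippets below, Model is tf.keras.Model, and input_ and output_ are the input and output tensors of a Functional model.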
1.1 Accessing trainable weights
model = Model(input_, output_)
model.trainable_weights
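To make the one-liner concrete, here is a minimal sketch with a hypothetical two-layer Functional model standing in for Model(input_, output_); the layer names 'hidden' and 'logits' are assumptions chosen for readability.

```python
import tensorflow as tf

# Hypothetical Functional model standing in for Model(input_, output_).
input_ = tf.keras.Input(shape=(784,))
hidden = tf.keras.layers.Dense(256, activation='relu', name='hidden')(input_)
output_ = tf.keras.layers.Dense(10, name='logits')(hidden)
model = tf.keras.Model(input_, output_)

# Each Dense layer contributes a kernel and a bias: 4 trainable variables.
for w in model.trainable_weights:
    print(w.name, tuple(w.shape))
```

The exact variable names printed differ between Keras versions, so it is safer to rely on shapes or on the owning layer than on the name strings.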
1.2 Accessing non-trainable weights
model = Model(input_, output_)
model.non_trainable_weights
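A model built only from Dense layers has no non-trainable weights, so this sketch adds a BatchNormalization layer, whose moving statistics are non-trainable, and then freezes a layer by setting trainable = False. The architecture and layer names are illustrative assumptions.

```python
import tensorflow as tf

input_ = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(256, name='hidden')(input_)
# BatchNormalization keeps moving_mean / moving_variance as non-trainable
# variables: they change during training, but not through the optimizer.
x = tf.keras.layers.BatchNormalization(name='bn')(x)
output_ = tf.keras.layers.Dense(10, name='logits')(x)
model = tf.keras.Model(input_, output_)

print(len(model.trainable_weights))      # 6: kernel/bias x2 + gamma/beta
print(len(model.non_trainable_weights))  # 2: moving_mean, moving_variance

# Freezing a layer moves its variables into non_trainable_weights.
model.get_layer('hidden').trainable = False
print(len(model.trainable_weights))      # 4
print(len(model.non_trainable_weights))  # 4
```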
1.3 Getting the weight values of a model
model = Model(input_, output_)
model.get_weights()
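get_weights() returns the current values as a list of NumPy arrays, in the same order as model.weights. A minimal sketch, again assuming a hypothetical 784 → 256 → 10 Functional model:

```python
import numpy as np
import tensorflow as tf

input_ = tf.keras.Input(shape=(784,))
hidden = tf.keras.layers.Dense(256, activation='relu', name='hidden')(input_)
output_ = tf.keras.layers.Dense(10, name='logits')(hidden)
model = tf.keras.Model(input_, output_)

# A list of NumPy arrays: kernel and bias for each Dense layer.
weights = model.get_weights()
print([w.shape for w in weights])   # [(784, 256), (256,), (256, 10), (10,)]

# The total number of scalars matches model.count_params().
print(sum(w.size for w in weights))  # 203530
```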
1.4 Changing the weight values of a model
model = Model(input_, output_)
model.set_weights(model.get_weights())
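Setting a model's weights to its own values, as above, is a no-op; a more typical use is copying weights between two models with the same architecture. The sketch below assumes a hypothetical make_model() helper that rebuilds the same Functional graph each time it is called.

```python
import numpy as np
import tensorflow as tf


def make_model():
    # Hypothetical helper: builds the same Functional architecture on each call.
    input_ = tf.keras.Input(shape=(784,))
    hidden = tf.keras.layers.Dense(256, activation='relu')(input_)
    output_ = tf.keras.layers.Dense(10)(hidden)
    return tf.keras.Model(input_, output_)


source, target = make_model(), make_model()

# set_weights() needs a list whose length and array shapes match get_weights().
target.set_weights(source.get_weights())

x = np.random.rand(4, 784).astype(np.float32)
print(np.allclose(source(x), target(x)))  # True

# Any arrays with the right shapes are accepted, e.g. all zeros:
target.set_weights([np.zeros_like(w) for w in target.get_weights()])
print(np.abs(target(x).numpy()).max())  # 0.0
```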
2. Managing Keras Layers
2.1 Getting information about all layers
dnn = Model(input_, output_)
dnn.layers
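dnn.layers is an ordinary Python list of layer objects, so it can be iterated to inspect each layer. A minimal sketch, assuming a hypothetical Functional model with explicitly named layers:

```python
import tensorflow as tf

input_ = tf.keras.Input(shape=(784,))
hidden = tf.keras.layers.Dense(256, activation='relu', name='hidden')(input_)
output_ = tf.keras.layers.Dense(10, name='logits')(hidden)
dnn = tf.keras.Model(input_, output_)

# Layers are listed in graph order; for a Functional model the
# InputLayer is typically included as well.
for layer in dnn.layers:
    print(layer.name, type(layer).__name__, len(layer.weights))
```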
2.2 Accessing a specific layer by name
dnn.get_layer('dense')
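The lookup above relies on the default name 'dense'. Keras numbers default names ('dense', 'dense_1', ...) in creation order, so giving layers explicit names keeps get_layer() predictable. A sketch with assumed names 'hidden' and 'logits':

```python
import tensorflow as tf

input_ = tf.keras.Input(shape=(784,))
hidden = tf.keras.layers.Dense(256, activation='relu', name='hidden')(input_)
output_ = tf.keras.layers.Dense(10, name='logits')(hidden)
dnn = tf.keras.Model(input_, output_)

# get_layer() returns the layer object itself, so its attributes and
# variables are directly available.
logits = dnn.get_layer('logits')
print(logits.units)         # 10
print(logits.kernel.shape)  # (256, 10)
```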
3. Creating the initial weights with build()
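As in the base code's Dense, a layer can postpone creating its weights until build(), which Keras calls on the first call once the input shape is known. The sketch below uses a hypothetical LazyDense layer and the standard add_weight() helper instead of constructing tf.Variable directly as the base code does.

```python
import tensorflow as tf


class LazyDense(tf.keras.layers.Layer):
    """Hypothetical layer: weights are created in build(), not in __init__()."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # input_shape is only known here, so the kernel size can depend on it.
        self.w = self.add_weight(name='w', shape=(input_shape[-1], self.units),
                                 initializer='random_normal', trainable=True)
        self.b = self.add_weight(name='b', shape=(self.units,),
                                 initializer='zeros', trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b


layer = LazyDense(256)
print(layer.built, len(layer.weights))  # False 0

# The first call triggers build() with the real input shape.
y = layer(tf.zeros((6, 784)))
print(layer.built, len(layer.weights))  # True 2
print(layer.w.shape)                    # (784, 256)
```

The same layer class can therefore be reused for inputs of any feature size without passing that size to the constructor.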
4. Cloning a model with get_config and from_config
dnn = Model(input_, output_)
config = dnn.get_config()
clone = Model.from_config(config)
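get_config() captures only the architecture, so a model rebuilt with from_config() starts with freshly initialized weights. The sketch below, assuming a hypothetical Functional model, copies the weights afterwards to obtain an exact duplicate. This works for Functional models; a subclassed model such as DNN generally has to implement its own get_config() before it can be rebuilt this way.

```python
import numpy as np
import tensorflow as tf

input_ = tf.keras.Input(shape=(784,))
hidden = tf.keras.layers.Dense(256, activation='relu', name='hidden')(input_)
output_ = tf.keras.layers.Dense(10, name='logits')(hidden)
dnn = tf.keras.Model(input_, output_)

# Architecture only: layers, their settings, and how they are connected.
config = dnn.get_config()
clone = tf.keras.Model.from_config(config)

# Same structure, but newly initialized weights ...
print([w.shape for w in clone.get_weights()] ==
      [w.shape for w in dnn.get_weights()])  # True

# ... so copy the weights over when an exact duplicate is needed.
clone.set_weights(dnn.get_weights())
x = np.random.rand(2, 784).astype(np.float32)
print(np.allclose(dnn(x), clone(x)))  # True
```

tf.keras.models.clone_model(dnn) performs the architecture part in a single call; it likewise does not copy the weights.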