tf.Module - pai-plznw4me/tensorflow_basic GitHub Wiki
tf.Module 을 사용하면 tf.Variable 을 쉽게 관리할 수 있다.
tf.Module 인스턴스의 속성 variables, trainable_variables, non_trainable_variables 로 모듈 내 tf.Variable 을 관리할 수 있다.
사용 패턴 1 : static input
class Dense(tf.Module):
    """Fully connected layer followed by a ReLU activation.

    Holds a weight matrix ``w`` of shape ``[in_features, out_features]``,
    initialized with ``tf.random.normal``. It also holds a bias vector ``b``
    of shape ``[out_features]``, initialized with zeros.
    """

    def __init__(self, in_features, out_features, name=None):
        super().__init__(name=name)
        weight_init = tf.random.normal([in_features, out_features])
        self.w = tf.Variable(weight_init, name='w')
        bias_init = tf.zeros([out_features])
        self.b = tf.Variable(bias_init, name='b')

    def __call__(self, x):
        """Return ``relu(x @ w + b)``."""
        return tf.nn.relu(tf.matmul(x, self.w) + self.b)
class SimpleDNN(tf.Module):
    """Two stacked ``Dense`` layers mapping 3 input features to 2 outputs.

    The layers are stored as attributes (``dense_1``, ``dense_2``). Their
    ``tf.Variable`` objects are therefore reachable through this module's
    ``variables`` / ``trainable_variables`` properties.
    """

    def __init__(self, name):
        super().__init__(name=name)
        self.dense_1 = Dense(in_features=3, out_features=3)
        self.dense_2 = Dense(in_features=3, out_features=2)

    def __call__(self, x):
        """Apply ``dense_1`` and then ``dense_2`` to ``x``."""
        hidden = self.dense_1(x)
        output = self.dense_2(hidden)
        return output
SimpleDNN 은 두 개의 Dense 인스턴스(dense_1, dense_2)에서 생성된 tf.Variable 들을 SimpleDNN 인스턴스의 variables 속성으로 관리합니다.
그 외에도 submodules, trainable_variables, non_trainable_variables 속성이 있습니다.
if __name__ == '__main__':
    # Show what tf.Module tracks automatically for a model built from
    # nested modules: its submodules and their variables.
    simple_dnn = SimpleDNN(name='simple_dnn')

    sections = (
        ('Submodule', simple_dnn.submodules),
        ('Variable', simple_dnn.variables),
        ('Trainable variable', simple_dnn.trainable_variables),
        ('Non trainable variable', simple_dnn.non_trainable_variables),
    )
    # Plain loops: printing is a side effect. A list comprehension here
    # would only build and throw away a list of ``None`` values.
    for title, items in sections:
        print(title)
        for item in items:
            print(item)
사용 패턴 2 : dynamic input
입력 shape 가 동적인 상황에서는 아래 모듈을 사용할 수 있습니다. 변수(w, b)는 첫 호출 시 입력의 마지막 차원 크기를 이용해 생성되므로, in_features 를 미리 지정할 필요가 없습니다.
class FlexibleDenseModule(tf.Module):
    """ReLU dense layer whose input size is inferred on the first call.

    Unlike a layer that takes ``in_features`` up front, this module creates
    ``w`` (shape ``[last_dim_of_x, out_features]``) and ``b`` (shape
    ``[out_features]``) the first time it is called. Until that first call,
    the module owns no variables.
    """

    def __init__(self, out_features, name=None):
        super().__init__(name=name)
        # Set to True once ``w`` and ``b`` have been created.
        self.is_built = False
        self.out_features = out_features

    def __call__(self, x):
        """Return ``relu(x @ w + b)``, creating ``w`` and ``b`` on first use.

        ``x`` may be a tensor or any value ``tf.convert_to_tensor`` accepts,
        such as a nested Python list. Its last dimension must be known on the
        first call, because it sizes ``w``. Later calls must use the same
        last-dimension size, since ``w`` is built only once.
        """
        # Convert first: nested Python lists have no ``.shape`` attribute.
        # For an input that is already a tensor this is a no-op.
        x = tf.convert_to_tensor(x)
        if not self.is_built:
            self.w = tf.Variable(
                tf.random.normal([x.shape[-1], self.out_features]), name='w')
            self.b = tf.Variable(tf.zeros([self.out_features]), name='b')
            self.is_built = True
        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)
Reference
[tf.Module guide (official)](https://www.tensorflow.org/guide/intro_to_modules)