tf.Module - pai-plznw4me/tensorflow_basic GitHub Wiki

Using tf.Module makes tf.Variable objects easy to manage. Every tf.Variable in a model can be reached through the module's variables, trainable_variables, and non_trainable_variables properties.

Usage pattern 1: static input

import tensorflow as tf


class Dense(tf.Module):
    def __init__(self, in_features, out_features, name=None):
        super().__init__(name=name)
        # Variables assigned as attributes are tracked by tf.Module automatically.
        self.w = tf.Variable(
            tf.random.normal([in_features, out_features]), name='w')
        self.b = tf.Variable(tf.zeros([out_features]), name='b')

    def __call__(self, x):
        # Affine transform followed by ReLU.
        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)


class SimpleDNN(tf.Module):
    def __init__(self, name):
        super().__init__(name=name)

        self.dense_1 = Dense(in_features=3, out_features=3)
        self.dense_2 = Dense(in_features=3, out_features=2)

    def __call__(self, x):
        x = self.dense_1(x)
        return self.dense_2(x)

SimpleDNN holds two Dense instances, and the tf.Variable objects they create are all collected in the variables attribute of the SimpleDNN instance.

A module also exposes the submodules, trainable_variables, and non_trainable_variables attributes.

if __name__ == '__main__':
    simple_dnn = SimpleDNN(name='simple_dnn')

    print('Submodule')
    for submodule in simple_dnn.submodules:
        print(submodule)

    print('Variable')
    for variable in simple_dnn.variables:
        print(variable)

    print('Trainable variable')
    for trainable_variable in simple_dnn.trainable_variables:
        print(trainable_variable)

    print('Non trainable variable')
    for non_trainable_variable in simple_dnn.non_trainable_variables:
        print(non_trainable_variable)
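In the run above, non_trainable_variables comes back empty, because a tf.Variable is trainable by default. The following is a minimal sketch of how the two groups can differ. CountingDense and its calls counter are hypothetical names invented for this illustration, not part of the original example.

```python
import tensorflow as tf


class CountingDense(tf.Module):
    """Hypothetical layer: a Dense-like module plus a non-trainable call counter."""

    def __init__(self, in_features, out_features, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(
            tf.random.normal([in_features, out_features]), name='w')
        self.b = tf.Variable(tf.zeros([out_features]), name='b')
        # trainable=False keeps this variable out of trainable_variables.
        self.calls = tf.Variable(0, trainable=False, name='calls')

    def __call__(self, x):
        # Update bookkeeping state, then apply the affine transform and ReLU.
        self.calls.assign_add(1)
        return tf.nn.relu(tf.matmul(x, self.w) + self.b)


layer = CountingDense(in_features=3, out_features=2, name='counting_dense')
layer(tf.ones([1, 3]))

n_all = len(layer.variables)
n_trainable = len(layer.trainable_variables)
n_non_trainable = len(layer.non_trainable_variables)
print(n_all, n_trainable, n_non_trainable)
```

Here w and b land in trainable_variables, while calls appears only in non_trainable_variables. All three are still listed in variables.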

Usage pattern 2: dynamic input

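A module can also be used when the input shape is not known in advance. In that case the variables are created on the first call, once the last dimension of the input is known.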

class FlexibleDenseModule(tf.Module):
    def __init__(self, out_features, name=None):
        super().__init__(name=name)
        # No variables yet: the input size is unknown until the first call.
        self.is_built = False
        self.out_features = out_features

    def __call__(self, x):
        # Create the variables once, sized from the last dimension of the input.
        if not self.is_built:
            self.w = tf.Variable(
                tf.random.normal([x.shape[-1], self.out_features]), name='w')
            self.b = tf.Variable(tf.zeros([self.out_features]), name='b')
            self.is_built = True

        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)
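A short usage sketch may make the deferred creation easier to see. The class is repeated so that the snippet runs on its own, and the input shape [4, 5] is an arbitrary choice for illustration.

```python
import tensorflow as tf


class FlexibleDenseModule(tf.Module):
    # Same class as above, repeated so this sketch is self-contained.
    def __init__(self, out_features, name=None):
        super().__init__(name=name)
        self.is_built = False
        self.out_features = out_features

    def __call__(self, x):
        if not self.is_built:
            self.w = tf.Variable(
                tf.random.normal([x.shape[-1], self.out_features]), name='w')
            self.b = tf.Variable(tf.zeros([self.out_features]), name='b')
            self.is_built = True
        return tf.nn.relu(tf.matmul(x, self.w) + self.b)


layer = FlexibleDenseModule(out_features=2, name='flexible_dense')

# Before the first call the module tracks no variables.
n_before = len(layer.variables)

# The first call sees 5 input features and builds w with shape [5, 2].
out = layer(tf.ones([4, 5]))

n_after = len(layer.variables)
w_shape = tuple(layer.w.shape)
out_shape = tuple(out.shape)
print(n_before, n_after, w_shape, out_shape)
```

Once the variables exist their shapes are fixed, so later calls need an input with the same last dimension as the first one.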

Reference

[tf.Module guide (official)](