프로그래밍 패턴 - pai-plznw4me/tensorflow_basic GitHub Wiki
1. Tensorflow Module 패턴
SubClass(tf.Module):
MainClass(tf.Module):
def __init__():
subclass1 = SubClass()
subclass2 = SubClass()
def __call__():
subclass1()
subclass2()
- Main Class 내 생성된 Sub Class instance attribute 을 하나의 변수로 관리
- Main Class 내 생성된 Sub Class instance 를 모아 하나의 변수로 관리
2. Keras Layer,Model 패턴
class Dense(Layer):
def __init__(self, out_features, **kwargs):
super().__init__(**kwargs)
self.out_features = out_features
self.w, self.b = None, None
def build(self, input_shape):
self.w = tf.Variable(tf.random.normal([input_shape[-1], self.out_features], stddev=0.1), name='w')
self.b = tf.Variable(tf.zeros([self.out_features]), name='b')
def call(self, inputs, activation):
return activation(tf.matmul(inputs, self.w) + self.b)
- build 는 call 을 처음 부를때만 한번 호출 된다.
- call 함수는 call 함수를 호출하는 식과 같은 방법으로 호출된다.
- variable, non_trainable, 등 Model 단에서 layer 에서 생성한 변수 관리 가능
Keras polymorphism
위 inputs, outputs 들어갈수 있는 자료형이 3가지임
tuple
, list
, dict
Input Structure 에 들어온 자료형은 Shallow Structure 에 영향을 미치고 최종적으로 출력되는 dtype 에도 영향을 미침
model = Model(inputs, outputs)
model.compile(optimzier, losses=[])
fit(xs, ys)
아래 점선 끼리는 같은 dtype 이 되어야 한다.
예를 들자면 아래 코드에서 주석을 보면 Input structure
의 output 변수의 dtype
을 dict
으로 잡으면 Follow structure
의 loss
, fit.ys
변수도 모두 dict
이 되어야 한다.
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense
import tensorflow as tf
import numpy as np
# multi input
input_1 = Input(shape=[784])
output_1 = Dense(units=10, activation='relu', name='output_1')(input_1)
input_2 = Input(shape=[784])
output_2 = Dense(units=10, activation='relu', name='output_2')(input_2)
# single output
inputs = (input_1, input_2)
outputs = {"output_1": output_1, "output_2": output_2} # << dtype : dict
model = Model(inputs, outputs)
x = tf.zeros(shape=[1, 784])
input_values = {"input_1": x, "input_2": x}
losses = {"output_1": 'sparse_categorical_crossentropy',
"output_2": 'sparse_categorical_crossentropy'} # << dtype : dict
model.compile('rmsprop', loss=losses)
batch_xs = np.zeros(shape=[1, 784], dtype=np.float32)
batch_ys = np.zeros(shape=[1], dtype=np.float32)
batch_ys_bucket = {'output_1': batch_ys, 'output_2': batch_ys} # << dtype : dict
model.fit([batch_xs, batch_xs], batch_ys_bucket)
# model 실행 : __call__() => numpy
print(model.predict(input_values))
# model 실행 : Model.predict() => tensor
print(model(input_values))