AutoGraph - lshhhhh/deep-learning-study GitHub Wiki
Why do we need graphs in TensorFlow?
- ์ฑ๋ฅ: (๊ณตํต์ sub-expressions์ ์ ๊ฑฐํ๊ฑฐ๋, pruning, kernel fusing์ ํ๋ ๋ฑ์) ๋ชจ๋ ์ข ๋ฅ์ ์ต์ ํ๋ฅผ ํ ์ ์๊ฒ ํ๋ค.
- ์ด์์ฑ(portability): ๊ณ์ฐ์ ๋ํ platform-independent ๋ชจ๋ธ์ ๋ง๋ค๊ธฐ ๋๋ฌธ์, distributed training๊ณผ ๋ชจ๋ ์ข
๋ฅ์ ํ๊ฒฝ์์์ deployment๋ฅผ ์ฉ์ดํ๊ฒ ํ๋ค.
multiple GPU๋ TPU์์์ distributed training
TensorFlow Lite๋ฅผ ์ด์ฉํ์ฌ mobile์ด๋ IoT๊ณผ ๊ฐ์ ๋ค๋ฅธ platform์์ ๋ชจ๋ธ์ deploy
TensorFlow Functions and Graphs
TF 1.x์์๋ session.run์ ํตํด ์
๋ ฅ๊ณผ ํจ์๋ฅผ ์ง์ ํ์ฌ ํจ์๋ฅผ ํธ์ถํ์๋ค. TF 2.0์์๋ ์ธ์
๋์ tf.function()๋ฅผ ์ฌ์ฉํ๋ค. ์ด๋ ๊ฒ ํ๋ฉด TF๊ฐ ์ด ํจ์๋ฅผ ํ๋์ ๊ทธ๋ํ๋ก ์คํํ๊ธฐ ์ํด JIT ์ปดํ์ผ์ ํ๋ค. ์ด ๋ฉ์ปค๋์ฆ ๋๋ถ์ TF 2.0์์ default ๋ชจ๋๊ฐ ๊ทธ๋ํ ๋ชจ๋๊ฐ ์๋ eager ๋ชจ๋์ด์ง๋ง, ๊ทธ๋ํ์ ์ฅ์ ์ ๋ชจ๋ ๊ฐ์ ธ์ฌ ์ ์์๋ค.
(Eager Execution for Faster Prototyping, Graph for Execution)
# TF 1.x
outputs = session.run(f(placeholder), feed_dict={placeholder: input})
# TF 2.0
outputs = f(input)
๋จผ์ Python function -> TensorFlow function ๋ณํ์ ํด๋ณด์.
def cube(x):
return x ** 3
>>> cube(2)
8
>>> cube(tf.constant(2.0))
<tf.Tensor: id=18634148, shape=(), dtype=float32, numpy=8.0>
>>> tf_cube = tf.function(cube)
>>> tf_cube
<tensorflow.python.eager.def_function.Function at 0x1546fc080>
>>> tf_cube(2)
<tf.Tensor: id=18634201, shape=(), dtype=int32, numpy=8>
>>> tf_cube(tf.constant(2.0))
<tf.Tensor: id=18634211, shape=(), dtype=float32, numpy=8.0>
๋ ์ผ๋ฐ์ ์ผ๋ก๋ tf.function decorator๋ฅผ ์จ์ ๋ง๋ ๋ค.
@tf.function
def tf_cube(x):
return x ** 3
TF function์์ ๋ณธ๋์ python function๋ ์ฝ๊ฒ ๊ฐ์ ธ์ฌ ์ ์๋ค.
>>> tf_cube.python_function(2)
8
- Python ํจ์์ ๋ํ์ฌ computation graph, ์ฌ์ฉํ์ง ์๋ node๋ค pruning, expressions์ ๊ฐ๋จํ๊ฒ ๋ง๋ค๊ธฐ(e.g., 1 + 2 -> 3).. ๋ฑ๋ฑ์ ์ต์ ํํ๋ค.
- ์ต์ ํ๋ graph๊ฐ ์ค๋น๋๋ฉด, TF ํจ์๋ graph์์ ์ ์ ํ ์์๋ก, ๋ณ๋ ฌ ์ํ์ ํ ์ ์๋ค๋ฉด ๋ณ๋ ฌ๋ก๋, ํจ์จ์ ์ผ๋ก ์คํ๋๋ค.
๊ฒฐ๊ณผ์ ์ผ๋ก ๊ธฐ์กด์ Python ํจ์์ ๋นํด ๋ง์ด ๋นจ๋ผ์ง๊ณ , ํนํ ๋ณต์กํ ๊ณ์ฐ์ ์ํํ ๋ ๋ ์ ์ฉํ๋ค. Python ํจ์๋ฅผ boostํ๊ณ ์ถ๋ค๋ฉด TF function์ผ๋ก ๋ฐ๊ฟ๋ณด์.
AutoGraph and Tracing
ํ
์ํ๋ก๋ Python ์ธํฐํ๋ฆฌํฐ๊ฐ ์๋ ๋ชจ๋ฐ์ผ, C++, ์๋ฐ์คํฌ๋ฆฝํธ ๊ฐ์ ํ๊ฒฝ์์๋ ์คํ๋๋๋ฐ, ์ฌ์ฉ์๊ฐ ํ๊ฒฝ์ ๋ฐ๋ผ ์ฝ๋๋ฅผ ์ฌ์์ฑํ์ง ์๋๋ก @tf.function๋ฅผ ์ถ๊ฐํ๋ฉด AutoGraph๊ฐ ํ์ด์ฌ ์ฝ๋๋ฅผ ๋์ผํ ํ
์ํ๋ก ๊ทธ๋ํ ์ฝ๋๋ก ๋ณ๊ฒฝํจ์ผ๋ก์จ ๊ฐ๋ฅํ ์ผ์ด๋ค.
TF 2.0์์ AutoGraph๋ tf.function์ด ์ฌ์ฉ๋๋ฉด ์๋์ผ๋ก ์ ์ฉ๋๋ค.
AutoGraph process
- ๋จผ์ , ๋ชจ๋ control flow statement (for, if, while, break, continue, return)๋ฅผ ์๊ธฐ ์ํด Python ํจ์๋ฅผ ๋ถ์ํ๋ค. ํจ์ ์ฝ๋๋ฅผ ๋ถ์ํ ํ์, autoGraph๋ control flow๋ฅผ ์ ์ ํ TF operation์ผ๋ก ์
๊ทธ๋ ์ด๋๋ ํจ์๋ฅผ output์ผ๋ก ๋ด๋ณด๋ธ๋ค.
for/while->tf.while_loop(break๊ณผcontinue๋ฌธ ์ง์)if->tf.condfor _ in dataset->dataset.reduce
- ๋ค์์ผ๋ก๋ TF๋ ์ด ์ ๊ทธ๋ ์ด๋๋ ํจ์๋ฅผ callํ๋๋ฐ, argument๋ฅผ passํ๋ ๋์ , symbolic tensor(์ค์ ๊ฐ์ ์๊ณ , ์ด๋ฆ๊ณผ data type, shape๋ง ๋ค์ด์๋ tensor)๋ฅผ passํ๋ค.
- ์ด TF ํจ์๋ฅผ tracingํ์ฌ ๊ทธ๋ํ๊ฐ ๋ง๋ค์ด์ง๋ค.
์ด์ ์ด ํจ์๋ ๊ทธ๋ํ ๋ชจ๋๋ก ์คํ๋๋ค.
AutoGraph๋ ์์์ ์ค์ฒฉ๋ control flow๋ ์ง์ํ๋ค. ์ํ์ค(sequence) ๋ชจ๋ธ, ๊ฐํ ํ์ต, ๋ ์์ ์ธ ํ๋ จ ๋ฃจํ ๋ฑ ๋ณต์กํ ๋จธ์ ๋ฌ๋ ํ๋ก๊ทธ๋จ์ ๊ฐ๊ฒฐํ๋ฉด์ ๋์ ์ฑ๋ฅ์ ๋ด๋๋ก ๊ตฌํํ ์ ์๋ค.