TensorFlow Cheat Sheet - Simsso/NIPS-2018-Adversarial-Vision-Challenge GitHub Wiki

Setting Up Your Local Environment

It is assumed that you have Python (2 or 3) already installed.

Install virtualenv

pip install virtualenv

Create a new virtual environment

cd ~
mkdir .env
virtualenv ~/.env/tf

Activate the environment

Do this each time you start working on the project:

source ~/.env/tf/bin/activate

To leave the virtual environment, run deactivate. To activate the environment by simply typing tf, append the following alias to your ~/.bash_profile (on most Linux systems, ~/.bashrc):

alias tf="source ~/.env/tf/bin/activate"
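To check from Python whether the environment is actually active (e.g. at the top of a script or in the interpreter), a minimal sketch can compare interpreter prefixes. It assumes the standard sys attributes: sys.base_prefix (Python 3.3+, differs from sys.prefix inside venv and recent virtualenv releases) and sys.real_prefix (set by older virtualenv releases and under Python 2). The function name in_virtualenv is our own.

```python
import sys


def in_virtualenv():
    """Return True if the running interpreter belongs to a virtual environment."""
    # Older virtualenv releases (and Python 2 environments) set sys.real_prefix.
    if getattr(sys, "real_prefix", None) is not None:
        return True
    # venv and modern virtualenv keep the base installation in sys.base_prefix,
    # so the two prefixes differ only inside an environment.
    return sys.prefix != getattr(sys, "base_prefix", sys.prefix)


if __name__ == "__main__":
    print("virtual environment active:", in_virtualenv())
```

Running it after source ~/.env/tf/bin/activate should print True; after deactivate, it should print False.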

Install TensorFlow

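pip install tensorflow

Note that the code in this cheat sheet uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session, tf.layers), which is no longer available at the top level in TensorFlow 2.x. Either install a 1.x release (pip install "tensorflow<2", which is only available for older Python versions) or access the old API through tf.compat.v1 with eager execution disabled (tf.compat.v1.disable_eager_execution()), which may still require some adjustments.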
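Because the Basic Setup code relies on APIs that only exist at the top level of TensorFlow 1.x, it helps to check which release pip actually installed. A minimal sketch, assuming Python 3.8+ (for importlib.metadata) and that the package was installed under the distribution name tensorflow; the helper names installed_version and supports_graph_api are our own. Reading the package metadata avoids importing TensorFlow itself.

```python
# Requires Python 3.8+ for importlib.metadata.
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def supports_graph_api(version):
    """True for 1.x releases, which expose tf.placeholder and tf.Session at the top level."""
    return version is not None and version.split(".")[0] == "1"


if __name__ == "__main__":
    version = installed_version("tensorflow")
    if version is None:
        print("TensorFlow is not installed in this environment.")
    else:
        print("TensorFlow", version, "| 1.x graph API:", supports_graph_api(version))
```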

TensorFlow Cookbook

Project Structure

The default project structure we use for TensorFlow code is listed below. It is inspired by MrGemy95/Tensorflow-Project-Template and Morgan's Medium post; the tree was created using this JSFiddle.

`-- project-name/
    |-- .dockerignore
    |-- .gitignore
    |-- project_name/
    |   |-- main.py
    |   |-- data/
    |   |   |-- data_set_name.py
    |   |   `-- loader.sh
    |   |-- model/
    |   |   |-- cnn_3_layers.py
    |   |   `-- rnn.py
    |   |-- tests/
    |   |-- trainer/
    |   |   |-- sgd.py
    |   |   `-- random_search.py
    |   `-- util/
    |       |-- logging.py
    |       `-- linear_combination.py
    |-- Dockerfile
    |-- README.md
    |-- requirements.txt
    |-- setup.py
    `-- start.sh
  • Project root contains non-source files.
  • main.py starts the model. It may read flags, such as --train or --inference to distinguish modes of operation.
  • data contains data set related files. If needed, the loader.sh script downloads the data and stores it in /tmp/data. The Python file (data_set_name.py) provides this data to the other Python files. It also contains constants such as INPUT_SIZE.
  • model contains the TensorFlow graph definitions of the ML models, including the loss function(s).
  • tests contains unit tests for the Python scripts (if we ever happen to write any).
  • trainer contains the procedures for training the model weights (e.g. sgd.py, random_search.py). Each file uses the model definitions and defines a train function.
  • util contains utility functions, e.g. for logging.
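main.py may read flags such as --train or --inference to select the mode of operation. A minimal sketch using the standard argparse module, under the assumption (not stated above) that the two flags are mutually exclusive and that exactly one is required; the dispatch in the __main__ block is hypothetical and only prints.

```python
import argparse


def parse_args(argv=None):
    """Parse the mode flags; exactly one of --train / --inference must be given."""
    parser = argparse.ArgumentParser(description="Start the model.")
    mode = parser.add_mutually_exclusive_group(required=True)
    # Both flags write to the same destination, so args.mode is "train" or "inference".
    mode.add_argument("--train", dest="mode", action="store_const", const="train",
                      help="train the model weights")
    mode.add_argument("--inference", dest="mode", action="store_const", const="inference",
                      help="run the trained model on new inputs")
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    if args.mode == "train":
        print("Training ...")  # hypothetical: call a train function from trainer/ here
    else:
        print("Running inference ...")  # hypothetical: load weights and predict here
```

For example, python main.py --train selects training; passing both flags (or neither) makes argparse exit with a usage error.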

Basic Setup

This is how basic TensorFlow code for a simple model is usually structured:


# imports and hyperparameters
import tensorflow as tf

LEARNING_RATE = 0.01
NUM_STEPS = 1000
BATCH_SIZE = 50

# -- define the graph --
# here, we assume that a training example has 784 dimensions - e.g. MNIST images - and that the labels are one-hot encoded

# create placeholders for the data inputs
# (None in the shape means that it will be determined at runtime - depending on our batch size)
x = tf.placeholder(tf.float32, shape=[None, 784], name="input-x")
labels = tf.placeholder(tf.float32, shape=[None, 10], name="true-labels")  # one-hot labels for 10 classes

# define the network architecture (here, a simple 2-layer feedforward network with a ReLU-activated hidden layer)
hidden_layer = tf.layers.dense(inputs=x, units=200, activation=tf.nn.relu, name="hidden-layer-200")
logits = tf.layers.dense(inputs=hidden_layer, units=10, name="logits-10")
probs = tf.nn.softmax(logits, name="probabilities-softmax")

# define the loss function (mean cross-entropy over the batch) and the optimizer training step
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits))
optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
train = optimizer.minimize(loss)

# -- run the session --
with tf.Session() as session:
    session.run(tf.global_variables_initializer()) # don't forget this
    
    for _ in range(NUM_STEPS):
        # get_training_batch is a placeholder for your own data-loading code;
        # the feed_dict must provide values for the placeholders defined earlier
        x_batch, y_batch = get_training_batch(BATCH_SIZE)
        session.run(train, feed_dict={x: x_batch, labels: y_batch})

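    # -- evaluate the model --
    # accuracy = fraction of examples whose highest-probability class matches the true label
    x_val, y_val = get_validation_data()
    correct = tf.equal(tf.argmax(probs, axis=1), tf.argmax(labels, axis=1))
    accuracy_op = tf.reduce_mean(tf.cast(correct, tf.float32))
    accuracy = session.run(accuracy_op, feed_dict={x: x_val, labels: y_val})

    print("Accuracy on validation set: {:.2f}%".format(accuracy * 100))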
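The training loop relies on get_training_batch and get_validation_data, which this cheat sheet does not define. Below is a hypothetical sketch of how a data module such as data/data_set_name.py might provide them: it holds the INPUT_SIZE constant and serves shuffled mini-batches with one-hot labels, reshuffling after every full pass over the data. The names BatchProvider, one_hot and as_feedable are illustrative, and plain Python lists are used for portability (feed_dict accepts lists, although NumPy arrays are more common).

```python
import random

INPUT_SIZE = 784  # e.g. flattened 28x28 MNIST images
NUM_CLASSES = 10


def one_hot(label, num_classes=NUM_CLASSES):
    """Encode an integer class label as a one-hot list of floats."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def as_feedable(inputs, labels):
    """Return (inputs, one-hot labels), e.g. as the result of get_validation_data()."""
    return list(inputs), [one_hot(label) for label in labels]


class BatchProvider:
    """Serves shuffled mini-batches, reshuffling after each pass (epoch) over the data."""

    def __init__(self, inputs, labels, seed=0):
        if not inputs or len(inputs) != len(labels):
            raise ValueError("need a non-empty data set with one label per input")
        self._inputs = inputs
        self._labels = labels
        self._rng = random.Random(seed)
        self._order = list(range(len(inputs)))
        self._pos = len(self._order)  # forces a shuffle on the first call

    def next_batch(self, batch_size):
        """Return (x_batch, y_batch) with batch_size examples, wrapping around epochs."""
        x_batch, y_batch = [], []
        while len(x_batch) < batch_size:
            if self._pos >= len(self._order):
                self._rng.shuffle(self._order)  # start a new epoch in a new order
                self._pos = 0
            i = self._order[self._pos]
            self._pos += 1
            x_batch.append(self._inputs[i])
            y_batch.append(one_hot(self._labels[i]))
        return x_batch, y_batch
```

In a training script, get_training_batch could then be bound to provider.next_batch, where provider wraps the training data that loader.sh stored under /tmp/data.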

Commands

  • List uninitialized variables: sess.run(tf.report_uninitialized_variables())

Links