tfmnist - juedaiyuer/researchNote GitHub Wiki

MNIST tensorflow

1. input_data.py

代码位置:

tensorflow/examples/tutorials/mnist/input_data.py

下载和读取MNIST数据

代码中有如下的一段

from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

这行导入语句说明了真正实现代码的位置:该 mnist 模块(mnist.py)才是负责下载并读取数据的核心。

代码位置

tensorflow/contrib/learn/python/learn/datasets/mnist.py

为了方便随时阅读,将源代码粘贴如下(共 200 多行):

  """Functions for downloading and reading MNIST data."""

  from __future__ import absolute_import
  from __future__ import division
  from __future__ import print_function

  import gzip

  import numpy
  from six.moves import xrange  # pylint: disable=redefined-builtin

  from tensorflow.contrib.learn.python.learn.datasets import base
  from tensorflow.python.framework import dtypes

  SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'


  def _read32(bytestream):
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]


  def extract_images(f):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth].

    Args:
      f: A file object that can be passed into a gzip reader. It does not
        need a `name` attribute (e.g. an in-memory io.BytesIO works).

    Returns:
      data: A 4D uint8 numpy array [index, y, x, depth]; depth is always 1.

    Raises:
      ValueError: If the bytestream does not start with 2051, or if it holds
        fewer pixel bytes than its header declares.
    """
    # In-memory file objects have no `name` attribute; fall back to their
    # repr so they can still be read and reported.
    name = getattr(f, 'name', repr(f))
    print('Extracting', name)
    with gzip.GzipFile(fileobj=f) as bytestream:
      magic = _read32(bytestream)
      if magic != 2051:
        raise ValueError('Invalid magic number %d in MNIST image file: %s' %
                         (magic, name))
      # Convert header fields to Python ints so the size product below is
      # computed with unbounded ints rather than wrapping numpy.uint32 math.
      num_images = int(_read32(bytestream))
      rows = int(_read32(bytestream))
      cols = int(_read32(bytestream))
      expected_bytes = rows * cols * num_images
      buf = bytestream.read(expected_bytes)
      # A short read means the file is truncated; report that directly
      # instead of failing later with an unrelated reshape error.
      if len(buf) != expected_bytes:
        raise ValueError('Truncated MNIST image file %s: expected %d bytes of '
                         'pixel data, got %d' % (name, expected_bytes, len(buf)))
      data = numpy.frombuffer(buf, dtype=numpy.uint8)
      return data.reshape(num_images, rows, cols, 1)


  def dense_to_one_hot(labels_dense, num_classes):
    """Convert class labels from scalars to one-hot vectors.

    Args:
      labels_dense: 1D numpy array of integer class ids, one per example.
      num_classes: Number of classes, i.e. the width of each one-hot row.

    Returns:
      A float64 numpy array of shape [labels_dense.shape[0], num_classes]
      holding a single 1.0 per row, in the column given by that row's label.

    Raises:
      ValueError: If any label lies outside [0, num_classes).
    """
    num_labels = labels_dense.shape[0]
    labels_one_hot = numpy.zeros((num_labels, num_classes))
    if num_labels == 0:
      # min()/max() below are undefined on empty arrays.
      return labels_one_hot
    flat_labels = labels_dense.ravel()
    # Without this check an out-of-range label is written into a neighbouring
    # row (or, if negative, wraps around from the end) with no error at all.
    if flat_labels.min() < 0 or flat_labels.max() >= num_classes:
      raise ValueError('Labels must be in [0, %d), got values in [%d, %d]' %
                       (num_classes, flat_labels.min(), flat_labels.max()))
    # Row i starts at flat offset i * num_classes; adding the label selects
    # the column to set.
    index_offset = numpy.arange(num_labels) * num_classes
    labels_one_hot.flat[index_offset + flat_labels] = 1
    return labels_one_hot


  def extract_labels(f, one_hot=False, num_classes=10):
    """Extract the labels into a 1D uint8 numpy array [index].

    Args:
      f: A file object that can be passed into a gzip reader. It does not
        need a `name` attribute (e.g. an in-memory io.BytesIO works).
      one_hot: Does one hot encoding for the result.
      num_classes: Number of classes for the one hot encoding.

    Returns:
      labels: a 1D uint8 numpy array or, when `one_hot` is true, the
        [index, num_classes] array returned by
        `dense_to_one_hot(labels, num_classes)`.

    Raises:
      ValueError: If the bytestream doesn't start with 2049, or if it holds
        fewer label bytes than its header declares.
    """
    # In-memory file objects have no `name` attribute; fall back to their
    # repr so they can still be read and reported.
    name = getattr(f, 'name', repr(f))
    print('Extracting', name)
    with gzip.GzipFile(fileobj=f) as bytestream:
      magic = _read32(bytestream)
      if magic != 2049:
        raise ValueError('Invalid magic number %d in MNIST label file: %s' %
                         (magic, name))
      num_items = int(_read32(bytestream))
      buf = bytestream.read(num_items)
      # A short read means the file is truncated; fail loudly instead of
      # returning fewer labels than the header declares.
      if len(buf) != num_items:
        raise ValueError('Truncated MNIST label file %s: expected %d labels, '
                         'got %d' % (name, num_items, len(buf)))
      labels = numpy.frombuffer(buf, dtype=numpy.uint8)
      if one_hot:
        return dense_to_one_hot(labels, num_classes)
      return labels


  class DataSet(object):
    """In-memory MNIST split that hands out mini-batches via `next_batch`.

    After each completed epoch the images and labels are shuffled together
    with `numpy.random.shuffle`. The very first epoch is served in the order
    the arrays were given.
    """

    def __init__(self,
                 images,
                 labels,
                 fake_data=False,
                 one_hot=False,
                 dtype=dtypes.float32,
                 reshape=True):
      """Construct a DataSet.

      one_hot arg is used only if fake_data is true.  `dtype` can be either
      `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
      `[0, 1]`.

      Args:
        images: numpy array of shape [num examples, rows, columns, depth].
          Ignored when `fake_data` is true.
        labels: numpy array with one entry (or one-hot row) per image.
          Ignored when `fake_data` is true.
        fake_data: If true, describe a fake set of 10000 examples whose
          batches are synthesized by `next_batch(..., fake_data=True)`.
        one_hot: Whether fake labels are one-hot vectors.
        dtype: `uint8` or `float32`, in any form `dtypes.as_dtype` accepts.
        reshape: If true, flatten each image into a rows*columns vector
          (requires depth == 1).

      Raises:
        TypeError: If `dtype` is neither uint8 nor float32.
        ValueError: If `images` and `labels` disagree on the number of
          examples, or `reshape` is requested for images whose depth is not 1.
      """
      dtype = dtypes.as_dtype(dtype).base_dtype
      if dtype not in (dtypes.uint8, dtypes.float32):
        raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
                        dtype)
      # Stored unconditionally so next_batch(fake_data=True) also works on a
      # DataSet built from real data (previously raised AttributeError).
      self.one_hot = one_hot
      if fake_data:
        self._num_examples = 10000
      else:
        # Explicit exceptions instead of `assert`, which is stripped under -O.
        if images.shape[0] != labels.shape[0]:
          raise ValueError('images.shape: %s labels.shape: %s' %
                           (images.shape, labels.shape))
        self._num_examples = images.shape[0]

        # Convert shape from [num examples, rows, columns, depth]
        # to [num examples, rows*columns] (assuming depth == 1)
        if reshape:
          if images.shape[3] != 1:
            raise ValueError('reshape=True requires images with depth 1, got '
                             'images.shape: %s' % (images.shape,))
          images = images.reshape(images.shape[0],
                                  images.shape[1] * images.shape[2])
        if dtype == dtypes.float32:
          # Convert from [0, 255] -> [0.0, 1.0].
          images = images.astype(numpy.float32)
          images = numpy.multiply(images, 1.0 / 255.0)
      self._images = images
      self._labels = labels
      self._epochs_completed = 0
      self._index_in_epoch = 0

    @property
    def images(self):
      """The (possibly reshaped / rescaled) image array."""
      return self._images

    @property
    def labels(self):
      """The label array, in the same order as `images`."""
      return self._labels

    @property
    def num_examples(self):
      """Number of examples in this split."""
      return self._num_examples

    @property
    def epochs_completed(self):
      """How many times `next_batch` has wrapped around the data."""
      return self._epochs_completed

    def next_batch(self, batch_size, fake_data=False):
      """Return the next `batch_size` examples from this data set.

      When a batch would run past the end of the data, the epoch counter is
      incremented, the data is reshuffled, and the batch is taken from the
      start. The examples left over at the end of the old epoch are skipped,
      so `batch_size` should divide `num_examples` evenly.

      Args:
        batch_size: Number of examples to return; must not exceed
          `num_examples`.
        fake_data: If true, return `batch_size` copies of a constant fake
          image (784 ones) and label instead of real data.

      Returns:
        An `(images, labels)` pair of slices (lists when `fake_data` is true).

      Raises:
        ValueError: If `batch_size` is larger than the number of examples.
      """
      if fake_data:
        fake_image = [1] * 784
        if self.one_hot:
          fake_label = [1] + [0] * 9
        else:
          fake_label = 0
        return [fake_image for _ in xrange(batch_size)], [
            fake_label for _ in xrange(batch_size)
        ]
      # Validate before touching any state, so a bad call does not advance
      # the epoch counter or reshuffle the data.
      if batch_size > self._num_examples:
        raise ValueError('batch_size %d exceeds the number of examples %d' %
                         (batch_size, self._num_examples))
      start = self._index_in_epoch
      self._index_in_epoch += batch_size
      if self._index_in_epoch > self._num_examples:
        # Finished epoch
        self._epochs_completed += 1
        # Shuffle images and labels with the same permutation.
        perm = numpy.arange(self._num_examples)
        numpy.random.shuffle(perm)
        self._images = self._images[perm]
        self._labels = self._labels[perm]
        # Start next epoch
        start = 0
        self._index_in_epoch = batch_size
      end = self._index_in_epoch
      return self._images[start:end], self._labels[start:end]


  def read_data_sets(train_dir,
                     fake_data=False,
                     one_hot=False,
                     dtype=dtypes.float32,
                     reshape=True,
                     validation_size=5000):
    """Build the MNIST train/validation/test DataSets.

    Args:
      train_dir: Directory passed to `base.maybe_download` for the four
        MNIST .gz files.
      fake_data: If true, skip all file access and return three fake
        DataSets.
      one_hot: Whether labels are one-hot encoded.
      dtype: Image dtype handed to every DataSet.
      reshape: Whether every DataSet flattens its images.
      validation_size: Number of leading training examples that are split
        off into the validation set.

    Returns:
      A `base.Datasets` with `train`, `validation` and `test` fields.

    Raises:
      ValueError: If `validation_size` is not between 0 and the number of
        training images.
    """
    if fake_data:
      def make_fake():
        return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)

      return base.Datasets(train=make_fake(), validation=make_fake(),
                           test=make_fake())

    def load(filename, extractor):
      # Obtain a local copy via base.maybe_download, then parse it.
      local_path = base.maybe_download(filename, train_dir,
                                       SOURCE_URL + filename)
      with open(local_path, 'rb') as fileobj:
        return extractor(fileobj)

    def load_labels(fileobj):
      return extract_labels(fileobj, one_hot=one_hot)

    train_images = load('train-images-idx3-ubyte.gz', extract_images)
    train_labels = load('train-labels-idx1-ubyte.gz', load_labels)
    test_images = load('t10k-images-idx3-ubyte.gz', extract_images)
    test_labels = load('t10k-labels-idx1-ubyte.gz', load_labels)

    num_train = len(train_images)
    if validation_size < 0 or validation_size > num_train:
      raise ValueError(
          'Validation size should be between 0 and {}. Received: {}.'
          .format(num_train, validation_size))

    # The first `validation_size` training examples become the validation set.
    train = DataSet(train_images[validation_size:],
                    train_labels[validation_size:],
                    dtype=dtype, reshape=reshape)
    validation = DataSet(train_images[:validation_size],
                         train_labels[:validation_size],
                         dtype=dtype, reshape=reshape)
    test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape)
    return base.Datasets(train=train, validation=validation, test=test)


  def load_mnist(train_dir='MNIST-data'):
    """Return `read_data_sets(train_dir)` using all of its default options."""
    datasets = read_data_sets(train_dir)
    return datasets

2. mnist.py

代码位置:

tensorflow/examples/tutorials/mnist/mnist.py

2.1 源代码

"""Builds the MNIST network.

Implements the inference/loss/training pattern for model building.

1. inference() - Builds the model as far as is required for running the network
forward to make predictions.
2. loss() - Adds to the inference model the layers required to generate loss.
3. training() - Adds to the loss model the Ops required to generate and
apply gradients.

This file is used by the various "fully_connected_*.py" files and not meant to
be run.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import tensorflow as tf

# The MNIST dataset has 10 classes, representing the digits 0 through 9.
NUM_CLASSES = 10

# The MNIST images are always 28x28 pixels.
IMAGE_SIZE = 28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE


def _fully_connected(inputs, fan_in, fan_out, activation=None):
  """Create `weights`/`biases` variables and return the layer's output.

  Weights are drawn from a truncated normal with stddev 1/sqrt(fan_in) and
  biases start at zero. Call this inside the desired tf.name_scope so the
  variables are prefixed with the layer name.

  Args:
    inputs: Tensor multiplied on the left of the [fan_in, fan_out] weights.
    fan_in: Number of input units.
    fan_out: Number of output units.
    activation: Optional callable applied to `inputs @ weights + biases`.

  Returns:
    The (optionally activated) affine output tensor.
  """
  weights = tf.Variable(
      tf.truncated_normal([fan_in, fan_out],
                          stddev=1.0 / math.sqrt(float(fan_in))),
      name='weights')
  biases = tf.Variable(tf.zeros([fan_out]), name='biases')
  pre_activation = tf.matmul(inputs, weights) + biases
  if activation is None:
    return pre_activation
  return activation(pre_activation)


def inference(images, hidden1_units, hidden2_units):
  """Build the MNIST model up to where it may be used for inference.

  Two ReLU hidden layers are followed by a linear layer producing one logit
  per class.

  Args:
    images: Images placeholder, from inputs().
    hidden1_units: Size of the first hidden layer.
    hidden2_units: Size of the second hidden layer.

  Returns:
    softmax_linear: Output tensor with the computed logits.
  """
  with tf.name_scope('hidden1'):
    hidden1 = _fully_connected(images, IMAGE_PIXELS, hidden1_units,
                               activation=tf.nn.relu)
  with tf.name_scope('hidden2'):
    hidden2 = _fully_connected(hidden1, hidden1_units, hidden2_units,
                               activation=tf.nn.relu)
  # Linear output layer: no activation, the softmax lives in loss().
  with tf.name_scope('softmax_linear'):
    logits = _fully_connected(hidden2, hidden2_units, NUM_CLASSES)
  return logits


def loss(logits, labels):
  """Return the batch-mean softmax cross-entropy of `logits` vs `labels`.

  Args:
    logits: Logits tensor, float - [batch_size, NUM_CLASSES].
    labels: Labels tensor, int32 - [batch_size].

  Returns:
    loss: Loss tensor of type float.
  """
  # Class ids are cast to int64 before being fed to the sparse loss op.
  int64_labels = tf.to_int64(labels)
  per_example_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=int64_labels, logits=logits, name='xentropy')
  return tf.reduce_mean(per_example_xent, name='xentropy_mean')


def training(loss, learning_rate):
  """Create the op that performs one gradient-descent step on `loss`.

  Also registers a scalar summary named 'loss' for TensorBoard and a
  non-trainable `global_step` counter that each training step increments.
  The returned op is what must be passed to `sess.run()` to train.

  Args:
    loss: Loss tensor, from loss().
    learning_rate: The learning rate to use for gradient descent.

  Returns:
    train_op: The Op for training.
  """
  tf.summary.scalar('loss', loss)
  sgd = tf.train.GradientDescentOptimizer(learning_rate)
  step_counter = tf.Variable(0, name='global_step', trainable=False)
  # minimize() both applies the gradients and increments step_counter.
  return sgd.minimize(loss, global_step=step_counter)


def evaluation(logits, labels):
  """Count how many examples in the batch are classified correctly.

  An example counts as correct when its label is the single highest-scoring
  logit (top-1 via tf.nn.in_top_k with k=1).

  Args:
    logits: Logits tensor, float - [batch_size, NUM_CLASSES].
    labels: Labels tensor, int32 - [batch_size], with values in the
      range [0, NUM_CLASSES).

  Returns:
    A scalar int32 tensor with the number of examples (out of batch_size)
    that were predicted correctly.
  """
  top1_hits = tf.nn.in_top_k(logits, labels, 1)
  hits_as_int = tf.cast(top1_hits, tf.int32)
  return tf.reduce_sum(hits_as_int)

2.2 代码注解

2.2.1 inference

inference() 构建前向计算图(graph),一直到输出 logits 为止。它相当于一个方程式:给定输入图像,就能计算出每个类别的得分(logits)。

with tf.name_scope('hidden1') 为该层创建一个名字作用域,在该层中创建的每一个变量名都会带上这个前缀。

比如 hidden1 层里面的 weights 和 biases,其完整名称便是 hidden1/weights 和 hidden1/biases。

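2.2.2 loss

loss() 先用 tf.to_int64 把标签转换为 int64,再用 tf.nn.sparse_softmax_cross_entropy_with_logits 计算每个样本的交叉熵,最后用 tf.reduce_mean 求整个批次的平均值作为损失。

2.2.3 training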

tf.summary.scalar('loss', loss)

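tf.summary.scalar 为损失值创建一个标量汇总(summary)操作。运行该操作得到的汇总值(summary values)可以由 SummaryWriter(TF 1.x 中为 tf.summary.FileWriter)写入事件文件(events file),供 TensorBoard 显示损失随时间的变化。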

3. fully_connected_feed.py

代码位置:

tensorflow/examples/tutorials/mnist/fully_connected_feed.py

只需要直接运行就可以开始训练

python fully_connected_feed.py

3.1 代码注解

mnist.py定义了mnist的方法;input_data.py定义了数据的下载和操作

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.examples.tutorials.mnist import mnist

运行该段代码时的参数解读

`__name__` 是当前模块名,当模块被直接运行时模块名为 `__main__`。下面这句话的意思是:当模块被直接运行时,其下的代码块将被运行;当模块是被导入时,该代码块不被运行。

if __name__ == '__main__':

命令行解析代码如下

# Command-line flags accepted by fully_connected_feed.py.
parser = argparse.ArgumentParser()
# Optimisation settings.
parser.add_argument('--learning_rate', type=float, default=0.01,
                    help='Initial learning rate.')
parser.add_argument('--max_steps', type=int, default=2000,
                    help='Number of steps to run trainer.')
# Network size.
parser.add_argument('--hidden1', type=int, default=128,
                    help='Number of units in hidden layer 1.')
parser.add_argument('--hidden2', type=int, default=32,
                    help='Number of units in hidden layer 2.')
parser.add_argument('--batch_size', type=int, default=100,
                    help='Batch size.  Must divide evenly into the dataset sizes.')
# Locations on disk.
parser.add_argument('--input_data_dir', type=str,
                    default='/tmp/tensorflow/mnist/input_data',
                    help='Directory to put the input data.')
parser.add_argument('--log_dir', type=str,
                    default='/tmp/tensorflow/mnist/logs/fully_connected_feed',
                    help='Directory to put the log data.')
# Boolean switch: present on the command line -> True, otherwise False.
parser.add_argument('--fake_data', action='store_true', default=False,
                    help='If true, uses fake data for unit testing.')

运行下面的代码

$ python2 fully_connected_feed.py --help

结果如下

usage: fully_connected_feed.py [-h] [--learning_rate LEARNING_RATE] [--max_steps MAX_STEPS] [--hidden1 HIDDEN1] [--hidden2 HIDDEN2] [--batch_size BATCH_SIZE] [--input_data_dir INPUT_DATA_DIR] [--log_dir LOG_DIR] [--fake_data]

optional arguments:
  -h, --help            show this help message and exit
  --learning_rate LEARNING_RATE
                        Initial learning rate.
  --max_steps MAX_STEPS
                        Number of steps to run trainer.
  --hidden1 HIDDEN1     Number of units in hidden layer 1.
  --hidden2 HIDDEN2     Number of units in hidden layer 2.
  --batch_size BATCH_SIZE
                        Batch size. Must divide evenly into the dataset sizes.
  --input_data_dir INPUT_DATA_DIR
                        Directory to put the input data.
  --log_dir LOG_DIR     Directory to put the log data.
  --fake_data           If true, uses fake data for unit testing.

source