Practical Deep Neural Network in a few lines of code with TensorFlow

This uses the very latest TensorFlow tf.contrib.learn API.  The documentation is extremely thin, and as of writing I haven’t found any tutorials similar to this one.  The API may change, and I don’t actually know if I’m doing things correctly.  This is just trial and error 🙂

So let’s say you want to:

  1. Train a Deep Neural Network on some existing data with known labels, e.g. labeled images.
  2. Use that trained network in your app to predict y given x, e.g. guess the best label for a given image.

And as a bonus:

  3. Have some nice graphs of how well the training is doing.

I assume that the data you have is in a plain array-of-arrays or a numpy 2D array.  It should be easy to change this to use csv or pandas data etc. A row is a single example, and a column is a particular feature (e.g. house price, or pixel intensity at a specific location).
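As a rough sketch of that layout, here is one way to turn CSV text into an array-of-arrays using only Python's standard library. The column names (x1, x2, y), the sample values and the helper name rows_from_csv are all made up for illustration; adapt them to your own data.

```python
import csv
import io

# Made-up sample data: two feature columns (x1, x2) and one label column (y)
SAMPLE_CSV = """x1,x2,y
1.0,2.0,3.0
4.0,5.0,9.0
0.5,0.5,1.0
"""

def rows_from_csv(text, feature_names, label_name):
    """Return (x_rows, y_values): one inner list of floats per example."""
    reader = csv.DictReader(io.StringIO(text))
    x_rows, y_values = [], []
    for record in reader:
        # One row per example, one entry per feature column
        x_rows.append([float(record[name]) for name in feature_names])
        y_values.append(float(record[label_name]))
    return x_rows, y_values

x_rows, y_values = rows_from_csv(SAMPLE_CSV, ["x1", "x2"], "y")
print(x_rows)    # [[1.0, 2.0], [4.0, 5.0], [0.5, 0.5]]
print(y_values)  # [3.0, 9.0, 1.0]
```

Passing x_rows to np.array(x_rows, dtype=np.float32) then gives a numpy 2D array with one row per example.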

So, without further ado, here’s the function that will be shared between your training and your app:

import numpy as np
import tensorflow as tf

def get_tf_model():
    """ Setup our Deep Neural Network and LOAD any existing model

    Returns:
          ( Our (maybe trained) model, a helper input function)
    """
    # FEATURES is a short description for each column
    # of your data.  Change this to match your data
    FEATURES = ["x1", "x2"]
    # Set this to describe your columns.  If they are all real values,
    # you don't need to change it.
    feature_columns = [tf.contrib.layers.real_valued_column(k) for k in FEATURES]

    # Build 3 layer DNN.  You can change this however you want, or use a
    # linear regressor, or use a classifier etc.
    # NOTE:
    #   This will LOAD any existing model in the "model_dir" directory!
    # The documentation fails to mention this point as of writing
    regressor = tf.contrib.learn.DNNRegressor(feature_columns=feature_columns,
                                              hidden_units=[128, 128, 128],
                                              model_dir="tf_model")

    def input_fn(x_data, y_data = None):
        # Note the 'shape' parameter is to suppress a very noisy warning.
        # You can probably remove this parameter in a month or two.
        feature_cols = {k: tf.constant(x_data[:,i], shape=[len(x_data[:,i]), 1])
                        for i, k in enumerate(FEATURES)}
        if y_data is None:
            return feature_cols
        labels = tf.constant(y_data)
        return feature_cols, labels

    return regressor, input_fn
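To make it clearer what input_fn hands over, here is a TensorFlow-free sketch of the same column splitting using plain lists. It is only an illustration: the real function wraps each column in tf.constant, and split_columns is a hypothetical name that does not appear in the code above.

```python
FEATURES = ["x1", "x2"]

def split_columns(x_data, feature_names):
    """Map each feature name to its column, shaped as N rows of 1 value."""
    return {name: [[row[i]] for row in x_data]
            for i, name in enumerate(feature_names)}

# Made-up data: 3 examples (rows), 2 features (columns)
x_data = [[1.0, 2.0],
          [4.0, 5.0],
          [0.5, 0.5]]

feature_cols = split_columns(x_data, FEATURES)
print(feature_cols["x1"])  # [[1.0], [4.0], [0.5]]
print(feature_cols["x2"])  # [[2.0], [5.0], [0.5]]
```

Each feature ends up as its own N-by-1 column keyed by name, which mirrors the shape=[len(...), 1] argument in input_fn.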

Now the code for training. Note that it’s absolutely fine to carry on training a model that is already trained, if you have some new data for it.

x_test = None
y_test = None
regressor, input_fn = get_tf_model()

def train(training_data_y, training_data_x):
    global x_test, y_test, regressor, input_fn

    if x_test is None:
        # First call: hold out the last 20 rows as a fixed test set
        x_train = np.array(training_data_x[:-20], dtype=np.float32)
        x_test = np.array(training_data_x[-20:], dtype=np.float32)
        y_train = np.array(training_data_y[:-20], dtype=np.float32)
        y_test = np.array(training_data_y[-20:], dtype=np.float32)
    else:
        # Later calls: keep the existing test set and train on all the new data
        x_train = np.array(training_data_x, dtype=np.float32)
        y_train = np.array(training_data_y, dtype=np.float32)

    print("Training model ...")
    # Fit model.
    regressor.fit(input_fn=lambda: input_fn(x_train, y_train), steps=2000)

    ev = regressor.evaluate(input_fn=lambda: input_fn(x_test, y_test), steps=1)
    print('  -  Trained Loss: {0:f}'.format(ev["loss"]))
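The hold-out step in train() relies on Python's negative slicing. Here is a small sketch of just that step with plain lists; split_holdout is a hypothetical helper, and the hold-out size of 20 from the code above becomes a parameter.

```python
def split_holdout(rows, holdout=20):
    """Split rows into (train, test), keeping the last `holdout` rows for testing."""
    return rows[:-holdout], rows[-holdout:]

data = list(range(100))  # stand-in for 100 training examples
train_rows, test_rows = split_holdout(data)
print(len(train_rows), len(test_rows))  # 80 20
print(test_rows[0], test_rows[-1])      # 80 99
```

One thing to watch: if the first batch has 20 rows or fewer, the training slice comes out empty, so make sure the first call gets more data than the hold-out size.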

And now finally, the code to use this model in the app. Note that no explicit loading is needed, because the regressor loads the model from the model_dir:

regressor, input_fn = get_tf_model()
def predict(x_data):
    return regressor.predict(input_fn=lambda:input_fn(x_data))

Isn’t that simple?

We can also view a graph of the loss etc with:

tensorboard --logdir tf_model  # or whatever you set model_dir to

Then navigate in the browser to the address that TensorBoard prints when it starts.

