Pre-trained Models with Keras in TensorFlow
Tuesday May 2, 2017
With TensorFlow 1.1, Keras is now at tf.contrib.keras. With TensorFlow 1.2, it'll be at tf.keras. This is great for making new models, but we also get canned models previously found outside core Keras. It's so easy to classify images!
import tensorflow as tf

model = tf.contrib.keras.applications.ResNet50()
This will automatically download trained weights for a model based on Deep Residual Learning for Image Recognition. The weights are cached below your home directory, in
Convenient image tools are also included. Let's use an image of a koala from the imagen ImageNet subset.
filename = 'n01882714_4157_koala_bear.jpg'
image = tf.contrib.keras.preprocessing.image.load_img(
    filename, target_size=(224, 224))
This model can take input images that are 224 pixels on a side, so we have to make our image that size. We're just doing it by squishing, in this case.
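Squishing distorts any image that isn't already square. One common alternative, not used in this post, is to center-crop to a square before resizing. Here is a minimal sketch of the crop-box arithmetic. The helper center_crop_box is hypothetical (it is not part of Keras), and the 500x375 size is just an illustrative photo size.

```python
def center_crop_box(width, height):
    """Return (left, upper, right, lower) for the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    upper = (height - side) // 2
    return (left, upper, left + side, upper + side)


# A hypothetical 500x375 landscape photo keeps its middle 375x375 region.
box = center_crop_box(500, 375)
print(box)
```

The box is in the (left, upper, right, lower) order that PIL's Image.crop expects, so it could be applied to a loaded image before resizing it to 224 by 224.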
We'll make that into an array that the model can take as input.
import numpy as np

array = tf.contrib.keras.preprocessing.image.img_to_array(image)
array = np.expand_dims(array, axis=0)
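The expand_dims call is there because the model expects a batch of images rather than a single one. Here is a small sketch of the shape change, using an array of zeros as a stand-in for the real image data so that it runs without an image file. It assumes the default channels-last layout.

```python
import numpy as np

# Stand-in for the output of img_to_array: height x width x RGB channels.
# (Zeros are used here only so the sketch runs without an image file.)
single = np.zeros((224, 224, 3), dtype=np.float32)

# Add a leading batch axis so the model sees a batch containing one image.
batch = np.expand_dims(single, axis=0)

print(single.shape)  # (224, 224, 3)
print(batch.shape)   # (1, 224, 224, 3)
```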
Now we can classify the image!
probabilities = model.predict(array)
We have one thousand probabilities, one for each class the model knows about. To interpret the result, we can use another helpful function.
tf.contrib.keras.applications.resnet50.decode_predictions(probabilities)
## [[(u'n01882714', u'koala', 0.99466419),
##   (u'n02497673', u'Madagascar_cat', 0.0013330306),
##   (u'n01877812', u'wallaby', 0.00085774728),
##   (u'n02137549', u'mongoose', 0.00063530984),
##   (u'n02123045', u'tabby', 0.00056512095)]]
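To get a feel for what that function is doing, here is a rough sketch of the ranking step only: pick the highest probabilities and pair them with class names. The top_k helper and the four toy labels are made up for illustration. The real function works from the full ImageNet class list and also returns the WordNet IDs.

```python
import numpy as np


def top_k(probabilities, labels, k=5):
    """Return the k (label, probability) pairs with the highest probability."""
    # argsort is ascending, so take the last k indices and reverse them.
    best = np.argsort(probabilities)[-k:][::-1]
    return [(labels[i], float(probabilities[i])) for i in best]


# Toy four-class example; the real model has 1000 classes.
labels = ['koala', 'wallaby', 'tabby', 'mongoose']
probs = np.array([0.90, 0.05, 0.02, 0.03])
result = top_k(probs, labels, k=2)
print(result)
```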
Great success! The model is highly confident that it's looking at a koala. Not bad.
It's pretty fun that this kind of super-easy access to quite good pre-trained models is now available all within the TensorFlow package. Just pip install and go!
The thousand ImageNet categories this model knows about include some things that are commonly associated with people, but not a "person" class. Still, just for fun, what will ResNet50 say about me?
## [[(u'n02883205', u'bow_tie', 0.3144455),
##   (u'n03787032', u'mortarboard', 0.059674311),
##   (u'n02992529', u'cellular_telephone', 0.049916871),
##   (u'n04357314', u'sunscreen', 0.048197504),
##   (u'n04350905', u'suit', 0.03481029)]]
I guess I'll take it?
The model may have been trained on the very koala picture we're testing it with. I'm okay with that. Feel free to test your own koala pictures!
There's also another function, resnet50.preprocess_input, which in theory should help the model work better, but my tests gave seemingly worse results when using that pre-processing. It would be used like this:
array = tf.contrib.keras.applications.resnet50.preprocess_input(array)
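As far as I understand it, for ResNet50 this pre-processing flips the channels from RGB to BGR and subtracts a per-channel ImageNet mean. The sketch below is my reading of that behavior in plain NumPy, not the library's code. The function name caffe_style_preprocess is made up, and the mean values are an assumption worth checking against your Keras version.

```python
import numpy as np

# Per-channel ImageNet means in BGR order (an assumption; verify locally).
BGR_MEANS = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def caffe_style_preprocess(batch_rgb):
    """Return a new array: RGB flipped to BGR, then channel means subtracted."""
    bgr = batch_rgb[..., ::-1].astype(np.float32)  # astype makes a copy
    return bgr - BGR_MEANS


# One 2x2 "image" in which every pixel is pure red (255, 0, 0).
batch = np.zeros((1, 2, 2, 3), dtype=np.float32)
batch[..., 0] = 255.0
out = caffe_style_preprocess(batch)
print(out[0, 0, 0])
```

The sketch returns a new array and leaves its input untouched, which makes it easy to compare predictions with and without the pre-processing on the same image.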
Keras in TensorFlow also contains xception models, along the same lines as
I'm working on Building TensorFlow systems from components, a workshop at OSCON 2017.