Step by Step: A Tutorial on How to Feed Your Own Image Data to TensorFlow

This post shows you how to use your own data to train a convolutional neural network for image recognition in TensorFlow. The focus is on how to feed your own data to the network, not on how to design the network architecture.
Before I started looking into TensorFlow, my colleagues and I were using Torch7 or Caffe. Both are very good machine learning tools for neural networks. Our original reason for turning to TensorFlow was that we believed it would have better support on mobile, since Android and TensorFlow are both driven by Google.
If you are in a hurry to import data into your program, visit my GitHub repo to get the code that generates, loads, and reads data through TFRecords. I’m too busy to keep the blog updated, so just clone the project and run build_image_data.py and read_tfrecord_data.py.

Torch7 vs. TensorFlow

Although Facebook’s Torch7 already had some support on Android, we still believed it was necessary to keep an eye on Google. After a few updates, TensorFlow on Android was launched.
Comparing Torch7 and TensorFlow from a developer’s point of view, Torch7 is much easier to use. Torch7 uses Lua. Even though I don’t like the scripting language Lua (the reason is that its name sounds odd; they say the name “Lua” comes from the Portuguese word for “moon”), I still think Torch7 is an excellent framework. It’s fast, it’s easy, and most of the time you can use it without knowing how it works. I used to analyze Torch7’s C code, and I would say it is a very fast framework. Its drawback, I think, is that it is somewhat more resource-hungry: it achieves faster training and inference at the cost of more memory.
Another point is that Torch7’s I/O API (Application Programming Interface) is very user friendly: to load an image, all you need to do is call an imread function with the argument “/path/of/your/image/data.jpg”.
TensorFlow’s basic tutorials, however, don’t tell you how to turn your own data into an efficient input pipeline. The official tutorials show how to decode the MNIST and CIFAR-10 datasets, both of which are stored in binary formats, but our own images are usually .jpeg or .png files.
So I decided to summarize my experience of feeding your own image data to TensorFlow and building a simple convolutional neural network.
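Because the official tutorials only cover binary datasets, it helps to know how to tell a JPEG from a PNG yourself. The build script later in this post decides by file name (`'.png' in filename`); the stdlib-only sketch below checks the leading signature bytes instead. The helper names `sniff_image_format` and `sniff_file` are my own, not TensorFlow APIs.

```python
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"  # fixed 8-byte header of every PNG file
JPEG_SOI = b"\xff\xd8\xff"            # JPEG start-of-image marker plus next marker byte


def sniff_image_format(data):
    """Return 'png', 'jpeg', or None based on the leading bytes of `data`."""
    if data.startswith(PNG_SIGNATURE):
        return "png"
    if data.startswith(JPEG_SOI):
        return "jpeg"
    return None


def sniff_file(path):
    """Read only the first 8 bytes of the file at `path` and identify its format."""
    with open(path, "rb") as f:
        return sniff_image_format(f.read(8))
```

A file named photo.png that actually contains JPEG data is reported as `'jpeg'` here, whereas a name-based check would send it through the PNG-to-JPEG conversion.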

Resources

I hoped TensorFlow would be as nice as Torch7; unfortunately, it is not. I didn’t even know how to code in Python before I started using TensorFlow. Such is life; I had plenty of homework to do.
I assume that you have already installed TensorFlow and can successfully run at least one demo, wherever you got it. Here’s my road to TensorFlow:
I learned basic Python syntax from the well-known book A Byte of Python. Going from C to Python was a huge gap for me; I felt uncomfortable when I couldn’t explicitly use pointers and references. Python is much easier than statically typed languages. It has at least two drawbacks, I think: first, it is slow; second, there are too many APIs to remember. Python can do almost anything you need; all you have to do is google a feasible answer.
After that, I learned NumPy from this tutorial. I still cannot remember all the APIs it covers, so I google them when necessary. Because NumPy’s core is written in C, its vectorized array operations are much faster than equivalent pure-Python loops.
Is it a good time to go through the official TensorFlow documentation? Maybe. I went through about 80% of the official tutorials, but they didn’t help much.
Then I looked for more basic tutorials. I highly recommend the article Hello, tensorflow and the tutorial LearningTensorflow.
These last two were really helpful to me: they explain how TensorFlow actually works and how to correctly use some of the key ops, such as placeholder or the image reverse APIs.
Finally, don’t forget the almighty GitHub: another branch of the TensorFlow project hosts a few open-source network architectures. They may not give you state-of-the-art performance, but I believe they are good enough for training your own solution. The powerful Inception-v3 and ResNet are both open-sourced in TensorFlow.
If you want to play with a simple demo, please click here and follow the README.
I created this simple implementation to help TensorFlow newbies get started. You can feed your own image data to the network simply by changing the I/O paths in the Python code.

Good News

The good news is that Google released a new document for TF-Slim today (08/31/2016), with a few scripts for training or fine-tuning Inception-v3. I followed that document, and it works.
So far, I think it is the best TensorFlow document, because Inception-v3 is one of the few state-of-the-art architectures and TensorFlow is a very powerful deep learning tool.
Google open-sourced Inception-ResNet-v2 yesterday (02/09/2016). What can I say~ :)

Play with the Data

Python has plenty of data I/O APIs, so this is not a difficult task. What I’m going to do here is write a Python script that turns all the images in a folder, together with their labels (each subfolder is named after its label), into TFRecord files, and then feed those TFRecords into the network.
The related skills, I think, include: python-numpy, python-os, python-scipy, python-pillow, protocol buffers, and TensorFlow.
Let’s start with the directory traversal script. It traverses your current directory, lists all the file and folder names, selects the files that end with .tfrecord, and returns a list of their names.
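Under the hood, a TFRecord file is just a sequence of length-prefixed records, each protected by masked CRC-32C checksums, and each record normally holds one serialized `tf.train.Example`. To make that concrete, here is a stdlib-only sketch of the framing, so you can see roughly what `tf.python_io.TFRecordWriter` puts on disk. It is only an illustration: it treats each payload as opaque bytes instead of building real protocol buffers, and the function names are my own.

```python
import io
import struct


def _make_crc32c_table():
    """Lookup table for CRC-32C (Castagnoli), reflected polynomial 0x82F63B78."""
    table = []
    for byte in range(256):
        crc = byte
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data):
    crc = 0xFFFFFFFF
    for b in data:
        crc = _CRC32C_TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    """TFRecord stores a rotated-and-offset CRC rather than the raw CRC-32C."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_record(stream, payload):
    header = struct.pack("<Q", len(payload))            # uint64 length, little-endian
    stream.write(header)
    stream.write(struct.pack("<I", masked_crc(header)))  # checksum of the length
    stream.write(payload)
    stream.write(struct.pack("<I", masked_crc(payload)))  # checksum of the data


def read_records(stream):
    """Yield each payload in order, verifying both checksums."""
    while True:
        header = stream.read(8)
        if not header:
            return                                        # clean end of input
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", stream.read(4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupt record length")
        payload = stream.read(length)
        (data_crc,) = struct.unpack("<I", stream.read(4))
        if data_crc != masked_crc(payload):
            raise ValueError("corrupt record payload")
        yield payload


if __name__ == "__main__":
    buf = io.BytesIO()
    for example in (b"first example", b"second example"):
        write_record(buf, example)
    buf.seek(0)
    print(list(read_records(buf)))
```

Every record costs 16 bytes of framing on top of its payload, which is one reason the build script packs many images into a few shard files rather than writing one file per image.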

import os  # handle system paths and filenames
import tensorflow as tf  # import tensorflow as usual


# Pick out the .tfrecord files from a list of file names.
def list_tfrecord_file(file_list):
    tfrecord_list = []
    for file_name in file_list:
        current_file_abs_path = os.path.abspath(file_name)
        if current_file_abs_path.endswith(".tfrecord"):
            tfrecord_list.append(current_file_abs_path)
            print("Found %s successfully!" % file_name)
    return tfrecord_list


# Traverse the current directory.
def tfrecord_auto_traversal():
    # Change this path if you want to traverse another directory.
    current_folder_filename_list = os.listdir("./")
    print("%s files were found under the current folder." % len(current_folder_filename_list))
    print("Please note that only files ending with '.tfrecord' will be loaded!")
    tfrecord_list = list_tfrecord_file(current_folder_filename_list)
    if tfrecord_list:
        for tfrecord_file in tfrecord_list:
            print(tfrecord_file)
    else:
        print("Cannot find any tfrecord files, please check the path.")
    return tfrecord_list


def main():
    tfrecord_list = tfrecord_auto_traversal()


if __name__ == "__main__":
    main()

With this program, we no longer need to list all the tfrecord files manually.
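If you prefer fewer lines, the standard library’s `glob` module can do the same filtering in one call. This is a minimal alternative sketch, assuming the .tfrecord files sit directly in the searched folder (no recursion); the function name `find_tfrecords` is my own.

```python
import glob
import os


def find_tfrecords(folder="./"):
    """Return sorted absolute paths of all *.tfrecord files directly in `folder`."""
    pattern = os.path.join(folder, "*.tfrecord")
    return sorted(os.path.abspath(path) for path in glob.glob(pattern))


if __name__ == "__main__":
    for path in find_tfrecords():
        print(path)
```

Sorting only makes the listing deterministic; the reading script shuffles the file order anyway through `string_input_producer(..., shuffle=True)`.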

Then I found the following script in the TensorFlow repo for converting your own images into TFRecords. I slightly modified the path and filename parts.
To use it:

  • Create a label.txt file in your current directory.
  • Edit label.txt according to your image folders: the name of each image folder is the real label of the images inside it, such as “sushi”, “steak”, “cat”, “dog”. Write one label per line.
  • Make sure your image folders reside under the current folder.
  • Run the script:
    python build_image_data.py

It will then turn all your images into sharded TFRecord files.
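Writing label.txt by hand is easy to get wrong. Here is a minimal stdlib sketch, assuming every immediate subfolder of the data directory is one class, that writes label.txt from the subfolder names and reproduces how the script turns lines into integer labels (in line order, starting at 1, because index 0 is left for a background class). The helper names `write_label_file` and `label_map` are hypothetical.

```python
import os


def write_label_file(data_dir, labels_path="label.txt"):
    """Write one subfolder name per line (sorted) and return the names."""
    names = sorted(
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    )
    with open(labels_path, "w") as f:
        f.write("\n".join(names) + "\n")
    return names


def label_map(labels_path="label.txt"):
    """Map each non-empty line to an integer id, starting at 1 like build_image_data.py."""
    with open(labels_path) as f:
        names = [line.strip() for line in f if line.strip()]
    return {name: index for index, name in enumerate(names, start=1)}
```

Since the default data directory is the current folder, check the generated file and delete any folder that is not a class (for example resized_image, which the reading script creates).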

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os
import random
import sys
import threading

import numpy as np
import tensorflow as tf

tf.app.flags.DEFINE_string('train_directory', './',
                           'Training data directory')
tf.app.flags.DEFINE_string('validation_directory', '',
                           'Validation data directory')
tf.app.flags.DEFINE_string('output_directory', './',
                           'Output data directory')

tf.app.flags.DEFINE_integer('train_shards', 4,
                            'Number of shards in training TFRecord files.')
tf.app.flags.DEFINE_integer('validation_shards', 0,
                            'Number of shards in validation TFRecord files.')

tf.app.flags.DEFINE_integer('num_threads', 4,
                            'Number of threads to preprocess the images.')

# The labels file holds the list of valid labels.
# Assumes that the file contains entries such as:
#   dog
#   cat
#   flower
# where each line corresponds to a label. We map each label contained in
# the file to an integer corresponding to its line number, starting from 1
# (index 0 is left empty as a background class).
tf.app.flags.DEFINE_string('labels_file', './label.txt', 'Labels file')


FLAGS = tf.app.flags.FLAGS


def _int64_feature(value):
    """Wrapper for inserting int64 features into Example proto."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value):
    """Wrapper for inserting bytes features into Example proto."""
    # as_bytes lets str values such as 'RGB' or 'dog' work under Python 3 too.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes(value)]))


def _convert_to_example(filename, image_buffer, label, text, height, width):
    """Build an Example proto for an example.

    Args:
      filename: string, path to an image file, e.g., '/path/to/example.JPG'
      image_buffer: string, JPEG encoding of RGB image
      label: integer, identifier for the ground truth for the network
      text: string, unique human-readable, e.g. 'dog'
      height: integer, image height in pixels
      width: integer, image width in pixels
    Returns:
      Example proto
    """
    colorspace = 'RGB'
    channels = 3
    image_format = 'JPEG'

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/colorspace': _bytes_feature(colorspace),
        'image/channels': _int64_feature(channels),
        'image/class/label': _int64_feature(label),
        'image/class/text': _bytes_feature(text),
        'image/format': _bytes_feature(image_format),
        'image/filename': _bytes_feature(os.path.basename(filename)),
        'image/encoded': _bytes_feature(image_buffer)}))
    return example


class ImageCoder(object):
    """Helper class that provides TensorFlow image coding utilities."""

    def __init__(self):
        # Create a single Session to run all image coding calls.
        self._sess = tf.Session()

        # Initializes function that converts PNG to JPEG data.
        self._png_data = tf.placeholder(dtype=tf.string)
        image = tf.image.decode_png(self._png_data, channels=3)
        self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)

        # Initializes function that decodes RGB JPEG data.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

    def png_to_jpeg(self, image_data):
        return self._sess.run(self._png_to_jpeg,
                              feed_dict={self._png_data: image_data})

    def decode_jpeg(self, image_data):
        image = self._sess.run(self._decode_jpeg,
                               feed_dict={self._decode_jpeg_data: image_data})
        assert len(image.shape) == 3
        assert image.shape[2] == 3
        return image


def _is_png(filename):
    """Determine if a file contains a PNG format image.

    Args:
      filename: string, path of the image file.

    Returns:
      boolean indicating if the image is a PNG.
    """
    # Decided by extension only; lower() also catches '.PNG'.
    return filename.lower().endswith('.png')


def _process_image(filename, coder):
    """Process a single image file.

    Args:
      filename: string, path to an image file e.g., '/path/to/example.JPG'.
      coder: instance of ImageCoder to provide TensorFlow image coding utils.
    Returns:
      image_buffer: string, JPEG encoding of RGB image.
      height: integer, image height in pixels.
      width: integer, image width in pixels.
    """
    # Read the image file as raw bytes ('rb', not 'r').
    image_data = tf.gfile.FastGFile(filename, 'rb').read()

    # Convert any PNG to JPEG for consistency.
    if _is_png(filename):
        print('Converting PNG to JPEG for %s' % filename)
        image_data = coder.png_to_jpeg(image_data)

    # Decode the RGB JPEG.
    image = coder.decode_jpeg(image_data)

    # Check that image converted to RGB
    assert len(image.shape) == 3
    height = image.shape[0]
    width = image.shape[1]
    assert image.shape[2] == 3

    return image_data, height, width


def _process_image_files_batch(coder, thread_index, ranges, name, filenames,
                               texts, labels, num_shards):
    """Processes and saves list of images as TFRecord in 1 thread.

    Args:
      coder: instance of ImageCoder to provide TensorFlow image coding utils.
      thread_index: integer, unique batch to run index is within [0, len(ranges)).
      ranges: list of pairs of integers specifying ranges of each batches to
        analyze in parallel.
      name: string, unique identifier specifying the data set
      filenames: list of strings; each string is a path to an image file
      texts: list of strings; each string is human readable, e.g. 'dog'
      labels: list of integer; each integer identifies the ground truth
      num_shards: integer number of shards for this data set.
    """
    # Each thread produces N shards where N = int(num_shards / num_threads).
    # For instance, if num_shards = 128, and the num_threads = 2, then the first
    # thread would produce shards [0, 64).
    num_threads = len(ranges)
    assert not num_shards % num_threads
    num_shards_per_batch = int(num_shards / num_threads)

    shard_ranges = np.linspace(ranges[thread_index][0],
                               ranges[thread_index][1],
                               num_shards_per_batch + 1).astype(int)
    num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0]

    counter = 0
    for s in range(num_shards_per_batch):
        # Generate a sharded version of the file name, e.g. 'train-02-of-10.tfrecord'
        shard = thread_index * num_shards_per_batch + s
        output_filename = '%s-%.2d-of-%.2d.tfrecord' % (name, shard, num_shards)
        output_file = os.path.join(FLAGS.output_directory, output_filename)
        writer = tf.python_io.TFRecordWriter(output_file)

        shard_counter = 0
        files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
        for i in files_in_shard:
            filename = filenames[i]
            label = labels[i]
            text = texts[i]

            image_buffer, height, width = _process_image(filename, coder)

            example = _convert_to_example(filename, image_buffer, label,
                                          text, height, width)
            writer.write(example.SerializeToString())
            shard_counter += 1
            counter += 1

            if not counter % 1000:
                print('%s [thread %d]: Processed %d of %d images in thread batch.' %
                      (datetime.now(), thread_index, counter, num_files_in_thread))
                sys.stdout.flush()

        writer.close()  # flush the shard to disk before moving on
        print('%s [thread %d]: Wrote %d images to %s' %
              (datetime.now(), thread_index, shard_counter, output_file))
        sys.stdout.flush()
        shard_counter = 0
    print('%s [thread %d]: Wrote %d images to %d shards.' %
          (datetime.now(), thread_index, counter, num_shards_per_batch))
    sys.stdout.flush()


def _process_image_files(name, filenames, texts, labels, num_shards):
    """Process and save list of images as TFRecord of Example protos.

    Args:
      name: string, unique identifier specifying the data set
      filenames: list of strings; each string is a path to an image file
      texts: list of strings; each string is human readable, e.g. 'dog'
      labels: list of integer; each integer identifies the ground truth
      num_shards: integer number of shards for this data set.
    """
    assert len(filenames) == len(texts)
    assert len(filenames) == len(labels)

    # Break all images into batches with a [ranges[i][0], ranges[i][1]].
    spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(int)
    ranges = []
    for i in range(len(spacing) - 1):
        ranges.append([spacing[i], spacing[i + 1]])

    # Launch a thread for each batch.
    print('Launching %d threads for spacings: %s' % (FLAGS.num_threads, ranges))
    sys.stdout.flush()

    # Create a mechanism for monitoring when all threads are finished.
    coord = tf.train.Coordinator()

    # Create a generic TensorFlow-based utility for converting all image codings.
    coder = ImageCoder()

    threads = []
    for thread_index in range(len(ranges)):
        args = (coder, thread_index, ranges, name, filenames,
                texts, labels, num_shards)
        t = threading.Thread(target=_process_image_files_batch, args=args)
        t.start()
        threads.append(t)

    # Wait for all the threads to terminate.
    coord.join(threads)
    print('%s: Finished writing all %d images in data set.' %
          (datetime.now(), len(filenames)))
    sys.stdout.flush()


def _find_image_files(data_dir, labels_file):
    """Build a list of all images files and labels in the data set.

    Args:
      data_dir: string, path to the root directory of images.

        Assumes that the image data set resides in JPEG files located in
        the following directory structure.

          data_dir/dog/another-image.JPEG
          data_dir/dog/my-image.jpg

        where 'dog' is the label associated with these images.

      labels_file: string, path to the labels file.

        The list of valid labels are held in this file. Assumes that the file
        contains entries as such:
          dog
          cat
          flower
        where each line corresponds to a label. We map each label contained in
        the file to an integer starting with the integer 1 corresponding to the
        label contained in the first line (0 is left for a background class).

    Returns:
      filenames: list of strings; each string is a path to an image file.
      texts: list of strings; each string is the class, e.g. 'dog'
      labels: list of integer; each integer identifies the ground truth.
    """
    print('Determining list of input files and labels from %s.' % data_dir)
    # Skip blank lines, e.g. a trailing empty line at the end of label.txt.
    unique_labels = [l.strip() for l in tf.gfile.FastGFile(
        labels_file, 'r').readlines() if l.strip()]

    labels = []
    filenames = []
    texts = []

    # Leave label index 0 empty as a background class.
    label_index = 1

    # Construct the list of JPEG files and labels.
    for text in unique_labels:
        jpeg_file_path = os.path.join(data_dir, text, '*')
        matching_files = tf.gfile.Glob(jpeg_file_path)

        labels.extend([label_index] * len(matching_files))
        texts.extend([text] * len(matching_files))
        filenames.extend(matching_files)

        if not label_index % 100:
            print('Finished finding files in %d of %d classes.' % (
                label_index, len(unique_labels)))
        label_index += 1

    # Shuffle the ordering of all image files in order to guarantee
    # random ordering of the images with respect to label in the
    # saved TFRecord files. Make the randomization repeatable.
    shuffled_index = list(range(len(filenames)))
    random.seed(12345)
    random.shuffle(shuffled_index)

    filenames = [filenames[i] for i in shuffled_index]
    texts = [texts[i] for i in shuffled_index]
    labels = [labels[i] for i in shuffled_index]

    print('Found %d JPEG files across %d labels inside %s.' %
          (len(filenames), len(unique_labels), data_dir))
    return filenames, texts, labels


def _process_dataset(name, directory, num_shards, labels_file):
    """Process a complete data set and save it as a TFRecord.

    Args:
      name: string, unique identifier specifying the data set.
      directory: string, root path to the data set.
      num_shards: integer number of shards for this data set.
      labels_file: string, path to the labels file.
    """
    filenames, texts, labels = _find_image_files(directory, labels_file)
    _process_image_files(name, filenames, texts, labels, num_shards)


def main(unused_argv):
    assert not FLAGS.train_shards % FLAGS.num_threads, (
        'Please make the FLAGS.num_threads commensurate with FLAGS.train_shards')
    assert not FLAGS.validation_shards % FLAGS.num_threads, (
        'Please make the FLAGS.num_threads commensurate with '
        'FLAGS.validation_shards')
    print('Saving results to %s !' % FLAGS.output_directory)

    # Run it! Skip the validation set when no directory or no shards are
    # given (the defaults), instead of globbing an empty path.
    if FLAGS.validation_directory and FLAGS.validation_shards:
        _process_dataset('validation', FLAGS.validation_directory,
                         FLAGS.validation_shards, FLAGS.labels_file)
    _process_dataset('train', FLAGS.train_directory,
                     FLAGS.train_shards, FLAGS.labels_file)


if __name__ == '__main__':
    tf.app.run()
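The threading and sharding logic in `_process_image_files` and `_process_image_files_batch` above is the least obvious part of the script. The stdlib sketch below reproduces only its index arithmetic, so you can check which images land in which shard file before running anything. `split_evenly` is an exact-integer stand-in for the `np.linspace(...).astype(int)` calls, and both function names are hypothetical.

```python
def split_evenly(start, end, parts):
    """Integer boundaries that cut [start, end) into `parts` near-equal pieces."""
    return [start + (i * (end - start)) // parts for i in range(parts + 1)]


def plan_shards(num_images, num_threads, num_shards, name="train"):
    """Return {output filename: (first, stop) image index range} for every shard."""
    assert num_shards % num_threads == 0, "num_shards must be a multiple of num_threads"
    shards_per_thread = num_shards // num_threads
    thread_bounds = split_evenly(0, num_images, num_threads)
    plan = {}
    for t in range(num_threads):
        # Each thread splits its own slice of images into its own shards.
        shard_bounds = split_evenly(thread_bounds[t], thread_bounds[t + 1], shards_per_thread)
        for s in range(shards_per_thread):
            shard = t * shards_per_thread + s
            filename = "%s-%.2d-of-%.2d.tfrecord" % (name, shard, num_shards)
            plan[filename] = (shard_bounds[s], shard_bounds[s + 1])
    return plan
```

With the default flags (4 training shards, 4 threads) and 300 images, each thread writes exactly one shard of 75 images, from train-00-of-04.tfrecord to train-03-of-04.tfrecord.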

Finally, we need to read the images back from the TFRecord files to feed the network, or to do whatever else you want with them.
I wrote the following script to do this. Note that it must be used together with the directory traversal script above, saved as dir_traversal_tfrecord.py (the module it imports); otherwise, believe me, it won’t work.
This program calls the traversal script to find all the TFRecord files, then extracts the images, labels, filenames, etc. from them. It center-crops or zero-pads each image to 299x299x3 (tf.image.resize_image_with_crop_or_pad does not rescale the image) and saves the preprocessed images to the resized_image folder.
My demo has only 300 example images, so the loop runs 300 times.
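Since `tf.image.resize_image_with_crop_or_pad` only crops or pads, it is worth knowing exactly what happens to each image. A dimension larger than the target is cropped around the center, and a smaller one is padded evenly with zeros. The NumPy sketch below approximates that behavior for an HxWxC array for illustration; `crop_or_pad` is a hypothetical helper, not TensorFlow’s implementation.

```python
import numpy as np


def crop_or_pad(image, target_height, target_width):
    """Center-crop or zero-pad an HxWxC array to target_height x target_width."""
    height, width = image.shape[:2]

    # Crop: skip (excess // 2) rows/columns at the top/left, the rest at the bottom/right.
    top = max((height - target_height) // 2, 0)
    left = max((width - target_width) // 2, 0)
    cropped = image[top:top + min(height, target_height),
                    left:left + min(width, target_width)]

    # Pad: add (missing // 2) zero rows/columns before, the remainder after.
    pad_h = target_height - cropped.shape[0]
    pad_w = target_width - cropped.shape[1]
    padding = [(pad_h // 2, pad_h - pad_h // 2),
               (pad_w // 2, pad_w - pad_w // 2)]
    padding += [(0, 0)] * (image.ndim - 2)  # leave the channel axis untouched
    return np.pad(cropped, padding, mode="constant")
```

For a typical photo larger than 299x299, this means you keep only the central 299x299 patch, so objects near the border can be cut off; rescale first (for example with `tf.image.resize_images`) if you need the whole scene.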

import os

import tensorflow as tf
from PIL import Image

# The directory traversal script above, saved as dir_traversal_tfrecord.py.
from dir_traversal_tfrecord import tfrecord_auto_traversal

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer("image_number", 300, "Number of images in your tfrecord, default is 300.")
flags.DEFINE_integer("class_number", 3, "Number of classes in your dataset/label.txt, default is 3.")
flags.DEFINE_integer("image_height", 299, "Height of the output image after crop or pad. Default is 299.")
flags.DEFINE_integer("image_width", 299, "Width of the output image after crop or pad. Default is 299.")


class image_object(object):
    """Plain container for the tensors parsed from one TFRecord example."""

    def __init__(self):
        self.image = None
        self.height = None
        self.width = None
        self.filename = None
        self.label = None


def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        "image/encoded": tf.FixedLenFeature([], tf.string),
        "image/height": tf.FixedLenFeature([], tf.int64),
        "image/width": tf.FixedLenFeature([], tf.int64),
        "image/filename": tf.FixedLenFeature([], tf.string),
        "image/class/label": tf.FixedLenFeature([], tf.int64)})

    image_encoded = features["image/encoded"]
    image_raw = tf.image.decode_jpeg(image_encoded, channels=3)

    current_image_object = image_object()

    # Center-crop or zero-pad (no rescaling) to image_height x image_width, 299x299 by default.
    current_image_object.image = tf.image.resize_image_with_crop_or_pad(
        image_raw, FLAGS.image_height, FLAGS.image_width)
    # Optional normalization to [-0.5, 0.5] before feeding a network:
    # current_image_object.image = tf.cast(current_image_object.image, tf.float32) * (1. / 255) - 0.5
    current_image_object.height = features["image/height"]  # height of the raw image
    current_image_object.width = features["image/width"]  # width of the raw image
    current_image_object.filename = features["image/filename"]  # filename of the raw image
    current_image_object.label = tf.cast(features["image/class/label"], tf.int32)  # label of the raw image

    return current_image_object


filename_queue = tf.train.string_input_producer(
    tfrecord_auto_traversal(),
    shuffle=True)

current_image_object = read_and_decode(filename_queue)

output_dir = "./resized_image"
if not os.path.isdir(output_dir):
    os.mkdir(output_dir)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    print("Write cropped or padded images to the folder './resized_image'")
    for i in range(FLAGS.image_number):  # number of examples in your tfrecord
        pre_image, pre_label = sess.run([current_image_object.image, current_image_object.label])
        img = Image.fromarray(pre_image, "RGB")
        img.save(os.path.join(output_dir, "class_%d_Index_%d.jpeg" % (pre_label, i)))
        if i % 10 == 0:
            print("%d of %d images have finished!" % (i, FLAGS.image_number))
    print("Complete!!")
    coord.request_stop()
    coord.join(threads)

print("cd to the current directory; the folder 'resized_image' should contain %d images of size %dx%d." % (FLAGS.image_number, FLAGS.image_height, FLAGS.image_width))

The script named flower_train_cnn.py feeds a flower dataset to a typical CNN trained from scratch.

Follow ups

Currently, the above code meets my needs; I’ll keep updating it to make things easier.
The next steps are:

  • Display the label and the image at the same time, and generate the preprocessed images according to their labels. The problem was how to handle multiple return values from tf.graph(). (Fixed.)
  • Feed the images to a network to complete the demo. (Fixed.)
  • Add a batch shuffle function. (Fixed.)
  • Add input arguments. (Fixed.)
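Batch shuffling in TensorFlow’s queue-based input API (for example `tf.train.shuffle_batch`, whose `min_after_dequeue` argument controls how many examples stay buffered) works by keeping a pool of examples and drawing from it at random. The stdlib sketch below models that idea in plain Python; it is a conceptual illustration, not TensorFlow’s implementation, and `shuffled_batches` is a hypothetical name.

```python
import random


def shuffled_batches(examples, batch_size, buffer_size, seed=None):
    """Yield batches drawn at random from a bounded buffer of upcoming examples.

    A larger buffer_size mixes the data better but holds more examples in memory.
    """
    rng = random.Random(seed)
    buffer, batch = [], []
    for example in examples:
        buffer.append(example)
        if len(buffer) >= buffer_size:
            # Swap a random element to the end and pop it (O(1) removal).
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            batch.append(buffer.pop())
            if len(batch) == batch_size:
                yield batch
                batch = []
    rng.shuffle(buffer)  # drain whatever is left once the input runs out
    for example in buffer:
        batch.append(example)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # final, possibly smaller, batch
```

Because the build script already shuffles file order once with a fixed seed, a modest buffer is usually enough to vary the batch composition from epoch to epoch.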


License


The content of this blog itself is licensed under the Creative Commons Attribution 4.0 International License.
CC-BY-SA LICENCES

The containing source code (if applicable) and the source code used to format and display that content is licensed under the Apache License 2.0.
Copyright [2016] [yeephycho]
Licensed under the Apache License, Version 2.0 (the “License”);
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Apache License 2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an “AS IS” BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
express or implied. See the License for the specific language
governing permissions and limitations under the License.
APACHE LICENCES