@@ -0,0 +1,219 @@
+import argparse
+import tensorflow as tf
+
+CROP_SIZE = 256  # scale_size = CROP_SIZE
+ngf = 64
+ndf = 64
+
+
+def preprocess(image):
+    with tf.name_scope('preprocess'):
+        # [0, 1] => [-1, 1]
+        return image * 2 - 1
+
+
+def deprocess(image):
+    with tf.name_scope('deprocess'):
+        # [-1, 1] => [0, 1]
+        return (image + 1) / 2
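+
+# Sanity check: preprocess and deprocess are exact inverses, so
+# deprocess(preprocess(img)) == img for img in [0, 1].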
+
+
+def conv(batch_input, out_channels, stride):
+    with tf.variable_scope('conv'):
+        in_channels = batch_input.get_shape()[3]
+        filter = tf.get_variable('filter', [4, 4, in_channels, out_channels], dtype=tf.float32,
+                                 initializer=tf.random_normal_initializer(0, 0.02))
+        # [batch, in_height, in_width, in_channels], [filter_height, filter_width, in_channels, out_channels]
+        #     => [batch, out_height, out_width, out_channels]
+        padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='CONSTANT')
+        conv = tf.nn.conv2d(padded_input, filter, [1, stride, stride, 1], padding='VALID')
+        return conv
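+
+# Shape check for conv: with a 4x4 kernel, 1 pixel of padding on each side and
+# stride 2, spatial size halves: out = (in + 2 - 4) / 2 + 1 = in / 2 (e.g. 256 -> 128).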
+
+
+def lrelu(x, a):
+    with tf.name_scope('lrelu'):
+        # adding these together creates the leak part and linear part
+        # then cancels them out by subtracting/adding an absolute value term
+        # leak: a*x/2 - a*abs(x)/2
+        # linear: x/2 + abs(x)/2
+
+        # this block looks like it has 2 inputs on the graph unless we do this
+        x = tf.identity(x)
+        return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
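+
+# Equivalence check: for x >= 0 the two terms sum to x; for x < 0, abs(x) = -x and
+# the sum collapses to a*x. So lrelu(x, a) == tf.maximum(a * x, x) for 0 <= a <= 1.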
+
+
+def batchnorm(input):
+    with tf.variable_scope('batchnorm'):
+        # this block looks like it has 3 inputs on the graph unless we do this
+        input = tf.identity(input)
+
+        channels = input.get_shape()[3]
+        offset = tf.get_variable('offset', [channels], dtype=tf.float32, initializer=tf.zeros_initializer())
+        scale = tf.get_variable('scale', [channels], dtype=tf.float32,
+                                initializer=tf.random_normal_initializer(1.0, 0.02))
+        mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False)
+        variance_epsilon = 1e-5
+        normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)
+        return normalized
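+
+# Note: batchnorm above uses the statistics of the current batch (no moving
+# averages), so at prediction time each image is normalized against its own
+# mean/variance; with batch size 1 this behaves like instance normalization.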
+
+
+def deconv(batch_input, out_channels):
+    with tf.variable_scope('deconv'):
+        batch, in_height, in_width, in_channels = [int(d) for d in batch_input.get_shape()]
+        filter = tf.get_variable('filter', [4, 4, out_channels, in_channels], dtype=tf.float32,
+                                 initializer=tf.random_normal_initializer(0, 0.02))
+        # [batch, in_height, in_width, in_channels], [filter_height, filter_width, out_channels, in_channels]
+        #     => [batch, out_height, out_width, out_channels]
+        conv = tf.nn.conv2d_transpose(batch_input, filter, [batch, in_height * 2, in_width * 2, out_channels],
+                                      [1, 2, 2, 1], padding='SAME')
+        return conv
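+
+# Shape check for deconv: conv2d_transpose with stride 2 and 'SAME' padding is the
+# transpose of a stride-2 convolution, so height and width double (e.g. 128 -> 256).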
+
+
+def process_image(x):
+    with tf.name_scope('load_images'):
+        raw_input = tf.image.convert_image_dtype(x, dtype=tf.float32)
+
+        raw_input.set_shape([None, None, 3])
+
+        # break apart image pair and move to range [-1, 1]
+        width = tf.shape(raw_input)[1]  # [height, width, channels]
+        a_images = preprocess(raw_input[:, :width // 2, :])
+        b_images = preprocess(raw_input[:, width // 2:, :])
+
+    inputs, targets = [a_images, b_images]
+
+    # synchronize seed for image operations so that we do the same operations to both
+    # input and output images
+    def transform(image):
+        r = image
+
+        # area produces a nice downscaling, but does nearest neighbor for upscaling
+        # assume we're going to be doing downscaling here
+        r = tf.image.resize_images(r, [CROP_SIZE, CROP_SIZE], method=tf.image.ResizeMethod.AREA)
+
+        return r
+
+    with tf.name_scope('input_images'):
+        input_images = tf.expand_dims(transform(inputs), 0)
+
+    with tf.name_scope('target_images'):
+        target_images = tf.expand_dims(transform(targets), 0)
+
+    return input_images, target_images
+
+    # Tensor('batch:1', shape=(1, 256, 256, 3), dtype=float32) -> 1 batch size
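+
+# Usage sketch: given the uint8 placeholder defined under __main__ below,
+#   inputs, targets = process_image(x)
+# yields two (1, 256, 256, 3) float32 tensors scaled to [-1, 1], one per half
+# of the side-by-side A|B image pair.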
+
+
+def create_generator(generator_inputs, generator_outputs_channels):
+    layers = []
+
+    # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
+    with tf.variable_scope('encoder_1'):
+        output = conv(generator_inputs, ngf, stride=2)
+        layers.append(output)
+
+    layer_specs = [
+        ngf * 2,  # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
+        ngf * 4,  # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
+        ngf * 8,  # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
+        ngf * 8,  # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
+        ngf * 8,  # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
+        ngf * 8,  # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
+        ngf * 8,  # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
+    ]
+
+    for out_channels in layer_specs:
+        with tf.variable_scope('encoder_%d' % (len(layers) + 1)):
+            rectified = lrelu(layers[-1], 0.2)
+            # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
+            convolved = conv(rectified, out_channels, stride=2)
+            output = batchnorm(convolved)
+            layers.append(output)
+
+    layer_specs = [
+        (ngf * 8, 0.5),  # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
+        (ngf * 8, 0.5),  # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
+        (ngf * 8, 0.5),  # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
+        (ngf * 8, 0.0),  # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
+        (ngf * 4, 0.0),  # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
+        (ngf * 2, 0.0),  # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
+        (ngf, 0.0),  # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
+    ]
+
+    num_encoder_layers = len(layers)
+    for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
+        skip_layer = num_encoder_layers - decoder_layer - 1
+        with tf.variable_scope('decoder_%d' % (skip_layer + 1)):
+            if decoder_layer == 0:
+                # first decoder layer doesn't have skip connections
+                # since it is directly connected to the skip_layer
+                input = layers[-1]
+            else:
+                input = tf.concat([layers[-1], layers[skip_layer]], axis=3)
+
+            rectified = tf.nn.relu(input)
+            # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
+            output = deconv(rectified, out_channels)
+            output = batchnorm(output)
+
+            if dropout > 0.0:
+                output = tf.nn.dropout(output, keep_prob=1 - dropout)
+
+            layers.append(output)
+
+    # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
+    with tf.variable_scope('decoder_1'):
+        input = tf.concat([layers[-1], layers[0]], axis=3)
+        rectified = tf.nn.relu(input)
+        output = deconv(rectified, generator_outputs_channels)
+        output = tf.tanh(output)
+        layers.append(output)
+
+    return layers[-1]
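+
+# Architecture note: this is a U-Net generator. Eight stride-2 convs shrink the
+# 256x256 input down to 1x1, eight deconvs mirror the path back up, and each
+# decoder layer concatenates the matching encoder activation (skip connection),
+# which is why the decoder input channel counts above are doubled. The final
+# tanh keeps outputs in [-1, 1], matching the range produced by preprocess().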
+
+
+def create_model(inputs, targets):
+    with tf.variable_scope('generator'):
+        out_channels = int(targets.get_shape()[-1])
+        outputs = create_generator(inputs, out_channels)
+
+    return outputs
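+
+# Note: only the generator subgraph is built here; the discriminator and losses
+# used during training are omitted, which is what makes the exported
+# checkpoint a "reduced" model.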
+
+
+def convert(image):
+    return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True, name='output')  # output tensor
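+
+# saturate=True clamps values to the representable range before casting, so a
+# float just outside [0, 1] becomes 0 or 255 instead of wrapping around.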
+
+
+def generate_output(x):
+    with tf.name_scope('generate_output'):
+        test_inputs, test_targets = process_image(x)
+
+        # inputs and targets are [batch_size, height, width, channels]
+        model = create_model(test_inputs, test_targets)
+
+        # deprocess the generator output from [-1, 1] back to [0, 1]
+        outputs = deprocess(model)
+
+        # reverse any processing on images so they can be written to disk or displayed to user
+        converted_outputs = convert(outputs)
+    return converted_outputs
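+
+# End-to-end: uint8 A|B pair -> float halves in [-1, 1] -> generator -> [0, 1] -> uint8.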
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model-input', dest='input_folder', type=str, help='Model folder to import.')
+    parser.add_argument('--model-output', dest='output_folder', type=str, help='Model (reduced) folder to export.')
+    args = parser.parse_args()
+
+    x = tf.placeholder(tf.uint8, shape=(256, 512, 3), name='image_tensor')  # input tensor
+    y = generate_output(x)
+
+    with tf.Session() as sess:
+        # Restore the generator variables from the original training checkpoint
+        saver = tf.train.Saver()
+        checkpoint = tf.train.latest_checkpoint(args.input_folder)
+        saver.restore(sess, checkpoint)
+
+        # Export the reduced model used for prediction; the same saver can be
+        # reused since the graph holds only the generator variables
+        saver.save(sess, '{}/reduced_model'.format(args.output_folder))
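+
+# Example invocation (the script filename and folder paths are assumptions;
+# adjust them to your checkout and checkpoint layout):
+#   python reduce_model.py --model-input ./checkpoints/train --model-output ./checkpoints/reduced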