diff --git a/encoder/perceptual_model.py b/encoder/perceptual_model.py index 5da09fb29..f9430527a 100644 --- a/encoder/perceptual_model.py +++ b/encoder/perceptual_model.py @@ -30,6 +30,12 @@ def __init__(self, img_size, layer=9, batch_size=1, sess=None): self.features_weight = None self.loss = None + self.features_weight_placeholder = None + self.img_features_placeholder = None + + self.features_weight_op = None + self.img_features_op = None + def build_perceptual_model(self, generated_image_tensor): vgg16 = VGG16(include_top=False, input_shape=(self.img_size, self.img_size, 3)) self.perceptual_model = Model(vgg16.input, vgg16.layers[self.layer].output) @@ -46,6 +52,12 @@ def build_perceptual_model(self, generated_image_tensor): self.loss = tf.losses.mean_squared_error(self.features_weight * self.ref_img_features, self.features_weight * generated_img_features) / 82890.0 + self.features_weight_placeholder = tf.placeholder(self.features_weight.dtype, shape=self.features_weight.get_shape()) + self.img_features_placeholder = tf.placeholder(self.ref_img_features.dtype, shape=self.ref_img_features.get_shape()) + + self.features_weight_op = self.features_weight.assign(self.features_weight_placeholder) + self.img_features_op = self.ref_img_features.assign(self.img_features_placeholder) + def set_reference_images(self, images_list): assert(len(images_list) != 0 and len(images_list) <= self.batch_size) loaded_image = load_images(images_list, self.img_size) @@ -65,8 +77,8 @@ def set_reference_images(self, images_list): image_features = np.vstack([image_features, np.zeros(empty_features_shape)]) - self.sess.run(tf.assign(self.features_weight, weight_mask)) - self.sess.run(tf.assign(self.ref_img_features, image_features)) + self.sess.run(self.features_weight_op, {self.features_weight_placeholder: weight_mask}) + self.sess.run(self.img_features_op, {self.img_features_placeholder: image_features}) def optimize(self, vars_to_optimize, iterations=500, learning_rate=1.): vars_to_optimize = vars_to_optimize if isinstance(vars_to_optimize, list) else [vars_to_optimize]