diff --git a/encoder/generator_model.py b/encoder/generator_model.py index 867d2a6d6..27112b780 100644 --- a/encoder/generator_model.py +++ b/encoder/generator_model.py @@ -29,10 +29,13 @@ def __init__(self, model, batch_size, randomize_noise=False): self.sess = tf.get_default_session() self.graph = tf.get_default_graph() - self.dlatent_variable = next(v for v in tf.global_variables() if 'learnable_dlatents' in v.name) + in_expr, out_expr = next(expr for expr in model.components.synthesis._run_cache.values()) + + self.dlatent_variable = next(v for v in in_expr if 'learnable_dlatents' in v.name) + self.generator_output = next(v for v in out_expr if '_Run/concat' in v.name) + self.set_dlatents(self.initial_dlatents) - self.generator_output = self.graph.get_tensor_by_name('G_synthesis_1/_Run/concat:0') self.generated_image = tflib.convert_images_to_uint8(self.generator_output, nchw_to_nhwc=True, uint8_cast=False) self.generated_image_uint8 = tf.saturate_cast(self.generated_image, tf.uint8)