diff --git a/maml.py b/maml.py index 92eccbd9b..4af0f1dc2 100644 --- a/maml.py +++ b/maml.py @@ -200,7 +200,7 @@ def construct_conv_weights(self): weights['b4'] = tf.Variable(tf.zeros([self.dim_hidden])) if FLAGS.datasource == 'miniimagenet': # assumes max pooling - weights['w5'] = tf.get_variable('w5', [self.dim_hidden*5*5, self.dim_output], initializer=fc_initializer) + weights['w5'] = tf.get_variable('w5', [self.dim_hidden*6*6, self.dim_output], initializer=fc_initializer) weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5') else: weights['w5'] = tf.Variable(tf.random_normal([self.dim_hidden, self.dim_output]), name='w5')