Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Excuse me, why are the results I get when testing on the flower dataset so different from result.png? #41

Open
keqkeq opened this issue Apr 21, 2020 · 0 comments

Comments

@keqkeq
Copy link

keqkeq commented Apr 21, 2020

import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *
from tensorlayer.prepro import *
from tensorlayer.cost import *
import numpy as np
import scipy
from scipy.io import loadmat
import time, os, re, nltk

from utils import *
from model import *
import model
import pickle

###======================== PREPARE DATA ====================================###
# Load the preprocessed vocabulary, image arrays, dataset counts and caption ids
# from the pickle files in the working directory.
# NOTE(review): pickle.load can execute arbitrary code; only load pickle files
# you produced yourself.
print("Loading data from pickle ...")
with open("_vocab.pickle", 'rb') as f:
    vocab = pickle.load(f)
with open("_image_train.pickle", 'rb') as f:
    # Each image pickle unpacks into two items. The first must be kept (not
    # discarded into `_`) because `images_train_256` is converted below;
    # discarding it raised NameError.
    images_train_256, images_train = pickle.load(f)
with open("_image_test.pickle", 'rb') as f:
    images_test_256, images_test = pickle.load(f)
with open("_n.pickle", 'rb') as f:
    n_captions_train, n_captions_test, n_captions_per_image, n_images_train, n_images_test = pickle.load(f)
with open("_caption.pickle", 'rb') as f:
    captions_ids_train, captions_ids_test = pickle.load(f)

# Convert the loaded image collections to numpy arrays.
images_train_256 = np.array(images_train_256)
images_test_256 = np.array(images_test_256)
images_train = np.array(images_train)
images_test = np.array(images_test)

# Side length of the square sample grid: ni * ni tiles cover batch_size images.
ni = int(np.ceil(np.sqrt(batch_size)))
save_dir = "checkpoint"

# Graph inputs. t_real_caption and t_z are fed in the sess.run call further
# down; t_real_image is declared but not fed anywhere in this script.
t_real_image = tf.placeholder(
    dtype='float32',
    shape=[batch_size, image_size, image_size, 3],
    name='real_image',
)
t_real_caption = tf.placeholder(
    dtype=tf.int64,
    shape=[batch_size, None],
    name='real_caption_input',
)
t_z = tf.placeholder(
    dtype=tf.float32,
    shape=[batch_size, z_dim],
    name='z_noise',
)

# Build the caption RNN and the generator with is_train=False. The generator
# consumes the noise tensor plus the RNN outputs; its second return value is
# not needed here.
generator_txt2img = model.generator_txt2img_resnet
net_rnn = rnn_embed(t_real_caption, is_train=False, reuse=False)
net_g, _ = generator_txt2img(
    t_z,
    net_rnn.outputs,
    is_train=False,
    reuse=False,
    batch_size=batch_size,
)

# Create the session and initialise all variables, then overwrite the caption
# RNN and generator weights with the saved checkpoints.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
tl.layers.initialize_global_variables(sess)

# Checkpoint paths; the "400" suffix presumably marks the training epoch they
# were saved at (TODO confirm). Only net_rnn and net_g are restored below; the
# cnn / d paths are kept for reference but not loaded by this script.
net_rnn_name = os.path.join(save_dir, 'net_rnn.npz400.npz')
net_cnn_name = os.path.join(save_dir, 'net_cnn.npz400.npz')
net_g_name = os.path.join(save_dir, 'net_g.npz400.npz')
net_d_name = os.path.join(save_dir, 'net_d.npz400.npz')

# Per the TensorLayer docs, load_and_assign_npz returns False instead of
# raising when the file is missing. Previously that only surfaced later as an
# AttributeError on `net_g_res.outputs`. Fail early with the offending path.
for _ckpt_path in (net_rnn_name, net_g_name):
    if not os.path.exists(_ckpt_path):
        raise IOError("Checkpoint not found: %s" % _ckpt_path)

net_rnn_res = tl.files.load_and_assign_npz(sess=sess, name=net_rnn_name, network=net_rnn)
net_g_res = tl.files.load_and_assign_npz(sess=sess, name=net_g_name, network=net_g)

# One fixed noise vector per sample, drawn from a standard normal.
sample_size = batch_size
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32)

# Repeat each of the 8 seed captions n times. With batch_size == 64
# (ni == n == 8) this yields exactly batch_size captions, presumably one
# caption per row of the ni x ni output grid.
n = sample_size // ni
sample_sentence = (
    ["the flower shown has yellow anther red pistil and bright red petals."] * n
    + ["this flower has petals that are yellow, white and purple and has dark lines"] * n
    + ["the petals on this flower are white with a yellow center"] * n
    + ["this flower has a lot of small round pink petals."] * n
    + ["this flower is orange in color, and has petals that are ruffled and rounded."] * n
    + ["the flower has yellow petals and the center of it is brown."] * n
    + ["this flower has petals that are blue and white."] * n
    + ["these white flowers have petals that start off white in color and end in a white towards the tips."] * n
)

# t_real_caption has a fixed leading dimension of batch_size, so the feed must
# contain exactly sample_size captions; fail here with a clear message instead
# of a shape error inside sess.run.
if len(sample_sentence) != sample_size:
    raise ValueError(
        "Built %d sample captions but batch_size is %d; adjust the seed "
        "captions to match batch_size." % (len(sample_sentence), sample_size))

# Preprocess and tokenize each caption, map words to vocabulary ids, append the
# END id, then pad all sequences at the end ('post') to a common length.
encoded_sentences = []
for sentence in sample_sentence:
    print("seed: %s" % sentence)
    tokens = nltk.tokenize.word_tokenize(preprocess_caption(sentence))
    encoded_sentences.append([vocab.word_to_id(word) for word in tokens] + [vocab.end_id])  # add END_ID
sample_sentence = tl.prepro.pad_sequences(encoded_sentences, padding='post')

# Run the generator (and caption RNN) on the encoded captions and fixed noise.
img_gen, rnn_out = sess.run([net_g_res.outputs, net_rnn_res.outputs], feed_dict={
    t_real_caption: sample_sentence,
    t_z: sample_seed,
})

# Create the output directory first; presumably save_images does not create
# missing parent directories (TODO confirm against utils.save_images).
_gen_dir = os.path.join('samples', 'gen_samples')
if not os.path.isdir(_gen_dir):
    os.makedirs(_gen_dir)
save_images(img_gen, [ni, ni], 'samples/gen_samples/gen.png')

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant