Showing 7 changed files with 2,596 additions and 0 deletions.
@@ -0,0 +1 @@
../grading.py
@@ -0,0 +1,83 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import random


def test_vocab(vocab, PAD, UNK, START, END):
    # Grades the vocabulary: its size, the number of distinct indices,
    # and whether all four special tokens are present.
    return [
        len(vocab),
        len(np.unique(list(vocab.values()))),
        int(all([_ in vocab for _ in [PAD, UNK, START, END]]))
    ]
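test_vocab assumes a token-to-index mapping that contains the four special tokens. A minimal sketch of a vocabulary that would pass all three checks (the token strings here are assumptions; the course code builds the real vocabulary from training captions):

```python
# Hypothetical special tokens and a toy vocabulary, for illustration only.
PAD, UNK, START, END = "#PAD#", "#UNK#", "#START#", "#END#"
vocab = {tok: idx for idx, tok in enumerate([PAD, UNK, START, END, "a", "cat"])}

print(test_vocab(vocab, PAD, UNK, START, END))  # -> [6, 6, 1]
```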
def test_captions_indexing(train_captions_indexed, vocab, UNK):
    # Collects the distinct first, last, and middle token indices across all
    # captions, and counts how many tokens were mapped to the UNK index.
    starts = set()
    ends = set()
    between = set()
    unk_count = 0
    for caps in train_captions_indexed:
        for cap in caps:
            starts.add(cap[0])
            between.update(cap[1:-1])
            ends.add(cap[-1])
            for w in cap:
                if w == vocab[UNK]:
                    unk_count += 1
    return [
        len(starts),
        len(ends),
        len(between),
        len(between | starts | ends),
        int(all([isinstance(x, int) for x in (between | starts | ends)])),
        unk_count
    ]
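The grader expects train_captions_indexed to be a list with one entry per image, each entry a list of captions, and each caption a list of int indices beginning with the START index and ending with the END index. A toy instance, assuming the vocabulary from the previous sketch:

```python
# Two images, each with one indexed caption; indices refer to the toy vocab above.
train_captions_indexed = [
    [[vocab[START], vocab["a"], vocab["cat"], vocab[END]]],
    [[vocab[START], vocab[UNK], vocab[END]]],
]
print(test_captions_indexing(train_captions_indexed, vocab, UNK))
# -> [1, 1, 3, 5, 1, 1]
```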
def test_captions_batching(batch_captions_to_matrix):
    # Grades padding/truncation behavior for three values of max_len.
    return (batch_captions_to_matrix([[1, 2, 3], [4, 5]], -1, max_len=None).ravel().tolist()
            + batch_captions_to_matrix([[1, 2, 3], [4, 5]], -1, max_len=2).ravel().tolist()
            + batch_captions_to_matrix([[1, 2, 3], [4, 5]], -1, max_len=10).ravel().tolist())
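batch_captions_to_matrix is the notebook function under test; the calls above imply its contract: pad ragged captions with pad_idx up to the longest caption, optionally capped at max_len. A minimal sketch satisfying those calls (not the assignment's reference implementation):

```python
def batch_captions_to_matrix(batch_captions, pad_idx, max_len=None):
    # Target width: longest caption, optionally capped by max_len.
    columns = max(map(len, batch_captions))
    if max_len is not None:
        columns = min(columns, max_len)
    matrix = np.full((len(batch_captions), columns), pad_idx, dtype=np.int32)
    for i, cap in enumerate(batch_captions):
        row = cap[:columns]  # truncate captions longer than the target width
        matrix[i, :len(row)] = row
    return matrix
```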
def get_feed_dict_for_testing(decoder, IMG_EMBED_SIZE, vocab):
    # Random image embeddings and token indices: batch of 32, sentence length 20.
    return {
        decoder.img_embeds: np.random.random((32, IMG_EMBED_SIZE)),
        decoder.sentences: np.random.randint(0, len(vocab), (32, 20))
    }


def test_decoder_shapes(decoder, IMG_EMBED_SIZE, vocab, s):
    # Runs each decoder tensor on random inputs and collects the output shapes.
    tensors_to_test = [
        decoder.h0,
        decoder.word_embeds,
        decoder.flat_hidden_states,
        decoder.flat_token_logits,
        decoder.flat_ground_truth,
        decoder.flat_loss_mask,
        decoder.loss
    ]
    all_shapes = []
    for t in tensors_to_test:
        _ = s.run(t, feed_dict=get_feed_dict_for_testing(decoder, IMG_EMBED_SIZE, vocab))
        all_shapes.extend(_.shape)
    return all_shapes


def test_random_decoder_loss(decoder, IMG_EMBED_SIZE, vocab, s):
    loss = s.run(decoder.loss, feed_dict=get_feed_dict_for_testing(decoder, IMG_EMBED_SIZE, vocab))
    return loss


def test_validation_loss(decoder, s, generate_batch, val_img_embeds, val_captions_indexed):
    # Averages the decoder loss over 1000 random validation batches,
    # with fixed seeds so the grader is reproducible.
    np.random.seed(300)
    random.seed(300)
    val_loss = 0
    for _ in range(1000):
        val_loss += s.run(decoder.loss, generate_batch(val_img_embeds,
                                                       val_captions_indexed,
                                                       32,
                                                       20))
    val_loss /= 1000.
    return val_loss
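test_validation_loss assumes generate_batch returns a TensorFlow feed dict for the decoder's placeholders. A minimal sketch of that contract, reusing batch_captions_to_matrix from the earlier sketch and a decoder from the enclosing scope (the sampling scheme and pad_idx are assumptions):

```python
def generate_batch(images_embeddings, indexed_captions, batch_size, max_len=None, pad_idx=0):
    # Sample batch_size images, pick one random caption each, pad to a matrix.
    idx = np.random.choice(len(images_embeddings), size=batch_size)
    batch_image_embeddings = images_embeddings[idx]
    captions = [random.choice(indexed_captions[i]) for i in idx]
    batch_captions_matrix = batch_captions_to_matrix(captions, pad_idx, max_len=max_len)
    return {decoder.img_embeds: batch_image_embeddings,
            decoder.sentences: batch_captions_matrix}
```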
(Three of the changed files could not be rendered by the diff viewer.)
@@ -0,0 +1,110 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import queue
import threading
import zipfile
import tqdm
import cv2
import numpy as np
import pickle


def image_center_crop(img):
    # Intentionally left unimplemented in this file; a sketch follows below.
    raise NotImplementedError()
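A minimal sketch of a center crop, assuming the intended behavior is to cut the largest centered square from the image before resizing, as is usual when feeding square-input CNNs (this is not the course's reference solution):

```python
def image_center_crop(img):
    # Crop the largest centered square from an H x W x C image.
    h, w = img.shape[0], img.shape[1]
    side = min(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    return img[top:top + side, left:left + side, :]
```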
def decode_image_from_buf(buf):
    # Decode raw image bytes and convert from OpenCV's BGR order to RGB.
    img = cv2.imdecode(np.asarray(bytearray(buf), dtype=np.uint8), 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img


def crop_and_preprocess(img, input_shape, preprocess_for_model):
    img = image_center_crop(img)  # take center crop
    img = cv2.resize(img, input_shape)  # resize for our model
    img = img.astype("float32")  # prepare for normalization
    img = preprocess_for_model(img)  # preprocess for model
    return img
def apply_model(zip_fn, model, preprocess_for_model, extensions=(".jpg",), input_shape=(224, 224), batch_size=32):
    # queue for cropped images
    q = queue.Queue(maxsize=batch_size * 10)

    # set when the reading thread has put all images into the queue
    read_thread_completed = threading.Event()

    # signals the reading thread to stop
    kill_read_thread = threading.Event()

    def reading_thread(zip_fn):
        zf = zipfile.ZipFile(zip_fn)
        for fn in tqdm.tqdm_notebook(zf.namelist()):
            if kill_read_thread.is_set():
                break
            if os.path.splitext(fn)[-1] in extensions:
                buf = zf.read(fn)  # read raw bytes from zip for fn
                img = decode_image_from_buf(buf)  # decode raw bytes
                img = crop_and_preprocess(img, input_shape, preprocess_for_model)
                while True:
                    try:
                        q.put((os.path.split(fn)[-1], img), timeout=1)  # put in queue
                    except queue.Full:
                        if kill_read_thread.is_set():
                            break
                        continue
                    break

        read_thread_completed.set()  # all images have been read

    # start reading thread
    t = threading.Thread(target=reading_thread, args=(zip_fn,))
    t.daemon = True
    t.start()

    img_fns = []
    img_embeddings = []

    batch_imgs = []

    def process_batch(batch_imgs):
        batch_imgs = np.stack(batch_imgs, axis=0)
        batch_embeddings = model.predict(batch_imgs)
        img_embeddings.append(batch_embeddings)

    try:
        while True:
            try:
                fn, img = q.get(timeout=1)
            except queue.Empty:
                if read_thread_completed.is_set():
                    break
                continue
            img_fns.append(fn)
            batch_imgs.append(img)
            if len(batch_imgs) == batch_size:
                process_batch(batch_imgs)
                batch_imgs = []
            q.task_done()
        # process the last, possibly incomplete batch
        if len(batch_imgs):
            process_batch(batch_imgs)
    finally:
        kill_read_thread.set()
        t.join()

    q.join()

    img_embeddings = np.vstack(img_embeddings)
    return img_embeddings, img_fns
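A sketch of how apply_model might be driven; the choice of encoder and the zip file name are assumptions, and any model whose predict accepts a stacked image batch would work:

```python
import keras

# Hypothetical encoder: InceptionV3 without its classification head,
# so model.predict returns one embedding vector per image.
model = keras.applications.InceptionV3(include_top=False, pooling="avg")
preprocess = keras.applications.inception_v3.preprocess_input

embeddings, filenames = apply_model("train2014.zip", model, preprocess,
                                    input_shape=(299, 299))
save_pickle(embeddings, "train_img_embeds.pickle")
save_pickle(filenames, "train_img_fns.pickle")
```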
def save_pickle(obj, fn):
    with open(fn, "wb") as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)


def read_pickle(fn):
    with open(fn, "rb") as f:
        return pickle.load(f)