-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathsrez_input.py
136 lines (111 loc) · 5.17 KB
/
srez_input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import tensorflow as tf
import numpy as np
# TF1 global command-line flags object. setup_inputs_one_sources reads
# sample_size, sample_size_y and batch_size from it; the flag definitions are
# presumably made elsewhere in the project (not visible here) — confirm.
FLAGS = tf.app.flags.FLAGS
# generate mask based on alpha
def generate_mask_alpha(size=(128, 128), r_factor_designed=5.0, r_alpha=3, axis_undersample=1,
                        acs=3, seed=0, mute=0):
    """Generate a random line-undersampling mask for a 2-D k-space.

    ``floor(N / r_factor_designed)`` distinct lines (N = ``size[axis_undersample]``)
    are drawn at random along the undersampled axis. Line ``i`` is drawn with
    probability proportional to ``d ** r_alpha``, where
    ``d = |i - N/2| / (N/2)`` is its normalized distance from index N/2.
    ``r_alpha=0`` gives a uniform density.

    In addition, ``acs`` lines are always kept: ``ceil(acs/2)`` at the start
    of the axis and ``floor(acs/2)`` at the end. Both ends of the axis hold
    the low frequencies only for unshifted k-space. ``setup_inputs_one_sources``
    applies the mask that way, directly to ``tf.fft2d`` output, where index 0
    is DC.

    Args:
        size: shape of the mask, passed to ``np.zeros``.
        r_factor_designed: target acceleration, used to decide how many random
            lines to draw. The ACS lines are added on top, so the achieved
            factor returned below can differ from this value.
        r_alpha: exponent of the sampling density.
        axis_undersample: 0 samples rows; any other value samples columns.
        acs: number of always-sampled lines, split between both ends of the axis.
        seed: when >= 0, ``np.random.seed(seed)`` is called first. This reseeds
            NumPy's global RNG, which also affects later ``np.random`` calls.
        mute: when falsy, print a short summary of the generated mask.

    Returns:
        ``(mask, r_factor)``:
        - ``mask``: a float array of zeros and ones with shape ``size``.
        - ``r_factor``: the achieved reduction factor ``mask.size / mask.sum()``.
    """
    mask = np.zeros(size)
    if seed >= 0:
        np.random.seed(seed)

    # number of lines along the undersampled axis, and how many to draw randomly
    num_phase_encode = size[axis_undersample]
    num_phase_sampled = int(np.floor(num_phase_encode / r_factor_designed))

    # sampling density: normalized distance from index N/2, raised to r_alpha
    coordinate_normalized = np.arange(num_phase_encode)
    coordinate_normalized = np.abs(coordinate_normalized - num_phase_encode / 2) / (num_phase_encode / 2.0)
    prob_sample = coordinate_normalized ** r_alpha
    prob_sample = prob_sample / sum(prob_sample)
    index_sample = np.random.choice(num_phase_encode, size=num_phase_sampled,
                                    replace=False, p=prob_sample)

    axis = 0 if axis_undersample == 0 else 1

    def _lines(index):
        # index expression selecting whole lines along the undersampled axis
        return (index, slice(None)) if axis == 0 else (slice(None), index)

    # randomly drawn lines
    mask[_lines(index_sample)] = 1

    # ACS lines: ceil(acs/2) at the start, floor(acs/2) at the end of the axis.
    # The tail slice uses an explicit start index. A negative slice would be
    # [-0:] for acs < 2, which selects the WHOLE axis and fills the mask.
    axis_len = mask.shape[axis]
    num_acs_head = int((acs + 1) / 2)
    num_acs_tail = int(acs / 2)
    mask[_lines(slice(0, num_acs_head))] = 1
    mask[_lines(slice(max(axis_len - num_acs_tail, 0), axis_len))] = 1

    # achieved reduction factor
    r_factor = mask.size / mask.sum()
    if not mute:
        print('gen mask size of {1} for R-factor={0:.4f}'.format(r_factor, mask.shape))
        # 1-D sampling profile along the undersampled axis
        profile = mask[:, 0] if axis == 0 else mask[0, :]
        print(num_phase_encode, num_phase_sampled, np.where(profile))
    return mask, r_factor
# generate mask based on .mat mask
def generate_mask_mat(mask=None, mute=0):
    """Prepare a precomputed sampling mask, e.g. one loaded from a .mat file.

    The mask is passed through ``np.fft.ifftshift``. It is therefore presumably
    stored centered (DC in the middle) and is moved to the unshifted layout,
    matching the direct ``tf.fft2d`` output it is multiplied with in
    ``setup_inputs_one_sources``. Confirm against how the .mat masks are made.

    Args:
        mask: non-empty array-like sampling mask. ``None`` (the default) is
            rejected.
        mute: when falsy, print the mask shape and reduction factor.

    Returns:
        ``(mask, r_factor)``: the shifted mask as an ndarray, and
        ``mask.size / mask.sum()``.

    Raises:
        ValueError: if ``mask`` is None or empty.
    """
    # An empty mask would otherwise die with an opaque 0/0 ZeroDivisionError.
    if mask is None or np.size(mask) == 0:
        raise ValueError('generate_mask_mat requires a non-empty sampling mask')
    # shift
    mask = np.fft.ifftshift(mask)
    # compute reduction
    r_factor = mask.size / mask.sum()
    if not mute:
        print('load mask size of {1} for R-factor={0:.4f}'.format(r_factor, mask.shape))
    return mask, r_factor
def setup_inputs_one_sources(sess, filenames_input, filenames_output, image_size=None,
                             axis_undersample=1, capacity_factor=1,
                             r_factor=4, r_alpha=0, r_seed=0,
                             sampling_mask=None, num_threads=1,
                             input_shape=(256, 256)):
    """Build a TF1 queue-based pipeline of undersampled-k-space training pairs.

    Each JPEG in ``filenames_input`` is processed as follows:
    - decoded with 3 channels and cropped to ``image_size`` from the top-left
      corner;
    - scaled to [0, 1] and reduced to its last channel, which is the label;
    - the feature is the inverse 2-D FFT of the label's 2-D FFT multiplied by
      one fixed sampling mask, split into real and imaginary channels.

    Args:
        sess: TF session in which the queue runners are started.
        filenames_input: JPEG paths for ``tf.train.string_input_producer``.
            They are read in order (``shuffle=False``).
        filenames_output: not used by this function (the label is derived from
            the input image); kept for interface compatibility.
        image_size: ``[H, W]`` crop size. When None it is taken from
            ``FLAGS.sample_size`` / ``FLAGS.sample_size_y``; the crop is square
            if ``sample_size_y <= 0``.
        axis_undersample, r_factor, r_alpha, r_seed: forwarded to
            ``generate_mask_alpha`` when ``sampling_mask`` is None.
        capacity_factor: batch-queue capacity in multiples of ``FLAGS.batch_size``.
        sampling_mask: optional precomputed mask, passed through
            ``generate_mask_mat``.
        num_threads: number of enqueuing threads for ``tf.train.batch``.
        input_shape: static ``(H, W)`` shape of the decoded JPEGs. The default
            keeps the previously hard-coded 256x256.

    Returns:
        ``(features, labels, masks)``, batched by ``tf.train.batch``:
        - ``features``: ``[batch, H, W, 2]`` float32;
        - ``labels``: ``[batch, H, W, 1]`` float32;
        - ``masks``: ``[batch, H, W]`` complex64.

    Raises:
        ValueError: if ``image_size`` exceeds ``input_shape`` along either axis.
    """
    # image size defaults to the command-line flags
    if image_size is None:
        if FLAGS.sample_size_y > 0:
            image_size = [FLAGS.sample_size, FLAGS.sample_size_y]
        else:
            image_size = [FLAGS.sample_size, FLAGS.sample_size]

    # The crop below cannot enlarge the decoded image. Reject an oversize
    # request up front instead of failing later with a shape mismatch.
    if image_size[0] > input_shape[0] or image_size[1] > input_shape[1]:
        raise ValueError('image_size {0} exceeds the decoded input shape {1}'.format(
            list(image_size), list(input_shape)))

    # one fixed sampling mask shared by every example
    if sampling_mask is None:
        mask_np, _ = generate_mask_alpha(image_size,  # kspace size
                                         r_factor_designed=r_factor,
                                         r_alpha=r_alpha,
                                         seed=r_seed,
                                         axis_undersample=axis_undersample)
    else:
        mask_np, _ = generate_mask_mat(sampling_mask)
    # convert to a complex64 tensor so it can multiply the k-space directly
    mask_tf = tf.cast(tf.constant(mask_np), tf.float32)
    mask_tf_c = tf.cast(mask_tf, tf.complex64)

    # read each JPEG file, in filename order (no shuffling)
    reader_input = tf.WholeFileReader()
    filename_queue_input = tf.train.string_input_producer(filenames_input, shuffle=False)
    _, value_input = reader_input.read(filename_queue_input)
    channels = 3
    image_input = tf.image.decode_jpeg(value_input, channels=channels, name="input_image")
    image_input.set_shape([input_shape[0], input_shape[1], channels])
    print('size_input_image', image_input.get_shape())
    # crop the top-left image_size region
    image_input = image_input[0:image_size[0], 0:image_size[1]]
    # cast image to float in 0~1
    image_input = tf.cast(image_input, tf.float32) / 255.0
    # use the last channel (B) for input and output, assume image is in gray-scale
    image_output = image_input[:, :, -1]
    image_input = image_input[:, :, -1]

    # apply the undersampling mask to the unshifted 2-D FFT
    kspace_input = tf.fft2d(tf.cast(image_input, tf.complex64))
    kspace_zpad = kspace_input * mask_tf_c
    # zero-filled image reconstructed from the undersampled k-space
    image_zpad = tf.ifft2d(kspace_zpad)
    image_zpad_real = tf.reshape(tf.real(image_zpad), [image_size[0], image_size[1], 1])
    image_zpad_imag = tf.reshape(tf.imag(image_zpad), [image_size[0], image_size[1], 1])

    # feature: zero-filled image as 2 channels (real, imag) -> [H, W, 2]
    # label: ground-truth real-valued image -> [H, W, 1]
    feature = tf.concat(axis=2, values=[image_zpad_real, image_zpad_imag])
    label = tf.reshape(image_output, [image_size[0], image_size[1], 1])
    mask = tf.reshape(mask_tf_c, [image_size[0], image_size[1]])

    # batch through an asynchronous queue and start its runners in sess
    features, labels, masks = tf.train.batch([feature, label, mask],
                                             batch_size=FLAGS.batch_size,
                                             num_threads=num_threads,
                                             capacity=capacity_factor * FLAGS.batch_size,
                                             name='labels_and_features')
    tf.train.start_queue_runners(sess=sess)
    return features, labels, masks