Extending ImageDataGenerator #3338
I would be interested if someone could make ImageDataGenerator extensible. Another feature I would like is to separate out the random-variable generation that ImageDataGenerator uses, so the random variables can be saved somewhere and the process made reproducible: you don't need to store the transformed images, just the random variables, perhaps as a dictionary. These random variables can then be fed into an ImageDataGenerator to reproduce the results. This would be very useful for dense prediction (such as segmentation), where we need to augment both the input image and the output image: we can create two ImageDataGenerators, use the separate random generator to produce the random variables, and feed them into both, so that spatial transforms such as rotation are synchronized between input and output. I will begin working on this; please let me know if anyone else has ideas about it. EDIT: perhaps you would think there is a |
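A minimal NumPy sketch of the proposal above: the random variables are drawn once into a plain dict, which can be saved and replayed, and then applied to both the input image and its mask. The function names are hypothetical, the transforms are limited to flips and wrap-around shifts (standing in for rotation and fill_mode), and arrays are assumed channels-last; this is not the API of the fork or of any Keras release. As far as I know, later Keras 2.x releases added `get_random_transform()` and `apply_transform()` methods to ImageDataGenerator that follow a similar split, so check the docs for your version.

```python
import json
import numpy as np

def sample_transform_params(rng, shape):
    """Draw the random variables for one spatial transform as a plain dict.

    Assumes channels-last arrays of shape (height, width, channels). The dict
    holds only Python bools/ints, so it can be saved (e.g. as JSON) and
    replayed later instead of storing the augmented images.
    """
    height, width = shape[:2]
    max_dr, max_dc = height // 10, width // 10  # shift up to ~10% of each side
    return {
        "flip_horizontal": bool(rng.random() < 0.5),
        "flip_vertical": bool(rng.random() < 0.5),
        "shift_rows": int(rng.integers(-max_dr, max_dr + 1)),
        "shift_cols": int(rng.integers(-max_dc, max_dc + 1)),
    }

def apply_transform_params(x, params):
    """Apply a previously sampled transform; no randomness happens here."""
    if params["flip_horizontal"]:
        x = x[:, ::-1]
    if params["flip_vertical"]:
        x = x[::-1, :]
    # np.roll wraps pixels around the border; it stands in for a real fill_mode
    return np.roll(x, shift=(params["shift_rows"], params["shift_cols"]), axis=(0, 1))

# One draw, applied to both an image and its label mask
rng = np.random.default_rng(0)
image = np.arange(32 * 32 * 3, dtype=np.float64).reshape(32, 32, 3)
mask = (image[..., :1] > 1500).astype(np.uint8)
params = sample_transform_params(rng, image.shape)
saved = json.dumps(params)  # store the random variables, not the images
aug_image = apply_transform_params(image, params)
aug_mask = apply_transform_params(mask, json.loads(saved))
```

Because the mask is transformed with the replayed dict, it stays pixel-aligned with the augmented image.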
I am working on improving ImageDataGenerator, but it would be great if anyone interested could provide suggestions and feedback.
For
For
|
@oeway When it comes to visual semantic segmentation, I find that extending ImageDataGenerator is really necessary! The current ImageDataGenerator will only conduct transformations on |
@pengpaiSH I have a working version right now; you are welcome to try it out: https://github.com/oeway/keras/tree/extendImageDataGenerator (branch: extendImageDataGenerator). For now, you can customize the pipeline and synchronize two ImageDataGenerators by a

```python
import numpy as np
# ImageDataGenerator, standardize and random_transform as provided by the
# extendImageDataGenerator branch; the standardize/random_transform imports,
# seed/verbose/featurewise_standardize_axis, config, set_pipeline and
# read_formats are features of that branch, not of stock Keras
from keras.preprocessing.image import ImageDataGenerator, standardize, random_transform

# input generator with standardization on
datagenX = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    featurewise_standardize_axis=(0, 2, 3),
    rotation_range=180,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    fill_mode='reflect',
    seed=0,
    verbose=1)
# output generator with standardization off; same seed and same spatial
# settings so its random transforms match datagenX
datagenY = ImageDataGenerator(
    featurewise_center=False,
    featurewise_std_normalization=False,
    rotation_range=180,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    fill_mode='reflect',
    seed=0)

# arrays are channels-first here: x.shape == (channels, width, height)
def center_crop(x, center_crop_size, **kwargs):
    centerw, centerh = x.shape[1] // 2, x.shape[2] // 2
    halfw, halfh = center_crop_size[0] // 2, center_crop_size[1] // 2
    # take exactly center_crop_size pixels, so odd sizes also work
    return x[:, centerw - halfw:centerw - halfw + center_crop_size[0],
             centerh - halfh:centerh - halfh + center_crop_size[1]]

def random_crop(x, random_crop_size, sync_seed=None, **kwargs):
    # the same sync_seed gives the same offsets for input and output
    np.random.seed(sync_seed)
    w, h = x.shape[1], x.shape[2]
    # largest valid offset along each axis (0 if the crop fills the axis)
    rangew = w - random_crop_size[0]
    rangeh = h - random_crop_size[1]
    # randint(n + 1) draws from 0..n inclusive, covering every valid window
    offsetw = np.random.randint(rangew + 1)
    offseth = np.random.randint(rangeh + 1)
    return x[:, offsetw:offsetw + random_crop_size[0], offseth:offseth + random_crop_size[1]]

datagenX.config['random_crop_size'] = (800, 800)
datagenY.config['random_crop_size'] = (800, 800)
datagenX.config['center_crop_size'] = (512, 512)
datagenY.config['center_crop_size'] = (360, 360)

# customize the pipeline: the functions are applied in list order
datagenX.set_pipeline([random_crop, random_transform, standardize, center_crop])
datagenY.set_pipeline([random_crop, random_transform, center_crop])

# flow_from_directory is extended to support more formats, and you can even
# use your own reader function; here is an example of reading image data
# saved in csv files:
# datagenX.flow_from_directory(csvFolder, image_reader=csvReaderGenerator, read_formats={'csv'}, reader_config={'target_size': (572, 572), 'resolution': 20, 'crange': (0, 100)}, class_mode=None, batch_size=1)
# inputDir / outputDir: directories of input images and matching label images
dgdx = datagenX.flow_from_directory(inputDir, class_mode=None, read_formats={'png'}, batch_size=2)
dgdy = datagenY.flow_from_directory(outputDir, class_mode=None, read_formats={'png'}, batch_size=2)

# you can now fit on a generator as well (computes the featurewise statistics)
datagenX.fit_generator(dgdx, nb_iter=100)

# here we synchronize the two generators and combine them into one
train_generator = dgdx + dgdy

# model and validation_generator are defined elsewhere (Keras 1.x API)
model.fit_generator(
    train_generator,
    samples_per_epoch=2000,
    nb_epoch=50,
    validation_data=validation_generator,
    nb_val_samples=800)
```
|
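`dgdx + dgdy` relies on the fork's own operator, and the synchronization comes from giving both generators the same seed. The sketch below shows that underlying idea in plain NumPy, assuming both streams hold the same number of samples in the same order: two generators seeded identically make the same shuffling and flipping decisions, so zipping them keeps images and masks aligned. `augmented_batches` and `synchronized` are hypothetical helpers, not part of Keras.

```python
import numpy as np

def augmented_batches(data, batch_size, seed):
    """Endless stream of shuffled, randomly flipped batches from (N, H, W, C) data.

    Every random decision comes from one Generator seeded with `seed`, and the
    number of draws depends only on len(data) and batch_size, so two streams
    with the same seed, length and batch size make identical decisions.
    """
    rng = np.random.default_rng(seed)
    n = len(data)
    while True:
        order = rng.permutation(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            batch = data[idx]  # fancy indexing returns a copy
            flips = rng.random(len(idx)) < 0.5
            batch[flips] = batch[flips, :, ::-1]  # horizontal flip
            yield batch

def synchronized(images, masks, batch_size=4, seed=0):
    """Pair two equally seeded streams into (image_batch, mask_batch) tuples."""
    return zip(augmented_batches(images, batch_size, seed),
               augmented_batches(masks, batch_size, seed))

images = np.random.default_rng(1).random((10, 8, 8, 3))
masks = (images[..., :1] > 0.5).astype(np.float64)
pairs = synchronized(images, masks)
first_x, first_y = next(pairs)
```

In stock Keras the documented equivalent is to pass the same `seed` to the image and mask `flow()` or `flow_from_directory()` calls and zip the two iterators; the same caveat applies that both must see identical sample counts and batch sizes.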
@oeway Thank you for your quick response and for your enhancement of ImageDataGenerator, which I think is really necessary. Below are my questions:
|
Sorry for the confusion; I will write documentation about it when I have time.
1. Yes, you just need to provide a function and put it in the pipeline.
2. I forgot to include the import line: random_transform and standardize come from `from keras.preprocessing.image import standardize, random_transform`.
3. fit_generator and fit are needed for augmentation options that depend on dataset statistics, such as the featurewise ones.
I will document all the features if more people are interested, and
On Tue, Aug 9, 2016 at 10:49, Pai Peng <notifications@github.com> wrote:
|
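Point 3 above concerns the featurewise options: stock Keras's `ImageDataGenerator.fit(x)` computes featurewise mean/std (and the ZCA matrix) from an in-memory array, so fitting from a generator needs running statistics instead. A sketch for channels-last batches, similar in spirit to the fork's `fit_generator(dgdx, nb_iter=100)` (I have not checked how the fork implements it); `featurewise_stats` is a hypothetical helper. It keeps plain sums and sums of squares, which is fine for image-range values but less numerically robust than Welford's algorithm.

```python
import numpy as np

def featurewise_stats(batches, n_batches):
    """Per-channel mean and std accumulated over `n_batches` batches.

    Assumes channels-last batches of shape (N, H, W, C). Only running sums
    are kept, so the dataset never has to fit in memory at once.
    """
    count = 0
    total = None
    total_sq = None
    for _, batch in zip(range(n_batches), batches):
        flat = batch.reshape(-1, batch.shape[-1]).astype(np.float64)
        if total is None:
            total = np.zeros(flat.shape[1])
            total_sq = np.zeros(flat.shape[1])
        total += flat.sum(axis=0)
        total_sq += (flat ** 2).sum(axis=0)
        count += flat.shape[0]
    mean = total / count
    # clamp tiny negative values caused by floating-point cancellation
    std = np.sqrt(np.maximum(total_sq / count - mean ** 2, 0.0))
    return mean, std

data = np.random.default_rng(0).random((12, 4, 4, 3))
mean, std = featurewise_stats(iter(np.split(data, 3)), n_batches=3)
standardized = (data - mean) / (std + 1e-7)
```

The `1e-7` epsilon guarding against zero variance is an arbitrary choice here.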
@oeway Thank you for clarifying. Perfect! By the way, are the transformations defined in the pipeline applied in order? |
@pengpaiSH You're welcome. Of course: the idea is to use an ordered list to define the pipeline, and you can change the order of the steps. |
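The ordered-pipeline idea can be sketched without the fork: a list of functions applied in order, each receiving a seed derived from a shared `sync_seed`. All names here are hypothetical (they mirror, but are not, the fork's `random_crop`/`standardize`), arrays are assumed channels-last, and `sync_seed` must be a non-negative integer. Deriving each step's seed from its function name rather than its position means the mask pipeline can omit `standardize` and still get the same crop and flip as the image pipeline.

```python
import zlib
import numpy as np

def run_pipeline(x, steps, sync_seed=None, **config):
    """Apply `steps` in list order; each is called as step(x, sync_seed=..., **config).

    Each step's seed combines sync_seed with a stable hash of the step's
    function name, so two pipelines sharing a seed get identical draws for
    same-named random steps, even if one has extra deterministic steps.
    """
    for step in steps:
        step_seed = None
        if sync_seed is not None:
            step_seed = [sync_seed, zlib.crc32(step.__name__.encode())]
        x = step(x, sync_seed=step_seed, **config)
    return x

def random_crop(x, sync_seed=None, crop_size=(4, 4), **kwargs):
    rng = np.random.default_rng(sync_seed)
    top = rng.integers(0, x.shape[0] - crop_size[0] + 1)
    left = rng.integers(0, x.shape[1] - crop_size[1] + 1)
    return x[top:top + crop_size[0], left:left + crop_size[1]]

def random_flip(x, sync_seed=None, **kwargs):
    rng = np.random.default_rng(sync_seed)
    return x[:, ::-1] if rng.random() < 0.5 else x

def standardize(x, sync_seed=None, **kwargs):
    # deterministic step: ignores the seed
    return (x - x.mean()) / (x.std() + 1e-7)

image = np.arange(64, dtype=np.float64).reshape(8, 8, 1)
mask = (image > 31).astype(np.float64)
x = run_pipeline(image, [random_crop, random_flip, standardize], sync_seed=42, crop_size=(4, 4))
y = run_pipeline(mask, [random_crop, random_flip], sync_seed=42, crop_size=(4, 4))
```

Reordering the list changes the result (flipping before cropping can select a different window), which is why the order is part of the pipeline definition.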
I also need random cropping, better scaling, etc., and was just about to start implementing something like this myself. This looks very useful. Will you try to get it merged into the main repo? |
Any plans to merge @oeway's branch? I'm specifically interested in the CSV reader option. |
@burgalon It's very straightforward (maybe 10 lines of code) to extend the data generator to read CSV. I still agree it should be merged, but in case you don't want to wait :) I attached some very unpolished code that works perfectly for me (see the 'flow_np_from_directory' function and the 'NumpyDirectoryIterator' class specifically) to read .npy files stored in the same directory structure as the image version, e.g.:
|
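The attached code is not preserved here, so below is an independent sketch of the same idea: reading `.npy` arrays from the class-per-subdirectory layout that `flow_from_directory` uses. `iterate_npy_directory` is a hypothetical name (not ncullen93's `flow_np_from_directory`), labels are integer indices in sorted subdirectory order, and all arrays are assumed to share one shape; since nothing is decoded as an image, they can have any number of channels.

```python
import os
import tempfile
import numpy as np

def iterate_npy_directory(directory, batch_size=32, shuffle=True, seed=None):
    """Yield (x_batch, y_batch) forever from directory/<class_name>/*.npy."""
    classes = sorted(d for d in os.listdir(directory)
                     if os.path.isdir(os.path.join(directory, d)))
    paths, labels = [], []
    for label, name in enumerate(classes):
        class_dir = os.path.join(directory, name)
        for fname in sorted(os.listdir(class_dir)):
            if fname.endswith(".npy"):
                paths.append(os.path.join(class_dir, fname))
                labels.append(label)
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    while True:  # one pass per epoch, reshuffled each time if requested
        order = rng.permutation(len(paths)) if shuffle else np.arange(len(paths))
        for start in range(0, len(paths), batch_size):
            idx = order[start:start + batch_size]
            x = np.stack([np.load(paths[i]) for i in idx])
            yield x, labels[idx]

# throwaway demo dataset: two classes of 4x4 arrays with 5 channels
demo_dir = tempfile.mkdtemp()
for name, count in [("class_a", 3), ("class_b", 2)]:
    os.makedirs(os.path.join(demo_dir, name))
    for i in range(count):
        np.save(os.path.join(demo_dir, name, f"sample_{i}.npy"), np.zeros((4, 4, 5)))
batches = iterate_npy_directory(demo_dir, batch_size=5, shuffle=False)
x_batch, y_batch = next(batches)
```

A CSV reader would only change the `np.load` line (for example to `np.loadtxt(path, delimiter=",")`) plus a reshape to the expected array shape.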
Thanks @ncullen93 |
@oeway @burgalon @ncullen93 Also, I notice that .flow_from_directory() can resize the original input to target_size, but .flow() cannot. Maybe we should add this functionality to ImageDataGenerator so that, for example, I can resize my 21x21 images to 224x224. I couldn't find any function that does this. |
Hm, yeah, I noticed that too. That's probably because flow_from_directory() resizes the image object with PIL's resize method before converting it to an array (see line 302 in preprocessing/image.py). .flow() would require resampling a NumPy array, probably with SciPy, so I guess it's assumed you will resize the data yourself, since it already fits in memory. |
@ncullen93 Yup, exactly, but sometimes your data may be a large array rather than an image, so it would be nice to provide a way to resize it. I think cv2 can do that nicely. If this functionality were included, I wouldn't need to write a custom data generator to do the resizing with cv2. |
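Until `flow()` supports `target_size`, arrays can be resized before they are passed in. A dependency-free sketch of nearest-neighbour resizing for `(H, W, ...)` arrays with any number of channels; `resize_nearest` and `resize_batch` are hypothetical helpers. For smoother interpolation of images, `cv2.resize` or `scipy.ndimage.zoom` are the usual choices; nearest-neighbour is mainly attractive for label maps, because it never invents new label values.

```python
import numpy as np

def resize_nearest(x, target_size):
    """Nearest-neighbour resize of an (H, W, ...) array to target_size=(h, w)."""
    in_h, in_w = x.shape[:2]
    out_h, out_w = target_size
    # map each output pixel centre back to the input pixel that contains it
    rows = np.minimum(((np.arange(out_h) + 0.5) * in_h / out_h).astype(int), in_h - 1)
    cols = np.minimum(((np.arange(out_w) + 0.5) * in_w / out_w).astype(int), in_w - 1)
    # broadcasting the two index vectors selects a full (out_h, out_w) grid;
    # any trailing channel axes are carried along unchanged
    return x[rows[:, None], cols[None, :]]

def resize_batch(batch, target_size):
    """Resize every sample of an (N, H, W, ...) batch, e.g. data passed to flow()."""
    return np.stack([resize_nearest(sample, target_size) for sample in batch])

small = np.random.default_rng(0).random((2, 21, 21, 3))
large = resize_batch(small, (224, 224))
```

Upsampling 21x21 to 224x224 this way simply repeats pixels; it adds no detail.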
True, though unfortunately I don't make those decisions. In the meantime, here's a slightly modified image.py file with a big commented section at line 719 (look for "### ADD CODE HERE TO RESHAPE ARRAY TO TARGET SIZE ####") where you can add one line of code to upsample/reshape your array however you like. Then you can use "target_size" with flow() just like with flow_from_directory(). It's very simple. Replace the image.py file in your keras directory with this one and reinstall. That's all I personally can offer. |
@ncullen93 Wow, thanks! I will look at that. |
I extended @oeway's ImageDataGenerator fork to use a multiprocessing Pool and got a pretty good speedup. Code: https://github.com/stratospark/food-101-keras/blob/master/tools/image_gen_extended.py
The goal was to fully utilize a Titan X GPU, as I describe here: http://blog.stratospark.com/deep-learning-applied-food-classification-deep-learning-keras.html |
Can you comment a little more on it? Does it support both arrays and directory sampling? From a quick look, it seems to support only arrays. In my opinion directory sampling is probably the biggest bottleneck, since bigger nets typically go with datasets that don't fit in memory, so the speedup from using this with directory sampling would be much greater. Anyway, implementing just the multiprocessing speedup in the actual Keras image generator, without the pipeline stuff (since that's not standardized), would be a huge contribution in my opinion. |
@ncullen93 I'm planning on cleaning up the code soon and posting a PR so it can hopefully be merged into the main Keras branch, at least enough to maintain compatibility with the mainline Keras features. It was a quick hack to support my use case: augmenting images in memory fast enough to keep my GPU close to full utilization.

I had to disable any augmentation that requires fitting, such as featurewise normalization and ZCA whitening, because the locks could not be pickled for the multiprocessing map. I also had to pass a multiprocessing.Pool explicitly: Python multiprocessing forks the process, so it is easy to run into out-of-memory errors if you fork after loading your images. I create the pool at the very beginning of the script so it uses as few resources as possible. I'm pretty new at this, so there is probably a more elegant way of handling it.

One thing I would like to fix: interrupting training in Jupyter Notebook kills the processes, so I have to restart the kernel and load all the images again before another run. This really breaks the flow of trying different models and parameters!

To support reproducible results, I also had to create a separate random number generator for each process, so that the processes didn't return the same random numbers (and thus exactly the same images as each other). I know I left a few uses of NumPy's default random generator, so I need to fix those. |
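Giving each worker its own random generator can be sketched with NumPy's `SeedSequence.spawn`, which derives independent child seeds from one base seed: workers then do not return identical augmentations, yet the whole run is reproducible. This is not stratospark's code; `make_task_seeds` and `augment_batch` are hypothetical, the augmentation is just a random horizontal flip of channels-last batches, and `SeedSequence`/`default_rng` require NumPy 1.17 or later.

```python
import numpy as np

def make_task_seeds(base_seed, n_tasks):
    """One independent child SeedSequence per task, all derived from base_seed."""
    return np.random.SeedSequence(base_seed).spawn(n_tasks)

def augment_batch(args):
    """Worker function: random horizontal flips with a task-local Generator.

    Takes a single (seed_sequence, batch) tuple so it can be passed straight
    to multiprocessing.Pool.map. It never touches NumPy's global random state.
    """
    seed_seq, batch = args
    rng = np.random.default_rng(seed_seq)
    flips = rng.random(len(batch)) < 0.5
    out = batch.copy()
    out[flips] = out[flips, :, ::-1]
    return out

data_rng = np.random.default_rng(0)
batches = [data_rng.random((8, 4, 4, 1)) for _ in range(4)]
seeds = make_task_seeds(1234, len(batches))
results = list(map(augment_batch, zip(seeds, batches)))
# With a pool, created early in the script as described above:
#     with multiprocessing.Pool(4) as pool:
#         results = pool.map(augment_batch, zip(seeds, batches))
```

Because each task's output depends only on its own seed, `pool.map` returns the same list as the sequential `map`, regardless of which process handles which batch.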
This would fit nicely in the contrib library: https://github.com/farizrahman4u/keras-contrib |
@stratospark I agree with @joeyearsley: you should open a PR for keras-contrib, particularly now that Keras 2 is out. |
I shared how I extended Keras' ImageDataGenerator to support random cropping in this blog post: I think the same technique could be easily adapted as a solution to the original question. "Is there an easy way to write generator extensions for Keras? I'd like to use some of the ImageDataGenerator preprocessing steps but also add some of my own such as randomly occluding areas of the image, adding noise etc. If not I can add some of these things to the ImageDataGenerator class in Keras, is that something that would be useful?" |
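Applied to the original question, the wrapping technique might look like the sketch below: a generator that takes `(x, y)` batches from any Keras iterator and adds Gaussian noise plus one randomly placed zeroed rectangle per image. This is not the blog's code; `noise_and_occlusion` and its parameters are hypothetical, and images are assumed channels-last. For purely per-image changes, stock Keras 2 also accepts a `preprocessing_function` argument on ImageDataGenerator, which is called on each single image.

```python
import numpy as np

def noise_and_occlusion(batches, noise_std=0.05, max_box=0.3, seed=None):
    """Wrap an (x, y) batch generator with extra augmentation steps.

    For each image in x (N, H, W, C), adds Gaussian noise and zeroes one
    random rectangle whose sides are at most `max_box` of the image sides.
    Labels pass through unchanged; the input arrays are not modified.
    """
    rng = np.random.default_rng(seed)
    for x, y in batches:
        # astype makes a copy, so the caller's batch is left untouched
        x = x.astype(np.float32) + rng.normal(0.0, noise_std, size=x.shape).astype(np.float32)
        n, h, w = x.shape[:3]
        for i in range(n):
            box_h = rng.integers(1, max(1, int(h * max_box)) + 1)
            box_w = rng.integers(1, max(1, int(w * max_box)) + 1)
            top = rng.integers(0, h - box_h + 1)
            left = rng.integers(0, w - box_w + 1)
            x[i, top:top + box_h, left:left + box_w] = 0.0
        yield x, y

base = iter([(np.ones((4, 20, 20, 3), dtype=np.float32), np.arange(4))])
x_aug, y_aug = next(noise_and_occlusion(base, seed=0))
```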
I am doing semantic segmentation, so I have both input images X and label images Y of size 1440x1920, and I do: `dgdy = datagenY.flow_np_from_directory()` and synchronize the generators: And if I modify image.py, I can add a random_crop_from_categories. Is it possible to implement a random_crop_from_categories in image.py that samples equally from the categories in Y and applies the same crop to X? |
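One possible `random_crop_from_categories` (hypothetical, not an existing Keras function): pick a class uniformly among those present in the label map, pick a random pixel of that class, and centre the crop window on it, clipped to the image, applying the same window to X. It assumes a single sample with an `(H, W, C)` image and an `(H, W)` integer label map (take `argmax` over a one-hot Y first). Sampling classes uniformly over-represents rare classes relative to plain random crops, which is usually the point.

```python
import numpy as np

def random_crop_from_categories(x, y, crop_size, rng):
    """Crop x and y at the same spot, centred on a pixel of a random class.

    Returns (x_crop, y_crop, chosen_class). Clipping keeps the window inside
    the image, and the chosen pixel always stays inside the window.
    """
    h, w = y.shape
    crop_h, crop_w = crop_size
    cls = rng.choice(np.unique(y))        # uniform over classes present in y
    rows, cols = np.nonzero(y == cls)
    k = rng.integers(len(rows))           # uniform over pixels of that class
    top = int(np.clip(rows[k] - crop_h // 2, 0, h - crop_h))
    left = int(np.clip(cols[k] - crop_w // 2, 0, w - crop_w))
    window = (slice(top, top + crop_h), slice(left, left + crop_w))
    return x[window], y[window], cls

labels = np.zeros((20, 20), dtype=int)
labels[15:17, 15:17] = 1           # a small, rare class
image = labels[..., None] * 10.0   # toy image aligned with the labels
rng = np.random.default_rng(0)
crops = [random_crop_from_categories(image, labels, (8, 8), rng) for _ in range(50)]
```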
In Keras, you have to define the image type, i.e. color_mode='rgb' or 'grayscale', when using flow_from_directory with model.fit_generator. How do I change color_mode in flow_from_directory to handle 2, 4, 5 or 6 channels? Thank you for your attention. |
@ChristianEschen, I copied my reply on Disqus over here. You can define your own paired_crop_generator() to crop the same region from X (images) and Y (masks) simultaneously. The code below is just a quick and dirty hack, but hopefully you get the idea.
|
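The Disqus snippet itself did not survive here. A sketch of what such a `paired_crop_generator()` could look like, assuming it wraps a generator of `(x, y)` batches (for example two equally seeded Keras iterators zipped together) whose images and masks share batch size, height and width. One random offset is drawn per sample and used for both arrays, which keeps them aligned.

```python
import numpy as np

def paired_crop_generator(batches, crop_size, seed=None):
    """Crop the same random window from each image and its mask.

    x: (N, H, W, C) images, y: (N, H, W, K) masks, both channels-last.
    """
    rng = np.random.default_rng(seed)
    crop_h, crop_w = crop_size
    for x, y in batches:
        if x.shape[:3] != y.shape[:3]:
            raise ValueError("image and mask batches must share N, H and W")
        n, h, w = x.shape[:3]
        x_out = np.empty((n, crop_h, crop_w) + x.shape[3:], dtype=x.dtype)
        y_out = np.empty((n, crop_h, crop_w) + y.shape[3:], dtype=y.dtype)
        for i in range(n):
            top = rng.integers(0, h - crop_h + 1)
            left = rng.integers(0, w - crop_w + 1)
            x_out[i] = x[i, top:top + crop_h, left:left + crop_w]
            y_out[i] = y[i, top:top + crop_h, left:left + crop_w]
        yield x_out, y_out

imgs = np.random.default_rng(2).random((2, 16, 16, 3))
msks = (imgs[..., :1] > 0.5).astype(np.float64)
xc, yc = next(paired_crop_generator(iter([(imgs, msks)]), (8, 8), seed=3))
```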
@gledsonmelotti, I believe Keras' flow_from_directory() API could only handle 'rgb' or 'grayscale' images. For training data with 2, 4, 5, 6, ... channels, you'd need to implement your own data loading code. |
@oeway I have a model with multiple outputs, each with its own loss function, and I need to provide the ground-truth masks for each of those losses, something like: |
Hey @tinalegre |
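@tinalegre @mbenami If your output images are the same size, what I would do is combine them into one image with multiple channels, so you can generate the data as if there were one input image and one output image. With my extension, most of the transformation functions support any number of channels. In the objective, you can then split the channels again:

```python
from keras import backend as K

def customized_objective(y_true, y_pred):
    # split the target and prediction tensors into channels (channels-last)
    ch1_true, ch1_pred = y_true[:, :, :, :1], y_pred[:, :, :, :1]
    ch2_true, ch2_pred = y_true[:, :, :, 1:2], y_pred[:, :, :, 1:2]
    # ... slice further channels the same way as needed
    # mean squared error on channel 1 plus mean absolute (L1) error on channel 2
    return (K.mean(K.square(ch1_true - ch1_pred))
            + K.mean(K.abs(ch2_true - ch2_pred)))
```

At least, this is how I do it for outputs; good luck. |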
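A NumPy sketch of the data side of this approach, with hypothetical helpers: separate same-sized masks are stacked into one channels-last target, and a NumPy version of the split-channel objective (MSE on channel 0 plus mean absolute error on channel 1) can serve as a reference when unit-testing the Keras loss.

```python
import numpy as np

def stack_targets(*masks):
    """Stack same-sized (N, H, W) masks into one (N, H, W, K) float target."""
    return np.stack(masks, axis=-1).astype(np.float32)

def combined_loss_np(y_true, y_pred):
    """NumPy reference: MSE on channel 0 plus mean absolute error on channel 1."""
    mse = np.mean((y_true[..., :1] - y_pred[..., :1]) ** 2)
    l1 = np.mean(np.abs(y_true[..., 1:2] - y_pred[..., 1:2]))
    return float(mse + l1)

m1 = np.zeros((2, 4, 4))
m2 = np.ones((2, 4, 4))
y_true = stack_targets(m1, m2)
y_pred = np.full_like(y_true, 0.5)
loss = combined_loss_np(y_true, y_pred)  # 0.25 (MSE) + 0.5 (L1)
```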
@jkjung-avt Thank you very much. |
@oeway @mbenami Thank you for your feedback; my previous solution was actually wrong. I followed your suggestion, @oeway, and changed my network to concatenate the outputs with a Lambda layer as follows:
Without the Lambda it did not work, as Keras expected a Tensor as the output parameter. When compiling the model, I'm doing:
where y_preds is of shape |
@tinalegre It seems to me that you need to make y_true the same shape as y_pred, for example by padding with zeros. I guess Keras expects y_true and y_pred to have the same shape, so you just need to pass the Keras engine's check. |
@oeway the problem is somehow in the |
@tinalegre I think |
@oeway Hi, I am confused about the memory allocation scheme of ImageDataGenerator. I wrote a generator with a structure similar to @jagiella's example solution. But when I use |
For transforming images and masks synchronously (e.g. for segmentation), take a look at the Keras documentation (checked recently) 👍 |
@amineio Thanks for the info. Did you use it? And how would you handle one-hot encoding for the masks? |
@tinalegre There are only two classes (black and white masks); the final layer looks like this: |
@oeway Currently, I am using the ... I've been searching the documentation, but the only thing I could find is
Does your code account for that? |
@bernardohenz Thanks for your feedback. I don't have time to look into the details right now. I think it's probably not guaranteed; Keras introduced |
Hi @oeway, I've already checked this link, but the generator already offers so many features (flow_from_directory, several transformations, pipeline ordering) that it would take some time to implement them in the |
@oeway I've adapted your code to use the |
Closing as this is resolved |
Just for reference, here's another way of adding image manipulation (random and center crop in my case) to the pipeline before Keras resizes images. It's done by monkey patching: https://gist.github.com/rstml/bbd491287efc24133b90d4f7f3663905 |
@stratospark @oeway Have any of the scripts here been merged into mainline Keras? Three years have passed and default Keras still doesn't support csv/npy ... |
@oeway In your code you showed how to synchronize two generators and combine them into one:
|
Hi all. Many thanks to @oeway and @bernardohenz for their very helpful solutions. Just checking back in two years later: is there a similar adaptation of ImageDataGenerator now that the Keras team have refactored this code? The specific functionality I'm looking for is just precise reproducibility; I'd like ImageDataGenerator to be seedable but do not need the other features. For reference, my workflow is ImageDataGenerator -> flow -> fit_generator and I'm running Keras 2.3.1 with TF 2.0 backend in python 3.6.9 in a docker container with base Ubuntu 18.04. |
@jkjung-avt @XiaoTonyLuo |
Is there an easy way to write generator extensions for Keras? I'd like to use some of the ImageDataGenerator preprocessing steps but also add some of my own such as randomly occluding areas of the image, adding noise etc. If not I can add some of these things to the ImageDataGenerator class in Keras, is that something that would be useful?
Also can I just double check that the ImageDataGenerator does indeed generate batches of examples in a non-locking way, thereby not causing any GPU training bottlenecks?
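On the locking question: as far as I can tell from the Keras source of that era, its `Iterator` classes take a lock only while drawing the next index array and do the loading and augmentation outside it, so several worker threads can prepare batches at once. A minimal sketch of that pattern; `LockedIndexBatches` is a hypothetical name, it runs a single epoch for simplicity, and Python threads still contend for the GIL in pure-Python code (NumPy releases it in many operations).

```python
import threading
import numpy as np

class LockedIndexBatches:
    """Thread-safe batch iterator that locks only the cheap index step."""

    def __init__(self, n_samples, batch_size, load_fn):
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.load_fn = load_fn
        self._next_start = 0
        self._lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self._lock:  # only the index bookkeeping is serialized
            start = self._next_start
            if start >= self.n_samples:
                raise StopIteration
            self._next_start = start + self.batch_size
        idx = np.arange(start, min(start + self.batch_size, self.n_samples))
        # expensive loading/augmentation runs outside the lock
        return self.load_fn(idx)

def drain(batches, out):
    for batch in batches:
        out.append(batch)  # list.append is thread-safe in CPython

shared = LockedIndexBatches(1000, 32, load_fn=lambda idx: idx * 2)
collected = []
workers = [threading.Thread(target=drain, args=(shared, collected)) for _ in range(4)]
for t in workers:
    t.start()
for t in workers:
    t.join()
```

Every index is handed out exactly once across the four threads, so the collected batches cover the dataset without duplicates.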