How to train on part of the dataset #1382

Closed
Nic-Ma opened this issue Oct 13, 2020 · 17 comments

Contributor

Nic-Ma commented Oct 13, 2020

❓ Questions/Help/Support

Hi @vfdev-5 ,

We have a requirement for very large datasets:
Say 1 epoch has 100 iterations. We want to call run() to train on the first 10 iterations of data, do some other things with the model, then call run() again to train on the next 10 iterations, and so on.
Is this possible in ignite now?
It seems the dataloader iterator in ignite always starts from the beginning:
https://github.com/pytorch/ignite/blob/v0.4.2/ignite/engine/engine.py#L771

Thanks.

Collaborator

vfdev-5 commented Oct 13, 2020

Hi @Nic-Ma

Say 1 epoch has 100 iterations. We want to call run() to train on the first 10 iterations of data, do some other things with the model, then call run() again to train on the next 10 iterations, and so on.

I'd say you can attach a handler that does things with the model every 10 iterations and call trainer.run() once.

from ignite.engine import Engine, Events

trainer = ...  # an Engine built elsewhere


@trainer.on(Events.ITERATION_COMPLETED(every=10))
def do_something_with_model():
    # ...
    pass

trainer.run(big_dataset_loader, max_epochs=1)

Otherwise, we have a request here: #1371 that may be related to your need ...

Let me know if that is what you need, or whether the solution above works for you.

Contributor Author

Nic-Ma commented Oct 14, 2020

Hi @vfdev-5 ,

Thanks for your suggestion, your solution should work for most cases.
But my problem is that we are trying to connect MONAI with our Clara-FL framework, which runs the following logic:

  1. fl-client creates an engine and sets the global number of training epochs (say 10 epochs, each with 100 iterations).
  2. fl-client calls engine.run() to train for 10 iterations and stop.
  3. fl-client uploads the model weights to fl-server for aggregation.
  4. fl-client downloads the new model weights from the server and applies them to the model.
  5. fl-client calls engine.run() to resume training from iteration 11.
    ...
  6. engine runs validation every 2 epochs.

In other words, the request is to resume training from the previous iteration and epoch. Can we do that in ignite?

The problem with your event-handler solution is that all the logic is controlled by fl-client itself. We only develop callbacks for fl-client, so it is not easy to call fl-client logic from a handler. Of course, we could use multiple threads, blocking the training inside the handler while fl-client does other things, but that does not feel straightforward and is bug-prone. So I think maybe run() -> resume() -> resume() ... -> stop() is the best solution?

Thanks.

Collaborator

vfdev-5 commented Oct 14, 2020

Hi @Nic-Ma

OK, I understand the need better now. Unfortunately, there is no simple way to implement the run() -> resume() -> resume() ... -> stop() logic without turning Engine into a generator (related to #1371, I'm working on that).

Currently, the main issue is correctly restoring the dataflow from iterations 10, 20, etc.
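
For illustration only, the generator idea can be sketched in plain Python without ignite. The names `train_in_chunks`, `step_fn` and `chunk_size` are hypothetical and not part of any ignite API. The loop hands control back to the caller every `chunk_size` iterations and keeps its position in the data, which is exactly the part that repeated run() calls do not give us today.

```python
def train_in_chunks(data, step_fn, max_epochs, chunk_size):
    """Run step_fn over data, pausing every chunk_size iterations.

    Yields (epoch, iteration) at each pause; the caller resumes with next().
    """
    iteration = 0
    for epoch in range(1, max_epochs + 1):
        for batch in data:
            step_fn(batch)
            iteration += 1
            if iteration % chunk_size == 0:
                # Hand control back to the caller (e.g. an FL client).
                yield epoch, iteration


seen = []  # records every batch the (dummy) training step has processed
loop = train_in_chunks(list(range(100)), seen.append, max_epochs=2, chunk_size=10)

epoch, iteration = next(loop)  # trains on iterations 1-10, then pauses
# ... upload weights, aggregate, download new weights ...
epoch, iteration = next(loop)  # resumes at iteration 11
```

The caller stays in charge of the control flow, so no threads or locks are needed; the cost is that the training loop itself has to be written as a generator.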

If the extra steps could be done once per epoch instead of every 10 iterations, then a solution could be:

from ignite.engine import Engine, Events

engine = Engine(lambda e, b: None)

def once_at_start(engine, _):
    return engine.state.epoch == 0

def once_at_end(engine, _):
    return engine.state.epoch == 3

engine.add_event_handler(Events.STARTED(once_at_start), lambda x: print("started"))
engine.add_event_handler(Events.EPOCH_STARTED, lambda x: print("{} epoch started".format(x.state.epoch)))
engine.add_event_handler(Events.EPOCH_COMPLETED, lambda x: print("{} epoch completed".format(x.state.epoch)))
engine.add_event_handler(Events.COMPLETED(once_at_end), lambda x: print("completed"))

state = engine.run([0, 1, 2], max_epochs=1)
print("Do something else")
state = engine.run([0, 1, 2], max_epochs=state.max_epochs + 1)
print("Do something else")
engine.run([0, 1, 2], max_epochs=state.max_epochs + 1)

with output

started
1 epoch started
1 epoch completed
Do something else
2 epoch started
2 epoch completed
Do something else
3 epoch started
3 epoch completed
completed

PS: We can reduce the epoch length to 10 iterations, but that has its own implications...

Contributor Author

Nic-Ma commented Oct 15, 2020

Hi @vfdev-5 ,

Thanks for your detailed suggestions with code.
Let me investigate a little more and discuss with our internal team, then I'll get back to you soon.

Thanks.

Collaborator

vfdev-5 commented Oct 15, 2020

Hi @Nic-Ma, sounds good! Anyway, we would also like to propose a clean solution for Engine as a generator.

Contributor Author

Nic-Ma commented Oct 19, 2020

Hi @vfdev-5 ,

I tried adding multiple threads to control the logic, for example: thread 1 is the FL client, thread 2 is the ignite training loop.
During training, thread 1 waits for the lock. The FL handler (every 10 iterations) in thread 2 releases the lock, then thread 1 sends the model weights to the FL server while thread 2 waits for the lock before doing the next steps or aborting.
But how can this logic be done in multi-GPU training based on torch.distributed.launch?
I didn't find a way to communicate between the master process and the sub-processes:
https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py
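
For reference, the single-process handshake described above can be sketched with the standard library only. This is a simplified stand-in, not the real FL client: `training_loop`, `fl_client` and the two `threading.Event` objects are hypothetical names, and the "sync" step only appends to a log. It does not address the multi-GPU case.

```python
import threading

weights_ready = threading.Event()  # set by the training thread every 10 iterations
may_continue = threading.Event()   # set by the FL client thread after syncing
log = []


def training_loop(num_iterations, sync_every):
    for iteration in range(1, num_iterations + 1):
        log.append(f"train {iteration}")
        if iteration % sync_every == 0:
            weights_ready.set()   # wake up the FL client
            may_continue.wait()   # block until the client is done
            may_continue.clear()


def fl_client(num_rounds):
    for round_idx in range(1, num_rounds + 1):
        weights_ready.wait()      # block until the trainer pauses
        weights_ready.clear()
        log.append(f"sync {round_idx}")  # stand-in for upload/aggregate/download
        may_continue.set()


trainer_thread = threading.Thread(target=training_loop, args=(30, 10))
client_thread = threading.Thread(target=fl_client, args=(3,))
trainer_thread.start()
client_thread.start()
trainer_thread.join()
client_thread.join()
```

Only one of the two threads is ever runnable at a time, so the log is strictly ordered: 10 training iterations, one sync, 10 more iterations, and so on.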

Thanks.

Collaborator

vfdev-5 commented Oct 19, 2020

Hi @Nic-Ma

Interesting question. Could you please provide a link or a code snippet with the FL client and the ignite training loop that I could run on my side? It would help a lot in suggesting a solution for distributed computation too.

I didn't find a method to communicate between master process and the sub-processes

I think what you would like to do is the following:

import ignite.distributed as idist
from ignite.engine import Events

# the FL client only exists on rank 0
fl_client = None
if idist.get_rank() == 0:
    fl_client = ...

@fl_trainer.on(Events.ITERATION_COMPLETED(every=10))
@idist.one_rank_only(rank=0, with_barrier=True)
def fl_sync_model():
    fl_client.send(ddp_model.module)

where idist.one_rank_only executes the handler on rank 0 only, with a barrier:

    dist.barrier()
    if dist.get_rank() == 0:
        fl_client.send(ddp_model.module)
    dist.barrier()

Contributor Author

Nic-Ma commented Oct 19, 2020

Hi @vfdev-5 ,

Your solution is interesting, I have also been investigating distributed training recently.
While the processes are waiting at dist.barrier(), can I also broadcast the new model weights from rank 0 to the others, or send a message asking them to abort?

Thanks.

Collaborator

vfdev-5 commented Oct 19, 2020

While the processes are waiting at dist.barrier(), can I also broadcast the new model weights from rank 0 to the others, or send a message asking them to abort?

You can think of the barrier as an explicit call to synchronize all processes (e.g. calling all_reduce on a dummy input). So, yes, you can do

dist.barrier()
# this should be called from all processes
dist.broadcast(...)
# or abort from master
if dist.get_rank() == 0 and something_wrong:
    raise RuntimeError("abc")
dist.barrier()

Sending messages is trickier, I'd say, and would probably require adding some communication methods (so that it works with torch dist, horovod and xla).

PS. Sorry, wrong button pushed

Contributor Author

Nic-Ma commented Oct 19, 2020

Hi @vfdev-5 ,

Sounds like a good plan, let me first develop a simple demo program to verify it.
About the messaging: can raise RuntimeError("abc") kill all processes directly? Maybe we also need to call dist.destroy_process_group()?
Let's investigate how to send messages between processes gracefully, maybe by just broadcasting flag variables.

Thanks.

Collaborator

vfdev-5 commented Oct 19, 2020

@Nic-Ma this should be confirmed but I think in the end all processes will be killed by a timeout on dist.barrier().

Let's investigate how to send messages between processes gracefully, maybe by just broadcasting flag variables.

Yes, something like that. I'd do it with all_gather, like this:

dist.barrier()

should_terminate = False
if dist.get_rank() == 0 and something_wrong:
    should_terminate = True

should_terminate = idist.all_gather(int(should_terminate))
if any(should_terminate):
    raise RuntimeError("stop")

dist.barrier()
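
To see why gathering the flag on every rank matters, here is a toy simulation that uses threads in place of processes. It is only a stand-in under stated assumptions: `threading.Barrier` plays the role of dist.barrier(), the hypothetical `all_gather` helper below imitates the collective with a shared list, and no torch or ignite call is involved. The point is that every rank sees the flag and leaves at the same iteration, instead of rank 0 raising alone while the others wait for a timeout.

```python
import threading

WORLD_SIZE = 4
barrier = threading.Barrier(WORLD_SIZE)  # stand-in for dist.barrier()
flags = [0] * WORLD_SIZE                 # stand-in buffer for all_gather
stopped_at = {}


def all_gather(rank, value):
    """Toy all_gather: every rank contributes one value and reads all of them."""
    flags[rank] = value
    barrier.wait()           # all contributions are written
    gathered = list(flags)
    barrier.wait()           # all ranks have read before the next write
    return gathered


def worker(rank, fail_at):
    for iteration in range(1, 101):
        # ... one training step would run here ...
        something_wrong = rank == 0 and iteration == fail_at
        should_terminate = all_gather(rank, int(something_wrong))
        if any(should_terminate):
            stopped_at[rank] = iteration  # every rank leaves at the same iteration
            return


threads = [threading.Thread(target=worker, args=(r, 30)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

In this simulation only rank 0 detects the problem, yet all four workers record iteration 30 as their stopping point.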

Contributor Author

Nic-Ma commented Oct 19, 2020

Cool, your trick seems very useful, let me make a demo to verify first.
Thanks.

Contributor Author

Nic-Ma commented Oct 20, 2020

Hi @vfdev-5 ,

I developed a simulation program according to our discussion and sent it to your email.
Could you please help review it? Then I can move forward to the real implementation.

Thanks.

Collaborator

vfdev-5 commented Oct 20, 2020

Hi @Nic-Ma, thanks, I received your email and will take a look a bit later today.

Contributor Author

Nic-Ma commented Oct 28, 2020

Hi @vfdev-5 ,

Actually, I double-checked your engine source code, and it seems you already save the previous training epochs, iterations, etc. in order to restore a training run. If we simply don't delete the previous self._dataloader_iter, it can remember the previous iterations:
https://github.com/pytorch/ignite/blob/v0.4.2/ignite/engine/engine.py#L764
Am I right?
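
To make the point concrete outside of ignite, here is a small plain-Python illustration (the list `batches` is a hypothetical stand-in for a dataloader). Creating a new iterator for each round restarts the data, while keeping one iterator between rounds continues from where the previous round stopped.

```python
from itertools import islice

batches = list(range(100))  # stand-in for a dataloader with 100 iterations per epoch

# A fresh iterator per round always starts from the first batch.
restart_1 = list(islice(iter(batches), 10))
restart_2 = list(islice(iter(batches), 10))

# A single iterator kept between rounds continues from where it stopped.
kept_iter = iter(batches)
round_1 = list(islice(kept_iter, 10))
round_2 = list(islice(kept_iter, 10))
```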

I tried to do the step based aggregation in this way:

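from ignite.engine import Events

class FLhandler:
    def __init__(self):
        # keeps the dataloader iterator alive between rounds
        self.dataloader_iter = None

    def __call__(self, engine):
        # _dataloader_iter is a private attribute of Engine (v0.4.2)
        self.dataloader_iter = engine._dataloader_iter
        engine.terminate()

fl_handler = FLhandler()  # register an instance, not the class itself
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=10), fl_handler)
# round 1
trainer.run()
# round 2
trainer._dataloader_iter = fl_handler.dataloader_iter
trainer.run()
...
# round N
trainer._dataloader_iter = fl_handler.dataloader_iter
trainer.run()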

What do you think?

Thanks.

Collaborator

vfdev-5 commented Oct 28, 2020

Hi @Nic-Ma ,

Yes, you are right about that 👍 This can be an interesting approach. Thanks!
Currently, one problem with calling trainer.run() multiple times is that each call triggers events like STARTED, EPOCH_STARTED, etc., which may not be what we expect. This is more or less a conceptual question of what a run is, and it is something we are discussing with the team.

There is another approach (workaround) to keep control of the dataflow:

dataloader = ...

def cycle(dataloader):
    while True:
        for i in dataloader:
            yield i

dataloader_iter = cycle(dataloader)

@trainer.on(Events.ITERATION_STARTED)
def prepare_batch(engine):
    engine.state.batch = next(dataloader_iter)

trainer.add_event_handler(Events.ITERATION_COMPLETED(every=10), FLhandler)

data = list(range(len(dataloader) * num_epochs))
# round 1
trainer.run(data, max_epochs=num_epochs)
# FL sync
# ...
# round 2
trainer.run(data)
# FL sync
# ...
# round 3
trainer.run(data)
# etc
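
As a quick check of the dataflow only, here is a plain-Python sketch of the same idea without ignite. `run_round` is a hypothetical stand-in for one trainer.run() call, and the 5-element list stands in for the dataloader. Because the batches come from the shared `dataloader_iter`, each round continues where the previous one stopped, including across the epoch boundary.

```python
def cycle(dataloader):
    """Endlessly re-iterate the dataloader, restarting it after each pass."""
    while True:
        for batch in dataloader:
            yield batch


dataloader = ["a", "b", "c", "d", "e"]  # stand-in: 5 batches per epoch
dataloader_iter = cycle(dataloader)


def run_round(num_iterations):
    """Stand-in for one trainer.run() call that pulls batches itself."""
    return [next(dataloader_iter) for _ in range(num_iterations)]


round_1 = run_round(3)  # stops mid-epoch
round_2 = run_round(3)  # continues, then wraps into the next epoch
```

Unlike itertools.cycle, which replays a saved copy of the first pass, this version iterates the dataloader again on every pass, so a shuffling dataloader would produce a new order each epoch.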

Collaborator

vfdev-5 commented Jan 20, 2021

@Nic-Ma Let me close this issue as answered, feel free to reopen if needed.

vfdev-5 closed this as completed Jan 20, 2021