How to train on part of the dataset #1382
Hi @Nic-Ma

I'd say you can attach a handler that does things with the model every 10 iterations and call run() once:

trainer = ...

@trainer.on(Events.ITERATION_COMPLETED(every=10))
def do_something_with_model():
    # ...

trainer.run(big_dataset_loader, max_epochs=1)

Otherwise, we have a request here: #1371 that may be related to your need... Let me know if that is what you need or whether the above solution works for you.
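For illustration, here is a minimal self-contained sketch of that pattern; the dummy update function and the range-based stand-in for the dataloader are assumptions, not part of the original suggestion:

from ignite.engine import Engine, Events

def update(engine, batch):
    # dummy training step; a real trainer would forward/backward on the batch
    return batch

trainer = Engine(update)

@trainer.on(Events.ITERATION_COMPLETED(every=10))
def do_something_with_model():
    # stand-in for model aggregation / evaluation / sending weights
    print("doing something with the model at iteration", trainer.state.iteration)

big_dataset_loader = range(100)  # stands in for a real DataLoader with 100 iterations per epoch
trainer.run(big_dataset_loader, max_epochs=1)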
Hi @vfdev-5 , thanks for your suggestion; your solution should work for most cases.

In other words, the request is to resume training from the previous iteration and epoch. Can we do that in ignite? The problem with the event-handler solution is that all the logic is controlled by the trainer itself rather than by our outer loop.
Thanks.
Hi @Nic-Ma

OK, I understand the need better now. Unfortunately, there is no simple way to do that currently; the main issue is to correctly restore the dataflow from iterations 10, 20, etc. If you could express the pause points as whole epochs instead of every 10 iterations, then a solution could be:

from ignite.engine import Engine, Events

engine = Engine(lambda e, b: None)

def once_at_start(engine, _):
    return engine.state.epoch == 0

def once_at_end(engine, _):
    return engine.state.epoch == 3

engine.add_event_handler(Events.STARTED(once_at_start), lambda x: print("started"))
engine.add_event_handler(Events.EPOCH_STARTED, lambda x: print("{} epoch started".format(x.state.epoch)))
engine.add_event_handler(Events.EPOCH_COMPLETED, lambda x: print("{} epoch completed".format(x.state.epoch)))
engine.add_event_handler(Events.COMPLETED(once_at_end), lambda x: print("completed"))

state = engine.run([0, 1, 2], max_epochs=1)
print("Do something else")
state = engine.run([0, 1, 2], max_epochs=state.max_epochs + 1)
print("Do something else")
engine.run([0, 1, 2], max_epochs=state.max_epochs + 1)

with the output sketched below.

PS: We can reduce the epoch length to 10 iterations, but that has its own implications...
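Assuming default logging (so only the print calls produce console output) and that the engine resumes its epoch counter across the three successive run() calls, as the filtered handlers are meant to demonstrate, the expected output is roughly:

started
1 epoch started
1 epoch completed
Do something else
2 epoch started
2 epoch completed
Do something else
3 epoch started
3 epoch completed
completed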
Hi @vfdev-5 , thanks for your detailed suggestions and the code. Thanks.
Hi @Nic-Ma , sounds good! Anyway, we would also like to propose a clean solution for the Engine as a generator.
Hi @vfdev-5 , I tried adding multiple threads to control the logic, for example: thread 1 is the FL client, thread 2 is the ignite training loop. Thanks.
Hi @Nic-Ma

Interesting question. Could you please provide a link or a code snippet with the FL client and the ignite training loop that I could run on my side? It would help a lot in suggesting a solution for distributed computation too.

I think what you would like to do is the following:

fl_client = None
if idist.get_rank() == 0:
    fl_client = ...

@fl_trainer.on(Events.ITERATION_COMPLETED(every=10))
@idist.one_rank_only(rank=0, with_barrier=True)
def fl_sync_model():
    fl_client.send(ddp_model.module)

where idist.one_rank_only executes the handler on rank 0 only, with a barrier:

dist.barrier()
if dist.get_rank() == 0:
    fl_client.send(ddp_model.module)
dist.barrier()
Hi @vfdev-5 , your solution is interesting; I have also been investigating distributed training recently. Thanks.
You can consider the barrier as an explicit call to synchronize all processes (e.g. calling all_reduce on a dummy input). So, yes, you can do:

dist.barrier()
# this should be called from all processes
dist.broadcast(...)
# or abort from master
if dist.get_rank() == 0 and something_wrong:
    raise RuntimeError("abc")
dist.barrier()

For sending messages, I'd say it is more tricky and would probably require adding some communication methods (in order to be executable with torch dist, horovod, and xla).

PS. Sorry, wrong button pushed.
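To make the "barrier as a collective call on a dummy input" remark concrete, here is a minimal runnable sketch; the gloo backend, the single-process group setup, and the manual_barrier helper name are assumptions for illustration only:

import os
import torch
import torch.distributed as dist

def manual_barrier():
    # an all_reduce on a dummy tensor blocks until every process reaches it,
    # which is essentially what dist.barrier() does
    dummy = torch.zeros(1)
    dist.all_reduce(dummy)

if __name__ == "__main__":
    # single-process group just to make the snippet runnable standalone
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    manual_barrier()
    dist.destroy_process_group()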
Hi @vfdev-5 , sounds like a good plan; let me try to develop a simple demo program to verify it first. Thanks.
@Nic-Ma this should be confirmed, but I think in the end all the other processes would be killed by a timeout on the pending collective call.
Yes, something like that. I'd do it with all_gather, like this:

dist.barrier()
should_terminate = False
if dist.get_rank() == 0 and something_wrong:
    should_terminate = True
should_terminate = idist.all_gather(int(should_terminate))
if any(should_terminate):
    raise RuntimeError("stop")
dist.barrier()
Cool, your trick seems very useful; let me make a demo to verify it first.
Hi @vfdev-5 , I developed a simulation program according to our discussion and sent it to your email. Thanks.
Hi @Nic-Ma , thanks, I received your email and will take a look a bit later today.
Hi @vfdev-5 , actually, I double-checked your suggestion and tried to do the step-based aggregation in this way:

dataloader_iter = None

class FLhandler:
    def __call__(self, engine):
        # save the current dataloader iterator so the next round can resume from it
        global dataloader_iter
        dataloader_iter = engine._dataloader_iter
        engine.terminate()

trainer.add_event_handler(Events.ITERATION_COMPLETED(every=10), FLhandler())

# round 1
trainer.run()
# round 2
trainer._dataloader_iter = dataloader_iter
trainer.run()
...
# round N
trainer._dataloader_iter = dataloader_iter
trainer.run()

What do you think? Thanks.
Hi @Nic-Ma , yes, you are right about that 👍 This can be an interesting approach. Thanks!

There is another approach (a workaround) to keep control of the dataflow:

dataloader = ...

def cycle(dataloader):
    # infinite iterator over the dataloader, so its position is never reset between runs
    while True:
        for i in dataloader:
            yield i

dataloader_iter = cycle(dataloader)

@trainer.on(Events.ITERATION_STARTED)
def prepare_batch(engine):
    # fetch the batch ourselves instead of letting the engine consume `data`
    engine.state.batch = next(dataloader_iter)

trainer.add_event_handler(Events.ITERATION_COMPLETED(every=10), FLhandler())

data = list(range(len(dataloader) * num_epochs))

# round 1
trainer.run(data, max_epochs=num_epochs)
# FL sync
# ...
# round 2
trainer.run(data)
# FL sync
# ...
# round 3
trainer.run(data)
# etc
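A minimal self-contained sketch of this workaround is below; the dummy update function, the small TensorDataset, and the print-based fl_sync stub (used in place of the real FL handler) are illustrative assumptions:

import torch
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine, Events

dataset = TensorDataset(torch.arange(40, dtype=torch.float32).reshape(-1, 1))
dataloader = DataLoader(dataset, batch_size=4)  # 10 iterations per epoch

def cycle(dataloader):
    # never-ending iterator: keeps its position across trainer.run() calls
    while True:
        for batch in dataloader:
            yield batch

dataloader_iter = cycle(dataloader)

def update(engine, batch):
    # dummy training step
    return batch

trainer = Engine(update)

@trainer.on(Events.ITERATION_STARTED)
def prepare_batch(engine):
    # real batches come from our own iterator; `data` below only sets the run length
    engine.state.batch = next(dataloader_iter)

@trainer.on(Events.ITERATION_COMPLETED(every=10))
def fl_sync(engine):
    # stand-in for FL aggregation; stop the run so the caller regains control
    print("FL sync at iteration", engine.state.iteration)
    engine.terminate()

num_epochs = 2
data = list(range(len(dataloader) * num_epochs))

trainer.run(data, max_epochs=num_epochs)  # round 1: stops at iteration 10
trainer.run(data)                         # round 2: resumes and stops at iteration 20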
@Nic-Ma Let me close this issue as answered; feel free to reopen if needed.
❓ Questions/Help/Support
Hi @vfdev-5 ,
We have a requirement for very big datasets:
Say we have 100 iterations in 1 epoch. We want to call run() to train on the data of the first 10 iterations, do some other things with the model, then call run() to train on the data of the next 10 iterations, and so on. Is this possible in ignite now?
I found that the dataloader iterator in ignite always starts from the beginning:
https://github.com/pytorch/ignite/blob/v0.4.2/ignite/engine/engine.py#L771
Thanks.