Catch error in CommandBuffer and poison the events#3523
Conversation
b13de7b to
1d33462
Compare
angeloskath
left a comment
There was a problem hiding this comment.
I left a couple of comments.
I think the main issue is that check_error in the Event/Fence wait is not doing what it's meant to since the calling thread is practically never the thread that scheduled the work.
|
I changed the implementation to propagate the error from
So error in one stream would be propagated to all other streams via events. (Except for CPU stream which currently just throws and crashes.) There is no good way to test it though, I'm mostly just using |
|
This honestly looks like an awesome cleanup! I am curious about any overheads... probably overthinking it. I want to run some generations and small model training to verify no regression and then we can merge. Love the fence cleanup in particular. |
|
I did a simple benchmarking by running For main: $ mlx_lm.benchmark --model meta-llama/Llama-3.1-8B-Instruct -n 10 -p 64 -g 512
Timing with prompt_tokens=64, generation_tokens=512, batch_size=1.
Trial 1: prompt_tps=371.523, generation_tps=24.235, peak_memory=16.175, total_time=21.350
Trial 2: prompt_tps=369.580, generation_tps=24.232, peak_memory=16.175, total_time=21.354
Trial 3: prompt_tps=354.408, generation_tps=24.237, peak_memory=16.175, total_time=21.357
Trial 4: prompt_tps=353.003, generation_tps=24.240, peak_memory=16.176, total_time=21.355
Trial 5: prompt_tps=368.560, generation_tps=24.238, peak_memory=16.176, total_time=21.350
Trial 6: prompt_tps=370.512, generation_tps=24.241, peak_memory=16.176, total_time=21.345
Trial 7: prompt_tps=354.048, generation_tps=24.238, peak_memory=16.176, total_time=21.357
Trial 8: prompt_tps=351.635, generation_tps=24.239, peak_memory=16.177, total_time=21.357
Trial 9: prompt_tps=371.316, generation_tps=24.239, peak_memory=16.177, total_time=21.347
Trial 10: prompt_tps=368.632, generation_tps=24.239, peak_memory=16.177, total_time=21.348
Averages: prompt_tps=363.322, generation_tps=24.238, peak_memory=16.176for this branch: $ mlx_lm.benchmark --model meta-llama/Llama-3.1-8B-Instruct -n 10 -p 64 -g 512
Timing with prompt_tokens=64, generation_tokens=512, batch_size=1.
Trial 1: prompt_tps=354.694, generation_tps=24.242, peak_memory=16.175, total_time=21.353
Trial 2: prompt_tps=371.278, generation_tps=24.237, peak_memory=16.175, total_time=21.348
Trial 3: prompt_tps=370.507, generation_tps=24.240, peak_memory=16.175, total_time=21.347
Trial 4: prompt_tps=352.574, generation_tps=24.246, peak_memory=16.176, total_time=21.350
Trial 5: prompt_tps=354.255, generation_tps=24.236, peak_memory=16.176, total_time=21.358
Trial 6: prompt_tps=355.678, generation_tps=24.242, peak_memory=16.176, total_time=21.353
Trial 7: prompt_tps=370.067, generation_tps=24.238, peak_memory=16.176, total_time=21.349
Trial 8: prompt_tps=371.358, generation_tps=24.230, peak_memory=16.177, total_time=21.355
Trial 9: prompt_tps=353.896, generation_tps=24.246, peak_memory=16.177, total_time=21.350
Trial 10: prompt_tps=352.977, generation_tps=24.235, peak_memory=16.177, total_time=21.360
Averages: prompt_tps=360.728, generation_tps=24.239, peak_memory=16.176the generation seemed to become slightly faster but it is probably just variability. Logically I think this branch added overhead when waiting events, but also reduced overhead by reducing the number of completion handlers. |
Close #2670.
When we got error in CommandBuffer's completion handler, save it, and then throw the error later in a few safe points:
In this way the program would be able to catch the error, and could continue running safely after recovering from the error.
The downside is that the timing of error throwing would be delayed much later until the program run into the check points, which is actually quite similar to how CUDA handles errors.