@@ -481,8 +481,13 @@ static void llama_params_fit_impl(
481481 } else {
482482 LLAMA_LOG_INFO (" %s: filling dense-only layers back-to-front:\n " , __func__);
483483 }
484- uint32_t n_unassigned = hp_ngl;
485484 for (int id = nd - 1 ; id >= 0 ; id--) {
485+ uint32_t n_unassigned = hp_ngl;
486+ for (size_t jd = id + 1 ; jd < nd; ++jd) {
487+ assert (n_unassigned >= ngl_per_device[jd].n_layer );
488+ n_unassigned -= ngl_per_device[jd].n_layer ;
489+ }
490+
486491 std::vector<ngl_t > ngl_per_device_high = ngl_per_device;
487492 ngl_per_device_high[id].n_layer = n_unassigned;
488493 if (hp_nex > 0 ) {
@@ -491,7 +496,9 @@ static void llama_params_fit_impl(
491496 if (ngl_per_device_high[id].n_layer > 0 ) {
492497 std::vector<int64_t > mem_high = get_memory_for_layers (__func__, ngl_per_device_high, overflow_bufts, partial_moe);
493498 if (mem_high[id] > targets[id]) {
499+ assert (ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer );
494500 uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer ;
501+ LLAMA_LOG_DEBUG (" %s: start filling device %" PRIu32 " , delta=%" PRIu32 " \n " , __func__, id, delta);
495502 while (delta > 1 ) {
496503 uint32_t step_size = int64_t (delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
497504 step_size = std::max (step_size, uint32_t (1 ));
@@ -505,20 +512,19 @@ static void llama_params_fit_impl(
505512 const std::vector<int64_t > mem_test = get_memory_for_layers (__func__, ngl_per_device_test, overflow_bufts, partial_moe);
506513
507514 if (mem_test[id] <= targets[id]) {
508- ngl_per_device = ngl_per_device_test;
509- mem = mem_test;
510- n_unassigned -= ngl_per_device[id].n_layer ;
515+ ngl_per_device = ngl_per_device_test;
516+ mem = mem_test;
511517 LLAMA_LOG_DEBUG (" %s: set ngl_per_device[%d].n_layer=%" PRIu32 " \n " , __func__, id, ngl_per_device[id].n_layer );
512518 } else {
513519 ngl_per_device_high = ngl_per_device_test;
514520 mem_high = mem_test;
515- LLAMA_LOG_DEBUG (" %s: set ngl_per_device_high[%d].n_layer=%" PRIu32 " \n " , __func__, id, ngl_per_device [id].n_layer );
521+ LLAMA_LOG_DEBUG (" %s: set ngl_per_device_high[%d].n_layer=%" PRIu32 " \n " , __func__, id, ngl_per_device_high [id].n_layer );
516522 }
517523 delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer ;
518524 }
519525 } else {
520- ngl_per_device = ngl_per_device_high ;
521- n_unassigned -= ngl_per_device[id]. n_layer ;
526+ assert (ngl_per_device_high[id]. n_layer == n_unassigned) ;
527+ ngl_per_device = ngl_per_device_high ;
522528 LLAMA_LOG_DEBUG (" %s: set ngl_per_device[%d].n_layer=%" PRIu32 " \n " , __func__, id, ngl_per_device[id].n_layer );
523529 }
524530 }
0 commit comments