LeetCode 3013 Solution: Divide an Array Into Subarrays With Minimum Cost II

You’re given a 0-indexed array nums (length n) and two integers k and dist. You must split nums into k disjoint contiguous subarrays. The cost of any subarray is simply its first element. You also have a constraint on where the later subarrays can start:

If the subarrays start at indices 0 = s0 < s1 < s2 < ... < s(k-1), then you must have:

s(k-1) - s1 <= dist

Your goal: minimize the sum of subarray costs.

The constraints are large (n up to 1e5, with k - 2 <= dist <= n - 2), so anything quadratic is too slow; we need a roughly O(n log n) solution.


1) The key reduction (turning “partitioning” into “pick k-1 indices”)

Cost structure

The first subarray always starts at index 0, so nums[0] is always included in the total cost.

The remaining (k - 1) subarrays start at indices:

s1, s2, ..., s(k-1)

So the total cost is:

nums[0] + nums[s1] + nums[s2] + ... + nums[s(k-1)]

“Contiguous subarrays” doesn’t add extra restrictions here

Once you choose the start indices in strictly increasing order, the partition is automatically valid, and every piece is non-empty:

  • Subarray 1: nums[0 .. s1-1]
  • Subarray 2: nums[s1 .. s2-1]
  • ...
  • Subarray k: nums[s(k-1) .. n-1]
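To make the claim concrete, here is a minimal sketch that cuts an array at a list of start indices and sums the first element of each piece. The helper names split_by_starts and partition_cost are hypothetical; the code assumes starts is strictly increasing, begins with 0, and stays below len(nums).

```python
from typing import List

def split_by_starts(nums: List[int], starts: List[int]) -> List[List[int]]:
    """Cut nums into contiguous pieces beginning at each index in starts.

    Assumes starts is strictly increasing, starts[0] == 0 and
    starts[-1] < len(nums), so every piece is non-empty.
    """
    bounds = list(starts) + [len(nums)]  # piece i is nums[bounds[i]:bounds[i+1]]
    return [nums[bounds[i]:bounds[i + 1]] for i in range(len(starts))]

def partition_cost(nums: List[int], starts: List[int]) -> int:
    # The cost of each piece is its first element, i.e. nums[start].
    return sum(piece[0] for piece in split_by_starts(nums, starts))

print(split_by_starts([1, 3, 2, 6, 4, 2], [0, 2, 5]))  # [[1, 3], [2, 6, 4], [2]]
print(partition_cost([1, 3, 2, 6, 4, 2], [0, 2, 5]))   # 5
```

The cost only ever reads nums[start] for each start, which is exactly nums[0] + nums[s1] + ... + nums[s(k-1)].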

So the problem becomes:

Choose (k-1) indices from 1..n-1 that minimize the sum of nums[i], subject to maxIndex - minIndex <= dist.
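The reduced problem can be checked directly on tiny inputs. Below is a minimal brute-force sketch (the name brute_force_min_cost is hypothetical) that enumerates every set of k-1 indices from 1..n-1 and keeps only those with spread at most dist. It is exponential, so it is only useful as a test oracle for faster versions.

```python
from itertools import combinations
from typing import List, Optional

def brute_force_min_cost(nums: List[int], k: int, dist: int) -> Optional[int]:
    """Exhaustive check of the reduction (exponential; tiny inputs only)."""
    n = len(nums)
    best = None
    # combinations() yields index tuples in increasing order: (s1, ..., s(k-1))
    for starts in combinations(range(1, n), k - 1):
        if starts[-1] - starts[0] <= dist:  # maxIndex - minIndex <= dist
            cost = nums[0] + sum(nums[i] for i in starts)
            if best is None or cost < best:
                best = cost
    return best

print(brute_force_min_cost([1, 3, 2, 6, 4, 2], 3, 3))  # 5
```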


2) From index constraint → sliding window of length dist + 1

The condition:

s(k-1) - s1 <= dist

means all chosen indices lie inside some interval:

[L, L + dist]

That interval contains exactly dist + 1 positions.

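Every valid choice spans at most dist + 1 consecutive indices, so it fits inside at least one such window; conversely, any (k-1) indices taken from a single window satisfy the constraint automatically. So for every window of dist + 1 consecutive indices over nums[1..n-1], the best you can do in that window is: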

pick the (k-1) smallest values in the window.

Therefore:

Answer = nums[0] + min_over_windows( sum of (k-1 smallest in window) )
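The formula can be evaluated literally by rebuilding every window from scratch. This sketch (hypothetical name min_cost_by_windows) is a reference implementation, not the final solution: it costs roughly O(n * dist * log k). It assumes dist <= n - 2, as in the LeetCode constraints, so at least one full window of dist + 1 indices fits inside 1..n-1.

```python
import heapq
from typing import List

def min_cost_by_windows(nums: List[int], k: int, dist: int) -> int:
    """nums[0] + min over windows of (sum of the k-1 smallest values)."""
    n = len(nums)
    width = dist + 1
    best = None
    # Window starts at index `left` and covers left .. left + dist.
    for left in range(1, n - width + 1):
        window = nums[left:left + width]
        total = sum(heapq.nsmallest(k - 1, window))
        if best is None or total < best:
            best = total
    return nums[0] + best

print(min_cost_by_windows([1, 3, 2, 6, 4, 2], 3, 3))  # 5
```

The sliding-window structures in the next section exist to avoid redoing the heapq.nsmallest call for every window.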

This is exactly why the problem is tagged with sliding window + heap / ordered set.


3) Maintaining “k-1 smallest in a sliding window”

Classic trick: maintain two groups:

  • small: the (k-1) smallest elements in the current window
  • large: the rest

And keep:

  • sumSmall = sum(small)

Then each time the window slides:

  1. remove the outgoing element
  2. add the incoming element
  3. rebalance so len(small) == k-1
  4. update answer with sumSmall
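The four steps can be sketched with a plain sorted Python list standing in for an ordered multiset. This is a deliberate simplification, not the heap-based solution: small is window[:m] and large is window[m:], so the rebalance step is implicit and only small_sum has to be adjusted. The function name min_small_sum_sorted_list and its parameters (vals, m = k-1, w = dist+1) are hypothetical. Because list.insert and list.pop shift elements, each step costs O(w) rather than O(log w).

```python
from bisect import bisect_left, bisect_right
from typing import List

def min_small_sum_sorted_list(vals: List[int], m: int, w: int) -> int:
    """Minimum over all length-w windows of the sum of the m smallest values.

    Assumes 1 <= m <= w <= len(vals). 'small' is window[:m].
    """
    window = sorted(vals[:w])
    small_sum = sum(window[:m])
    best = small_sum
    for r in range(w, len(vals)):
        # 1. remove the outgoing element vals[r - w]
        out = vals[r - w]
        pos = bisect_left(window, out)
        window.pop(pos)
        if pos < m:
            # It was in 'small'; the smallest of 'large' (now index m-1) moves up.
            small_sum -= out
            if len(window) >= m:
                small_sum += window[m - 1]
        # 2. add the incoming element vals[r]
        inc = vals[r]
        pos = bisect_right(window, inc)
        window.insert(pos, inc)
        if pos < m:
            # It lands in 'small'; the old largest of 'small' (now index m) drops out.
            small_sum += inc
            if len(window) > m:
                small_sum -= window[m]
        # 3. rebalance: implicit, since 'small' is always window[:m]
        # 4. update the answer
        best = min(best, small_sum)
    return best

# For this problem: nums[0] + min_small_sum_sorted_list(nums[1:], k - 1, dist + 1)
print(1 + min_small_sum_sorted_list([3, 2, 6, 4, 2], 2, 4))  # 5
```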

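Why the two-group approach works

Each slide changes the window by exactly one incoming and one outgoing element, so after each insertion or deletion at most one element has to cross between small and large. With an ordered structure, each of these steps costs O(log dist).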

Time: O(n log dist)
Space: O(dist)


4) Implementation notes (Python-friendly)

In C++ you’d use multiset for easy “remove arbitrary element” + “get min/max”. In Python, we typically use two heaps + lazy deletion:

  • small as a max-heap (store (-value, index))
  • large as a min-heap (store (value, index))
  • a delayed map to mark removed indices until they reach heap top
  • a where[index] map to know whether an item currently belongs to small or large
  • track small_sum

This avoids needing external libraries like sortedcontainers.
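Lazy deletion is easiest to see on a single heap. The class below is a hypothetical illustration (LazyMinHeap is not part of the solution that follows): remove() only records the index in delayed, and the stale entry is discarded once it reaches the top. It assumes each index is pushed at most once and removed only while it is live.

```python
import heapq
from collections import defaultdict

class LazyMinHeap:
    """Min-heap of (value, index) pairs with lazy deletion by index."""

    def __init__(self):
        self.heap = []
        self.delayed = defaultdict(int)  # index -> pending deletions
        self.size = 0                    # number of live entries

    def push(self, val, idx):
        heapq.heappush(self.heap, (val, idx))
        self.size += 1

    def remove(self, idx):
        # O(1): only mark the index; its entry stays in the heap for now.
        self.delayed[idx] += 1
        self.size -= 1

    def _prune(self):
        # Pop stale entries while one sits at the top.
        while self.heap and self.delayed.get(self.heap[0][1], 0):
            _, idx = heapq.heappop(self.heap)
            self.delayed[idx] -= 1
            if self.delayed[idx] == 0:
                del self.delayed[idx]

    def top(self):
        self._prune()
        return self.heap[0]

    def pop(self):
        self._prune()
        self.size -= 1
        return heapq.heappop(self.heap)

h = LazyMinHeap()
h.push(5, 0); h.push(1, 1); h.push(3, 2)
h.remove(1)      # (1, 1) is still physically in the heap
print(h.top())   # (3, 2)
```

The solution below uses the same idea twice (a max-heap for small, a min-heap for large) and adds a where map so a removal knows which side's size and sum to adjust.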


5) Python solution (Two Heaps + Lazy Deletion)

from heapq import heappush, heappop
from collections import defaultdict
from typing import List

class DualHeap:
    """
    Maintain:
      - 'small': k smallest elements (as max-heap via negative values)
      - 'large': remaining elements (min-heap)
      - small_sum: sum of values in 'small'

    Supports sliding window add/remove in O(log W).
    """
    def __init__(self, k: int):
        self.k = k  # desired size of small
        self.small = []  # (-val, idx)
        self.large = []  # (val, idx)
        self.delayed = defaultdict(int)  # idx -> pending deletions
        self.where = {}  # idx -> 'small' or 'large'
        self.small_size = 0
        self.large_size = 0
        self.small_sum = 0

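    def _prune_small(self) -> None:
        # Discard removed (stale) entries sitting at the top of 'small'.
        # .get() avoids inserting zero-count keys into the defaultdict.
        while self.small and self.delayed.get(self.small[0][1], 0):
            _, idx = heappop(self.small)
            self.delayed[idx] -= 1
            if self.delayed[idx] == 0:
                del self.delayed[idx]

    def _prune_large(self) -> None:
        # Same as above, for the min-heap 'large'.
        while self.large and self.delayed.get(self.large[0][1], 0):
            _, idx = heappop(self.large)
            self.delayed[idx] -= 1
            if self.delayed[idx] == 0:
                del self.delayed[idx]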

    def _pop_small_valid(self):
        self._prune_small()
        neg, idx = heappop(self.small)
        return -neg, idx

    def _pop_large_valid(self):
        self._prune_large()
        val, idx = heappop(self.large)
        return val, idx

    def _make_balance(self) -> None:
        # If total elements < k (early stage), small should just contain all of them.
        target = min(self.k, self.small_size + self.large_size)

        while self.small_size > target:
            val, idx = self._pop_small_valid()
            self.small_size -= 1
            self.small_sum -= val
            heappush(self.large, (val, idx))
            self.large_size += 1
            self.where[idx] = 'large'

        while self.small_size < target:
            val, idx = self._pop_large_valid()
            self.large_size -= 1
            heappush(self.small, (-val, idx))
            self.small_size += 1
            self.small_sum += val
            self.where[idx] = 'small'

        self._prune_small()
        self._prune_large()

    def add(self, val: int, idx: int) -> None:
        if self.k == 0:
            heappush(self.large, (val, idx))
            self.large_size += 1
            self.where[idx] = 'large'
            return

        self._prune_small()
        if not self.small or val <= -self.small[0][0]:
            heappush(self.small, (-val, idx))
            self.small_size += 1
            self.small_sum += val
            self.where[idx] = 'small'
        else:
            heappush(self.large, (val, idx))
            self.large_size += 1
            self.where[idx] = 'large'

        self._make_balance()

    def remove(self, val: int, idx: int) -> None:
        side = self.where.pop(idx, None)
        if side is None:
            return

        self.delayed[idx] += 1
        if side == 'small':
            self.small_size -= 1
            self.small_sum -= val
        else:
            self.large_size -= 1

        self._prune_small()
        self._prune_large()
        self._make_balance()


class Solution:
    def minimumCost(self, nums: List[int], k: int, dist: int) -> int:
        n = len(nums)
        need = k - 1               # choose k-1 starts after index 0
        window_size = dist + 1     # indices must fit in a window of length dist+1

        dh = DualHeap(need)

        # Initial window covers indices [1 .. dist+1]
        for i in range(1, dist + 2):
            dh.add(nums[i], i)

        best = dh.small_sum

        # Slide window by right boundary r = dist+2 .. n-1
        for r in range(dist + 2, n):
            dh.add(nums[r], r)
            out = r - window_size   # outgoing index
            dh.remove(nums[out], out)
            best = min(best, dh.small_sum)

        return nums[0] + best

6) Walkthrough on Example 1

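Example: nums = [1,3,2,6,4,2], k = 3, dist = 3; the expected output is 5.

  • nums[0] = 1 is always paid.
  • We need k-1 = 2 more start indices whose spread is at most dist = 3.
  • Windows have length dist+1 = 4 over indices 1..5, giving two windows.
  • Indices 1..4 hold [3,2,6,4]; the 2 smallest sum to 2 + 3 = 5
  • Indices 2..5 hold [2,6,4,2]; the 2 smallest sum to 2 + 2 = 4
  • The best choice is indices 2 and 5 (values 2 and 2, spread 3 <= dist) → sum = 4
  • Total = 1 + 4 = 5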
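The same trace can be printed mechanically. This short script (the function name walkthrough is hypothetical) sorts each window, which is fine for a six-element example but is not how the real solution works.

```python
from typing import List, Tuple

def walkthrough(nums: List[int], k: int, dist: int) -> Tuple[List[int], int]:
    """Print each window, its k-1 smallest values, and the final answer."""
    width = dist + 1
    sums = []
    for left in range(1, len(nums) - width + 1):
        window = nums[left:left + width]
        picked = sorted(window)[:k - 1]
        sums.append(sum(picked))
        print(f"indices {left}..{left + dist}: {window} -> {picked}, sum {sum(picked)}")
    answer = nums[0] + min(sums)
    print("answer:", answer)
    return sums, answer

walkthrough([1, 3, 2, 6, 4, 2], 3, 3)
# indices 1..4: [3, 2, 6, 4] -> [2, 3], sum 5
# indices 2..5: [2, 6, 4, 2] -> [2, 2], sum 4
# answer: 5
```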

7) Complexity

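  • Each slide does a constant number of heap pushes/pops plus pruning of stale entries; every stale entry is popped at most once, so pruning adds amortized O(1) pops per slide.
  • Total slides: O(n)
  • With lazy deletion, a stale entry can stay buried until it surfaces (e.g. when nums is increasing, removed values stay below the maximum of small and never reach its top), so the heaps can grow to O(n) entries: Time O(n log n), Space O(n) in the worst case.
  • With a true ordered multiset (such as C++ std::multiset), the structures hold only the dist + 1 window elements: Time O(n log dist), Space O(dist).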