You're given a 0-indexed array nums (length n) and two integers k and dist. You must split nums into k disjoint contiguous subarrays. The cost of any subarray is simply its first element. You also have a constraint on where the later subarrays can start:
If the subarrays start at indices 0 = s0 < s1 < s2 < ... < s(k-1), then you must have:
s(k-1) - s1 <= dist
Your goal: minimize the sum of subarray costs.
Constraints are big (n up to 1e5), so we need an O(n log n)-ish solution.
1) The key reduction (turning "partitioning" into "pick k-1 indices")
Cost structure
The first subarray always starts at index 0, so nums[0] is always included in the total cost.
The remaining (k - 1) subarrays start at indices:
s1, s2, ..., s(k-1)
So the total cost is:
nums[0] + nums[s1] + nums[s2] + ... + nums[s(k-1)]
"Contiguous subarrays" doesn't add extra restrictions here
Once you choose the start indices in increasing order, the partition is automatically valid:
- Subarray 1: nums[0 .. s1-1]
- Subarray 2: nums[s1 .. s2-1]
- ...
- Subarray k: nums[s(k-1) .. n-1]
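To make the mapping concrete, here is a small sketch that turns a list of start indices into the subarrays and their total cost. The helper name `split_by_starts` is hypothetical, and it assumes `starts` begins with 0 and is strictly increasing.

```python
from typing import List, Tuple


def split_by_starts(nums: List[int], starts: List[int]) -> Tuple[List[List[int]], int]:
    """Hypothetical helper: build the partition defined by `starts` and its cost.

    Assumes starts[0] == 0 and starts is strictly increasing.
    """
    bounds = starts + [len(nums)]  # each subarray ends right before the next start
    parts = [nums[bounds[j]:bounds[j + 1]] for j in range(len(starts))]
    cost = sum(part[0] for part in parts)  # cost of a subarray = its first element
    return parts, cost


parts, cost = split_by_starts([1, 3, 2, 6, 4, 2], [0, 2, 5])
print(parts)  # [[1, 3], [2, 6, 4], [2]]
print(cost)   # 5
```

Any strictly increasing choice of starts yields a valid partition, which is why only the chosen indices matter.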
So the problem becomes:
Choose k-1 distinct indices from 1..n-1 that minimize the sum of nums[i], subject to maxIndex - minIndex <= dist.
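Before optimizing, a brute-force reference is handy for checking faster versions on small inputs. This sketch (the name `min_cost_brute` is hypothetical) enumerates every increasing choice of k-1 indices, so it is exponential and only meant for tiny tests.

```python
from itertools import combinations
from typing import List


def min_cost_brute(nums: List[int], k: int, dist: int) -> int:
    """Exhaustive reference for tiny inputs: try every set of k-1 starts in 1..n-1."""
    best = None
    for starts in combinations(range(1, len(nums)), k - 1):  # tuples come out increasing
        if starts[-1] - starts[0] <= dist:                   # s(k-1) - s1 <= dist
            total = nums[0] + sum(nums[i] for i in starts)
            if best is None or total < best:
                best = total
    return best


print(min_cost_brute([1, 3, 2, 6, 4, 2], 3, 3))  # 5
```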
2) From index constraint → sliding window of length dist + 1
The condition:
s(k-1) - s1 <= dist
means all chosen indices lie inside some interval:
[L, L + dist]
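That interval contains exactly dist + 1 positions. Because the problem guarantees k - 2 <= dist <= n - 2, each such window holds at least k - 1 positions and fits inside 1..n-1 (if s1 + dist would run past n - 1, shift the window left until it ends at n - 1).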
So for every window of indices of length dist + 1 over nums[1..n-1], the best you can do in that window is:
pick the (k-1) smallest values in the window.
Therefore:
Answer = nums[0] + min_over_windows( sum of (k-1 smallest in window) )
This is exactly why the problem is tagged with sliding window + heap / ordered set.
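As a sanity check on the formula, it can be evaluated directly by sorting every window. This sketch (the name `min_cost_by_windows` is hypothetical) costs O(n * dist log dist), so it only illustrates the reduction; it assumes k - 2 <= dist <= n - 2.

```python
from typing import List


def min_cost_by_windows(nums: List[int], k: int, dist: int) -> int:
    """Direct evaluation of the window formula: O(n * dist log dist).

    Assumes k - 2 <= dist <= n - 2, so every window fits in 1..n-1 and
    holds at least k - 1 values.
    """
    n = len(nums)
    best = None
    for left in range(1, n - dist):         # window = indices left .. left + dist
        window = nums[left:left + dist + 1]
        cost = sum(sorted(window)[:k - 1])  # k-1 smallest values in this window
        best = cost if best is None else min(best, cost)
    return nums[0] + best


print(min_cost_by_windows([1, 3, 2, 6, 4, 2], 3, 3))  # 5
```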
3) Maintaining "k-1 smallest in a sliding window"
Classic trick: maintain two groups:
- small: the (k-1) smallest elements in the current window
- large: the rest
And keep:
sumSmall = sum(small)
Then each time the window slides:
- remove the outgoing element
- add the incoming element
- rebalance so that len(small) == k-1
- update the answer with sumSmall
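The two groups can also be pictured as the prefix and suffix of one sorted list: small = window[:k-1], large = the rest. The sketch below is a simpler stand-in for the heap-based structure, not the method used later: it keeps the window in a plain Python list with `bisect`, so each insert/pop costs O(dist) and the whole run is O(n * dist) in the worst case. What it does show is how sumSmall is adjusted incrementally instead of being recomputed. The name `min_cost_sorted_window` is hypothetical, and it assumes k - 2 <= dist <= n - 2.

```python
from bisect import bisect_left
from typing import List


def min_cost_sorted_window(nums: List[int], k: int, dist: int) -> int:
    """Hypothetical sketch: the window kept as one sorted list.

    small = window[:need], large = window[need:]; small_sum tracks sum(small).
    Assumes k - 2 <= dist <= n - 2.
    """
    need = k - 1
    window = sorted(nums[1:dist + 2])  # initial window: indices 1 .. dist+1
    small_sum = sum(window[:need])
    best = small_sum
    for r in range(dist + 2, len(nums)):
        # 1) Add the incoming value first, so the window briefly holds
        #    dist+2 >= need+1 values and window[need] exists in step 2.
        x = nums[r]
        pos = bisect_left(window, x)
        if pos < need:  # x enters small; the old largest of small drops to large
            small_sum += x - window[need - 1]
        window.insert(pos, x)
        # 2) Remove the outgoing value nums[r - dist - 1].
        y = nums[r - dist - 1]
        pos = bisect_left(window, y)  # any copy of y works: only the multiset matters
        if pos < need:  # y leaves small; the smallest of large moves up
            small_sum += window[need] - y
        window.pop(pos)
        best = min(best, small_sum)
    return nums[0] + best


print(min_cost_sorted_window([1, 3, 2, 6, 4, 2], 3, 3))  # 5
```

Adding before removing mirrors the order used in the Python solution further down and avoids a special case when the window holds exactly k-1 values.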
Why two groups works
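Because the window changes by one element per step, each insertion or removal moves at most one element between small and large to restore len(small) == k-1, so a step costs O(log dist) with a balanced ordered multiset.
Time: O(n log dist)
Space: O(dist)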
4) Implementation notes (Python-friendly)
In C++ you'd use multiset for easy "remove arbitrary element" + "get min/max". In Python, we typically use two heaps + lazy deletion:
- small as a max-heap (store (-value, index))
- large as a min-heap (store (value, index))
- a delayed map to mark removed indices until they reach the heap top
- a where[index] map to know whether an item currently belongs to small or large
- track small_sum
This avoids needing external libraries like sortedcontainers.
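Lazy deletion is the piece that lets a heap support "remove arbitrary element". A stripped-down, single-heap version of the idea (the class name `LazyMinHeap` is hypothetical) records pending deletions in a counter and only discards a stale entry once it reaches the top. Like the DualHeap below, it assumes callers only discard items that are currently live.

```python
import heapq
from collections import Counter


class LazyMinHeap:
    """Hypothetical helper: a min-heap with arbitrary deletion via lazy deletion."""

    def __init__(self):
        self._heap = []            # physical heap, may contain stale items
        self._pending = Counter()  # item -> deletions not yet applied
        self._size = 0             # number of live items

    def push(self, item):
        heapq.heappush(self._heap, item)
        self._size += 1

    def discard(self, item):
        # Assumes `item` is currently live; it stays in the heap until it surfaces.
        self._pending[item] += 1
        self._size -= 1

    def _prune(self):
        # Remove stale items, but only while they sit at the top.
        while self._heap and self._pending[self._heap[0]] > 0:
            stale = heapq.heappop(self._heap)
            self._pending[stale] -= 1
            if self._pending[stale] == 0:
                del self._pending[stale]

    def top(self):
        self._prune()
        return self._heap[0]

    def pop(self):
        self._prune()
        self._size -= 1
        return heapq.heappop(self._heap)

    def __len__(self):
        return self._size


h = LazyMinHeap()
for item in [(5, 1), (3, 2), (4, 3)]:  # (value, index) pairs, as in DualHeap
    h.push(item)
h.discard((3, 2))  # lazily remove index 2
print(h.top())     # (4, 3)
print(len(h))      # 2
```

Keying entries by (value, index) keeps every item unique, which is what makes the pending-deletion counter unambiguous.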
5) Python solution (Two Heaps + Lazy Deletion)
from heapq import heappush, heappop
from collections import defaultdict
from typing import List


class DualHeap:
    """
    Maintain:
    - 'small': the k smallest elements (max-heap via negated values)
    - 'large': the remaining elements (min-heap)
    - small_sum: sum of the values in 'small'
    Supports sliding-window add/remove in amortized O(log H), where H is the
    physical heap size (lazily deleted entries included).
    """

    def __init__(self, k: int):
        self.k = k                       # desired size of small
        self.small = []                  # (-val, idx)
        self.large = []                  # (val, idx)
        self.delayed = defaultdict(int)  # idx -> pending deletions
        self.where = {}                  # idx -> 'small' or 'large'
        self.small_size = 0              # live elements in small
        self.large_size = 0              # live elements in large
        self.small_sum = 0

    def _prune_small(self) -> None:
        # Drop lazily deleted entries from the top of small.
        # .get() avoids inserting zero-count keys into the defaultdict.
        while self.small and self.delayed.get(self.small[0][1], 0):
            _, idx = heappop(self.small)
            self.delayed[idx] -= 1
            if self.delayed[idx] == 0:
                del self.delayed[idx]

    def _prune_large(self) -> None:
        # Same as _prune_small, for large.
        while self.large and self.delayed.get(self.large[0][1], 0):
            _, idx = heappop(self.large)
            self.delayed[idx] -= 1
            if self.delayed[idx] == 0:
                del self.delayed[idx]

    def _pop_small_valid(self):
        # Pop the largest live element of small.
        self._prune_small()
        neg, idx = heappop(self.small)
        return -neg, idx

    def _pop_large_valid(self):
        # Pop the smallest live element of large.
        self._prune_large()
        val, idx = heappop(self.large)
        return val, idx

    def _make_balance(self) -> None:
        # If total elements < k (early stage), small should just contain all of them.
        target = min(self.k, self.small_size + self.large_size)
        while self.small_size > target:
            val, idx = self._pop_small_valid()
            self.small_size -= 1
            self.small_sum -= val
            heappush(self.large, (val, idx))
            self.large_size += 1
            self.where[idx] = 'large'
        while self.small_size < target:
            val, idx = self._pop_large_valid()
            self.large_size -= 1
            heappush(self.small, (-val, idx))
            self.small_size += 1
            self.small_sum += val
            self.where[idx] = 'small'
        self._prune_small()
        self._prune_large()

    def add(self, val: int, idx: int) -> None:
        if self.k == 0:
            heappush(self.large, (val, idx))
            self.large_size += 1
            self.where[idx] = 'large'
            return
        self._prune_small()
        # Go to small if val is no larger than small's current maximum.
        if not self.small or val <= -self.small[0][0]:
            heappush(self.small, (-val, idx))
            self.small_size += 1
            self.small_sum += val
            self.where[idx] = 'small'
        else:
            heappush(self.large, (val, idx))
            self.large_size += 1
            self.where[idx] = 'large'
        self._make_balance()

    def remove(self, val: int, idx: int) -> None:
        side = self.where.pop(idx, None)
        if side is None:
            return
        self.delayed[idx] += 1  # mark for lazy deletion
        if side == 'small':
            self.small_size -= 1
            self.small_sum -= val
        else:
            self.large_size -= 1
        self._prune_small()
        self._prune_large()
        self._make_balance()


class Solution:
    def minimumCost(self, nums: List[int], k: int, dist: int) -> int:
        n = len(nums)
        need = k - 1             # choose k-1 starts after index 0
        window_size = dist + 1   # indices must fit in a window of length dist+1
        dh = DualHeap(need)
        # Initial window covers indices [1 .. dist+1]
        # (constraints guarantee k - 2 <= dist <= n - 2, so these indices exist)
        for i in range(1, dist + 2):
            dh.add(nums[i], i)
        best = dh.small_sum
        # Slide window by right boundary r = dist+2 .. n-1
        for r in range(dist + 2, n):
            dh.add(nums[r], r)
            out = r - window_size  # outgoing index
            dh.remove(nums[out], out)
            best = min(best, dh.small_sum)
        return nums[0] + best
6) Walkthrough on Example 1
Example: nums = [1,3,2,6,4,2], k = 3, dist = 3; the output is 5.
- nums[0] = 1 is fixed.
- Need k-1 = 2 more start indices within distance <= 3.
- Window size = dist+1 = 4, sliding over indices 1..5.
- Check each window of length 4 and take the 2 smallest values: indices 1..4 give 2 + 3 = 5, indices 2..5 give 2 + 2 = 4.
- The best pair is (nums[2], nums[5]) = (2, 2) → sum = 4
- Total = 1 + 4 = 5
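The walkthrough can also be traced mechanically. This snippet is illustrative only: it recomputes each window with `heapq.nsmallest` instead of using the sliding structure, and lists every window of indices with its k-1 smallest values.

```python
import heapq

nums, k, dist = [1, 3, 2, 6, 4, 2], 3, 3
need, n = k - 1, len(nums)

trace = []
for left in range(1, n - dist):  # windows start at indices 1 and 2 here
    right = left + dist
    picked = heapq.nsmallest(need, nums[left:right + 1])
    trace.append((left, right, picked, sum(picked)))

for row in trace:
    print(row)
# (1, 4, [2, 3], 5)
# (2, 5, [2, 2], 4)
print(nums[0] + min(row[3] for row in trace))  # 5
```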
7) Complexity
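- Each slide does a constant number of heap pushes/pops (plus amortized pruning of stale entries), each O(log H), where H is the current heap size.
- Total slides: O(n)
- Time: O(n log dist) with a balanced ordered multiset; O(n log n) in the worst case for the lazy-deletion heaps, because stale entries can make a heap hold up to O(n) items.
- Space: O(dist) with an ordered multiset; O(n) in the worst case with lazy deletion.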
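The gap between the two bounds comes from stale entries. The simplified sketch below (hypothetical, a single heap rather than the DualHeap itself) slides a window over strictly decreasing values: the newest value is always the heap minimum, so removed entries never reach the top and are never pruned.

```python
import heapq


def stale_growth_demo(n: int, window: int) -> tuple:
    """Slide a window of size `window` over n strictly decreasing values using a
    lazily pruned min-heap. Returns (physical heap length, live element count)."""
    heap = []
    deleted = set()                             # indices marked for lazy deletion
    for i in range(n):
        heapq.heappush(heap, (n - i, i))        # strictly decreasing values
        if i >= window:
            deleted.add(i - window)             # outgoing index: mark, don't pop
        while heap and heap[0][1] in deleted:   # prune only at the top
            deleted.discard(heapq.heappop(heap)[1])
    return len(heap), len(heap) - len(deleted)


print(stale_growth_demo(1000, 10))  # (1000, 10)
```

Only 10 entries are live, yet the heap physically holds all 1000, which is why the lazy-deletion variant is bounded by O(n) space rather than O(dist).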