华为 OD 训练营 · 题解精讲
LeetCode698、划分K个相等的子集(回溯解法)
LeetCode698、划分K个相等的子集(回溯解法) 题目练习网址:https://leetcode.cn/problems/partition-to-k-equal-sum-subsets/ 视频直播讲解:2025/02/22 真题讲解(2024E卷回溯类型题目:划分k个相等子集问题)
题目描述
给定一个整数数组 nums 和一个正整数 k,找出是否有可能把这个数组分成 k 个非空子集,其总和都相等。
示例 1: 输入: nums = [4, 3, 2, 3, 5, 2, 1], k = 4 输出: True 说明: 有可能将其分成 4 个子集(5),(1,4),(2,3),(2,3)等于总和。 示例 2: 输入: nums = [1,2,3,4], k = 3 输出: False
提示: 1 <= k <= len(nums) <= 16 0 < nums[i] < 10000 每个元素的频率在 [1,4] 范围内
题目解析
这个问题有多种解法:回溯、DFS记忆化搜索、动态规划。 本篇题解主要介绍回溯解法。
往k个桶里分配n个数字 这个问题本身有点抽象,我们可以把它换成一个更加具象化的场景。 假设我们现在有k个桶和n个数字,现在我们希望把这n个数字分配在这k个桶里,使得这k个桶里各自的数字的和相等。 这很明显是一个分组问题。
其次,我们观察数据量,1 <= k <= len(nums) <= 16,并不是一个很大的数量级。 所以容易想到回溯穷举的做法。
特殊情况的排除 显然,要完成这个问题,这n个数字构成的数字nums必须满足以下条件 这些数字的和total = sum(nums)必须能够整除k,否则一定无法使得这k个桶各自的和相等 当total能够整除k时,最后每个桶各自的和per = total // k也是非常容易算出来的 这些数字中的最大值max(nums)必须不能大于per
以示例一为例,容易计算total = 20,per = 5。
所以本题可以进行一些特殊情况的排除 class Solution:
pass
def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
计算所有数字nums的和total
total = sum(nums)
如果total不能整除k,直接返回False
if total % k != 0: return False
当total可以整除k时,计算各个桶最终到达的和per
per = total // k
如果nums中的最大值已经大于per
那么也一定无法完成分配
if max(nums) > per: return False
pass
贪心地从大到小考虑nums数组 回溯无非是穷举,我们需要思考当人来做这个问题的时候,是如何进行穷举的。
我们会贪心地先把大的数字挑出来,把这些更大的数字先分配到桶里。
这个贪心的原理很容易理解:大的数占整个桶的空间更多,相对而言比较“笨重”,一旦确定放在某个桶里,较少的概率会发生调整。而小的数可以更加灵活地去填补每个桶剩余的缝隙,更加方便在后面进行不同桶之间的调整。
以示例一为例,nums = [4, 3, 2, 3, 5, 2, 1],k = 4,total = 20,per = 5。 如果我们选择了数字4放在一个桶里,这个桶剩下的空间是1,我们就知道在后面只需要再找一个数字1来放到这个桶里即可。 但如果我们先选择了数字1放在一个桶里,这个桶剩下的空间是4,那么将出现较多种可能的组合,比如1+1+2,2+2,1+3,4等等,就增加了枚举的情况。
所以,基于上述的贪心思想,我们可以在回溯之前,将数字nums进行逆序排序,使得大的数字放在nums的前面而先被考虑。 nums.sort(reverse = True)
状态树的构建 既然考虑使用回溯来完成这个问题,那么自然还需要思考整个计算过程的状态树如何构建。
我们可以使用一个变量i来表示nums数组里面的元素索引。 由于nums数组本身已经进行逆序排序,所以越靠后的索引i对应的值更小。
对于每一个nums中的数字nums[i],我们都考虑它可能会被放到某一个桶中。
考虑其中的某个状态以及其下一层的状态。 假设我们已经将nums[0] = 5和nums[1] = 4分别填入第0和第1个桶里,接下来我们要将nums[2] = 3分别尝试填入每一个桶中。
考虑回溯过程中的两个核心方向:纵向和横向。 纵向遍历涉及数组索引的遍历,也就是越往下层,索引i的值会增加 横向遍历涉及桶的索引的遍历,如果把桶的索引设置为j,那么从左到右的横向遍历j会增大
注意横向遍历的过程中,在进行完group[j]的状态更新之后,只有当group[j]的值不超过per,才可以继续进行递归函数的调用。 (考虑上述回溯状态树中的前两个分枝是直接不进入递归的,这也是某种意义上的剪枝)
那么我们就可以构造出如下的dfs函数来完成基本的回溯过程,仍然遵循常规的回溯模板 class Solution: def dfs(self, group, i, nums, per, n, k):
递归终止条件:当数组索引i到达n的时候
说明整个回溯完成,退出函数
if i == n: return
横向遍历:考虑k个不同的桶
设置桶的索引为j,横向遍历每一个桶
for j in range(k):
状态更新:将nums[i]放入第j个桶中,
也就是将nums[i]加到group[j]中
group[j] += nums[i]
纵向遍历,考虑下一个数字,也就是将i+1传入dfs函数中并调用
并且group[j]的值不能超过per,否则回溯将没有意义
if group[j] <= per: self.dfs(group, i+1, nums, per, n, k)
状态回滚:将nums[i]从第j个桶中拿出
也就是group[j]减去nums[i]
group[j] -= nums[i]
def canPartitionKSubsets(self, nums: List[int], k: int) -> bool: pass n = len(nums)
初始化长度为k的group数组,表示k个桶的求和情况
group = [0] * k
递归入口:传入i=0,表示nums[0]开始填入桶中
self.dfs(group, 0, nums, per, n, k) pass
除此之外,我们还需要考虑函数的返回值。 题目要求当我们能够找到正确分配时返回True,所以我们可以设置一个布尔类型成员变量self.ans(如果是ACM模式则设置一个全局变量)来表示是否找到了这样的分配。
容易更新上述代码为 class Solution: def dfs(self, group, i, nums, per, n, k):
递归终止条件:当数组索引i到达n的时候
说明整个回溯完成,将self.ans设置为True之后,退出函数
if i == n: self.ans = True return
横向遍历:考虑k个不同的桶
设置桶的索引为j,横向遍历每一个桶
for j in range(k):
状态更新:将nums[i]放入第j个桶中,
也就是将nums[i]加到group[j]中
group[j] += nums[i]
纵向遍历,考虑下一个数字,也就是将i+1传入dfs函数中并调用
并且group[j]的值不能超过per,否则回溯将没有意义
if group[j] <= per: self.dfs(group, i+1, nums, per, n, k)
状态回滚:将nums[i]从第j个桶中拿出
也就是group[j]减去nums[i]
group[j] -= nums[i]
def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
计算所有数字nums的和total
total = sum(nums)
如果total不能整除k,直接返回False
if total % k != 0: return False
当total可以整除k时,计算各个桶最终到达的和per
per = total // k
贪心地对nums进行逆序排序,减少遍历次数
nums.sort(reverse = True)
如果nums中的最大值已经大于per
那么也一定无法完成分配
if nums[0] > per: return False
n = len(nums)
初始化成员变量self.ans为False
self.ans = False
初始化长度为k的group数组,表示k个桶的求和情况
group = [0] * k
递归入口:传入i=0,表示nums[0]开始填入桶中
self.dfs(group, 0, nums, per, n, k)
返回self.ans的值来表示是否找到了答案
return self.ans
需要注意的是,本题的横向遍历过程和以往的常规回溯题目(如子集问题、组合问题等等)做法稍有不同,甚至可能会有点反直觉。 在过往的题目中,横向遍历通常涉及到整个nums数组的遍历。 但本题的横向遍历则涉及了k个桶的遍历,nums数组中数字的选择,则放在了纵向上。
这是因为本题的状态树分岔较多,只有这样进行遍历才能够尽可能地剪枝降低时间复杂度。
剪枝优化减少状态树分岔 上述代码就已经能够通过一些简单用例了,但是提交时仍然因为超时无法通过全部用例。 所以很显然我们需要考虑回溯的剪枝,来排除一些重复状态和遍历。
当我们已经找到一组正确的的分配时,就无需再做回溯过程了,可以直接返回,退出函数。所以可以修改dfs递归函数的退出条件为 class Solution: def dfs(self, group, i, nums, per, n, k):
递归终止条件:当数组索引i到达n的时候,或者已经找到一组正确的的分配时
说明整个回溯完成,将self.ans设置为True之后,退出函数
if i == n or self.ans: self.ans = True return
pass
pass
在横向遍历的过程中,由于所有的桶都是全同的没有差别的,在更新group[j]之前,如果发现当前桶的和group[j]和上一个桶的和group[j-1]相等,那么更新group[j]将没有意义,因为更新它和之前已经考虑过的group[j-1]是一样的。 (考虑上述回溯状态树中的最后一个分岔实际上和倒数第二个分岔的分配情况是一样的) class Solution: def dfs(self, group, i, nums, per, n, k): pass
横向遍历:考虑k个不同的桶
设置桶的索引为j,横向遍历每一个桶
for j in range(k):
如果当前桶的和group[j]和前一个桶的和group[j-1]相等
则无需考虑当前桶的更新,进行剪枝
if j != 0 and group[j] == group[j-1]: continue
状态更新:将nums[i]放入第j个桶中,
也就是将nums[i]加到group[j]中
group[j] += nums[i]
纵向遍历,考虑下一个数字,也就是将i+1传入dfs函数中并调用
只有当
if group[j] <= per: self.dfs(group, i+1, nums, per, n, k)
状态回滚:将nums[i]从第j个桶中拿出
也就是group[j]减去nums[i]
group[j] -= nums[i]
加上了这两个剪枝之后,就可以顺利通过所有用例了。
与本题基本完全一致的题目有LeetCode473. 火柴拼正方形
参考代码
参考代码
class Solution:
def dfs(self, group, i, nums, per, n, k):
# 递归终止条件:当数组索引i到达n的时候,或者已经找到了一组正确的分配
# 说明整个回溯完成,将self.ans设置为True之后,退出函数
if i == n or self.ans:
self.ans = True
return
# 横向遍历:考虑k个不同的桶
# 设置桶的索引为j,横向遍历每一个桶
for j in range(k):
# 如果当前桶的和group[j]和前一个桶的和group[j-1]相等
# 则无需考虑当前桶的更新,进行剪枝
if j != 0 and group[j] == group[j-1]:
continue
# 状态更新:将nums[i]放入第j个桶中,
# 也就是将nums[i]加到group[j]中
group[j] += nums[i]
# 纵向遍历,考虑下一个数字,也就是将i+1传入dfs函数中并调用
if group[j] <= per:
self.dfs(group, i+1, nums, per, n, k)
# 状态回滚:将nums[i]从第j个桶中拿出
# 也就是group[j]减去nums[i]
group[j] -= nums[i]
def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
# 计算所有数字nums的和total
total = sum(nums)
# 如果total不能整除k,直接返回False
if total % k != 0:
return False
# 当total可以整除k时,计算各个桶最终到达的和per
per = total // k
# 贪心地对nums进行逆序排序,减少遍历次数
nums.sort(reverse = True)
# 如果nums中的最大值已经大于per
# 那么也一定无法完成分配
if nums[0] > per:
return False
n = len(nums)
# 初始化成员变量self.ans为False
self.ans = False
# 初始化长度为k的group数组,表示k个桶的求和情况
group = [0] * k
# 递归入口:传入i=0,表示nums[0]开始填入桶中
self.dfs(group, 0, nums, per, n, k)
# 返回self.ans的值来表示是否找到了答案
return self.ans