正确解法:
// Approach 1: depth-first search over a bitmask of used elements, filling the
// k buckets one after another.
class Solution {
private:
    int n;
    int target;
    // failed[used] != 0 once dfs has returned false for that set of used
    // elements. Every call has 0 <= tmpsum < target, so sum(used) equals
    // (K - k) * target + tmpsum for the initial K: `used` alone determines both
    // k and tmpsum, which makes caching the result by mask sound.
    vector<char> failed;

    // Returns true if the unused elements can complete the current bucket and
    // the k - 1 buckets after it.
    //   k      - buckets not yet completed, the current one included
    //   tmpsum - sum already placed in the current bucket
    //   used   - bit i set means nums[i] has already been placed
    // Expects nums sorted in descending order; values are assumed positive
    // (problem constraint).
    bool dfs(int k, int tmpsum, int used, vector<int>& nums)
    {
        // k - 1 buckets already sum to target each, so the leftovers do too.
        if (k == 1) return true;
        if (failed[used]) return false;
        for (int i = 0; i < n; ++i)
        {
            if (used & (1 << i)) continue;
            // Important pruning, otherwise it times out, e.g. on
            // 5,5,5,5,16,4,4,4,4,4,3,3,3,3,4 with k = 4. Because nums is sorted
            // in descending order, this can only fire on the largest unused
            // element, and the state is abandoned instead of trying smaller ones.
            // That is safe: if a partition exists, the search can always fill
            // next the bucket holding the largest unused element and add that
            // element last, so the remaining capacity never drops below it.
            if (tmpsum + nums[i] > target) break;
            const int next = used | (1 << i);
            if (tmpsum + nums[i] == target)
            {
                // Bucket complete: start the next one empty.
                if (dfs(k - 1, 0, next, nums)) return true;
            }
            else if (dfs(k, tmpsum + nums[i], next, nums))
            {
                return true;
            }
        }
        failed[used] = 1;
        return false;
    }

public:
    // Returns true if nums can be split into k subsets with equal sums.
    // Reorders nums (sorts it in descending order).
    // NOTE(review): assumes the usual bounds (1 <= nums[i], n <= 16): `used` is
    // an int bitmask and `failed` holds 2^n entries.
    bool canPartitionKSubsets(vector<int>& nums, int k) {
        // Guards: k <= 0 would divide by zero below, and nums.front() is
        // undefined behavior on an empty vector.
        if (k <= 0 || nums.empty()) return false;
        n = nums.size();
        int sum = 0;
        for (int i : nums) sum += i;
        if (sum % k > 0) return false;
        target = sum / k;
        sort(nums.rbegin(), nums.rend());
        // An element larger than a bucket can never be placed.
        if (nums.front() > target) return false;
        failed.assign(size_t(1) << n, 0);
        return dfs(k, 0, 0, nums);
    }
};
另一种正确解法:
// Approach 2: fill the buckets one at a time; inside a bucket, elements are
// picked in increasing index order, so each bucket is enumerated as a
// combination rather than a permutation.
class Solution {
private:
    vector<bool> used;  // used[i] is true while nums[i] is placed in a bucket
    int target, n;

    // Tries to finish the current bucket (already holding cursum) with unused
    // elements at index >= start, then fills the remaining buckets.
    //   k - buckets still to fill, the current one included
    // Expects nums sorted in descending order; values are assumed positive
    // (problem constraint).
    bool dfs(int start, int cursum, int k, vector<int>& nums)
    {
        // k - 1 buckets already sum to target each, so the leftovers do too.
        if (k == 1) return true;
        for (int i = start; i < n; ++i)
        {
            if (used[i] || cursum + nums[i] > target) continue;
            used[i] = true;
            const bool done = (cursum + nums[i] == target)
                ? dfs(0, 0, k - 1, nums)                   // bucket full: rescan from 0 for the next one
                : dfs(i + 1, cursum + nums[i], k, nums);   // keep filling with later indices
            if (done) return true;
            used[i] = false;
            // On an empty bucket, nums[i] is the first unused element (it always
            // fits because max <= target). Buckets are interchangeable, so if no
            // completion exists with it in this bucket, none exists at all.
            if (cursum == 0) return false;
        }
        return false;
    }

public:
    // Returns true if nums can be split into k subsets with equal sums.
    // Reorders nums (sorts it in descending order).
    bool canPartitionKSubsets(vector<int>& nums, int k) {
        // Guards: k <= 0 would divide by zero below, and nums.front() is
        // undefined behavior on an empty vector.
        if (k <= 0 || nums.empty()) return false;
        int sum = accumulate(nums.begin(), nums.end(), 0);
        if (sum % k > 0) return false;
        n = nums.size();
        target = sum / k;
        // Largest first: big elements fail early, which cuts the search short.
        sort(nums.begin(), nums.end(), greater<int>());
        // An element larger than a bucket can never be placed.
        if (nums.front() > target) return false;
        used = vector<bool>(n, false);
        return dfs(0, 0, k, nums);
    }
};
错误解法:
class Solution {
// WRONG SOLUTION, kept on purpose as a counterexample: it misses feasible
// partitions, e.g. 10 10 10 7 7 7 7 7 7 6 6 6, k = 3.
// That input is partitionable (three buckets of 10+7+7+6 = 30), but this code
// first commits the bucket 10+10+10 found by dfs. The remaining 7s and 6s
// cannot form 30, and canPartitionKSubsets never backtracks into an earlier
// bucket, so it returns false.
// NOTE(review): it also returns false for k == 1 (e.g. nums = 1 2): cnt becomes
// 1, the check cnt == k - 1 (i.e. 0) never matches, and the loop falls through
// to the final return false.
private:
int n;
vector<bool> used;
// Searches for unused elements summing to exactly `target`, marking them in
// `used`. On success it returns true and leaves the chosen elements marked;
// on failure every mark it set is undone.
bool dfs(int target, vector<int>& nums)
{
if(target == 0) return true;
for (int i = 0; i < n; ++i)
{
// Skip elements already placed or larger than the remaining amount.
if(used[i] || target < nums[i]) continue;
used[i] = true;
if(dfs(target - nums[i], nums)) return true;
used[i] = false;
}
return false;
}
public:
// Intended to return whether nums splits into k subsets with equal sums; see
// the class comment for inputs where it answers incorrectly.
// NOTE(review): k <= 0 divides by zero and an empty nums makes nums.front()
// undefined behavior; neither is guarded here.
bool canPartitionKSubsets(vector<int>& nums, int k) {
n = nums.size();
int sum = accumulate(nums.begin(), nums.end(), 0);
if(sum % k > 0) return false;
sort(nums.rbegin(), nums.rend());
int target = sum / k;
if(nums.front() > target) return false;
used = vector<bool>(n, false);
int cnt = 0;
// Greedy bucket filling: the first unused element seeds each bucket and dfs
// completes it with the first combination it finds. A bucket that cannot be
// completed ends the search immediately (this is the missing backtracking).
// After k - 1 buckets succeed, the remaining elements must sum to target.
for (int i = 0; i < n; ++i)
{
if(used[i]) continue;
used[i] = true;
if(dfs(target - nums[i], nums)) cnt++;
else return false;
if(cnt == k - 1) return true;
}
return false;
}
};
上面这道题的简化版即使用错误解法也能通过,原因是其测试数据不够完整。