4Sum
Total Accepted: 45122 Total Submissions: 207795
Given an array S of n integers, are there elements a, b, c, and d in S such that a + b + c + d = target? Find all unique quadruplets in the array which give the sum of target.
Note:
Elements in a quadruplet (a,b,c,d) must be in non-descending order. (ie, a ≤ b ≤ c ≤ d)
The solution set must not contain duplicate quadruplets.
For example, given array S = {1 0 -1 0 -2 2}, and target = 0.
A solution set is:
(-1, 0, 0, 1)
(-2, -1, 1, 2)
(-2, 0, 0, 2)
这道题我采用的是将之前的3Sum算法进行扩展,只是此前的3Sum采用的哈希表是直接开了一个最大值与最小值之差那么大的数组,但在这道题之中,有一组测试用例的范围在千万这个数量级,所以开这样一个哈希表,就会变得空间开销极大,而且极有可能内存越界,因此改用177大小的哈希表,然后采用拉链法来解决冲突。因为改用了拉链法,所以维护这个hash表就要增添许多与之相关的函数,但这道题的本质仍然只是3Sum的扩展,下面是完整代码:
#include <iostream>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
using namespace std;
/* One node in the collision chain of a hash bucket. */
struct HashEntry
{
    int value;               /* the input number this node represents */
    int count;               /* how many input numbers have this value */
    struct HashEntry *next;  /* next node in the same bucket, NULL at the end */
};
/*
 * Copies the value and count of every entry in hashTable into the entry at
 * the same bucket and chain position of newHashTable.
 *
 * No nodes are allocated here: newHashTable is expected to already contain
 * chains with the same layout as hashTable (both tables have 177 buckets).
 * Copying for a bucket stops as soon as either chain ends, so a destination
 * chain that is shorter than the source is no longer dereferenced past its
 * end.
 */
void copyHash(struct HashEntry **hashTable, struct HashEntry **newHashTable)
{
    for (int bucket = 0; bucket < 177; bucket++)
    {
        struct HashEntry *src = hashTable[bucket];
        struct HashEntry *dst = newHashTable[bucket];
        while (src != NULL && dst != NULL)
        {
            dst->value = src->value;
            dst->count = src->count;
            src = src->next;
            dst = dst->next;
        }
    }
}
/*
 * Looks up num in the 177-bucket table.
 *
 * Returns the count stored for num, or 0 when no entry with that value
 * exists in its bucket.
 */
int GetHash(struct HashEntry **hashTable, int num)
{
    /* num % 177 may be negative; fold it into 0..176 */
    int bucket = num % 177;
    if (bucket < 0)
        bucket = -bucket;

    for (struct HashEntry *node = hashTable[bucket]; node != NULL; node = node->next)
    {
        if (node->value == num)
            return node->count;
    }
    return 0;
}
/*
 * Overwrites the count stored for num with v.
 *
 * Does nothing when the table holds no entry for num; entries are never
 * created here.
 */
void SetHash(struct HashEntry **hashTable, int num, int v)
{
    /* num % 177 may be negative; fold it into 0..176 */
    int bucket = num % 177;
    if (bucket < 0)
        bucket = -bucket;

    struct HashEntry *node = hashTable[bucket];
    while (node != NULL)
    {
        if (node->value == num)
        {
            node->count = v;
            return;
        }
        node = node->next;
    }
}
/*
 * Records one occurrence of num in the table.
 *
 * If num already has an entry its count is incremented; otherwise a new
 * entry with count 1 is allocated and placed at the head of its bucket's
 * chain.
 */
void CreateHash(struct HashEntry **hashTable, int num)
{
    /* num % 177 may be negative; fold it into 0..176 */
    int bucket = num % 177;
    if (bucket < 0)
        bucket = -bucket;

    struct HashEntry *oldHead = hashTable[bucket];
    for (struct HashEntry *node = oldHead; node != NULL; node = node->next)
    {
        if (node->value == num)
        {
            node->count++;
            return;
        }
    }

    /* not present yet: the new node becomes the head (oldHead may be NULL) */
    struct HashEntry *fresh = (struct HashEntry*)malloc(sizeof(struct HashEntry));
    hashTable[bucket] = fresh;
    fresh->value = num;
    fresh->count = 1;
    fresh->next = oldHead;
}
/*
 * Innermost stage of the hash-based search: for each nums[idx] from start
 * onward that still has a non-zero count in hashTable, looks for the value
 * target - nums[idx] in the table.
 *
 * On a hit, a new 4-int row is appended to result holding nums[idx], its
 * partner, current and previous in non-descending order.
 *
 * After each candidate (hit or not) the counts of both the candidate and its
 * partner are set to 0 in hashTable, so later iterations skip those values.
 * *returnSize is read on entry as the number of rows already stored and is
 * updated once on exit.
 */
void twoSum(int* nums, int numsSize, int* returnSize, struct HashEntry **hashTable, int target, int start, int**result, int previous, int current)
{
    int found = *returnSize;

    for (int idx = start; idx < numsSize; idx++)
    {
        int first = nums[idx];
        int remaining = GetHash(hashTable, first);
        if (remaining == 0)
            continue;

        /* take one copy of first out of the table before looking for its partner */
        SetHash(hashTable, first, remaining - 1);
        int partner = target - first;

        if (GetHash(hashTable, partner) != 0)
        {
            int picks[4] = { first, partner, current, previous };
            int *row = (int*)malloc(sizeof(int) * 4);
            result[found] = row;

            /* insertion sort of the four picks into the row */
            for (int filled = 0; filled < 4; filled++)
            {
                int slot = filled;
                while (slot > 0 && row[slot - 1] > picks[filled])
                {
                    row[slot] = row[slot - 1];
                    slot--;
                }
                row[slot] = picks[filled];
            }
            found++;
        }

        SetHash(hashTable, partner, 0);
        SetHash(hashTable, first, 0);
    }

    *returnSize = found;
}
/*
 * Middle stage of the hash-based search: fixes nums[idx] (idx from start
 * onward) as one member and hands the remaining sum t - nums[idx] to twoSum.
 *
 * Each iteration first refreshes newHashTable from hashTable with copyHash,
 * then removes one copy of the picked value from newHashTable before calling
 * twoSum on it. A value whose count is already 0 is skipped. After twoSum
 * returns, the picked value's count is set to 0 in hashTable so the same
 * value is not picked again by a later iteration.
 *
 * current is forwarded to twoSum as its `previous` argument and the picked
 * value as its `current` argument.
 */
void threeSum(int* nums, int numsSize, int* returnSize, struct HashEntry **hashTable, int t, int start, int **result, int current, struct HashEntry ** newHashTable)
{
    for (int idx = start; idx < numsSize; idx++)
    {
        int picked = nums[idx];
        int remainder = t - picked;

        copyHash(hashTable, newHashTable);

        int available = GetHash(newHashTable, picked);
        if (available == 0)
            continue;
        if (available > 0)
            SetHash(newHashTable, picked, available - 1);

        twoSum(nums, numsSize, returnSize, newHashTable, remainder, idx + 1, result, current, picked);
        SetHash(hashTable, picked, 0);
    }
}
int** fourSum(int* nums, int numsSize, int target, int* returnSize)
{
if(numsSize == 0)
return NULL;
struct HashEntry **hashTable = (struct HashEntry**)malloc(sizeof(struct HashEntry*) * 177);
struct HashEntry **newHashTable = (struct HashEntry**)malloc(sizeof(struct HashEntry*) * 177);
struct HashEntry **newHashTable2 = (struct HashEntry**)malloc(sizeof(struct HashEntry*) * 177);
for(int i = 0; i < 177; i++)
{
hashTable[i] = newHashTable[i] = newHashTable2[i] = NULL;
}
for(int i = 0; i < numsSize; i++)
{
CreateHash(hashTable, nums[i]);
CreateHash(newHashTable, nums[i]);
CreateHash(newHashTable2, nums[i]);
}
int **result = (int**)malloc(sizeof(int*) * 300);
*returnSize = 0;
for(int i = 0; i < numsSize; i++)
{
int t = target - nums[i];
copyHash(hashTable, newHashTable);
if(GetHash(newHashTable, nums[i]) > 0)
{
SetHash(newHashTable, nums[i], GetHash(newHashTable, nums[i])-1);
}
else if(GetHash(newHashTable, nums[i]) == 0)
continue;
threeSum(nums, numsSize, returnSize, newHashTable, t, i+1, result, nums[i], newHashTable2);
SetHash(hashTable, nums[i], 0);
}
return result;
}
int main()
{
int nums[300], numsSize = 0, returnSize = 0, target = 0;
cout<<"Input the numSize:"<<endl;
cin>>numsSize;
for(int i = 0; i < numsSize; i++)
{
scanf("%d,", &nums[i]);
}
cout<<"Input the target:"<<endl;
cin>>target;
int **result = fourSum(nums, numsSize, target, &returnSize);
for(int i = 0; i < returnSize; i++)
{
cout<<result[i][0]<<" "<<result[i][1]<<" "<<result[i][2]<<" "<<result[i][3]<<endl;;
}
return 0;
}
以上这种方法效率并不高,主要是由于每次的hash表的复制过程会很费时间,网上看到别人的一种解法 http://blog.csdn.net/doc_sgl/article/details/12427833,比我的方法更简洁,同时效率更高,此处记录一下。
// Sort + two-pointer variant of fourSum: opening part of the function.
int** fourSum(int* nums, int numsSize, int target, int* returnSize)
{
// Fewer than four inputs cannot form a quadruplet.
// NOTE(review): *returnSize is not written on this path, so the caller must
// have initialised it before the call.
if(numsSize < 4)
return NULL;
// Sort nums in place with cmp as the ordering.
qsort(nums, numsSize, sizeof(int), cmp);
// Space for 300 row pointers is allocated here.
int **result = (int**)malloc(sizeof(int*) * 300);
*returnSize = 0;
如果少于4个输入,那么必然不可能找出4个数之和等于target,所以返回NULL。然后用qsort对输入的数据进行排序,我采用的cmp函数如下:
/*
 * qsort comparator that orders ints ascending.
 *
 * Returns a negative value when *a < *b, zero when they are equal and a
 * positive value when *a > *b. The values are compared rather than
 * subtracted, because a - b overflows when the operands are far apart
 * (for example INT_MIN and INT_MAX) and would then report the wrong order.
 */
int cmp(const void *a, const void *b)
{
    int lhs = *(const int*)a;
    int rhs = *(const int*)b;
    return (lhs > rhs) - (lhs < rhs);
}
qsort采用的比较函数的约定是:如果返回值为负,表示第一个元素应排在第二个元素之前;如果返回值为正,表示第一个元素应排在第二个元素之后;返回0则表示两者相等。因此上面的cmp会使数组按升序排列。
然后分配result数组,大小为300,也就是说最多存储300个输出结果,这存在隐患,因为如果输入数据量很大,可能出现输出结果超过300的情况,但是对于这一题来说,300已经够用了。
// Core search loops of the sort + two-pointer fourSum (continues the function
// opened earlier). begin/end index the third and fourth members.
int begin, end;
// i selects the first member.
for(int i = 0; i < numsSize - 3; i++)
{
// Skip a first value equal to the one handled in the previous iteration.
if(i >= 1 && nums[i] == nums[i-1])
continue;
// j selects the second member.
for(int j = i + 1; j < numsSize - 2; j++)
{
// Skip a repeated second value, but allow j == i + 1 to equal nums[i].
if(j >= i + 2 && nums[j] == nums[j-1])
continue;
begin = j + 1;
end = numsSize - 1;
// Move the two indices toward each other until they meet.
while(begin < end)
{
// NOTE(review): the four-term sum is computed in int.
if(nums[i] + nums[j] + nums[begin] + nums[end] == target)
{
// Append a new row; nothing here checks the row count against the
// 300 slots allocated for result.
result[(*returnSize)++] = (int*)malloc(sizeof(int) * 4);
result[*returnSize - 1][0] = nums[i];
result[*returnSize - 1][1] = nums[j];
result[*returnSize - 1][2] = nums[begin];
result[*returnSize - 1][3] = nums[end];
// Advance both indices past runs of equal values.
// NOTE(review): these skip loops compare neighbours without checking
// begin/end against the array bounds, so on a run of equal values that
// reaches the array edge they read nums[numsSize] or nums[-1].
do
{
begin++;
}while(nums[begin-1] == nums[begin]);
do
{
end--;
}while(nums[end+1] == nums[end]);
}
// Sum too large: step end down to the next different value.
else if(nums[i] + nums[j] + nums[begin] + nums[end] > target)
{
do
{
end--;
}while(nums[end+1] == nums[end]);
}
// Sum too small: step begin up to the next different value.
else if(nums[i] + nums[j] + nums[begin] + nums[end] < target)
{
do
{
begin++;
}while(nums[begin-1] == nums[begin]);
}
}
}
}
return result;
这么长一段就是算法的核心了,最外层循环的i对应的是四个数中的第一个的索引,后面的if语句是为了防止重复,例如:输入是:0 0 0 -1 1 2 3,target是0,当i=0时,我们就找出了0 0 -1 1(对应索引0,1,3,4)这个组合,而当i=1时,我们仍然会找出0 0 -1 1(对应索引1,2,3,4)这个组合,虽然对应的索引不同,但是产生的组合都是0 0 -1 1, 因为题目要求结果中不能有重复,因此添加判断,如果当前的数之前已经有检测过了,那么就跳过,进入下一次循环,注意如果当前索引值是当前循环的第一个时,不用进行检测,因为即使它与前面一个相等也没关系,因为前面一个是前一个数字选的,后一个数字与前一个数字选相同的数是可以的,例如:0 0 -1 1,我们要避免的是前一轮已经选了某个数(假设为a)之后,下一轮循环的时候又选一次a(因为可能有多个a连在一起),这样就会产生重复。
然后进入第二层循环,同样要先判断重复。然后定义两个值begin和end,begin表示的是第三个数的索引值,它从第二个数的后一个开始取,end表示第四个数的索引值,它从最后一个开始取。然后试着把这四个数加起来,如果刚好和target相等,那么我们就找到了一个组合,然后把begin往后移一格,把end往前移一格,如果只移动其中一个,那么即使新取到的和与target相等,这个组合也必然与我们上一个找到的组合相同。在begin和end移动的过程中,都要注意避免重复,如果新的值与当前值相等,那么跳过它,寻找下一个新的值。
如果四个数加起来之和大于target,我们就把end往前移一格,因为越前面的值越小。如果四个数加起来之和小于target,我们就把begin往后移一格,因为越后面的值越大,同样在移动过程中要避免重复。
下面是完整程序代码:
#include <iostream>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
using namespace std;
/*
 * qsort comparator that orders ints ascending.
 *
 * Returns a negative value when *a < *b, zero when they are equal and a
 * positive value when *a > *b. The values are compared rather than
 * subtracted, because a - b overflows when the operands are far apart
 * (for example INT_MIN and INT_MAX) and would then report the wrong order.
 */
int cmp(const void *a, const void *b)
{
    int lhs = *(const int*)a;
    int rhs = *(const int*)b;
    return (lhs > rhs) - (lhs < rhs);
}
/*
 * Ascending int comparator for qsort, private to fourSum. It compares instead
 * of subtracting so the ordering is correct over the whole int range.
 */
static int compareInts(const void *lhs, const void *rhs)
{
    int a = *(const int *)lhs;
    int b = *(const int *)rhs;
    return (a > b) - (a < b);
}

/* Frees the first `count` rows of `rows` and then the row array itself. */
static void discardRows(int **rows, int count)
{
    for (int r = 0; r < count; r++)
        free(rows[r]);
    free(rows);
}

/*
 * Sort + two-pointer 4Sum: finds every distinct quadruplet of values in nums
 * whose sum is target.
 *
 * nums        input values; sorted ascending in place as a side effect
 * numsSize    number of values in nums
 * target      required sum of the four values
 * returnSize  out: number of rows in the returned array; always written,
 *             including on the NULL return paths
 *
 * Returns a malloc'd array of malloc'd 4-int rows, each in non-descending
 * order, or NULL when numsSize < 4 or an allocation fails (in which case
 * *returnSize is 0). The caller frees each row and the array.
 *
 * The row array starts with room for 300 rows and doubles when full, so any
 * number of results can be stored.
 */
int** fourSum(int* nums, int numsSize, int target, int* returnSize)
{
    enum { INITIAL_CAPACITY = 300 };

    /* make the out-parameter valid on every path, including early returns */
    *returnSize = 0;
    if (numsSize < 4)
        return NULL;

    qsort(nums, numsSize, sizeof(int), compareInts);

    int capacity = INITIAL_CAPACITY;
    int **result = (int **)malloc(sizeof(int *) * (size_t)capacity);
    if (result == NULL)
        return NULL;

    for (int i = 0; i < numsSize - 3; i++)
    {
        /* a first value equal to the previous one would repeat its results */
        if (i >= 1 && nums[i] == nums[i - 1])
            continue;

        for (int j = i + 1; j < numsSize - 2; j++)
        {
            /* same for the second value; j == i + 1 may equal nums[i] */
            if (j >= i + 2 && nums[j] == nums[j - 1])
                continue;

            int begin = j + 1;
            int end = numsSize - 1;
            while (begin < end)
            {
                /* widen before adding: four ints can exceed the int range */
                long long sum = (long long)nums[i] + nums[j] + nums[begin] + nums[end];

                if (sum == target)
                {
                    if (*returnSize == capacity)
                    {
                        int grownCapacity = capacity * 2;
                        int **grown = (int **)realloc(result, sizeof(int *) * (size_t)grownCapacity);
                        if (grown == NULL)
                        {
                            discardRows(result, *returnSize);
                            *returnSize = 0;
                            return NULL;
                        }
                        result = grown;
                        capacity = grownCapacity;
                    }

                    int *row = (int *)malloc(sizeof(int) * 4);
                    if (row == NULL)
                    {
                        discardRows(result, *returnSize);
                        *returnSize = 0;
                        return NULL;
                    }
                    row[0] = nums[i];
                    row[1] = nums[j];
                    row[2] = nums[begin];
                    row[3] = nums[end];
                    result[(*returnSize)++] = row;

                    /*
                     * Step both indices past runs of equal values. The
                     * begin < end guard keeps every read inside the array.
                     */
                    do
                    {
                        begin++;
                    } while (begin < end && nums[begin] == nums[begin - 1]);
                    do
                    {
                        end--;
                    } while (begin < end && nums[end] == nums[end + 1]);
                }
                else if (sum > target)
                {
                    /* too large: move to the next smaller distinct value */
                    do
                    {
                        end--;
                    } while (begin < end && nums[end] == nums[end + 1]);
                }
                else
                {
                    /* too small: move to the next larger distinct value */
                    do
                    {
                        begin++;
                    } while (begin < end && nums[begin] == nums[begin - 1]);
                }
            }
        }
    }
    return result;
}
int main()
{
int nums[300], numsSize = 0, returnSize = 0, target = 0;
cout<<"Input the numSize:"<<endl;
cin>>numsSize;
for(int i = 0; i < numsSize; i++)
{
scanf("%d,", &nums[i]);
}
cout<<"Input the target:"<<endl;
cin>>target;
int **result = fourSum(nums, numsSize, target, &returnSize);
for(int i = 0; i < returnSize; i++)
{
cout<<result[i][0]<<" "<<result[i][1]<<" "<<result[i][2]<<" "<<result[i][3]<<endl;;
}
return 0;
}