main.cpp
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <limits>
#include <random>
#include <thread>
using namespace std;
// Number of elements to sort. The data lives in a[1..n]; slot 0 is never sorted.
constexpr int n = 10000000;
// Remaining thread permits: how many additional threads may still be started.
// MultiMergeSort takes one permit before starting a thread and gives it back
// when that thread's call returns.
static std::atomic<int> P{12};
// a: the array being sorted.
// c: copy of the unsorted input, so both sort variants run on the same data.
// b: scratch buffer used by Merge.
// All three are sized from n (n + 1 because indexing starts at 1) so they can
// never be smaller than the range being sorted.
static int a[n + 1] = {0};
static int c[n + 1] = {0};
static int b[n + 1] = {0};
static void Merge(int l, int r);
static void MergeSort(int l, int r);
// incFlag is an internal bookkeeping flag; it is defaulted so that outside
// callers only pass the range.
static void MultiMergeSort(int l, int r, const bool incFlag = false);
int main(int argc, char* argv[]) {
if (P < 1) {
return -1;
}
std::random_device randomDevice;
std::default_random_engine engine(randomDevice());
std::uniform_int_distribution<int> distribution(0, INT32_MAX);
for (int i = 1; i <= n; i++) {
a[i] = distribution(engine);
}
std::copy(begin(a), end(a), begin(c));
{
const auto startTime1 = std::chrono::high_resolution_clock::now();
{
MergeSort(1, n);
}
const auto endTime1 = std::chrono::high_resolution_clock::now();
auto spentTime1 =
std::chrono::duration_cast<std::chrono::microseconds>(
endTime1 - startTime1
).count();
cout << "[Single thread] mode spends: " << spentTime1 << "ms" << endl;
for (int i = 1; i <= std::min(n, 50); i++)///
cout << a[i] << " ";
cout << endl << endl;
}
std::copy(begin(c), end(c), begin(a));
{
const auto startTime2 = std::chrono::high_resolution_clock::now();
MultiMergeSort(1, n);
const auto endTime2 = std::chrono::high_resolution_clock::now();
auto spentTime2 =
std::chrono::duration_cast<std::chrono::microseconds>(
endTime2 - startTime2
).count();
cout << "[Multi(mix) thread] mode spends: " << spentTime2 << "ms" << endl;
for(int i = 1;i <= std::min(n, 50); i++)
cout << a[i] << " ";
}
}
// Ranges with r - l below this value are sorted on the calling thread. Starting
// and joining a thread for such a small range is assumed to cost more than the
// sort itself; this is a tuning heuristic, not a correctness limit.
static constexpr int kMinParallelSpan = 1 << 12;

// Atomically takes one permit from P if any is left and reports whether it
// succeeded. Doing the check and the decrement in one compare-exchange keeps
// two threads from both claiming the last permit.
static bool TryAcquireThreadPermit() {
    int available = P.load();
    while (available > 0) {
        // On failure compare_exchange_weak stores the current value of P into
        // `available`, so the loop condition re-checks fresh data.
        if (P.compare_exchange_weak(available, available - 1)) {
            return true;
        }
    }
    return false;
}

// Multi-threaded merge sort of a[l..r] (inclusive bounds, ascending order).
//
// Each half is handed to a new thread when a permit can be taken from P;
// otherwise it is sorted on the calling thread with MergeSort.
//
// incFlag: true only when this call is the entry function of a thread that was
//          started after taking a permit. The permit is returned to P exactly
//          once, when the call finishes. Outside callers leave it false.
void MultiMergeSort(int l, int r, const bool incFlag) {
    if (r - l < kMinParallelSpan) {
        // Small (or empty, l >= r) range: not worth extra threads.
        MergeSort(l, r);
    } else {
        // Merge recomputes the split point as (l + r) / 2, so the same formula
        // must be used here.
        const int mid = (l + r) / 2;
        std::thread LeftRegion;
        std::thread RightRegion;
        if (TryAcquireThreadPermit()) {
            LeftRegion = std::thread(MultiMergeSort, l, mid, true);
        } else {
            MergeSort(l, mid);
        }
        if (TryAcquireThreadPermit()) {
            RightRegion = std::thread(MultiMergeSort, mid + 1, r, true);
        } else {
            MergeSort(mid + 1, r);
        }
        // Both halves must be completely sorted before they are merged.
        if (LeftRegion.joinable()) {
            LeftRegion.join();
        }
        if (RightRegion.joinable()) {
            RightRegion.join();
        }
        Merge(l, r);
    }
    // Single release point: give the permit back only if this call owns one.
    if (incFlag) {
        P++;
    }
}
// Single-threaded recursive merge sort of a[l..r] (inclusive bounds, ascending).
// A range with fewer than two elements (l >= r) is left untouched.
void MergeSort(int l, int r) {
    if (l < r) {
        // Merge derives the same split point from (l, r) on its own.
        const int middle = (l + r) / 2;
        MergeSort(l, middle);
        MergeSort(middle + 1, r);
        Merge(l, r);
    }
}
// Merges the two adjacent sorted runs a[l..middle] and a[middle+1..r], where
// middle = (l + r) / 2, into ascending order. The merged sequence is built in
// the scratch array b at the same indices and then copied back into a[l..r].
void Merge(int l, int r) {
    const int middle = (l + r) / 2;
    int left = l;            // next unread element of the left run
    int right = middle + 1;  // next unread element of the right run

    for (int out = l; out <= r; ++out) {
        // Take from the left run only while it has elements and either the
        // right run is exhausted or the left element is strictly smaller.
        const bool takeLeft = left <= middle && (right > r || a[left] < a[right]);
        if (takeLeft) {
            b[out] = a[left];
            ++left;
        } else {
            b[out] = a[right];
            ++right;
        }
    }

    int idx = l;
    while (idx <= r) {
        a[idx] = b[idx];
        ++idx;
    }
}
result
[Single thread] mode spends: 2038998ms
80 654 765 831 977 1235 1795 1809 1932 2111 2515 2617 2861 2898 2954 3038 3392 3444 3447 3566 4945 5153 5708 5760 6332 7
078 7089 7122 7342 7679 7806 7856 8126 8440 8698 8933 8968 9041 9182 9183 9272 9284 9882 10088 10382 10419 10558 10939 1
1092 11133
[Multi(mix) thread] mode spends: 563816ms
80 654 765 831 977 1235 1795 1809 1932 2111 2515 2617 2861 2898 2954 3038 3392 3444 3447 3566 4945 5153 5708 5760 6332 7
078 7089 7122 7342 7679 7806 7856 8126 8440 8698 8933 8968 9041 9182 9183 9272 9284 9882 10088 10382 10419 10558 10939 1
1092 11133
notice
并发单元 P 的数量最好接近机器实际的硬件线程数。主要运行逻辑是：根据当前是否还有剩余的并发单元，来决定是否使用新线程去执行归并排序的子任务（即对某个子区间的排序）。