#include <algorithm>
#include <array>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <openssl/bn.h>
#include <openssl/rsa.h>
// -------------------- Basic types and helpers --------------------
// Capacity of a BigInt in 32-bit words: 64 * 32 = 2048 bits. mul() produces
// a.size + b.size words (capped at this capacity), so values that get
// multiplied together must stay within roughly half of it.
constexpr size_t MAX_WORDS = 64;
// Limb array, least significant 32-bit word first (words[0] is the lowest).
using BigInt = std::array<uint32_t, MAX_WORDS>;
// A big integer: only the first `size` entries of `words` are significant.
// The arithmetic helpers below trim leading zero words (never below one word).
// Member order (words, then size) is relied on by aggregate initializers
// such as BigIntWrapper{ {2}, 1 }.
struct BigIntWrapper {
    BigInt words{};
    size_t size = 0;
};
// Radix of one limb: 2^32.
constexpr uint64_t BASE = uint64_t{1} << 32;
// Prints "label: 0x<hex>" for a, most significant word first, each word
// zero-padded to 8 hex digits; zero (size 0 or a single zero word) prints "0x0".
// The fill character and format flags of std::cout are restored afterwards so
// the '0' fill and hex base used here do not leak into later output.
void print_bigint(const std::string& label, const BigIntWrapper& a) {
    if (a.size == 0 || (a.size == 1 && a.words[0] == 0)) {
        std::cout << label << ": 0x0" << std::endl;
        return;
    }
    std::cout << label << ": 0x";
    const std::ios_base::fmtflags saved_flags = std::cout.flags();
    const char saved_fill = std::cout.fill('0');
    std::cout << std::hex;
    for (size_t i = a.size; i-- > 0;) {
        std::cout << std::setw(8) << a.words[i];
    }
    std::cout.flags(saved_flags);
    std::cout.fill(saved_fill);
    std::cout << std::endl;
}
// 比较两个大整数的大小
int cmp(const BigIntWrapper& a, const BigIntWrapper& b) {
if (a.size != b.size) return a.size < b.size ? -1 : 1;
for (size_t i = a.size; i-- > 0;) {
if (a.words[i] != b.words[i]) return a.words[i] < b.words[i] ? -1 : 1;
}
return 0;
}
// In-place subtraction a -= b. Precondition: a >= b; the final borrow is discarded.
// Afterwards leading zero words of a are trimmed (never below one word).
void sub(BigIntWrapper& a, const BigIntWrapper& b) {
    uint64_t borrow = 0;
    for (size_t pos = 0; pos < a.size; ++pos) {
        const uint64_t subtrahend = (pos < b.size ? uint64_t{b.words[pos]} : uint64_t{0}) + borrow;
        const uint64_t minuend = a.words[pos];
        // A borrow is needed exactly when this word's difference would go negative.
        borrow = (minuend < subtrahend) ? 1 : 0;
        a.words[pos] = static_cast<uint32_t>(minuend - subtrahend);
    }
    while (a.size > 1 && a.words[a.size - 1] == 0) {
        --a.size;
    }
}
// Returns a + b. A final carry is appended as a new top word when there is
// room; if the sum already occupies MAX_WORDS words the carry is dropped.
// The result has leading zero words trimmed (never below one word).
BigIntWrapper add(const BigIntWrapper& a, const BigIntWrapper& b) {
    BigIntWrapper sum;
    sum.size = a.size > b.size ? a.size : b.size;
    uint32_t carry = 0;
    for (size_t idx = 0; idx < sum.size; ++idx) {
        uint64_t acc = carry;
        if (idx < a.size) acc += a.words[idx];
        if (idx < b.size) acc += b.words[idx];
        sum.words[idx] = static_cast<uint32_t>(acc);
        carry = static_cast<uint32_t>(acc >> 32);
    }
    if (carry != 0 && sum.size < MAX_WORDS) {
        sum.words[sum.size++] = carry;
    }
    while (sum.size > 1 && sum.words[sum.size - 1] == 0) {
        --sum.size;
    }
    return sum;
}
// Schoolbook multiplication: returns a * b. The product is limited to
// MAX_WORDS words; anything above that is discarded, i.e. the result is
// reduced modulo 2^(32 * MAX_WORDS). Leading zero words are trimmed.
BigIntWrapper mul(const BigIntWrapper& a, const BigIntWrapper& b) {
    BigIntWrapper product;
    const bool a_is_zero = a.size == 1 && a.words[0] == 0;
    const bool b_is_zero = b.size == 1 && b.words[0] == 0;
    if (a_is_zero || b_is_zero) {
        product.size = 1;
        return product;
    }
    product.size = std::min(a.size + b.size, MAX_WORDS);
    for (size_t row = 0; row < a.size; ++row) {
        const uint64_t multiplier = a.words[row];
        uint64_t carry = 0;
        // Stop at the capacity limit; higher words of this row are discarded.
        for (size_t col = 0; col < b.size && row + col < product.size; ++col) {
            const uint64_t acc = multiplier * b.words[col] + product.words[row + col] + carry;
            product.words[row + col] = static_cast<uint32_t>(acc);
            carry = acc >> 32;
        }
        const size_t spill = row + b.size;
        if (carry != 0 && spill < product.size) {
            product.words[spill] += static_cast<uint32_t>(carry);
        }
    }
    while (product.size > 1 && product.words[product.size - 1] == 0) {
        --product.size;
    }
    return product;
}
// 高效的二进制长除法
std::pair<BigIntWrapper, BigIntWrapper> div_mod(const BigIntWrapper& dividend, const BigIntWrapper& divisor) {
if (divisor.size == 0 || (divisor.size == 1 && divisor.words[0] == 0)) {
throw std::runtime_error("Division by zero");
}
if (cmp(dividend, divisor) < 0) {
BigIntWrapper zero; zero.size = 1;
return { zero, dividend };
}
BigIntWrapper quotient; quotient.size = 1;
BigIntWrapper remainder = dividend;
BigIntWrapper temp_divisor = divisor;
// 找到最高位对齐位置
while (cmp(remainder, temp_divisor) >= 0) {
// 左移temp_divisor直到它大于remainder
BigIntWrapper temp = temp_divisor;
size_t shift_count = 0;
while (cmp(temp, remainder) <= 0 && temp.size < MAX_WORDS && (temp.words[temp.size - 1] & 0x80000000) == 0) {
temp_divisor = temp;
temp = mul(temp, BigIntWrapper{ {2}, 1 });
shift_count++;
}
}
// 执行长除法
while (cmp(temp_divisor, divisor) >= 0) {
quotient = mul(quotient, BigIntWrapper{ {2}, 1 });
if (cmp(remainder, temp_divisor) >= 0) {
sub(remainder, temp_divisor);
quotient.words[0] |= 1;
}
temp_divisor = div_mod(temp_divisor, BigIntWrapper{ {2}, 1 }).first;
}
return { quotient, remainder };
}
// Returns a mod m: the remainder half of div_mod(a, m).
// Propagates div_mod's std::runtime_error when m is zero.
BigIntWrapper mod(const BigIntWrapper& a, const BigIntWrapper& m) {
    const auto quotient_and_remainder = div_mod(a, m);
    return quotient_and_remainder.second;
}
// Returns the multiplicative inverse of n modulo 2^32 (meaningful for odd n).
// Newton/Hensel iteration x <- x * (2 - n * x) doubles the number of correct
// low bits per step; x = 1 is correct mod 2, so five steps reach 32 bits.
uint32_t inv32(uint32_t n) {
    uint32_t inverse = 1;
    for (int step = 0; step < 5; ++step) {
        inverse *= 2 - n * inverse;
    }
    return inverse;
}
// -------------------- Montgomery multiplication --------------------
// Montgomery reduction (REDC). With n = m.size and R = 2^(32 * n), returns
// t * R^{-1} mod m, provided t < m * R (e.g. a product of two residues < m).
// Requirements: m is odd and m_inv == -m^{-1} mod 2^32 (see inv32).
// Needs 2n + 1 words of working space; throws std::overflow_error when that
// exceeds MAX_WORDS instead of silently dropping carries.
BigIntWrapper montgomery_reduce(BigIntWrapper t, const BigIntWrapper& m, uint32_t m_inv) {
    const size_t n = m.size;
    const size_t scratch = 2 * n + 1;
    if (scratch > MAX_WORDS) {
        throw std::overflow_error("montgomery_reduce: modulus too large for MAX_WORDS");
    }
    if (t.size < scratch) {
        // Words past t.size are not guaranteed to be zero (a caller may have
        // truncated only the size), so clear them before widening.
        for (size_t i = t.size; i < scratch; ++i) t.words[i] = 0;
        t.size = scratch;
    }
    for (size_t i = 0; i < n; ++i) {
        // u is chosen so that adding u * m * 2^(32 i) zeroes word i of t.
        const uint32_t u = t.words[i] * m_inv;
        uint64_t carry = 0;
        for (size_t j = 0; j < n; ++j) {
            const uint64_t sum = static_cast<uint64_t>(u) * m.words[j] + t.words[i + j] + carry;
            t.words[i + j] = static_cast<uint32_t>(sum);
            carry = sum >> 32;
        }
        for (size_t k = i + n; carry != 0 && k < t.size; ++k) {
            const uint64_t sum = static_cast<uint64_t>(t.words[k]) + carry;
            t.words[k] = static_cast<uint32_t>(sum);
            carry = sum >> 32;
        }
    }
    // Divide by R: drop the n low words, which are now zero.
    BigIntWrapper res;
    res.size = t.size - n;
    for (size_t i = 0; i < res.size; ++i) res.words[i] = t.words[i + n];
    while (res.size > 1 && res.words[res.size - 1] == 0) res.size--;
    // For t < m * R the value here is < 2m, so one subtraction suffices.
    if (cmp(res, m) >= 0) {
        sub(res, m);
    }
    return res;
}
// -------------------- Barrett reduction --------------------
// Barrett reduction (HAC Algorithm 14.42): returns x mod m, where k = m.size
// and mu = floor(BASE^(2k) / m) is precomputed by the caller.
//  - x.size < k: x is already below m and is returned unchanged.
//  - x.size > 2k, k == 0, or k too large for the intermediate products to fit
//    in MAX_WORDS: falls back to the division-based mod().
BigIntWrapper barrett_reduce(const BigIntWrapper& x, const BigIntWrapper& m, const BigIntWrapper& mu) {
    const size_t k = m.size;
    if (x.size < k) return x;
    // q1 * mu can need (k + 1) + (k + 2) words; without that much room the
    // truncating mul() would give a wrong quotient estimate.
    if (k == 0 || x.size > 2 * k || 2 * k + 3 > MAX_WORDS) return mod(x, m);

    // Keeps only the low `limit` words (clearing the dropped ones) and trims
    // leading zeros: cmp() decides by word count first, so truncated values
    // must be re-normalized before they are compared.
    const auto truncate_to = [](BigIntWrapper& v, size_t limit) {
        if (v.size > limit) {
            for (size_t i = limit; i < v.size; ++i) v.words[i] = 0;
            v.size = limit;
        }
        while (v.size > 1 && v.words[v.size - 1] == 0) --v.size;
    };

    // q1 = floor(x / BASE^(k-1))
    BigIntWrapper q1;
    q1.size = x.size - (k - 1);
    for (size_t i = 0; i < q1.size; ++i) q1.words[i] = x.words[i + k - 1];

    // q3 = floor(q1 * mu / BASE^(k+1))
    const BigIntWrapper q2 = mul(q1, mu);
    BigIntWrapper q3;
    if (q2.size > k + 1) {
        q3.size = q2.size - (k + 1);
        for (size_t i = 0; i < q3.size; ++i) q3.words[i] = q2.words[i + k + 1];
    }
    else {
        q3.size = 1;
    }

    // r1 = x mod BASE^(k+1)
    BigIntWrapper r1 = x;
    truncate_to(r1, k + 1);
    // r2 = (q3 * m) mod BASE^(k+1)
    BigIntWrapper r2 = mul(q3, m);
    truncate_to(r2, k + 1);

    // r = r1 - r2, adding BASE^(k+1) first when the difference would be negative.
    if (cmp(r1, r2) < 0) {
        BigIntWrapper base_power;
        base_power.size = k + 2;
        base_power.words[k + 1] = 1;
        r1 = add(r1, base_power);
    }
    sub(r1, r2);
    // Per HAC 14.42 at most two further subtractions are needed.
    while (cmp(r1, m) >= 0) {
        sub(r1, m);
    }
    return r1;
}
// -------------------- RSA implementation --------------------
// Converts an OpenSSL BIGNUM (magnitude only) into a BigIntWrapper.
// BN_bn2bin writes big-endian bytes, so the least significant byte is the LAST
// byte of the buffer; the previous version packed the bytes in the wrong order,
// byte-reversing every converted value. Zero becomes a single zero word.
// Throws std::overflow_error when the value needs more than MAX_WORDS words.
BigIntWrapper bn_to_bigintwrapper(const BIGNUM* bn) {
    BigIntWrapper result;
    const int num_bytes = BN_num_bytes(bn);
    if (num_bytes <= 0) {
        result.size = 1;
        return result;
    }
    const size_t byte_count = static_cast<size_t>(num_bytes);
    const size_t word_count = (byte_count + 3) / 4;
    if (word_count > MAX_WORDS) {
        throw std::overflow_error("bn_to_bigintwrapper: value does not fit in MAX_WORDS");
    }
    std::vector<unsigned char> buffer(byte_count);
    BN_bn2bin(bn, buffer.data());
    for (size_t k = 0; k < byte_count; ++k) {
        // k counts bytes from the least significant end.
        const unsigned char byte = buffer[byte_count - 1 - k];
        result.words[k / 4] |= static_cast<uint32_t>(byte) << (8 * (k % 4));
    }
    // BN_num_bytes is minimal, so the top word is non-zero (normalized).
    result.size = word_count;
    return result;
}
// Converts a BigIntWrapper into a newly allocated OpenSSL BIGNUM (caller owns
// it and must BN_free it; nullptr on allocation failure). BN_bin2bn expects
// big-endian bytes, so word 0's low byte is written to the LAST buffer slot.
BIGNUM* bigintwrapper_to_bn(const BigIntWrapper& num) {
    const size_t byte_count = num.size * sizeof(uint32_t);
    // At least one byte so data() is never null, even for size 0.
    std::vector<unsigned char> buffer(byte_count > 0 ? byte_count : 1);
    for (size_t i = 0; i < num.size; ++i) {
        const uint32_t word = num.words[i];
        for (size_t b = 0; b < 4; ++b) {
            buffer[byte_count - 1 - (4 * i + b)] = static_cast<unsigned char>((word >> (8 * b)) & 0xFF);
        }
    }
    return BN_bin2bn(buffer.data(), static_cast<int>(byte_count), nullptr);
}
// Generates an RSA key pair with OpenSSL (public exponent 65537) and stores the
// modulus n, public exponent e and private exponent d.
// Key size: Montgomery exponentiation of a k-word modulus needs 2k + 1 words
// of scratch space, so the modulus may use at most MAX_WORDS / 2 - 1 words
// (992 bits for MAX_WORDS = 64). The old MAX_WORDS * 32 bits could never fit.
// Raise MAX_WORDS to at least 130 to work with 2048-bit keys.
// Throws std::runtime_error if OpenSSL allocation or key generation fails.
void rsa_key_gen(BigIntWrapper& n, BigIntWrapper& e, BigIntWrapper& d) {
    const int key_bits = static_cast<int>((MAX_WORDS / 2 - 1) * 32);
    // RAII ownership so both objects are released on every path, including throws.
    std::unique_ptr<RSA, decltype(&RSA_free)> rsa(RSA_new(), &RSA_free);
    std::unique_ptr<BIGNUM, decltype(&BN_free)> bn_e(BN_new(), &BN_free);
    if (!rsa || !bn_e || BN_set_word(bn_e.get(), 65537) != 1) {
        throw std::runtime_error("RSA密钥生成失败");
    }
    if (RSA_generate_key_ex(rsa.get(), key_bits, bn_e.get(), nullptr) != 1) {
        throw std::runtime_error("RSA密钥生成失败");
    }
    // get0 accessors: the BIGNUMs stay owned by rsa and are freed with it.
    n = bn_to_bigintwrapper(RSA_get0_n(rsa.get()));
    e = bn_to_bigintwrapper(RSA_get0_e(rsa.get()));
    d = bn_to_bigintwrapper(RSA_get0_d(rsa.get()));
}
// Modular exponentiation base^exp mod `mod` by left-to-right square-and-multiply
// in the Montgomery domain (R = 2^(32 * mod.size)).
// The parameter `mod` shadows the free function mod(), so an unqualified
// mod(R2, mod) tries to "call" the BigIntWrapper object and does not compile;
// the free function is therefore called as ::mod(...) here.
// Throws std::invalid_argument if mod is even or zero (Montgomery reduction
// needs an odd modulus), std::overflow_error if 2 * mod.size + 1 > MAX_WORDS.
BigIntWrapper montgomery_exp(const BigIntWrapper& base, const BigIntWrapper& exp, const BigIntWrapper& mod) {
    if (mod.size == 0 || (mod.words[0] & 1u) == 0) {
        throw std::invalid_argument("montgomery_exp: modulus must be odd");
    }
    if (2 * mod.size + 1 > MAX_WORDS) {
        throw std::overflow_error("montgomery_exp: modulus too large for MAX_WORDS");
    }
    // m' = -m^{-1} mod 2^32, as montgomery_reduce expects.
    const uint32_t m_inv = 0u - inv32(mod.words[0]);

    // R^2 mod m; R^2 = 2^(64 * mod.size) is a single 1 in word 2 * mod.size.
    BigIntWrapper R2;
    R2.size = 2 * mod.size + 1;
    R2.words[R2.size - 1] = 1;
    R2 = ::mod(R2, mod);

    // Reduce base first so base * R2 < m * R, the precondition of montgomery_reduce.
    const BigIntWrapper base_reduced = ::mod(base, mod);
    const BigIntWrapper base_mont = montgomery_reduce(mul(base_reduced, R2), mod, m_inv);
    BigIntWrapper result_mont = montgomery_reduce(R2, mod, m_inv);  // R mod m: 1 in Montgomery form

    for (size_t i = exp.size; i-- > 0;) {
        const uint32_t word = exp.words[i];
        for (int j = 31; j >= 0; j--) {
            result_mont = montgomery_reduce(mul(result_mont, result_mont), mod, m_inv);
            if ((word >> j) & 1) {
                result_mont = montgomery_reduce(mul(result_mont, base_mont), mod, m_inv);
            }
        }
    }
    // Leave the Montgomery domain: multiply by R^{-1}.
    return montgomery_reduce(result_mont, mod, m_inv);
}
// Textbook RSA encryption (no padding applied): returns msg^e mod n.
// NOTE(review): msg is presumably expected to be < n; verify with callers.
BigIntWrapper rsa_encrypt(const BigIntWrapper& msg, const BigIntWrapper& e, const BigIntWrapper& n) {
    BigIntWrapper ciphertext = montgomery_exp(msg, e, n);
    return ciphertext;
}
// Textbook RSA decryption (no padding removed): returns ctxt^d mod n.
BigIntWrapper rsa_decrypt(const BigIntWrapper& ctxt, const BigIntWrapper& d, const BigIntWrapper& n) {
    BigIntWrapper plaintext = montgomery_exp(ctxt, d, n);
    return plaintext;
}
// -------------------- 主函数 --------------------
int main() {
try {
std::cout << "生成RSA密钥对..." << std::endl;
BigIntWrapper n, e, d;
rsa_key_gen(n, e, d);
print_bigint("公钥模数 (n)", n);
print_bigint("公钥指数 (e)", e);
print_bigint("私钥指数 (d)", d);
// 测试消息
BigIntWrapper msg;
msg.size = 1;
msg.words[0] = 123456789;
print_bigint("原始消息", msg);
std::cout << "加密中..." << std::endl;
BigIntWrapper ctxt = rsa_encrypt(msg, e, n);
print_bigint("加密结果", ctxt);
std::cout << "解密中..." << std::endl;
BigIntWrapper decrypted = rsa_decrypt(ctxt, d, n);
print_bigint("解密结果", decrypted);
// 验证解密结果
if (cmp(msg, decrypted) == 0) {
std::cout << "成功!RSA加密解密验证通过。" << std::endl;
}
else {
std::cout << "错误!RSA加密解密验证失败。" << std::endl;
}
// 测试蒙哥马利乘法和Barrett归约法
std::cout << "\n测试蒙哥马利乘法和Barrett归约法..." << std::endl;
BigIntWrapper a, b, m;
a.size = 2; a.words[0] = 0x12345678; a.words[1] = 0x9ABCDEF0;
b.size = 2; b.words[0] = 0xFEDCBA98; b.words[1] = 0x76543210;
m.size = 2; m.words[0] = 0xFFFFFFFF; m.words[1] = 0xFFFFFFFF;
print_bigint("a", a);
print_bigint("b", b);
print_bigint("m", m);
// 蒙哥马利乘法
uint32_t m_inv = 0 - inv32(m.words[0]);
BigIntWrapper R2;
R2.size = 2 * m.size + 1;
if (R2.size > MAX_WORDS) R2.size = MAX_WORDS;
R2.words[R2.size - 1] = 1;
R2 = mod(R2, m);
BigIntWrapper a_mont = montgomery_reduce(mul(a, R2), m, m_inv);
BigIntWrapper b_mont = montgomery_reduce(mul(b, R2), m, m_inv);
BigIntWrapper c_mont = montgomery_reduce(mul(a_mont, b_mont), m, m_inv);
BigIntWrapper mont_result = montgomery_reduce(c_mont, m, m_inv);
print_bigint("蒙哥马利乘法结果", mont_result);
// Barrett归约法
size_t k = m.size;
BigIntWrapper mu;
mu.size = k * 2 + 1;
if (mu.size > MAX_WORDS) mu.size = MAX_WORDS;
mu.words[mu.size - 1] = 1;
mu = div_mod(mu, m).first;
BigIntWrapper ab = mul(a, b);
BigIntWrapper barrett_result = barrett_reduce(ab, m, mu);
print_bigint("Barrett归约结果", barrett_result);
// 直接计算 a*b mod m
BigIntWrapper direct_result = mod(ab, m);
print_bigint("直接计算结果", direct_result);
// 比较结果
if (cmp(mont_result, barrett_result) == 0 && cmp(barrett_result, direct_result) == 0) {
std::cout << "成功!蒙哥马利乘法和Barrett归约法验证通过。" << std::endl;
}
else {
std::cout << "错误!蒙哥马利乘法和Barrett归约法验证失败。" << std::endl;
}
}
catch (const std::exception& e) {
std::cerr << "错误: " << e.what() << std::endl;
return 1;
}
return 0;
}
代码中 `R2 = mod(R2, mod);` 一句报错:“在没有适当 operator() 或到函数指针类型的转换函数的情况下调用类类型的对象”。原因是 montgomery_exp 的参数 `mod` 遮蔽了同名的全局函数 `mod()`,编译器把 `mod(...)` 当成了对 BigIntWrapper 对象的调用。请修改。
最新发布