1. 前言
网上有很多关于RSA的介绍,大神阮一峰 http://www.ruanyifeng.com/blog/2013/06/rsa_algorithm_part_one.html 都写了相关的博客。 为了引出 HTTPS, 编写一个小示例。
需要注意的点:
欧几里得算法:辗转相除求最大公约数。 如果最大公约数为1, 则两数互素:
/**
* 欧几里得定理
* 辗转相除求最大公约数 可用于判断互质
* @param a 数1
* @param b 数2
* @return 最大公约数
*/
public static int max(int a, int b) {
if (a > b) {
int tmp = a;
a = b;
b = tmp;
}
int r;
return ( (r = b%a) == 0) ? a : max(a, r);
}
扩展的欧几里得算法:可以求最大公约数以及
乘法逆元:
/**
* 使用 扩展的欧几里得算法 求乘法逆元
* ax + ny = 1
* y = -k
* @return x : a 对 n 的乘法逆元
*/
private static int extendGcd(int a, int n) {
int x2=0, x3=n, y2=1, y3=a, q, t2, t3;
while (true) {
if (y3 == 0)
return 0;
if (y3 == 1) {
return y2 < 0 ? y2+n : y2;
}
q = x3 / y3;
t2 = x2 - q * y2;
t3 = x3 - q * y3;
x2 = y2;
x3 = y3;
y2 = t2;
y3 = t3;
}
}
当然要说这个欧拉定理有什么用处~~~自然是用到的人才知道好,用不到的人只知道他可以简化幂运算。
质数 p, q
ψ(N) = ψ(p * q) = ψ(p) * ψ(q) = (p-1)*(q-1);
任何合数都可以分解为质数的乘积的形式。
2. 只适用于本次展示的代码片段
只适用于小于128且大于0的数的编码:import java.math.BigInteger;
public class RSA {
/**
* 欧几里得定理
* 辗转相除求最大公约数 可用于判断互质
* @param a 数1
* @param b 数2
* @return 最大公约数
*/
public static int max(int a, int b) {
if (a > b) {
int tmp = a;
a = b;
b = tmp;
}
int r;
return ( (r = b%a) == 0) ? a : max(a, r);
}
// 选取两个大素数 p, q
// 注意别取的太小 否则求余的时候会很不精确
// 毕竟只是演示, 只对byte起作用, 乘积大于128即可.
// 显然, 只支持正数
private static int p = 11;
private static int q = 13;
// 得到 N = p * q
private static BigInteger N = BigInteger.valueOf(p*q);
// 则 ψ(N) = ψ(p * q) = ψ(p) * ψ(q) = (p-1)*(q-1) (欧拉定理)
private static int r = (p-1)*(q-1);
// 取任意一与 ψ(N) 互质的小于 ψ(N)的数.
private static int e = 97;
// 得到这个数关于 ψ(N) 的乘法逆元
// 此时的公钥对为(e, N), 私钥对为(d, N). 其余所有数据销毁
private static int d = extendGcd(e, r);
/**
* 使用 扩展的欧几里得算法 求乘法逆元
* ax + ny = 1
* y = -k
* @return x : a 对 n 的乘法逆元
*/
private static int extendGcd(int a, int n) {
int x2=0, x3=n, y2=1, y3=a, q, t2, t3;
while (true) {
if (y3 == 0)
return 0;
if (y3 == 1) {
return y2 < 0 ? y2+n : y2;
}
q = x3 / y3;
t2 = x2 - q * y2;
t3 = x3 - q * y3;
x2 = y2;
x3 = y3;
y2 = t2;
y3 = t3;
}
}
/**
* 因为是N的取值是在整数范围内的, 此示例使用int型作为返回方便查看
*
* 比如 in=2; type=3;
* return in^type % N = 2^3 % (97*101) = 8
*
* @param in 被编码的值
* @param type 秘钥: 公钥/私钥
* @return 编码后的值
*
*/
private static int code(int in, int type) {
return BigInteger.valueOf(in).pow(type).mod(N).intValue();
}
/**
* 只支持正数
* @param res 被编码数组
* @param type 编码类型 公/私钥匙
* @return 编码后的数组
*/
private static byte[] code(byte[] res, int type) {
byte[] b = new byte[res.length];
for (int i = 0; i < res.length; i ++) {
b[i] = (byte) code(res[i], type);
}
return b;
}
public static byte[] rsaEncode(byte[] res) {
return code(res, e);
}
public static byte[] rsaDecode(byte[] res) {
return code(res, d);
}
public static void main(String[] args) {
int needs = 100;
int needsEncode = code(needs, e);
System.out.println("needs="+needs+", encode="+needsEncode);
int needsDecode = code(needsEncode, d);
System.out.println("needs="+needs+", decode="+needsDecode);
String code = "hello";
byte[] encode = rsaEncode(code.getBytes());
System.out.println(new String(encode)); // 不知所云
byte[] decode = rsaDecode(encode);
System.out.println(new String(decode)); // 还原
}
}
如果希望将其扩大到整个Integer范围内, 可以稍作修改实现。 注意 byte 的强转
3. JAVA中比较通用的RSA写法
import java.io.ByteArrayOutputStream;
import java.security.*;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import javax.crypto.Cipher;
/**
* RSA:
* 罗纳德·李维斯特(Ron [R]ivest)、阿迪·萨莫尔(Adi [S]hamir)和伦纳德·阿德曼(Leonard [A]dleman)
* <p/>
* 字符串格式的密钥在未在特殊说明情况下都为BASE64编码格式<br/>
* 由于非对称加密速度极其缓慢,一般文件不使用它来加密而是使用对称加密,<br/>
* 非对称加密算法可以用来对对称加密的密钥加密,这样保证密钥的安全也就保证了数据的安全
* <p/>
* 部分摘录
*
* @see http://blog.youkuaiyun.com/keda8997110/article/details/16823361
*/
public class RSAUtils {
/**
* 加密算法RSA
*/
public static final String KEY_ALGORITHM = "RSA";
/**
* RSA最大加密明文大小
*/
private static final int MAX_ENCRYPT_BLOCK = 117;
/**
* RSA最大解密密文大小
*/
private static final int MAX_DECRYPT_BLOCK = 128;
private static RSAPublicKey publicKey;
private static RSAPrivateKey privateKey;
static {
KeyPairGenerator keyPairGen = null;
try {
keyPairGen = KeyPairGenerator.getInstance(KEY_ALGORITHM);
} catch (NoSuchAlgorithmException e) {
e.printStackTrace();
}
keyPairGen.initialize(1024);
KeyPair keyPair = keyPairGen.generateKeyPair();
publicKey = (RSAPublicKey) keyPair.getPublic();
privateKey = (RSAPrivateKey) keyPair.getPrivate();
}
/**
* 私钥解密
*
* @param encryptedData 已加密数据
* @return
* @throws Exception
*/
public static byte[] decode(byte[] encryptedData) throws Exception {
KeyFactory keyFactory = KeyFactory.getInstance(KEY_ALGORITHM);
Key privateK = privateKey;
Cipher cipher = Cipher.getInstance(keyFactory.getAlgorithm());
cipher.init(Cipher.DECRYPT_MODE, privateK);
int inputLen = encryptedData.length;
ByteArrayOutputStream out = new ByteArrayOutputStream();
int offSet = 0;
byte[] cache;
// 对数据分段解密
while (inputLen - offSet > 0) {
if (inputLen - offSet > MAX_DECRYPT_BLOCK) {
cache = cipher.doFinal(encryptedData, offSet, MAX_DECRYPT_BLOCK);
} else {
cache = cipher.doFinal(encryptedData, offSet, inputLen - offSet);
}
out.write(cache, 0, cache.length);
offSet += MAX_DECRYPT_BLOCK;
}
byte[] decryptedData = out.toByteArray();
out.close();
return decryptedData;
}
/**
* 公钥加密
*
* @param data 源数据
* @return
* @throws Exception
*/
public static byte[] encode(byte[] data)
throws Exception {
KeyFactory keyFactory = KeyFactory.getInstance(KEY_ALGORITHM);
Key publicK = publicKey;
// 对数据加密
Cipher cipher = Cipher.getInstance(keyFactory.getAlgorithm());
cipher.init(Cipher.ENCRYPT_MODE, publicK);
int inputLen = data.length;
ByteArrayOutputStream out = new ByteArrayOutputStream();
int offSet = 0;
byte[] cache;
// 对数据分段加密
while (inputLen - offSet > 0) {
if (inputLen - offSet > MAX_ENCRYPT_BLOCK) {
cache = cipher.doFinal(data, offSet, MAX_ENCRYPT_BLOCK);
} else {
cache = cipher.doFinal(data, offSet, inputLen - offSet);
}
out.write(cache, 0, cache.length);
offSet += MAX_ENCRYPT_BLOCK;
}
byte[] encryptedData = out.toByteArray();
out.close();
return encryptedData;
}
public static void main(String[] args) throws Exception {
String source = "china中国";
byte[] encodedData = RSAUtils.encode(source.getBytes());
System.out.println("encode:\t" + new String(encodedData)); // 不知所云
byte[] decodedData = RSAUtils.decode(encodedData);
System.out.println("decode: \t" + new String(decodedData)); // 成功解码
}
}
这里面也可以微微看出java对字符编码的一些小问题。
有的时候前面的 "encode:\t"是显示不出来的
造成这样的原因是: jdk将byte转化为char[] 是委托给 Charsets.jar里面的各种charset来完成的。
UTF_8.java:
public int decode(byte[] sa, int sp, int len, char[] da) {
final int sl = sp + len;
int dp = 0;
int dlASCII = Math.min(len, da.length);
ByteBuffer bb = null; // only necessary if malformed
// ASCII only optimized loop
while (dp < dlASCII && sa[sp] >= 0)
da[dp++] = (char) sa[sp++];
while (sp < sl) {
int b1 = sa[sp++];
if (b1 >= 0) {
// 1 byte, 7 bits: 0xxxxxxx
da[dp++] = (char) b1;
} else if ((b1 >> 5) == -2) {
// 2 bytes, 11 bits: 110xxxxx 10xxxxxx
if (sp < sl) {
int b2 = sa[sp++];
if (isMalformed2(b1, b2)) {
if (malformedInputAction() != CodingErrorAction.REPLACE)
return -1;
da[dp++] = replacement().charAt(0);
sp--; // malformedN(bb, 2) always returns 1
} else {
da[dp++] = (char) (((b1 << 6) ^ b2)^
(((byte) 0xC0 << 6) ^
((byte) 0x80 << 0)));
}
continue;
}
if (malformedInputAction() != CodingErrorAction.REPLACE)
return -1;
da[dp++] = replacement().charAt(0);
return dp;
} else if ((b1 >> 4) == -2) {
// 3 bytes, 16 bits: 1110xxxx 10xxxxxx 10xxxxxx
if (sp + 1 < sl) {
int b2 = sa[sp++];
int b3 = sa[sp++];
if (isMalformed3(b1, b2, b3)) {
if (malformedInputAction() != CodingErrorAction.REPLACE)
return -1;
da[dp++] = replacement().charAt(0);
sp -=3;
bb = getByteBuffer(bb, sa, sp);
sp += malformedN(bb, 3).length();
} else {
da[dp++] = (char)((b1 << 12) ^
(b2 << 6) ^
(b3 ^
(((byte) 0xE0 << 12) ^
((byte) 0x80 << 6) ^
((byte) 0x80 << 0))));
}
continue;
}
if (malformedInputAction() != CodingErrorAction.REPLACE)
return -1;
da[dp++] = replacement().charAt(0);
return dp;
} else if ((b1 >> 3) == -2) {
// 4 bytes, 21 bits: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
if (sp + 2 < sl) {
int b2 = sa[sp++];
int b3 = sa[sp++];
int b4 = sa[sp++];
int uc = ((b1 << 18) ^
(b2 << 12) ^
(b3 << 6) ^
(b4 ^
(((byte) 0xF0 << 18) ^
((byte) 0x80 << 12) ^
((byte) 0x80 << 6) ^
((byte) 0x80 << 0))));
if (isMalformed4(b2, b3, b4) ||
// shortest form check
!Character.isSupplementaryCodePoint(uc)) {
if (malformedInputAction() != CodingErrorAction.REPLACE)
return -1;
da[dp++] = replacement().charAt(0);
sp -= 4;
bb = getByteBuffer(bb, sa, sp);
sp += malformedN(bb, 4).length();
} else {
da[dp++] = Character.highSurrogate(uc);
da[dp++] = Character.lowSurrogate(uc);
}
continue;
}
if (malformedInputAction() != CodingErrorAction.REPLACE)
return -1;
da[dp++] = replacement().charAt(0);
return dp;
} else {
if (malformedInputAction() != CodingErrorAction.REPLACE)
return -1;
da[dp++] = replacement().charAt(0);
sp--;
bb = getByteBuffer(bb, sa, sp);
CoderResult cr = malformedN(bb, 1);
if (!cr.isError()) {
// leading byte for 5 or 6-byte, but don't have enough
// bytes in buffer to check. Consumed rest as malformed.
return dp;
}
sp += cr.length();
}
}
return dp;
}
}
其中期作用的地方就在将byte转为char那个while循环里面。
而char的编码,就必须遵从于Unicode的标准了。
因此看起来就会觉得怪怪的