Java — 基于 BouncyCastle 支持国密 SM2 公钥加密算法
以下 Java 代码依赖 BouncyCastle 类库,由该类库中 SM2Engine 类的源码修改而来,用于支持 SM2 公钥加密算法,符合《GM/T 0003.4-2012: SM2椭圆曲线公钥密码算法 第4部分:公钥加密算法》与《GM/T 0009-2012: SM2密码算法使用规范》。可以使用 GmSSL 工具进行交互测试(http://gmssl.org)。
引入 Maven 依赖:
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcprov-jdk15on</artifactId>
<version>1.68</version>
</dependency>
代码:
import java.io.IOException; import java.math.BigInteger; import java.security.SecureRandom; import org.bouncycastle.asn1.ASN1EncodableVector; import org.bouncycastle.asn1.ASN1InputStream; import org.bouncycastle.asn1.ASN1Integer; import org.bouncycastle.asn1.ASN1OctetString; import org.bouncycastle.asn1.ASN1Sequence; import org.bouncycastle.asn1.DEROctetString; import org.bouncycastle.asn1.DERSequence; import org.bouncycastle.crypto.CipherParameters; import org.bouncycastle.crypto.Digest; import org.bouncycastle.crypto.InvalidCipherTextException; import org.bouncycastle.crypto.digests.SM3Digest; import org.bouncycastle.crypto.params.ECDomainParameters; import org.bouncycastle.crypto.params.ECKeyParameters; import org.bouncycastle.crypto.params.ECPrivateKeyParameters; import org.bouncycastle.crypto.params.ECPublicKeyParameters; import org.bouncycastle.crypto.params.ParametersWithRandom; import org.bouncycastle.math.ec.ECFieldElement; import org.bouncycastle.math.ec.ECMultiplier; import org.bouncycastle.math.ec.ECPoint; import org.bouncycastle.math.ec.FixedPointCombMultiplier; import org.bouncycastle.util.Arrays; import org.bouncycastle.util.BigIntegers; import org.bouncycastle.util.Memoable; import org.bouncycastle.util.Pack; /** * 自义的 SM2 公钥加密、私钥解密引擎, 用于替换 BouncyCastle 中的 SM2Engine 的实现, * 可用于非 java 开发的系统之间交换数据时的公钥加密、私钥解密,完全符合 GM/T 以下两个标准: * <br/> * <li>《GM/T 0003.4-2012: SM2椭圆曲线公钥密钥算法:第4部分:公钥加密算法》</li> * <li>《GM/T 0009-2012: SM2密码算法使用规范》</li> * <br/> * <p/> * <li>1.加密密文默认为 C1||C3||C2, 输出内容为 ASN.1 编码。符合 《GM/T 0009-2012: SM2密码算法使用规范》 标准 * </li> * <li>2.加密密文设置为 C1||C2||C3,则输出内容不是 ASN.1 编码。与 《GM/T 0009-2012 ...》 标准不兼容。</li> * <br/> * 建议可以使用 GmSSL 工具进行交互测试,请参考 {@code http://gmssl.org} * <p> * SM2 public key encryption engine - based on https://tools.ietf.org/html/draft-shen-sm2-ecdsa-02. 
* * * @author YangHongFeng * @since 2020/5/28 creation */ public class MySm2Engine { public enum Mode { C1C2C3, C1C3C2; } private final Digest digest; private final Mode mode; private boolean forEncryption; private ECKeyParameters ecKey; private ECDomainParameters ecParams; private int curveLength; private SecureRandom random; /*** * 默认采用国标:C1||C3||C2 */ public MySm2Engine() { this(new SM3Digest(), Mode.C1C3C2); } public MySm2Engine(Mode mode) { this(new SM3Digest(), mode); } public MySm2Engine(Digest digest) { this(digest, Mode.C1C2C3); } public MySm2Engine(Digest digest, Mode mode) { if (mode == null) { throw new IllegalArgumentException("mode cannot be NULL"); } this.digest = digest; this.mode = mode; } /** * 初始化 * @param forEncryption true-公钥加密, false-私钥解密 * @param param 密码参数,从中获取公或私钥、及椭圆曲线相关参数 */ public void init(boolean forEncryption, CipherParameters param) { this.forEncryption = forEncryption; if (forEncryption) { ParametersWithRandom rParam = (ParametersWithRandom) param; ecKey = (ECKeyParameters) rParam.getParameters(); ecParams = ecKey.getParameters(); ECPoint s = ((ECPublicKeyParameters) ecKey).getQ().multiply(ecParams.getH()); if (s.isInfinity()) { throw new IllegalArgumentException("invalid key: [h]Q at infinity"); } random = rParam.getRandom(); } else { ecKey = (ECKeyParameters) param; ecParams = ecKey.getParameters(); } curveLength = (ecParams.getCurve().getFieldSize() + 7) / 8; } /** * 进行加密、或解密 * @param in * @param inOff * @param inLen * @return * @throws InvalidCipherTextException */ public byte[] processBlock( byte[] in, int inOff, int inLen) throws IOException, InvalidCipherTextException { if (forEncryption) { return encrypt(in, inOff, inLen); } else { return decrypt(in, inOff, inLen); } } public int getOutputSize(int inputLen) { return (1 + 2 * curveLength) + inputLen + digest.getDigestSize(); } protected ECMultiplier createBasePointMultiplier() { return new FixedPointCombMultiplier(); } private byte[] encrypt(byte[] in, int inOff, int inLen) 
throws IOException { byte[] c2 = new byte[inLen]; System.arraycopy(in, inOff, c2, 0, c2.length); ECMultiplier multiplier = createBasePointMultiplier(); ECPoint c1P; ECPoint kPB; do { BigInteger k = nextK(); c1P = multiplier.multiply(ecParams.getG(), k).normalize(); // c1 = c1P.getEncoded(false); kPB = ((ECPublicKeyParameters) ecKey).getQ().multiply(k).normalize(); kdf(digest, kPB, c2); } while (notEncrypted(c2, in, inOff)); byte[] c3 = new byte[digest.getDigestSize()]; addFieldElement(digest, kPB.getAffineXCoord()); digest.update(in, inOff, inLen); addFieldElement(digest, kPB.getAffineYCoord()); digest.doFinal(c3, 0); switch (mode) { case C1C3C2: // 2020/06/01 按国标组装为 ANS.1 编码 final ASN1EncodableVector vector = new ASN1EncodableVector(); vector.add(new ASN1Integer(c1P.getXCoord().toBigInteger())); vector.add(new ASN1Integer(c1P.getYCoord().toBigInteger())); vector.add(new DEROctetString(c3)); vector.add(new DEROctetString(c2)); return new DERSequence(vector).getEncoded(); default: byte[] c1 = c1P.getEncoded(false); return Arrays.concatenate(c1, c2, c3); } } private byte[] decrypt(byte[] in, int inOff, int inLen) throws InvalidCipherTextException, IOException { ECPoint c1P; // = ecParams.getCurve().decodePoint(c1); byte[] inHash ; byte[] inCipherData; if (mode == Mode.C1C3C2) { // 2020/06/01 按国标 ANS.1 编码 进行解码 ASN1InputStream inputStream = new ASN1InputStream(in); ASN1Sequence seq = (ASN1Sequence) inputStream.readObject(); if (seq.size() != 4){ throw new InvalidCipherTextException("invalid cipher text"); } int index = 0; // C1 == XCoordinate 、YCoordinate BigInteger x = ((ASN1Integer) seq.getObjectAt(index ++)).getPositiveValue(); // YCoordinate BigInteger y = ((ASN1Integer) seq.getObjectAt(index ++)).getPositiveValue(); // XCoord 、YCoord ==> CEPoint (C1) c1P = ecParams.getCurve().createPoint(x, y); // HASH (C3) inHash = ((ASN1OctetString)seq.getObjectAt(index ++)).getOctets(); // CipherText (C2) inCipherData = ((ASN1OctetString)seq.getObjectAt(index)).getOctets(); } 
else { // C1 byte[] c1 = new byte[curveLength * 2 + 1]; System.arraycopy(in, inOff, c1, 0, c1.length); c1P = ecParams.getCurve().decodePoint(c1); // C2 == inCipherData int digestSize = this.digest.getDigestSize(); inCipherData = new byte[inLen - c1.length - digestSize]; System.arraycopy(in, inOff + c1.length, inCipherData, 0, inCipherData.length); // C3 == Hash inHash = new byte[digestSize]; System.arraycopy(in, inOff + c1.length + inCipherData.length, inHash, 0, inHash.length); } // 解密 ==> inCipherData; ECPoint s = c1P.multiply(ecParams.getH()); if (s.isInfinity()) { throw new InvalidCipherTextException("[h]C1 at infinity"); } c1P = c1P.multiply(((ECPrivateKeyParameters)ecKey).getD()).normalize(); kdf(digest, c1P, inCipherData); // 动态计算已解密的明文的摘要并比较 byte[] cipherDigest = new byte[digest.getDigestSize()]; addFieldElement(digest, c1P.getAffineXCoord()); digest.update(inCipherData, 0, inCipherData.length); addFieldElement(digest, c1P.getAffineYCoord()); digest.doFinal(cipherDigest, 0); int check = 0; if (mode == Mode.C1C3C2) { for (int i = 0; i != cipherDigest.length; i++) { check |= cipherDigest[i] ^ inHash[i]; } } else { for (int i = 0; i != cipherDigest.length; i++) { check |= cipherDigest[i] ^ inHash[i]; } } // Arrays.fill(c1, (byte)0); Arrays.fill(cipherDigest, (byte)0); if (check != 0) { Arrays.fill(inCipherData, (byte)0); throw new InvalidCipherTextException("invalid cipher text"); } // return c2; return inCipherData; } private boolean notEncrypted(byte[] encData, byte[] in, int inOff) { for (int i = 0; i != encData.length; i++) { if (encData[i] != in[inOff + i]) { return false; } } return true; } private void kdf(Digest digest, ECPoint c1, byte[] encData) { int digestSize = digest.getDigestSize(); byte[] buf = new byte[Math.max(4, digestSize)]; int off = 0; Memoable memo = null; Memoable copy = null; if (digest instanceof Memoable) { addFieldElement(digest, c1.getAffineXCoord()); addFieldElement(digest, c1.getAffineYCoord()); memo = (Memoable) digest; copy = 
memo.copy(); } int ct = 0; while (off < encData.length) { if (memo != null) { memo.reset(copy); } else { addFieldElement(digest, c1.getAffineXCoord()); addFieldElement(digest, c1.getAffineYCoord()); } Pack.intToBigEndian(++ct, buf, 0); digest.update(buf, 0, 4); digest.doFinal(buf, 0); int xorLen = Math.min(digestSize, encData.length - off); xor(encData, buf, off, xorLen); off += xorLen; } } private void xor(byte[] data, byte[] kdfOut, int dOff, int dRemaining) { for (int i = 0; i != dRemaining; i++) { data[dOff + i] ^= kdfOut[i]; } } private BigInteger nextK() { int qBitLength = ecParams.getN().bitLength(); BigInteger k; do { k = BigIntegers.createRandomBigInteger(qBitLength, random); } while (k.equals(BigIntegers.ZERO) || k.compareTo(ecParams.getN()) >= 0); return k; } private void addFieldElement(Digest digest, ECFieldElement v) { byte[] p = BigIntegers.asUnsignedByteArray(curveLength, v.toBigInteger()); digest.update(p, 0, p.length); } }