diff --git a/Sources/Crypto/RSA/RSA.swift b/Sources/Crypto/RSA/RSA.swift index ed3bdbf..920b67d 100644 --- a/Sources/Crypto/RSA/RSA.swift +++ b/Sources/Crypto/RSA/RSA.swift @@ -147,8 +147,15 @@ public final class RSA { /// - returns: Decrypted data. /// - throws: `CryptoError` if encrypting fails. public static func decrypt(_ input: LosslessDataConvertible, padding: RSAPadding = .pkcs1, key: RSAKey) throws -> Data { - return try cipher(input, padding: padding, key: key) { - RSA_private_decrypt($0, $1, $2, $3!.convert(), $4) + switch key.type { + case .public: + return try cipher(input, padding: padding, key: key) { + RSA_public_decrypt($0, $1, $2, $3!.convert(), $4) + } + case .private: + return try cipher(input, padding: padding, key: key) { + RSA_private_decrypt($0, $1, $2, $3!.convert(), $4) + } } } @@ -163,8 +170,15 @@ public final class RSA { /// - returns: Encrypted data. /// - throws: `CryptoError` if encrypting fails. public static func encrypt(_ input: LosslessDataConvertible, padding: RSAPadding = .pkcs1, key: RSAKey) throws -> Data { - return try cipher(input, padding: padding, key: key) { - RSA_public_encrypt($0, $1, $2, $3!.convert(), $4) + switch key.type { + case .public: + return try cipher(input, padding: padding, key: key) { + RSA_public_encrypt($0, $1, $2, $3!.convert(), $4) + } + case .private: + return try cipher(input, padding: padding, key: key) { + RSA_private_encrypt($0, $1, $2, $3!.convert(), $4) + } } } diff --git a/Sources/Crypto/Utilities/CryptoError.swift b/Sources/Crypto/Utilities/CryptoError.swift index ee2d646..d8e1954 100644 --- a/Sources/Crypto/Utilities/CryptoError.swift +++ b/Sources/Crypto/Utilities/CryptoError.swift @@ -11,8 +11,7 @@ public struct CryptoError: Debuggable { /// Internal error creation from OpenSSLL internal static func openssl(identifier: String, reason: String) -> CryptoError { - let errmsg: UnsafeMutablePointer? = nil - ERR_error_string(ERR_get_error(), errmsg) + let errmsg = ERR_error_string(ERR_get_error(), nil) let cReason: String if let e = errmsg { diff --git a/Tests/CryptoTests/RSATests.swift b/Tests/CryptoTests/RSATests.swift index f600c19..c52f349 100644 --- a/Tests/CryptoTests/RSATests.swift +++ b/Tests/CryptoTests/RSATests.swift @@ -86,6 +86,20 @@ class RSATests: XCTestCase { ) XCTAssertEqual(key.type, .public) } + + // https://github.com/vapor/crypto/issues/78 + func testGH78() throws { + let passphrase = "abcdef" + + // From https://www.googleapis.com/oauth2/v3/certs + let key: RSAKey = try .components( + n: "vvAaaSpfr934Qx0ioFiWsopq7UCfLNn0zjYVbq4bvUcGSXU9kowYmQArR7WlIkjk1moffla0UV75QRaQPATva1oD5xQnnW-20haeMWTSsMgUHoN0Np9AD8ffPz-DfMJBOHIo4REL1BFFS33HSZgPl0hxJ-5UScqr4lW1JMy5XGeRho30dnmKTpakU1Oc35hFYKSea_O2SXfmbqiAkWlWkilEzgHq4pzVWiDZe4ZgfMdD4vqkSNrO_PkBFBT1mnBJztQ1h4v1jvUW-zeYYwIcPTaOX-xOTiGH9uQkcNPpe5pBrIZJqR5VNrDl_bJOmvVlhhXZSn4fkxA8kyQcZXGaTw", + e: "AQAB" + ) + let encrypted = try RSA.encrypt(passphrase, padding: .pkcs1, key: key) + let decrypted = try RSA.decrypt(encrypted, padding: .pkcs1, key: key) + XCTAssertEqual(decrypted.convert(to: String.self), passphrase) + } static var allTests = [ ("testPrivateKey", testPrivateKey), @@ -98,6 +112,7 @@ class RSATests: XCTestCase { ("testRand", testRand), ("testComps", testComps), ("testEncrypt", testEncrypt), + ("testGH78", testGH78), ] }