-
Notifications
You must be signed in to change notification settings - Fork 103
/
Copy pathoqs_kem.c
234 lines (202 loc) · 8.44 KB
/
oqs_kem.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
// SPDX-License-Identifier: Apache-2.0 AND MIT
/*
* OQS OpenSSL 3 provider
*
* Code strongly inspired by the OpenSSL RSA KEM implementation.
*
* ToDo: Add hybrid algorithm support; more testing with more key types.
*/
#include <openssl/core_dispatch.h>
#include <openssl/core_names.h>
#include <openssl/crypto.h>
#include <openssl/ec.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/params.h>
#include <string.h>
#include "oqs_prov.h"
/*
 * Debug tracing helpers.
 *
 * In release builds (NDEBUG defined) the macros expand to an empty statement
 * and their arguments are not evaluated. Otherwise they call printf() with
 * the given format/arguments, but only when the environment variable OQSKEM
 * is set.
 *
 * Every variant is wrapped in do { } while (0) so that an invocation behaves
 * as a single statement: it requires a trailing semicolon and cannot capture
 * an `else` that follows it at the call site.
 */
#ifdef NDEBUG
#define OQS_KEM_PRINTF(a) \
    do {                  \
    } while (0)
#define OQS_KEM_PRINTF2(a, b) \
    do {                      \
    } while (0)
#define OQS_KEM_PRINTF3(a, b, c) \
    do {                         \
    } while (0)
#else
#define OQS_KEM_PRINTF(a)     \
    do {                      \
        if (getenv("OQSKEM")) \
            printf(a);        \
    } while (0)
#define OQS_KEM_PRINTF2(a, b) \
    do {                      \
        if (getenv("OQSKEM")) \
            printf(a, b);     \
    } while (0)
#define OQS_KEM_PRINTF3(a, b, c) \
    do {                         \
        if (getenv("OQSKEM"))    \
            printf(a, b, c);     \
    } while (0)
#endif // NDEBUG
/*
 * Forward declarations using the OSSL_FUNC_kem_* function typedefs, so that
 * each definition further down is checked by the compiler against the
 * corresponding dispatch signature.
 * NOTE(review): oqs_kem_decaps_init is defined below without a matching
 * declaration in this list - consider adding one for consistency.
 */
static OSSL_FUNC_kem_newctx_fn oqs_kem_newctx;
static OSSL_FUNC_kem_encapsulate_init_fn oqs_kem_encaps_init;
static OSSL_FUNC_kem_encapsulate_fn oqs_qs_kem_encaps;
static OSSL_FUNC_kem_decapsulate_fn oqs_qs_kem_decaps;
static OSSL_FUNC_kem_freectx_fn oqs_kem_freectx;
/*
 * What's passed as an actual key is defined by the KEYMGMT interface.
 *
 * PROV_OQSKEM_CTX is the per-operation KEM context: it is allocated by
 * oqs_kem_newctx and released by oqs_kem_freectx.
 */
typedef struct {
/* Set once in oqs_kem_newctx from the provider context. */
OSSL_LIB_CTX *libctx;
/* Key used by the operation; stays NULL until an *_init call stores it. */
OQSX_KEY *kem;
} PROV_OQSKEM_CTX;
/// Common KEM functions
/*
 * Allocate a new KEM operation context.
 *
 * provctx: provider context; passed to PROV_OQS_LIBCTX_OF to obtain the
 *          library context stored in the new object.
 * Returns the newly allocated context, or NULL if the allocation failed.
 */
static void *oqs_kem_newctx(void *provctx) {
    PROV_OQSKEM_CTX *ctx = OPENSSL_zalloc(sizeof(PROV_OQSKEM_CTX));

    OQS_KEM_PRINTF("OQS KEM provider called: newctx\n");
    if (ctx != NULL) {
        // The key member (kem) is left unset here; it is only assigned by
        // the encapsulate/decapsulate init calls.
        ctx->libctx = PROV_OQS_LIBCTX_OF(provctx);
    }
    return ctx;
}
/*
 * Release a KEM operation context.
 *
 * Calls oqsx_key_free on the key stored in the context and then frees the
 * context itself. A NULL context is accepted and ignored, so the function is
 * safe to call on a context whose allocation failed.
 *
 * vpkemctx: context previously returned by oqs_kem_newctx, or NULL.
 */
static void oqs_kem_freectx(void *vpkemctx) {
    PROV_OQSKEM_CTX *pkemctx = (PROV_OQSKEM_CTX *)vpkemctx;

    OQS_KEM_PRINTF("OQS KEM provider called: freectx\n");
    if (pkemctx == NULL)
        return;

    oqsx_key_free(pkemctx->kem);
    OPENSSL_free(pkemctx);
}
/*
 * Shared initialisation for encapsulation and decapsulation: attach a key to
 * the operation context.
 *
 * Takes a new reference on vkem (oqsx_key_up_ref), releases the key that was
 * previously stored in the context (oqsx_key_free) and stores vkem instead.
 *
 * vpkemctx:  operation context created by oqs_kem_newctx.
 * vkem:      key object to use for the operation.
 * operation: EVP_PKEY_OP_* value supplied by the caller; currently unused.
 *
 * Returns 1 on success; 0 if either pointer is NULL or the reference on the
 * key could not be taken.
 */
static int oqs_kem_decapsencaps_init(void *vpkemctx, void *vkem,
                                     int operation) {
    PROV_OQSKEM_CTX *pkemctx = (PROV_OQSKEM_CTX *)vpkemctx;

    (void)operation;

    // Only read pkemctx->kem for the trace when the context exists, and hand
    // printf's %p a genuine void pointer.
    OQS_KEM_PRINTF3("OQS KEM provider called: _init : New: %p; old: %p \n",
                    vkem, pkemctx != NULL ? (void *)pkemctx->kem : NULL);

    if (pkemctx == NULL || vkem == NULL || !oqsx_key_up_ref(vkem))
        return 0;

    oqsx_key_free(pkemctx->kem);
    pkemctx->kem = vkem;
    return 1;
}
/*
 * Prepare the context for encapsulation with the given key.
 *
 * Forwards to oqs_kem_decapsencaps_init with EVP_PKEY_OP_ENCAPSULATE and
 * returns its result. The params array is not consulted.
 */
static int oqs_kem_encaps_init(void *vpkemctx, void *vkem,
                               const OSSL_PARAM params[]) {
    int rc;

    OQS_KEM_PRINTF("OQS KEM provider called: encaps_init\n");
    rc = oqs_kem_decapsencaps_init(vpkemctx, vkem, EVP_PKEY_OP_ENCAPSULATE);
    return rc;
}
/*
 * Prepare the context for decapsulation with the given key.
 *
 * Forwards to oqs_kem_decapsencaps_init with EVP_PKEY_OP_DECAPSULATE and
 * returns its result. The params array is not consulted.
 */
static int oqs_kem_decaps_init(void *vpkemctx, void *vkem,
                               const OSSL_PARAM params[]) {
    int rc;

    OQS_KEM_PRINTF("OQS KEM provider called: decaps_init\n");
    rc = oqs_kem_decapsencaps_init(vpkemctx, vkem, EVP_PKEY_OP_DECAPSULATE);
    return rc;
}
/// Quantum-Safe KEM functions (OQS)
/*
 * Encapsulate against the public key stored in slot `keyslot` of the
 * context's key.
 *
 * vpkemctx:  operation context; its key must have been set by an init call.
 * out:       buffer receiving the ciphertext, or NULL for a size query.
 * outlen:    in: capacity of `out`; out: ciphertext length. Must not be NULL.
 * secret:    buffer receiving the shared secret, or NULL for a size query.
 * secretlen: in: capacity of `secret`; out: shared secret length. Must not
 *            be NULL.
 * keyslot:   index into the key's comp_pubkey array.
 *
 * Size query: if `out` or `secret` is NULL, only *outlen and *secretlen are
 * filled in and 1 is returned.
 *
 * Returns 1 on success, 0 if OQS_KEM_encaps does not report OQS_SUCCESS, and
 * -1 on invalid arguments (no context/key, missing public key, NULL length
 * pointer, or a buffer that is too small).
 */
static int oqs_qs_kem_encaps_keyslot(void *vpkemctx, unsigned char *out,
                                     size_t *outlen, unsigned char *secret,
                                     size_t *secretlen, int keyslot) {
    const PROV_OQSKEM_CTX *pkemctx = (PROV_OQSKEM_CTX *)vpkemctx;
    const OQS_KEM *kem_ctx = NULL;

    OQS_KEM_PRINTF("OQS KEM provider called: encaps\n");
    if (pkemctx == NULL || pkemctx->kem == NULL) {
        OQS_KEM_PRINTF("OQS Warning: OQS_KEM not initialized\n");
        return -1;
    }
    kem_ctx = pkemctx->kem->oqsx_provider_ctx.oqsx_qs_ctx.kem;
    if (pkemctx->kem->comp_pubkey == NULL ||
        pkemctx->kem->comp_pubkey[keyslot] == NULL) {
        OQS_KEM_PRINTF("OQS Warning: public key is NULL\n");
        return -1;
    }
    if (outlen == NULL) {
        OQS_KEM_PRINTF("OQS Warning: outlen is NULL\n");
        return -1;
    }
    if (secretlen == NULL) {
        OQS_KEM_PRINTF("OQS Warning: secretlen is NULL\n");
        return -1;
    }
    if (out == NULL || secret == NULL) {
        // Size query: report the required lengths without producing output.
        *outlen = kem_ctx->length_ciphertext;
        *secretlen = kem_ctx->length_shared_secret;
        // The lengths are size_t values, so they are printed with %zu.
        OQS_KEM_PRINTF3("KEM returning lengths %zu and %zu\n",
                        kem_ctx->length_ciphertext,
                        kem_ctx->length_shared_secret);
        return 1;
    }
    if (*outlen < kem_ctx->length_ciphertext) {
        OQS_KEM_PRINTF("OQS Warning: out buffer too small\n");
        return -1;
    }
    if (*secretlen < kem_ctx->length_shared_secret) {
        OQS_KEM_PRINTF("OQS Warning: secret buffer too small\n");
        return -1;
    }
    // The lengths are reported before the call, so they are set even if the
    // encapsulation itself fails.
    *outlen = kem_ctx->length_ciphertext;
    *secretlen = kem_ctx->length_shared_secret;

    return OQS_SUCCESS == OQS_KEM_encaps(kem_ctx, out, secret,
                                         pkemctx->kem->comp_pubkey[keyslot]);
}
/*
 * Decapsulate a ciphertext with the private key stored in slot `keyslot` of
 * the context's key.
 *
 * vpkemctx: operation context; its key must have been set by an init call.
 * out:      buffer receiving the shared secret, or NULL for a size query.
 * outlen:   in: capacity of `out`; out: shared secret length.
 * in:       ciphertext to decapsulate.
 * inlen:    length of `in`; must equal the KEM's ciphertext length.
 * keyslot:  index into the key's comp_privkey array.
 *
 * Size query: if `out` is NULL, *outlen (when outlen is non-NULL) is set to
 * the shared secret length and 1 is returned.
 *
 * Returns 1 on success; 0 if `inlen` does not match the expected ciphertext
 * length or OQS_KEM_decaps does not report OQS_SUCCESS; -1 on other invalid
 * arguments (no context/key, missing private key, NULL `in`/`outlen`, or an
 * output buffer that is too small).
 */
static int oqs_qs_kem_decaps_keyslot(void *vpkemctx, unsigned char *out,
                                     size_t *outlen, const unsigned char *in,
                                     size_t inlen, int keyslot) {
    const PROV_OQSKEM_CTX *pkemctx = (PROV_OQSKEM_CTX *)vpkemctx;
    const OQS_KEM *kem_ctx = NULL;

    OQS_KEM_PRINTF("OQS KEM provider called: decaps\n");
    if (pkemctx == NULL || pkemctx->kem == NULL) {
        OQS_KEM_PRINTF("OQS Warning: OQS_KEM not initialized\n");
        return -1;
    }
    kem_ctx = pkemctx->kem->oqsx_provider_ctx.oqsx_qs_ctx.kem;
    if (pkemctx->kem->comp_privkey == NULL ||
        pkemctx->kem->comp_privkey[keyslot] == NULL) {
        OQS_KEM_PRINTF("OQS Warning: private key is NULL\n");
        return -1;
    }
    if (out == NULL) {
        // Size query: report the required length without producing output.
        if (outlen != NULL) {
            *outlen = kem_ctx->length_shared_secret;
        }
        // The length is a size_t value, so it is printed with %zu.
        OQS_KEM_PRINTF2("KEM returning length %zu\n",
                        kem_ctx->length_shared_secret);
        return 1;
    }
    // A ciphertext of the wrong length yields 0, unlike the other argument
    // errors below, which yield -1.
    if (inlen != kem_ctx->length_ciphertext) {
        OQS_KEM_PRINTF("OQS Warning: wrong input length\n");
        return 0;
    }
    if (in == NULL) {
        OQS_KEM_PRINTF("OQS Warning: in is NULL\n");
        return -1;
    }
    if (outlen == NULL) {
        OQS_KEM_PRINTF("OQS Warning: outlen is NULL\n");
        return -1;
    }
    if (*outlen < kem_ctx->length_shared_secret) {
        OQS_KEM_PRINTF("OQS Warning: out buffer too small\n");
        return -1;
    }
    // The length is reported before the call, so it is set even if the
    // decapsulation itself fails.
    *outlen = kem_ctx->length_shared_secret;

    return OQS_SUCCESS == OQS_KEM_decaps(kem_ctx, out, in,
                                         pkemctx->kem->comp_privkey[keyslot]);
}
/*
 * KEM encapsulate entry point for the non-hybrid dispatch table.
 *
 * Delegates to oqs_qs_kem_encaps_keyslot using key slot 0 and returns its
 * result unchanged.
 */
static int oqs_qs_kem_encaps(void *vpkemctx, unsigned char *out, size_t *outlen,
                             unsigned char *secret, size_t *secretlen) {
    const int slot = 0;

    return oqs_qs_kem_encaps_keyslot(vpkemctx, out, outlen, secret, secretlen,
                                     slot);
}
/*
 * KEM decapsulate entry point for the non-hybrid dispatch table.
 *
 * Delegates to oqs_qs_kem_decaps_keyslot using key slot 0 and returns its
 * result unchanged.
 */
static int oqs_qs_kem_decaps(void *vpkemctx, unsigned char *out, size_t *outlen,
                             const unsigned char *in, size_t inlen) {
    const int slot = 0;

    return oqs_qs_kem_decaps_keyslot(vpkemctx, out, outlen, in, inlen, slot);
}
#include "oqs_hyb_kem.c"
/*
 * Defines the dispatch table oqs_<alg>_kem_functions for a KEM that uses the
 * oqs_qs_kem_encaps / oqs_qs_kem_decaps entry points together with the common
 * newctx / init / freectx functions of this file. The table ends with the
 * {0, NULL} entry.
 * The expansion already ends in ';', so invocations are written without a
 * trailing semicolon.
 */
#define MAKE_KEM_FUNCTIONS(alg) \
const OSSL_DISPATCH oqs_##alg##_kem_functions[] = { \
{OSSL_FUNC_KEM_NEWCTX, (void (*)(void))oqs_kem_newctx}, \
{OSSL_FUNC_KEM_ENCAPSULATE_INIT, (void (*)(void))oqs_kem_encaps_init}, \
{OSSL_FUNC_KEM_ENCAPSULATE, (void (*)(void))oqs_qs_kem_encaps}, \
{OSSL_FUNC_KEM_DECAPSULATE_INIT, (void (*)(void))oqs_kem_decaps_init}, \
{OSSL_FUNC_KEM_DECAPSULATE, (void (*)(void))oqs_qs_kem_decaps}, \
{OSSL_FUNC_KEM_FREECTX, (void (*)(void))oqs_kem_freectx}, \
{0, NULL}};
/*
 * Defines the dispatch table oqs_<alg>_kem_functions for a hybrid KEM. It
 * differs from MAKE_KEM_FUNCTIONS only in the encapsulate/decapsulate
 * entries, which point at oqs_hyb_kem_encaps / oqs_hyb_kem_decaps
 * (not defined in the visible part of this file; presumably provided by the
 * included oqs_hyb_kem.c - confirm).
 * The expansion already ends in ';', so invocations are written without a
 * trailing semicolon.
 */
#define MAKE_HYB_KEM_FUNCTIONS(alg) \
const OSSL_DISPATCH oqs_##alg##_kem_functions[] = { \
{OSSL_FUNC_KEM_NEWCTX, (void (*)(void))oqs_kem_newctx}, \
{OSSL_FUNC_KEM_ENCAPSULATE_INIT, (void (*)(void))oqs_kem_encaps_init}, \
{OSSL_FUNC_KEM_ENCAPSULATE, (void (*)(void))oqs_hyb_kem_encaps}, \
{OSSL_FUNC_KEM_DECAPSULATE_INIT, (void (*)(void))oqs_kem_decaps_init}, \
{OSSL_FUNC_KEM_DECAPSULATE, (void (*)(void))oqs_hyb_kem_decaps}, \
{OSSL_FUNC_KEM_FREECTX, (void (*)(void))oqs_kem_freectx}, \
{0, NULL}};
// keep this just in case we need to become ALG-specific at some point in time
// Instantiates the two dispatch tables defined by this file:
// oqs_generic_kem_functions and oqs_hybrid_kem_functions.
MAKE_KEM_FUNCTIONS(generic)
MAKE_HYB_KEM_FUNCTIONS(hybrid)