Move src to src/lib, include to src/include, test to src/test.
[deb_shairplay.git] / src / lib / rsakey.c
1 #include <stdlib.h>
2 #include <string.h>
3 #include <stdint.h>
4 #include <assert.h>
5
6 #include "rsakey.h"
7 #include "rsapem.h"
8 #include "base64.h"
9 #include "crypto/crypto.h"
10
/* Minimum number of 0xff padding bytes required between the 00 01 header
 * and the 00 separator of a signature block (used by rsakey_sign). */
#define RSA_MIN_PADLEN 8
/* Largest accepted modulus length in bytes; also sizes the on-stack
 * working buffers in rsakey_sign and rsakey_decrypt. */
#define MAX_KEYLEN 512

/* RSA private key state. Every bigint below is imported into bi_ctx. */
struct rsakey_s {
	int keylen;      /* length of modulus in bytes, excluding leading zero bytes */
	BI_CTX *bi_ctx;  /* bigint context */

	bigint *n;       /* modulus, registered as BIGINT_M_OFFSET modulus */
	bigint *e;       /* public exponent (permanent) */
	bigint *d;       /* private exponent (permanent) */

	int use_crt;     /* use chinese remainder theorem; set only when all CRT parts were given */
	bigint *p;       /* p as in n = pq, registered as BIGINT_P_OFFSET modulus */
	bigint *q;       /* q as in n = pq, registered as BIGINT_Q_OFFSET modulus */
	bigint *dP;      /* d mod (p-1) (permanent) */
	bigint *dQ;      /* d mod (q-1) (permanent) */
	bigint *qInv;    /* q^-1 mod p (permanent) */

	base64_t *base64; /* codec for the base64 inputs/outputs of sign, decrypt and parseiv */
};
31
32 rsakey_t *
33 rsakey_init(const unsigned char *modulus, int mod_len,
34 const unsigned char *pub_exp, int pub_len,
35 const unsigned char *priv_exp, int priv_len,
36 /* Optional, used for crt optimization */
37 const unsigned char *p, int p_len,
38 const unsigned char *q, int q_len,
39 const unsigned char *dP, int dP_len,
40 const unsigned char *dQ, int dQ_len,
41 const unsigned char *qInv, int qInv_len)
42 {
43 rsakey_t *rsakey;
44 int i;
45
46 if (mod_len > MAX_KEYLEN) {
47 return NULL;
48 }
49
50 rsakey = calloc(1, sizeof(rsakey_t));
51 if (!rsakey) {
52 return NULL;
53 }
54 rsakey->base64 = base64_init(NULL, 0, 0);
55 if (!rsakey->base64) {
56 free(rsakey);
57 return NULL;
58 }
59
60 /* Initialize structure */
61 for (i=0; !modulus[i] && i<mod_len; i++);
62 rsakey->keylen = mod_len-i;
63 rsakey->bi_ctx = bi_initialize();
64
65 /* Import public and private keys */
66 rsakey->n = bi_import(rsakey->bi_ctx, modulus, mod_len);
67 rsakey->e = bi_import(rsakey->bi_ctx, pub_exp, pub_len);
68 rsakey->d = bi_import(rsakey->bi_ctx, priv_exp, priv_len);
69
70 if (p && q && dP && dQ && qInv) {
71 /* Import crt optimization keys */
72 rsakey->p = bi_import(rsakey->bi_ctx, p, p_len);
73 rsakey->q = bi_import(rsakey->bi_ctx, q, q_len);
74 rsakey->dP = bi_import(rsakey->bi_ctx, dP, dP_len);
75 rsakey->dQ = bi_import(rsakey->bi_ctx, dQ, dQ_len);
76 rsakey->qInv = bi_import(rsakey->bi_ctx, qInv, qInv_len);
77
78 /* Set imported keys either permanent or modulo */
79 bi_permanent(rsakey->dP);
80 bi_permanent(rsakey->dQ);
81 bi_permanent(rsakey->qInv);
82 bi_set_mod(rsakey->bi_ctx, rsakey->p, BIGINT_P_OFFSET);
83 bi_set_mod(rsakey->bi_ctx, rsakey->q, BIGINT_Q_OFFSET);
84
85 rsakey->use_crt = 1;
86 }
87
88 /* Add keys to the bigint context */
89 bi_set_mod(rsakey->bi_ctx, rsakey->n, BIGINT_M_OFFSET);
90 bi_permanent(rsakey->e);
91 bi_permanent(rsakey->d);
92 return rsakey;
93 }
94
95 rsakey_t *
96 rsakey_init_pem(const char *pemstr)
97 {
98 rsapem_t *rsapem;
99 unsigned char *modulus=NULL; unsigned int mod_len=0;
100 unsigned char *pub_exp=NULL; unsigned int pub_len=0;
101 unsigned char *priv_exp=NULL; unsigned int priv_len=0;
102 unsigned char *p=NULL; unsigned int p_len=0;
103 unsigned char *q=NULL; unsigned int q_len=0;
104 unsigned char *dP=NULL; unsigned int dP_len=0;
105 unsigned char *dQ=NULL; unsigned int dQ_len=0;
106 unsigned char *qInv=NULL; unsigned int qInv_len=0;
107 rsakey_t *rsakey=NULL;
108
109 rsapem = rsapem_init(pemstr);
110 if (!rsapem) {
111 return NULL;
112 }
113
114 /* Read public and private keys */
115 mod_len = rsapem_read_vector(rsapem, &modulus);
116 pub_len = rsapem_read_vector(rsapem, &pub_exp);
117 priv_len = rsapem_read_vector(rsapem, &priv_exp);
118 /* Read private keys for crt optimization */
119 p_len = rsapem_read_vector(rsapem, &p);
120 q_len = rsapem_read_vector(rsapem, &q);
121 dP_len = rsapem_read_vector(rsapem, &dP);
122 dQ_len = rsapem_read_vector(rsapem, &dQ);
123 qInv_len = rsapem_read_vector(rsapem, &qInv);
124
125 if (modulus && pub_exp && priv_exp) {
126 /* Initialize rsakey value */
127 rsakey = rsakey_init(modulus, mod_len, pub_exp, pub_len, priv_exp, priv_len,
128 p, p_len, q, q_len, dP, dP_len, dQ, dQ_len, qInv, qInv_len);
129 }
130
131 free(modulus);
132 free(pub_exp);
133 free(priv_exp);
134 free(p);
135 free(q);
136 free(dP);
137 free(dQ);
138 free(qInv);
139 rsapem_destroy(rsapem);
140 return rsakey;
141 }
142
143 void
144 rsakey_destroy(rsakey_t *rsakey)
145 {
146 if (rsakey) {
147 bi_free_mod(rsakey->bi_ctx, BIGINT_M_OFFSET);
148 bi_depermanent(rsakey->e);
149 bi_depermanent(rsakey->d);
150 bi_free(rsakey->bi_ctx, rsakey->e);
151 bi_free(rsakey->bi_ctx, rsakey->d);
152
153 if (rsakey->use_crt) {
154 bi_free_mod(rsakey->bi_ctx, BIGINT_P_OFFSET);
155 bi_free_mod(rsakey->bi_ctx, BIGINT_Q_OFFSET);
156 bi_depermanent(rsakey->dP);
157 bi_depermanent(rsakey->dQ);
158 bi_depermanent(rsakey->qInv);
159 bi_free(rsakey->bi_ctx, rsakey->dP);
160 bi_free(rsakey->bi_ctx, rsakey->dQ);
161 bi_free(rsakey->bi_ctx, rsakey->qInv);
162 }
163 bi_terminate(rsakey->bi_ctx);
164
165 base64_destroy(rsakey->base64);
166 free(rsakey);
167 }
168 }
169
170 static bigint *
171 rsakey_modpow(rsakey_t *rsakey, bigint *msg)
172 {
173 if (rsakey->use_crt) {
174 return bi_crt(rsakey->bi_ctx, msg,
175 rsakey->dP, rsakey->dQ,
176 rsakey->p, rsakey->q, rsakey->qInv);
177 } else {
178 rsakey->bi_ctx->mod_offset = BIGINT_M_OFFSET;
179 return bi_mod_power(rsakey->bi_ctx, msg, rsakey->d);
180 }
181 }
182
183 int
184 rsakey_sign(rsakey_t *rsakey, char *dst, int dstlen, const char *b64digest,
185 unsigned char *ipaddr, int ipaddrlen,
186 unsigned char *hwaddr, int hwaddrlen)
187 {
188 unsigned char buffer[MAX_KEYLEN];
189 unsigned char *digest;
190 int digestlen;
191 int inputlen;
192 bigint *bi_in;
193 bigint *bi_out;
194 int idx;
195
196 assert(rsakey);
197
198 if (dstlen < base64_encoded_length(rsakey->base64, rsakey->keylen)) {
199 return -1;
200 }
201
202 /* Decode the base64 digest */
203 digestlen = base64_decode(rsakey->base64, &digest, b64digest, strlen(b64digest));
204 if (digestlen < 0) {
205 return -2;
206 }
207
208 /* Calculate the input data length */
209 inputlen = digestlen+ipaddrlen+hwaddrlen;
210 if (inputlen > rsakey->keylen-3-RSA_MIN_PADLEN) {
211 free(digest);
212 return -3;
213 }
214 if (inputlen < 32) {
215 /* Minimum size is 32 */
216 inputlen = 32;
217 }
218
219 /* Construct the input buffer with padding */
220 /* See RFC 3447 9.2 for more information */
221 idx = 0;
222 memset(buffer, 0, sizeof(buffer));
223 buffer[idx++] = 0x00;
224 buffer[idx++] = 0x01;
225 memset(buffer+idx, 0xff, rsakey->keylen-inputlen-3);
226 idx += rsakey->keylen-inputlen-3;
227 buffer[idx++] = 0x00;
228 memcpy(buffer+idx, digest, digestlen);
229 idx += digestlen;
230 memcpy(buffer+idx, ipaddr, ipaddrlen);
231 idx += ipaddrlen;
232 memcpy(buffer+idx, hwaddr, hwaddrlen);
233 idx += hwaddrlen;
234
235 /* Calculate the signature s = m^d (mod n) */
236 bi_in = bi_import(rsakey->bi_ctx, buffer, rsakey->keylen);
237 bi_out = rsakey_modpow(rsakey, bi_in);
238
239 /* Encode and save the signature into dst */
240 bi_export(rsakey->bi_ctx, bi_out, buffer, rsakey->keylen);
241 base64_encode(rsakey->base64, dst, buffer, rsakey->keylen);
242
243 free(digest);
244 return 0;
245 }
246
247 /* Mask generation function with SHA-1 hash */
248 /* See RFC 3447 B.2.1 for more information */
249 static int
250 rsakey_mfg1(unsigned char *dst, int dstlen, const unsigned char *seed, int seedlen, int masklen)
251 {
252 SHA1_CTX sha_ctx;
253 int iterations;
254 int dstpos;
255 int i;
256
257 iterations = (masklen+SHA1_SIZE-1)/SHA1_SIZE;
258 if (dstlen < iterations*SHA1_SIZE) {
259 return -1;
260 }
261
262 dstpos = 0;
263 for (i=0; i<iterations; i++) {
264 unsigned char counter[4];
265 counter[0] = (i>>24)&0xff;
266 counter[1] = (i>>16)&0xff;
267 counter[2] = (i>>8)&0xff;
268 counter[3] = i&0xff;
269
270 SHA1_Init(&sha_ctx);
271 SHA1_Update(&sha_ctx, seed, seedlen);
272 SHA1_Update(&sha_ctx, counter, sizeof(counter));
273 SHA1_Final(dst+dstpos, &sha_ctx);
274 dstpos += SHA1_SIZE;
275 }
276 return masklen;
277 }
278
279 /* OAEP decryption with SHA-1 hash */
280 /* See RFC 3447 7.1.2 for more information */
281 int
282 rsakey_decrypt(rsakey_t *rsakey, unsigned char *dst, int dstlen, const char *b64input)
283 {
284 unsigned char buffer[MAX_KEYLEN];
285 unsigned char maskbuf[MAX_KEYLEN];
286 unsigned char *input;
287 int inputlen;
288 bigint *bi_in;
289 bigint *bi_out;
290 int outlen;
291 int i, ret;
292
293 assert(rsakey);
294 if (!dst || !b64input) {
295 return -1;
296 }
297
298 memset(buffer, 0, sizeof(buffer));
299 inputlen = base64_decode(rsakey->base64, &input, b64input, strlen(b64input));
300 if (inputlen < 0 || inputlen > rsakey->keylen) {
301 return -2;
302 }
303 memcpy(buffer+rsakey->keylen-inputlen, input, inputlen);
304 free(input);
305 input = NULL;
306
307 /* Decrypt the input data m = c^d (mod n) */
308 bi_in = bi_import(rsakey->bi_ctx, buffer, rsakey->keylen);
309 bi_out = rsakey_modpow(rsakey, bi_in);
310
311 memset(buffer, 0, sizeof(buffer));
312 bi_export(rsakey->bi_ctx, bi_out, buffer, rsakey->keylen);
313
314 /* First unmask seed in the buffer */
315 ret = rsakey_mfg1(maskbuf, sizeof(maskbuf),
316 buffer+1+SHA1_SIZE,
317 rsakey->keylen-1-SHA1_SIZE,
318 SHA1_SIZE);
319 if (ret < 0) {
320 return -3;
321 }
322 for (i=0; i<ret; i++) {
323 buffer[1+i] ^= maskbuf[i];
324 }
325
326 /* Then unmask the actual message */
327 ret = rsakey_mfg1(maskbuf, sizeof(maskbuf),
328 buffer+1, SHA1_SIZE,
329 rsakey->keylen-1-SHA1_SIZE);
330 if (ret < 0) {
331 return -4;
332 }
333 for (i=0; i<ret; i++) {
334 buffer[1+SHA1_SIZE+i] ^= maskbuf[i];
335 }
336
337 /* Finally find the first data byte */
338 for (i=1+2*SHA1_SIZE; i<rsakey->keylen && !buffer[i++];);
339
340 /* Calculate real output length and return */
341 outlen = rsakey->keylen-i;
342 if (outlen > dstlen) {
343 return -5;
344 }
345 memcpy(dst, buffer+i, outlen);
346 return outlen;
347 }
348
349 int
350 rsakey_parseiv(rsakey_t *rsakey, unsigned char *dst, int dstlen, const char *b64input)
351 {
352 unsigned char *tmpptr;
353 int length;
354
355 assert(rsakey);
356 if (!dst || !b64input) {
357 return -1;
358 }
359
360 length = base64_decode(rsakey->base64, &tmpptr, b64input, strlen(b64input));
361 if (length < 0) {
362 return -1;
363 } else if (length > dstlen) {
364 free(tmpptr);
365 return -2;
366 }
367
368 memcpy(dst, tmpptr, length);
369 free(tmpptr);
370 return length;
371 }