Fix a memory leak in the ALAC decoder.
[deb_shairplay.git] / src / lib / rsakey.c
1 /**
2 * Copyright (C) 2011-2012 Juho Vähä-Herttua
3 *
4 * This library is free software; you can redistribute it and/or
5 * modify it under the terms of the GNU Lesser General Public
6 * License as published by the Free Software Foundation; either
7 * version 2.1 of the License, or (at your option) any later version.
8 *
9 * This library is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
12 * Lesser General Public License for more details.
13 */
14
15 #include <stdlib.h>
16 #include <string.h>
17 #include <stdint.h>
18 #include <assert.h>
19
20 #include "rsakey.h"
21 #include "rsapem.h"
22 #include "base64.h"
23 #include "crypto/crypto.h"
24
25 #define RSA_MIN_PADLEN 8
26 #define MAX_KEYLEN 512
27
28 struct rsakey_s {
29 int keylen; /* length of modulus in bytes */
30 BI_CTX *bi_ctx; /* bigint context */
31
32 bigint *n; /* modulus */
33 bigint *e; /* public exponent */
34 bigint *d; /* private exponent */
35
36 int use_crt; /* use chinese remainder theorem */
37 bigint *p; /* p as in m = pq */
38 bigint *q; /* q as in m = pq */
39 bigint *dP; /* d mod (p-1) */
40 bigint *dQ; /* d mod (q-1) */
41 bigint *qInv; /* q^-1 mod p */
42
43 base64_t *base64;
44 };
45
46 rsakey_t *
47 rsakey_init(const unsigned char *modulus, int mod_len,
48 const unsigned char *pub_exp, int pub_len,
49 const unsigned char *priv_exp, int priv_len,
50 /* Optional, used for crt optimization */
51 const unsigned char *p, int p_len,
52 const unsigned char *q, int q_len,
53 const unsigned char *dP, int dP_len,
54 const unsigned char *dQ, int dQ_len,
55 const unsigned char *qInv, int qInv_len)
56 {
57 rsakey_t *rsakey;
58 int i;
59
60 if (mod_len > MAX_KEYLEN) {
61 return NULL;
62 }
63
64 rsakey = calloc(1, sizeof(rsakey_t));
65 if (!rsakey) {
66 return NULL;
67 }
68 rsakey->base64 = base64_init(NULL, 0, 0);
69 if (!rsakey->base64) {
70 free(rsakey);
71 return NULL;
72 }
73
74 /* Initialize structure */
75 for (i=0; !modulus[i] && i<mod_len; i++);
76 rsakey->keylen = mod_len-i;
77 rsakey->bi_ctx = bi_initialize();
78
79 /* Import public and private keys */
80 rsakey->n = bi_import(rsakey->bi_ctx, modulus, mod_len);
81 rsakey->e = bi_import(rsakey->bi_ctx, pub_exp, pub_len);
82 rsakey->d = bi_import(rsakey->bi_ctx, priv_exp, priv_len);
83
84 if (p && q && dP && dQ && qInv) {
85 /* Import crt optimization keys */
86 rsakey->p = bi_import(rsakey->bi_ctx, p, p_len);
87 rsakey->q = bi_import(rsakey->bi_ctx, q, q_len);
88 rsakey->dP = bi_import(rsakey->bi_ctx, dP, dP_len);
89 rsakey->dQ = bi_import(rsakey->bi_ctx, dQ, dQ_len);
90 rsakey->qInv = bi_import(rsakey->bi_ctx, qInv, qInv_len);
91
92 /* Set imported keys either permanent or modulo */
93 bi_permanent(rsakey->dP);
94 bi_permanent(rsakey->dQ);
95 bi_permanent(rsakey->qInv);
96 bi_set_mod(rsakey->bi_ctx, rsakey->p, BIGINT_P_OFFSET);
97 bi_set_mod(rsakey->bi_ctx, rsakey->q, BIGINT_Q_OFFSET);
98
99 rsakey->use_crt = 1;
100 }
101
102 /* Add keys to the bigint context */
103 bi_set_mod(rsakey->bi_ctx, rsakey->n, BIGINT_M_OFFSET);
104 bi_permanent(rsakey->e);
105 bi_permanent(rsakey->d);
106 return rsakey;
107 }
108
109 rsakey_t *
110 rsakey_init_pem(const char *pemstr)
111 {
112 rsapem_t *rsapem;
113 unsigned char *modulus=NULL; unsigned int mod_len=0;
114 unsigned char *pub_exp=NULL; unsigned int pub_len=0;
115 unsigned char *priv_exp=NULL; unsigned int priv_len=0;
116 unsigned char *p=NULL; unsigned int p_len=0;
117 unsigned char *q=NULL; unsigned int q_len=0;
118 unsigned char *dP=NULL; unsigned int dP_len=0;
119 unsigned char *dQ=NULL; unsigned int dQ_len=0;
120 unsigned char *qInv=NULL; unsigned int qInv_len=0;
121 rsakey_t *rsakey=NULL;
122
123 rsapem = rsapem_init(pemstr);
124 if (!rsapem) {
125 return NULL;
126 }
127
128 /* Read public and private keys */
129 mod_len = rsapem_read_vector(rsapem, &modulus);
130 pub_len = rsapem_read_vector(rsapem, &pub_exp);
131 priv_len = rsapem_read_vector(rsapem, &priv_exp);
132 /* Read private keys for crt optimization */
133 p_len = rsapem_read_vector(rsapem, &p);
134 q_len = rsapem_read_vector(rsapem, &q);
135 dP_len = rsapem_read_vector(rsapem, &dP);
136 dQ_len = rsapem_read_vector(rsapem, &dQ);
137 qInv_len = rsapem_read_vector(rsapem, &qInv);
138
139 if (modulus && pub_exp && priv_exp) {
140 /* Initialize rsakey value */
141 rsakey = rsakey_init(modulus, mod_len, pub_exp, pub_len, priv_exp, priv_len,
142 p, p_len, q, q_len, dP, dP_len, dQ, dQ_len, qInv, qInv_len);
143 }
144
145 free(modulus);
146 free(pub_exp);
147 free(priv_exp);
148 free(p);
149 free(q);
150 free(dP);
151 free(dQ);
152 free(qInv);
153 rsapem_destroy(rsapem);
154 return rsakey;
155 }
156
157 void
158 rsakey_destroy(rsakey_t *rsakey)
159 {
160 if (rsakey) {
161 bi_free_mod(rsakey->bi_ctx, BIGINT_M_OFFSET);
162 bi_depermanent(rsakey->e);
163 bi_depermanent(rsakey->d);
164 bi_free(rsakey->bi_ctx, rsakey->e);
165 bi_free(rsakey->bi_ctx, rsakey->d);
166
167 if (rsakey->use_crt) {
168 bi_free_mod(rsakey->bi_ctx, BIGINT_P_OFFSET);
169 bi_free_mod(rsakey->bi_ctx, BIGINT_Q_OFFSET);
170 bi_depermanent(rsakey->dP);
171 bi_depermanent(rsakey->dQ);
172 bi_depermanent(rsakey->qInv);
173 bi_free(rsakey->bi_ctx, rsakey->dP);
174 bi_free(rsakey->bi_ctx, rsakey->dQ);
175 bi_free(rsakey->bi_ctx, rsakey->qInv);
176 }
177 bi_terminate(rsakey->bi_ctx);
178
179 base64_destroy(rsakey->base64);
180 free(rsakey);
181 }
182 }
183
184 static bigint *
185 rsakey_modpow(rsakey_t *rsakey, bigint *msg)
186 {
187 if (rsakey->use_crt) {
188 return bi_crt(rsakey->bi_ctx, msg,
189 rsakey->dP, rsakey->dQ,
190 rsakey->p, rsakey->q, rsakey->qInv);
191 } else {
192 rsakey->bi_ctx->mod_offset = BIGINT_M_OFFSET;
193 return bi_mod_power(rsakey->bi_ctx, msg, rsakey->d);
194 }
195 }
196
197 int
198 rsakey_sign(rsakey_t *rsakey, char *dst, int dstlen, const char *b64digest,
199 unsigned char *ipaddr, int ipaddrlen,
200 unsigned char *hwaddr, int hwaddrlen)
201 {
202 unsigned char buffer[MAX_KEYLEN];
203 unsigned char *digest;
204 int digestlen;
205 int inputlen;
206 bigint *bi_in;
207 bigint *bi_out;
208 int idx;
209
210 assert(rsakey);
211
212 if (dstlen < base64_encoded_length(rsakey->base64, rsakey->keylen)) {
213 return -1;
214 }
215
216 /* Decode the base64 digest */
217 digestlen = base64_decode(rsakey->base64, &digest, b64digest, strlen(b64digest));
218 if (digestlen < 0) {
219 return -2;
220 }
221
222 /* Calculate the input data length */
223 inputlen = digestlen+ipaddrlen+hwaddrlen;
224 if (inputlen > rsakey->keylen-3-RSA_MIN_PADLEN) {
225 free(digest);
226 return -3;
227 }
228 if (inputlen < 32) {
229 /* Minimum size is 32 */
230 inputlen = 32;
231 }
232
233 /* Construct the input buffer with padding */
234 /* See RFC 3447 9.2 for more information */
235 idx = 0;
236 memset(buffer, 0, sizeof(buffer));
237 buffer[idx++] = 0x00;
238 buffer[idx++] = 0x01;
239 memset(buffer+idx, 0xff, rsakey->keylen-inputlen-3);
240 idx += rsakey->keylen-inputlen-3;
241 buffer[idx++] = 0x00;
242 memcpy(buffer+idx, digest, digestlen);
243 idx += digestlen;
244 memcpy(buffer+idx, ipaddr, ipaddrlen);
245 idx += ipaddrlen;
246 memcpy(buffer+idx, hwaddr, hwaddrlen);
247 idx += hwaddrlen;
248
249 /* Calculate the signature s = m^d (mod n) */
250 bi_in = bi_import(rsakey->bi_ctx, buffer, rsakey->keylen);
251 bi_out = rsakey_modpow(rsakey, bi_in);
252
253 /* Encode and save the signature into dst */
254 bi_export(rsakey->bi_ctx, bi_out, buffer, rsakey->keylen);
255 base64_encode(rsakey->base64, dst, buffer, rsakey->keylen);
256
257 free(digest);
258 return 0;
259 }
260
261 /* Mask generation function with SHA-1 hash */
262 /* See RFC 3447 B.2.1 for more information */
263 static int
264 rsakey_mfg1(unsigned char *dst, int dstlen, const unsigned char *seed, int seedlen, int masklen)
265 {
266 SHA1_CTX sha_ctx;
267 int iterations;
268 int dstpos;
269 int i;
270
271 iterations = (masklen+SHA1_SIZE-1)/SHA1_SIZE;
272 if (dstlen < iterations*SHA1_SIZE) {
273 return -1;
274 }
275
276 dstpos = 0;
277 for (i=0; i<iterations; i++) {
278 unsigned char counter[4];
279 counter[0] = (i>>24)&0xff;
280 counter[1] = (i>>16)&0xff;
281 counter[2] = (i>>8)&0xff;
282 counter[3] = i&0xff;
283
284 SHA1_Init(&sha_ctx);
285 SHA1_Update(&sha_ctx, seed, seedlen);
286 SHA1_Update(&sha_ctx, counter, sizeof(counter));
287 SHA1_Final(dst+dstpos, &sha_ctx);
288 dstpos += SHA1_SIZE;
289 }
290 return masklen;
291 }
292
293 /* OAEP decryption with SHA-1 hash */
294 /* See RFC 3447 7.1.2 for more information */
295 int
296 rsakey_decrypt(rsakey_t *rsakey, unsigned char *dst, int dstlen, const char *b64input)
297 {
298 unsigned char buffer[MAX_KEYLEN];
299 unsigned char maskbuf[MAX_KEYLEN];
300 unsigned char *input;
301 int inputlen;
302 bigint *bi_in;
303 bigint *bi_out;
304 int outlen;
305 int i, ret;
306
307 assert(rsakey);
308 if (!dst || !b64input) {
309 return -1;
310 }
311
312 memset(buffer, 0, sizeof(buffer));
313 inputlen = base64_decode(rsakey->base64, &input, b64input, strlen(b64input));
314 if (inputlen < 0 || inputlen > rsakey->keylen) {
315 return -2;
316 }
317 memcpy(buffer+rsakey->keylen-inputlen, input, inputlen);
318 free(input);
319 input = NULL;
320
321 /* Decrypt the input data m = c^d (mod n) */
322 bi_in = bi_import(rsakey->bi_ctx, buffer, rsakey->keylen);
323 bi_out = rsakey_modpow(rsakey, bi_in);
324
325 memset(buffer, 0, sizeof(buffer));
326 bi_export(rsakey->bi_ctx, bi_out, buffer, rsakey->keylen);
327
328 /* First unmask seed in the buffer */
329 ret = rsakey_mfg1(maskbuf, sizeof(maskbuf),
330 buffer+1+SHA1_SIZE,
331 rsakey->keylen-1-SHA1_SIZE,
332 SHA1_SIZE);
333 if (ret < 0) {
334 return -3;
335 }
336 for (i=0; i<ret; i++) {
337 buffer[1+i] ^= maskbuf[i];
338 }
339
340 /* Then unmask the actual message */
341 ret = rsakey_mfg1(maskbuf, sizeof(maskbuf),
342 buffer+1, SHA1_SIZE,
343 rsakey->keylen-1-SHA1_SIZE);
344 if (ret < 0) {
345 return -4;
346 }
347 for (i=0; i<ret; i++) {
348 buffer[1+SHA1_SIZE+i] ^= maskbuf[i];
349 }
350
351 /* Finally find the first data byte */
352 for (i=1+2*SHA1_SIZE; i<rsakey->keylen && !buffer[i++];);
353
354 /* Calculate real output length and return */
355 outlen = rsakey->keylen-i;
356 if (outlen > dstlen) {
357 return -5;
358 }
359 memcpy(dst, buffer+i, outlen);
360 return outlen;
361 }
362
363 int
364 rsakey_parseiv(rsakey_t *rsakey, unsigned char *dst, int dstlen, const char *b64input)
365 {
366 unsigned char *tmpptr;
367 int length;
368
369 assert(rsakey);
370 if (!dst || !b64input) {
371 return -1;
372 }
373
374 length = base64_decode(rsakey->base64, &tmpptr, b64input, strlen(b64input));
375 if (length < 0) {
376 return -1;
377 } else if (length > dstlen) {
378 free(tmpptr);
379 return -2;
380 }
381
382 memcpy(dst, tmpptr, length);
383 free(tmpptr);
384 return length;
385 }