diff --git a/cmds/keystore/keystore.c b/cmds/keystore/keystore.c index ba74c7814..37155e4c4 100644 --- a/cmds/keystore/keystore.c +++ b/cmds/keystore/keystore.c @@ -163,19 +163,23 @@ static struct __attribute__((packed)) { static int8_t encrypt_blob(char *name, AES_KEY *aes_key) { uint8_t vector[AES_BLOCK_SIZE]; - int length = blob.length; + int length; int fd; if (read(the_entropy, vector, AES_BLOCK_SIZE) != AES_BLOCK_SIZE) { return SYSTEM_ERROR; } - length += blob.value - blob.digested; + length = (blob.length + blob.value - blob.encrypted) % AES_BLOCK_SIZE; + if (length) { + length = AES_BLOCK_SIZE - length; + } + + length += blob.length + blob.value - blob.digested; blob.length = htonl(blob.length); MD5(blob.digested, length, blob.digest); length += blob.digested - blob.encrypted; - length = (length + AES_BLOCK_SIZE - 1) / AES_BLOCK_SIZE * AES_BLOCK_SIZE; memcpy(vector, blob.vector, AES_BLOCK_SIZE); AES_cbc_encrypt(blob.encrypted, blob.encrypted, length, aes_key, vector, AES_ENCRYPT); @@ -184,11 +188,9 @@ static int8_t encrypt_blob(char *name, AES_KEY *aes_key) length += blob.encrypted - (uint8_t *)&blob; fd = open(".tmp", O_WRONLY | O_TRUNC | O_CREAT, S_IRUSR | S_IWUSR); - if (fd == -1 || write(fd, &blob, length) != length) { - return SYSTEM_ERROR; - } + length -= write(fd, &blob, length); close(fd); - return rename(".tmp", name) ? SYSTEM_ERROR : NO_ERROR; + return (length || rename(".tmp", name)) ? SYSTEM_ERROR : NO_ERROR; } static int8_t decrypt_blob(char *name, AES_KEY *aes_key) @@ -210,14 +212,15 @@ static int8_t decrypt_blob(char *name, AES_KEY *aes_key) AES_cbc_encrypt(blob.encrypted, blob.encrypted, length, aes_key, blob.vector, AES_DECRYPT); length -= blob.digested - blob.encrypted; - if (!memcmp(blob.digest, MD5(blob.digested, length, NULL), - MD5_DIGEST_LENGTH)) { + if (memcmp(blob.digest, MD5(blob.digested, length, NULL), + MD5_DIGEST_LENGTH)) { return VALUE_CORRUPTED; } length -= blob.value - blob.digested; blob.length = ntohl(blob.length); - return (length < blob.length) ? VALUE_CORRUPTED : NO_ERROR; + return (blob.length < 0 || blob.length > length) ? VALUE_CORRUPTED : + NO_ERROR; } /* Here are the actions. Each of them is a function without arguments. All