diff --git a/tests/test_hash.c b/tests/test_hash.c index 891dd62aea..d83f08eb1d 100644 --- a/tests/test_hash.c +++ b/tests/test_hash.c @@ -12,35 +12,24 @@ #include "system_info.c" -#define BUFFER_SIZE 10000 +#define BUFFER_SIZE 30000 static int read_stdin(uint8_t **msg, size_t *msg_len) { *msg = malloc(BUFFER_SIZE); if (*msg == NULL) { return -1; } - uint8_t *msg_next_read = *msg; - ssize_t bytes_read; - *msg_len = 0; - while (1) { - bytes_read = read(0, msg_next_read, BUFFER_SIZE); - if (bytes_read == -1) { - return -1; - } - *msg_len += (size_t)bytes_read; - if (bytes_read < BUFFER_SIZE) { - break; - } else { - uint8_t *msgprime = malloc(*msg_len + BUFFER_SIZE); - if (msgprime == NULL) { - return -1; - } - memcpy(msgprime, *msg, *msg_len); - free(*msg); - *msg = msgprime; - msg_next_read = &((*msg)[*msg_len]); - } + size_t bytes_read; + bytes_read = fread(*msg, 1, BUFFER_SIZE, stdin); + if (ferror(stdin)) { + perror("Read from stdin failed"); + return -2; + } + if (bytes_read == BUFFER_SIZE && !feof(stdin)) { + fprintf(stderr, "Input too large for buffer (%d)\n", BUFFER_SIZE); + return -3; } + *msg_len = bytes_read; return 0; } @@ -56,7 +45,7 @@ static int do_sha256(void) { uint8_t *msg; size_t msg_len; if (read_stdin(&msg, &msg_len) != 0) { - fprintf(stderr, "ERROR: malloc failure\n"); + fprintf(stderr, "ERROR reading from stdin\n"); return -1; } // run main SHA-256 API @@ -103,7 +92,7 @@ static int do_sha384(void) { uint8_t *msg; size_t msg_len; if (read_stdin(&msg, &msg_len) != 0) { - fprintf(stderr, "ERROR: malloc failure\n"); + fprintf(stderr, "ERROR reading from stdin\n"); return -1; } // run main SHA-384 API @@ -150,7 +139,7 @@ static int do_sha512(void) { uint8_t *msg; size_t msg_len; if (read_stdin(&msg, &msg_len) != 0) { - fprintf(stderr, "ERROR: malloc failure\n"); + fprintf(stderr, "ERROR reading from stdin\n"); return -1; } // run main SHA-512 API @@ -197,7 +186,7 @@ static int do_arbitrary_hash(void (*hash)(uint8_t *, const uint8_t *, size_t), s uint8_t *msg; size_t msg_len; if (read_stdin(&msg, &msg_len) != 0) { - fprintf(stderr, "ERROR: malloc failure\n"); + fprintf(stderr, "ERROR reading from stdin\n"); return -1; } // run main SHA-256 API @@ -214,6 +203,7 @@ int main(int argc, char **argv) { if (argc != 2) { fprintf(stderr, "Usage: test_hash algname\n"); fprintf(stderr, " algname: sha256, sha384, sha512, sha256inc, sha384inc, sha512inc\n"); + fprintf(stderr, " sha3_256, sha3_384, sha3_512\n"); fprintf(stderr, " test_hash reads input from stdin and outputs hash value as hex string to stdout"); printf("\n"); print_system_info(); @@ -234,6 +224,12 @@ int main(int argc, char **argv) { return do_arbitrary_hash(&OQS_SHA2_sha384, 48); } else if (strcmp(hash_alg, "sha512") == 0) { return do_arbitrary_hash(&OQS_SHA2_sha512, 64); + } else if (strcmp(hash_alg, "sha3_256") == 0) { + return do_arbitrary_hash(&OQS_SHA3_sha3_256, 32); + } else if (strcmp(hash_alg, "sha3_384") == 0) { + return do_arbitrary_hash(&OQS_SHA3_sha3_384, 48); + } else if (strcmp(hash_alg, "sha3_512") == 0) { + return do_arbitrary_hash(&OQS_SHA3_sha3_512, 64); } else { fprintf(stderr, "ERROR: Test not implemented\n"); OQS_destroy(); diff --git a/tests/test_hash.py b/tests/test_hash.py index 609d9ea526..d19b8dee23 100644 --- a/tests/test_hash.py +++ b/tests/test_hash.py @@ -3,6 +3,7 @@ import hashlib import helpers import pytest +import random import sys @helpers.filtered_test @@ -20,49 +21,31 @@ def test_sha3(): ) @helpers.filtered_test -@pytest.mark.parametrize('msg', ['', 'a', 'abc', '1234567890123456789012345678901678901567890', '1234567890123456789012345678901678901567890andthensometohavemorethan64bytes', '1234567890123456789012345678901678901567890andlotsmoretexttosurelyandwithoutdoubtgobeyondthe128bytesrequiredtotriggerallincrementalblocklogiccases']) +@pytest.mark.parametrize('algname', ['sha256', 'sha384', 'sha512', 'sha3_256', 'sha3_384', 'sha3_512']) @pytest.mark.skipif(sys.platform.startswith("win"), reason="Not supported on Windows") -def test_sha256(msg): - output = helpers.run_subprocess( - [helpers.path_to_executable('test_hash'), 'sha256'], - input = msg.encode(), - ) - assert(output.rstrip() == hashlib.sha256(msg.encode()).hexdigest()) - output = helpers.run_subprocess( - [helpers.path_to_executable('test_hash'), 'sha256inc'], - input = msg.encode(), - ) - assert(output.rstrip() == hashlib.sha256(msg.encode()).hexdigest()) - -@helpers.filtered_test -@pytest.mark.parametrize('msg', ['', 'a', 'abc', '1234567890123456789012345678901678901567890', '1234567890123456789012345678901678901567890andthensometohavemorethan64bytes', '1234567890123456789012345678901678901567890andlotsmoretexttosurelyandwithoutdoubtgobeyondthe128bytesrequiredtotriggerallincrementalblocklogiccases']) -@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not supported on Windows") -def test_sha384(msg): - output = helpers.run_subprocess( - [helpers.path_to_executable('test_hash'), 'sha384'], - input = msg.encode(), - ) - assert(output.rstrip() == hashlib.sha384(msg.encode()).hexdigest()) - output = helpers.run_subprocess( - [helpers.path_to_executable('test_hash'), 'sha384inc'], - input = msg.encode(), - ) - assert(output.rstrip() == hashlib.sha384(msg.encode()).hexdigest()) - -@helpers.filtered_test -@pytest.mark.parametrize('msg', ['', 'a', 'abc', '1234567890123456789012345678901678901567890', '1234567890123456789012345678901678901567890andthensometohavemorethan64bytes', '1234567890123456789012345678901678901567890andlotsmoretexttosurelyandwithoutdoubtgobeyondthe128bytesrequiredtotriggerallincrementalblocklogiccases']) -@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not supported on Windows") -def test_sha512(msg): - output = helpers.run_subprocess( - [helpers.path_to_executable('test_hash'), 'sha512'], - input = msg.encode(), - ) - assert(output.rstrip() == hashlib.sha512(msg.encode()).hexdigest()) - output = helpers.run_subprocess( - [helpers.path_to_executable('test_hash'), 'sha512inc'], - input = msg.encode(), - ) - assert(output.rstrip() == hashlib.sha512(msg.encode()).hexdigest()) +def test_hash_sha2_random(algname): + # hash every size from 0 to 1024, then every 11th size after that + # (why 11? it's coprime with powers of 2, so we should land in a + # bunch of random-ish spots relative to block boundaries) + for i in list(range(0, 1024)) + list(range(1025, 20000, 11)): + msg = "".join("1" for j in range(i)).encode() + hasher = hashlib.new(algname) + hasher.update(msg) + output = helpers.run_subprocess( + [helpers.path_to_executable('test_hash'), algname], + input = msg, + ) + if output.rstrip() != hasher.hexdigest(): + print(msg.hex()) + assert False, algname + " hashes don't match for the above " + str(i) + "-byte hex string; liboqs output = " + output.rstrip() + "; Python output = " + hasher.hexdigest() + if algname[0:4] == "sha3": continue + output = helpers.run_subprocess( + [helpers.path_to_executable('test_hash'), algname + 'inc'], + input = msg, + ) + if output.rstrip() != hasher.hexdigest(): + print(msg.hex()) + assert False, algname + " hashes (using liboqs incremental API) don't match for the above " + str(i) + "-byte hex string; liboqs output = " + output.rstrip() + "; Python output = " + hasher.hexdigest() if __name__ == "__main__": import sys