Hello,
While trying to implement ML-KEM (Kyber), I realized that the current
API for SHAKE (sha3_256_shake) is too limited. ML-KEM uses SHAKE128 as
a source of pseudorandom samples[1], but the current API requires the
total number of output bytes to be known before the call, and the hash
context is reset after the call.
Here I propose adding a couple of helper functions to support this
streaming use case: sha3_256_shake_pad to process the final block, and
sha3_256_shake_read to read the digest from the current state without
resetting the context. With those functions, one could write a
streaming read interface like the one attached.
What do you think? The actual code changes can be found at:
https://git.lysator.liu.se/nettle/nettle/-/merge_requests/61
I also have a SHAKE128 implementation with an analogous API, which I
will post later.
Regards,
Footnotes:
[1] https://bwesterb.github.io/draft-schwabe-cfrg-kyber/draft-cfrg-schwabe-kyber...
--
Daiki Ueno
#include "sha3.h"
#include <stdio.h>
#include <string.h>
/* Incremental SHAKE reader built on a SHA3-256 context (i.e. SHAKE256).
   It buffers one block of squeezed output so that callers can request
   output in arbitrarily sized pieces across several calls.  */
struct shake_reader
{
/* Underlying hash state.  Callers absorb input into it and finalize it
   with sha3_256_shake_pad directly (as main does) before reading.  */
struct sha3_256_ctx ctx;
/* One block (SHA3_256_BLOCK_SIZE bytes) of squeezed output.  */
uint8_t buf[SHA3_256_BLOCK_SIZE];
/* Index of the next unread byte in BUF.  A value of sizeof(buf) means
   BUF is exhausted and must be refilled before the next byte is
   returned.  */
size_t offset;
};
/* Reset READER: mark its output buffer as empty and initialize the
   wrapped SHA3-256 context.  Because the buffer starts out empty, the
   first shake_reader_read call squeezes a fresh block.  */
static void
shake_reader_init (struct shake_reader *reader)
{
  /* offset == buffer size is the "nothing buffered" sentinel.  */
  reader->offset = sizeof reader->buf;
  sha3_256_init (&reader->ctx);
}
/* Copy the next LENGTH bytes of the SHAKE output stream from READER
   into DIGEST.

   Output is produced one block (sizeof(reader->buf) bytes) at a time.
   When the buffered block is exhausted, the current state is written
   into BUF with sha3_256_shake_read, and the state is then permuted so
   that the next refill yields the following block.  Bytes left over in
   BUF are returned by later calls, so consecutive reads form one
   continuous stream.

   A refill happens only when more output is actually requested.  A
   read that exactly drains BUF leaves offset == sizeof(buf) and does
   not permute early.

   Assumes sha3_256_shake_pad has already been applied to READER->ctx
   (main does this).  TODO: confirm against the final API contract.  */
static void
shake_reader_read (struct shake_reader *reader,
                   size_t length,
                   uint8_t *digest)
{
  while (length > 0)
    {
      size_t chunk;

      if (reader->offset == sizeof(reader->buf))
        {
          /* Buffer exhausted: take the current block from the state,
             then advance the permutation for the next one.  */
          sha3_256_shake_read (&reader->ctx, sizeof(reader->buf), reader->buf);
          sha3_permute (&reader->ctx.state);
          reader->offset = 0;
        }

      /* Copy as much as is buffered, but no more than requested.  */
      chunk = sizeof(reader->buf) - reader->offset;
      if (chunk > length)
        chunk = length;

      memcpy (digest, reader->buf + reader->offset, chunk);
      reader->offset += chunk;
      digest += chunk;
      length -= chunk;
    }
}
/* Demonstration of the streaming reader.  It draws 256 coefficients
   modulo q = 3329 by rejection sampling from a SHAKE stream and prints
   one coefficient per line.  Every 3 bytes of output yield two 12-bit
   candidates, mirroring ML-KEM's uniform sampling of polynomial
   coefficients.  Note that this example squeezes SHAKE256 via the
   sha3_256_* context, whereas ML-KEM itself uses SHAKE128.  */
int main (void)
{
  /* ML-KEM parameters: coefficients are reduced modulo MLKEM_Q, and a
     polynomial has MLKEM_N coefficients.  */
  enum { MLKEM_Q = 3329, MLKEM_N = 256 };
  struct shake_reader reader;
  uint8_t buf[3];
  size_t n = 0;

  shake_reader_init (&reader);
  /* Absorb an empty message, then process the final block so that
     output can be squeezed.  */
  sha3_256_update (&reader.ctx, 0, NULL);
  sha3_256_shake_pad (&reader.ctx);

  while (n < MLKEM_N)
    {
      /* Split 3 bytes into two 12-bit candidates: d1 takes buf[0] plus
         the low nibble of buf[1], and d2 takes the high nibble of
         buf[1] plus buf[2].  */
      shake_reader_read (&reader, sizeof(buf), buf);
      int d1 = buf[0] + 256 * (buf[1] % 16);
      int d2 = (buf[1] >> 4) + 16 * buf[2];

      /* Rejection sampling: keep only candidates below MLKEM_Q.  The
         second candidate is used only if more coefficients are still
         needed.  */
      if (d1 < MLKEM_Q)
        {
          printf ("%d\n", d1);
          n++;
        }
      if (n < MLKEM_N && d2 < MLKEM_Q)
        {
          printf ("%d\n", d2);
          n++;
        }
    }
  return 0;
}