diff --git a/secret_handshake/util.py b/secret_handshake/util.py index 895bc0d..429f798 100644 --- a/secret_handshake/util.py +++ b/secret_handshake/util.py @@ -44,14 +44,17 @@ def inc_nonce(nonce: bytes) -> bytes: return bnum -def split_chunks(seq: Sequence[T], n: int) -> Generator[Sequence[T], None, None]: +def split_chunks(seq: Sequence[T], chunk_size: int) -> Generator[Sequence[T], None, None]: """Split sequence in equal-sized chunks. The last chunk is not padded.""" + if chunk_size <= 0: + raise ValueError("chunk_size must be greater than zero") + while seq: - yield seq[:n] - seq = seq[n:] + yield seq[:chunk_size] + seq = seq[chunk_size:] def long_to_bytes(n: int, blocksize: int = 0) -> bytes: diff --git a/tests/test_util.py b/tests/test_util.py index cf64c30..7928c14 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -49,6 +49,16 @@ def test_split_chunks_is_generator() -> None: assert isinstance(split_chunks([], 1), GeneratorType) +@pytest.mark.parametrize("size", (-123, -1, 0)) +def test_nonpositive_chunk_size(size: int) -> None: + """Test if split_chunks() with non-positive chunk sizes raise an error""" + + with pytest.raises(ValueError) as ctx: + list(split_chunks(b"", size)) + + assert str(ctx.value) == "chunk_size must be greater than zero" + + @pytest.mark.parametrize( "in_,chunksize,out", (