Source code for kvikio.zarr._zarr_python_3

# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import functools
import os
from pathlib import Path

import packaging.version
import zarr

_zarr_version = packaging.version.parse(zarr.__version__)

if _zarr_version < packaging.version.parse("3.0.0"):
    # We include this runtime package checking to help users who relied on
    # installing kvikio to also include zarr, which is not an optional dependency.
    raise ImportError(
        f"'zarr>=3' is required, but 'zarr=={_zarr_version}' is installed."
    )

import zarr.storage  # noqa: E402
from zarr.abc.store import (  # noqa: E402
    ByteRequest,
    OffsetByteRequest,
    RangeByteRequest,
    SuffixByteRequest,
)
from zarr.core.buffer import Buffer, BufferPrototype  # noqa: E402
from zarr.core.buffer.core import default_buffer_prototype  # noqa: E402

import kvikio  # noqa: E402

# The GDSStore implementation follows the `LocalStore` implementation
# at https://github.com/zarr-developers/zarr-python/blob/main/src/zarr/storage/_local.py
# with differences coming swapping in `cuFile` for the stdlib open file object.


@functools.cache
def _is_ge_zarr_3_0_7():
    return _zarr_version >= packaging.version.parse("3.0.7")


def _get(
    path: Path, prototype: BufferPrototype, byte_range: ByteRequest | None
) -> Buffer:
    file_size = os.path.getsize(path)
    file_offset: int

    match byte_range:
        case None:
            nbytes = file_size
            file_offset = 0
        case OffsetByteRequest():
            nbytes = max(0, file_size - byte_range.offset)
            file_offset = byte_range.offset
        case RangeByteRequest():
            nbytes = byte_range.end - byte_range.start
            file_offset = byte_range.start
        case SuffixByteRequest():
            nbytes = byte_range.suffix
            file_offset = max(0, file_size - byte_range.suffix)
        case _:
            # This isn't allowed by mypy, but the tests assert we raise
            # something here.
            raise TypeError(f"Unexpected byte_range, got {byte_range}")

    # kvikio doesn't support reading past the end of a file. Some zarr tests
    # rely on this behavior: to "read" 3 bytes out of a 0 byte file, or to
    # "seek" past the end of a file with file_offset. The semantics seem to
    # be roughly the same as slicing an empty bytestring.

    nbytes = min(nbytes, file_size)
    file_offset = min(file_offset, file_size)

    if _is_ge_zarr_3_0_7():
        dtype = "B"
    else:
        dtype = "b"

    raw = prototype.nd_buffer.create(shape=(nbytes,), dtype=dtype).as_ndarray_like()
    buf = prototype.buffer.from_array_like(raw)

    with kvikio.CuFile(path) as f:
        # Note: this currently creates an IOFuture and then blocks
        # on reading it. The blocking read means this is in a regular
        # sync function, and so this must be run in a threadpool.
        future = f.pread(raw, size=nbytes, file_offset=file_offset)
        future.get()  # blocks

    return buf


def _put(
    path: Path,
    value: Buffer,
    start: int | None = None,
    exclusive: bool = False,
) -> int | None:
    path.parent.mkdir(parents=True, exist_ok=True)
    if start is not None:
        with kvikio.CuFile(path, "r+b") as f:
            f.write(value.as_array_like(), file_offset=start)
        return None
    else:
        buf = value.as_array_like()
        if exclusive:
            if path.exists():
                raise FileExistsError(f"File exists: {path}")
        mode = "wb"
        with kvikio.CuFile(path, flags=mode) as f:
            return f.write(buf)


[docs] class GDSStore(zarr.storage.LocalStore): def __repr__(self) -> str: return f"{type(self).__name__}('{self}')"
[docs] async def get( self, key: str, prototype: BufferPrototype | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: if prototype is None: prototype = default_buffer_prototype() if not self._is_open: await self._open() assert isinstance(key, str) path = self.root / key try: return await asyncio.to_thread(_get, path, prototype, byte_range) except (FileNotFoundError, IsADirectoryError, NotADirectoryError): return None
[docs] async def set(self, key: str, value: Buffer) -> None: return await self._set(key, value)
async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None: if not self._is_open: await self._open() self._check_writable() assert isinstance(key, str) if not isinstance(value, Buffer): raise TypeError( f"LocalStore.set(): `value` must be a Buffer instance. Got an " f"instance of {type(value)} instead." ) path = self.root / key await asyncio.to_thread(_put, path, value, start=None, exclusive=exclusive)