diff --git a/src/ASDF.jl b/src/ASDF.jl index 1fe903e..e85c90c 100644 --- a/src/ASDF.jl +++ b/src/ASDF.jl @@ -32,7 +32,9 @@ const compression_keys = Dict{Compression,Vector{UInt8}}( C_Zlib => Vector{UInt8}("zlib"), C_Zstd => Vector{UInt8}("zstd"), ) -@assert all(length(val)==4 for val in values(compression_keys)) +if !all(length(val)==4 for val in values(compression_keys)) + error("Invalid entry in `compression_keys`, please ensure all values have length 4.") +end const compression_enums = Dict{Vector{UInt8},Compression}( value => key for (key, value) in compression_keys ) @@ -115,8 +117,9 @@ function read_block_header(io::IO, position::Int64) header = Array{UInt8}(undef, max_header_size) seek(io, position) nb = readbytes!(io, header) - # TODO: Better error message - @assert nb == length(header) + if nb != length(header) + error("Number of bytes read from stream does not match length of header") + end # Decode block header token = @view header[1:4] @@ -128,15 +131,19 @@ function read_block_header(io::IO, position::Int64) data_size = big2native_U64(@view header[31:38]) checksum = @view header[39:54] - # TODO: Better error message - @assert token == block_magic_token + if token != block_magic_token + error("Block does not start with magic number") + end STREAMED = Bool(flags & 0x1) # We don't handle streamed blocks yet - @assert !STREAMED + if STREAMED + error("ASDF.jl does not yet support streamed blocks") + end - # TODO: Better error message - @assert allocated_size >= used_size + if allocated_size < used_size + error("ASDF file header incorrectly specifies amount of space to use") + end return BlockHeader(io, position, token, header_size, flags, compression, allocated_size, used_size, data_size, checksum) end @@ -158,14 +165,16 @@ function read_block(header::BlockHeader) seek(header.io, block_data_start) data = Array{UInt8}(undef, header.used_size) nb = readbytes!(header.io, data) - # TODO: Better error message - @assert nb == length(data) + if nb != length(data) + error("Number of bytes read from `header` does not match length of `data`") + end # Check checksum if any(header.checksum != 0) actual_checksum = md5(data) - # TODO: Better error message - @assert all(actual_checksum == header.checksum) + if any(actual_checksum != header.checksum) + error("Checksum mismatch in ASDF file header") + end end # Decompress data @@ -187,15 +196,15 @@ function read_block(header::BlockHeader) elseif compression == C_Zstd codec = ZstdCodec() else - # TODO: Better error message - @assert false + error("Invalid compression format found: $compression") end data = decode(codec, data) end data::AbstractVector{UInt8} - # TODO: Better error message - @assert length(data) == header.data_size + if length(data) != header.data_size + error("Actual data size different from declared data size in header.") + end return data end @@ -318,14 +327,32 @@ struct NDArray offset::Int64, strides::Vector{Int64}, ) - @assert (source === nothing) + (data === nothing) == 1 - @assert source === nothing || source >= 0 - @assert data === nothing || eltype(data) == Type(datatype) - @assert data === nothing || size(data) == Tuple(reverse(shape)) - @assert offset >= 0 - @assert length(shape) == length(strides) - @assert all(shape .>= 0) - @assert all(strides .> 0) + if (source === nothing) + (data === nothing) != 1 + throw(ArgumentError("Exactly one of `source` or `data` must be provided.")) + end + if source !== nothing && source < 0 + throw(ArgumentError("`source` must be >= 0 if provided.")) + end + if data !== nothing + if eltype(data) != Type(datatype) + throw(ArgumentError("`data` must contain elements of type given by `datatype`.")) + end + if size(data) != Tuple(reverse(shape)) + throw(ArgumentError("`shape` does not correctly describe the shape of `data`.")) + end + end + if offset < 0 + throw(ArgumentError("`offset` must be >= 0.")) + end + if length(shape) != length(strides) + throw(DimensionMismatch("`shape` and `strides` must have the same length.")) + end + if any(shape .< 0) + throw(ArgumentError("`shape` cannot have negative elements.")) + end + if !all(strides .> 0) + throw(ArgumentError("`strides` must have only positive elements.")) + end return new(lazy_block_headers, source, data, shape, datatype, byteorder, offset, strides) end end @@ -389,7 +416,12 @@ end function Base.getindex(ndarray::NDArray) if ndarray.data !== nothing data = ndarray.data - @assert ndarray.byteorder == host_byteorder + if ndarray.byteorder != host_byteorder + error( + "ndarray byteorder does not match system byteorder; ", + "byteorder swapping not yet implemented." + ) + end elseif ndarray.source !== nothing data = read_block(ndarray.lazy_block_headers.block_headers[ndarray.source + 1]) # Handle strides and offset. @@ -409,13 +441,19 @@ function Base.getindex(ndarray::NDArray) map!(bswap, data, data) end else - @assert false + error("`ndarray` is in invalid state; neither `source` nor `data` is given.") end # Check array layout - @assert size(data) == Tuple(reverse(ndarray.shape)) - @assert eltype(data) == Type(ndarray.datatype) - @assert sizeof(eltype(data)) .* Base.strides(data) == Tuple(reverse(ndarray.strides)) + if size(data) != Tuple(reverse(ndarray.shape)) + error("`data` does not conform to specified `ndarray.shape`") + end + if eltype(data) != Type(ndarray.datatype) + error("`data` does not match type specified by `ndarray.datatype`") + end + if sizeof(eltype(data)) .* Base.strides(data) != Tuple(reverse(ndarray.strides)) + error("`data` has different stride from `ndarray.strides`") + end return data::AbstractArray end @@ -427,8 +465,12 @@ struct NDArrayChunk ndarray::NDArray function NDArrayChunk(start::Vector{Int64}, ndarray::NDArray) - @assert length(start) == length(ndarray.strides) - @assert all(start .>= 0) + if length(start) != length(ndarray.strides) + error("`start` and `strides` have a different number of elements") + end + if any(start .< 0) + error("`start` cannot contain negative values") + end return new(start, ndarray) end end @@ -453,13 +495,23 @@ struct ChunkedNDArray chunks::AbstractVector{NDArrayChunk} function ChunkedNDArray(shape::Vector{Int64}, datatype::Datatype, chunks::Vector{NDArrayChunk}) - @assert all(shape .>= 0) + if any(shape .< 0) + error("`shape` cannot contain negative values") + end for chunk in chunks - @assert length(chunk.start) == length(shape) + if length(chunk.start) != length(shape) + error("Different number of dimensions specified by `chunks` and `shape`") + end # We allow overlaps and gaps in the chunks - @assert all(chunk.start .<= shape) - @assert all(chunk.start + chunk.ndarray.shape .<= shape) - @assert chunk.ndarray.datatype == datatype + if !all(chunk.start .<= shape) + error("`chunk.start` exceeds number of elements in dimension") + end + if !all(chunk.start + chunk.ndarray.shape .<= shape) + error("`chunk` exceeds number of elements as specified by `shape`") + end + if chunk.ndarray.datatype != datatype + error("`datatype` and type of `chunk` cannot be different") + end end return new(shape, datatype, chunks) end @@ -589,6 +641,7 @@ function YAML._print(io::IO, val::NDArrayWrapper, level::Int=0, ignore_level::Bo else global blocks source = length(blocks.arrays) + # `write_file()` has a corresponding `push!()` to `blocks.positions` push!(blocks.arrays, val) ndarray = Dict( :source => source::Integer, @@ -684,7 +737,7 @@ function write_file(filename::AbstractString, document::Dict{Any,Any}) elseif array.compression == C_Zstd encode_options = ZstdEncodeOptions(; compressionLevel=22) else - @assert false + error("`array` has invalid state: `compression` field has value not specified in `Compression` enum.") end data = encode(encode_options, input) end @@ -724,9 +777,16 @@ function write_file(filename::AbstractString, document::Dict{Any,Any}) # Check consistency endpos = position(io) - @assert endpos == pos + 6 + header_size + allocated_size + if endpos != pos + 6 + header_size + allocated_size + error("Ending position does not match number of bytes written") + end + end + if length(blocks.positions) != length(blocks.arrays) + error( + "Global `blocks` has invalid state: number of arrays does not match number of `positions`. ", + "Check for mismatches between `write_file()` and `YAML._print()`." + ) end - @assert length(blocks.positions) == length(blocks.arrays) # Write block list println(io, "#ASDF BLOCK INDEX")