Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 96 additions & 39 deletions src/ASDF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ const compression_keys = Dict{Compression,Vector{UInt8}}(
C_Zlib => Vector{UInt8}("zlib"),
C_Zstd => Vector{UInt8}("zstd"),
)
# Every compression key must be exactly 4 bytes: the ASDF block header stores
# the compression identifier in a fixed 4-byte field.
# (The old `@assert` form of this check was removed: `@assert` may be disabled
# at higher optimization levels and must not be used for validation.)
if !all(length(val) == 4 for val in values(compression_keys))
    error("Invalid entry in `compression_keys`, please ensure all values have length 4.")
end
# Inverse lookup table of `compression_keys`: maps a 4-byte on-disk
# compression tag back to its `Compression` enum value.
const compression_enums = Dict{Vector{UInt8},Compression}(
    tag => comp for (comp, tag) in pairs(compression_keys)
)
Expand Down Expand Up @@ -115,8 +117,9 @@ function read_block_header(io::IO, position::Int64)
header = Array{UInt8}(undef, max_header_size)
seek(io, position)
nb = readbytes!(io, header)
# TODO: Better error message
@assert nb == length(header)
if nb != length(header)
error("Number of bytes read from stream does not match length of header")
end

# Decode block header
token = @view header[1:4]
Expand All @@ -128,15 +131,19 @@ function read_block_header(io::IO, position::Int64)
data_size = big2native_U64(@view header[31:38])
checksum = @view header[39:54]

# TODO: Better error message
@assert token == block_magic_token
if token != block_magic_token
error("Block does not start with magic number")
end

STREAMED = Bool(flags & 0x1)
# We don't handle streamed blocks yet
@assert !STREAMED
if STREAMED
error("ASDF.jl does not support streamed blocks")
end

# TODO: Better error message
@assert allocated_size >= used_size
if allocated_size < used_size
error("ASDF file header incorrectly specifies amount of space to use")
end

return BlockHeader(io, position, token, header_size, flags, compression, allocated_size, used_size, data_size, checksum)
end
Expand All @@ -158,14 +165,16 @@ function read_block(header::BlockHeader)
seek(header.io, block_data_start)
data = Array{UInt8}(undef, header.used_size)
nb = readbytes!(header.io, data)
# TODO: Better error message
@assert nb == length(data)
if nb != length(data)
error("Number of bytes read from `header` does not match length of `data`")
end

# Check checksum
if any(header.checksum != 0)
actual_checksum = md5(data)
# TODO: Better error message
@assert all(actual_checksum == header.checksum)
if any(actual_checksum != header.checksum)
error("Checksum mismatch in ASDF file header")
end
end

# Decompress data
Expand All @@ -187,15 +196,15 @@ function read_block(header::BlockHeader)
elseif compression == C_Zstd
codec = ZstdCodec()
else
# TODO: Better error message
@assert false
error("Invalid compression format found: $compression")
end
data = decode(codec, data)
end
data::AbstractVector{UInt8}

# TODO: Better error message
@assert length(data) == header.data_size
if length(data) != header.data_size
error("Actual data size different from declared data size in header.")
end

return data
end
Expand Down Expand Up @@ -318,14 +327,32 @@ struct NDArray
offset::Int64,
strides::Vector{Int64},
)
@assert (source === nothing) + (data === nothing) == 1
@assert source === nothing || source >= 0
@assert data === nothing || eltype(data) == Type(datatype)
@assert data === nothing || size(data) == Tuple(reverse(shape))
@assert offset >= 0
@assert length(shape) == length(strides)
@assert all(shape .>= 0)
@assert all(strides .> 0)
if (source === nothing) + (data === nothing) != 1
throw(ArgumentError("Exactly one of `source` or `data` must be `nothing`."))
end
if source !== nothing && source < 0
throw(ArgumentError("`source` must be >= 0 if provided"))
end
if data !== nothing
if eltype(data) != Type(datatype)
throw(ArgumentError("`data` must contain elements of type given by `datatype`."))
end
if size(data) != Tuple(reverse(shape))
throw(ArgumentError("`shape` does not correctly describe the shape of `data`."))
end
end
if offset < 0
throw(ArgumentError("`offset` must be >= 0"))
end
if length(shape) != length(strides)
throw(DimensionMismatch("`shape` and `strides` must have the same length."))
end
if any(shape .< 0)
throw(ArgumentError("`shape` cannot have negative elements."))
end
if !all(strides .> 0)
throw(ArgumentError("`strides` must have only positive elements."))
end
return new(lazy_block_headers, source, data, shape, datatype, byteorder, offset, strides)
end
end
Expand Down Expand Up @@ -389,7 +416,9 @@ end
function Base.getindex(ndarray::NDArray)
if ndarray.data !== nothing
data = ndarray.data
@assert ndarray.byteorder == host_byteorder
if ndarray.byteorder != host_byteorder
error("ndarray byteorder does not match system byteorder")
end
elseif ndarray.source !== nothing
data = read_block(ndarray.lazy_block_headers.block_headers[ndarray.source + 1])
# Handle strides and offset.
Expand All @@ -409,13 +438,19 @@ function Base.getindex(ndarray::NDArray)
map!(bswap, data, data)
end
else
@assert false
error("`ndarray` is in invalid state, both `data` and `source` are `nothing`.")
end

# Check array layout
@assert size(data) == Tuple(reverse(ndarray.shape))
@assert eltype(data) == Type(ndarray.datatype)
@assert sizeof(eltype(data)) .* Base.strides(data) == Tuple(reverse(ndarray.strides))
if size(data) != Tuple(reverse(ndarray.shape))
error("`data` does not conform to specified `ndarray.shape`")
end
if eltype(data) != Type(ndarray.datatype)
error("`data` does not match type specified by `ndarray.datatype`")
end
if sizeof(eltype(data)) .* Base.strides(data) != Tuple(reverse(ndarray.strides))
error("`data` has different match stride specified by `ndarray.strides`")
end

return data::AbstractArray
end
Expand All @@ -427,8 +462,12 @@ struct NDArrayChunk
ndarray::NDArray

function NDArrayChunk(start::Vector{Int64}, ndarray::NDArray)
    # Validate the chunk origin: `start` must supply one coordinate per
    # dimension of the contained array (measured against its strides) and
    # every coordinate must be non-negative.
    # (Stale `@assert` duplicates of these checks, left over from the diff,
    # are removed here; `@assert` must not be used for input validation.)
    if length(start) != length(ndarray.strides)
        error("`start` of chunk does not match the number of `strides` in ndarray.")
    end
    if any(start .< 0)
        error("`start` cannot contain negative values")
    end
    return new(start, ndarray)
end
end
Expand All @@ -453,13 +492,23 @@ struct ChunkedNDArray
chunks::AbstractVector{NDArrayChunk}

function ChunkedNDArray(shape::Vector{Int64}, datatype::Datatype, chunks::Vector{NDArrayChunk})
    # Validate the overall shape and each chunk against it.
    # (Stale `@assert` duplicates of these checks, left over from the diff,
    # are removed here; `@assert` must not be used for input validation.)
    if any(shape .< 0)
        error("`shape` cannot contain negative values")
    end
    for chunk in chunks
        # Each chunk must be of the same rank as the full array.
        if length(chunk.start) != length(shape)
            error("Incorrect number of dimensions specified by `chunks` and `shape`")
        end
        # We allow overlaps and gaps in the chunks
        if !all(chunk.start .<= shape)
            error("`chunk.start` exceeds number of elements in dimension")
        end
        # The chunk, starting at `chunk.start`, must fit inside `shape`.
        if !all(chunk.start + chunk.ndarray.shape .<= shape)
            error("`chunk` exceeds number of elements as specified by `shape`")
        end
        # All chunks must share the array's element type.
        if chunk.ndarray.datatype != datatype
            error("`datatype` and type of `chunk` cannot be different")
        end
    end
    return new(shape, datatype, chunks)
end
Expand Down Expand Up @@ -589,6 +638,7 @@ function YAML._print(io::IO, val::NDArrayWrapper, level::Int=0, ignore_level::Bo
else
global blocks
source = length(blocks.arrays)
# `write_file()` has a corresponding `push!()` to `blocks.positions`
push!(blocks.arrays, val)
ndarray = Dict(
:source => source::Integer,
Expand Down Expand Up @@ -684,7 +734,7 @@ function write_file(filename::AbstractString, document::Dict{Any,Any})
elseif array.compression == C_Zstd
encode_options = ZstdEncodeOptions(; compressionLevel=22)
else
@assert false
error("`array` has invalid state: `compression` field has value not specified in `Compression` enum.")
end
data = encode(encode_options, input)
end
Expand Down Expand Up @@ -724,9 +774,16 @@ function write_file(filename::AbstractString, document::Dict{Any,Any})

# Check consistency
endpos = position(io)
@assert endpos == pos + 6 + header_size + allocated_size
if endpos != pos + 6 + header_size + allocated_size
error("Ending position does not match number of bytes written")
end
end
if length(blocks.positions) != length(blocks.arrays)
error(
"Global `blocks` has invalid state: number of arrays does not match number of `positions`. ",
"Check for mismatches between `write_file()` and `YAML._print()`."
)
end
@assert length(blocks.positions) == length(blocks.arrays)

# Write block list
println(io, "#ASDF BLOCK INDEX")
Expand Down
Loading