diff --git a/src/NPZ.jl b/src/NPZ.jl index adbf1bb..dd609e8 100644 --- a/src/NPZ.jl +++ b/src/NPZ.jl @@ -6,7 +6,7 @@ module NPZ using ZipFile, FileIO import Base.CodeUnits -export npzread, npzwrite +export npzread, npzwrite, npzreadfrom, npzwriteto const NPYMagic = UInt8[0x93, 'N', 'U', 'M', 'P', 'Y'] const ZIPMagic = UInt8['P', 'K', 3, 4] @@ -220,7 +220,39 @@ function _npzreadarray(f, hdr::Header{T}) where {T} ndims(x) == 0 ? x[1] : x end -function npzreadarray(f::IO) +""" + npzreadfrom(file::IO, [vars]) + +Read a variable or a collection of variables from `file`. +The input needs to be either an `npy` or an `npz` file. +The optional argument `vars` is used only for `npz` files. +If it is specified, only the matching variables are read in from the file. + +!!! note "Zero-dimensional arrays" + Zero-dimensional arrays are stripped while being read in, and the values that they + contain are returned. This is a notable difference from numpy, where + numerical values are written out and read back in as zero-dimensional arrays. + +# Examples + +```julia +julia> npzwriteto("temp.npz", x = ones(3), y = 3) + +julia> open("temp.npz", "rb") do f + return npzreadfrom(f) # Reads all variables +end +Dict{String,Any} with 2 entries: + "x" => [1.0, 1.0, 1.0] + "y" => 3 + +julia> open("temp.npz", "rb") do f + return npzreadfrom(f, ["x"]) # Reads only "x" +end +Dict{String,Array{Float64,1}} with 1 entry: + "x" => [1.0, 1.0, 1.0] +``` +""" +function npzreadfrom(f::IO) hdr = readheader(f) _npzreadarray(f, hdr) end @@ -277,7 +309,7 @@ function npzread(filename::AbstractString, vars...) close(fz) elseif samestart(b, NPYMagic) seekstart(f) - data = npzreadarray(f) + data = npzreadfrom(f) else close(f) error("not a NPY or NPZ/Zip file: $filename") @@ -289,7 +321,7 @@ end function npzread(dir::ZipFile.Reader, vars = map(f -> _maybetrimext(f.name), dir.files)) - Dict(_maybetrimext(f.name) => npzreadarray(f) + Dict(_maybetrimext(f.name) => npzreadfrom(f) for f in dir.files if f.name in vars || _maybetrimext(f.name) in vars) end @@ -329,7 +361,7 @@ function readheader(dir::ZipFile.Reader, if f.name in vars || _maybetrimext(f.name) in vars) end -function npzwritearray( +function npzwriteto( f::IO, x::AbstractArray{UInt8}, T::DataType, shape) if !haskey(Julia2Numpy, T) @@ -356,12 +388,32 @@ function npzwritearray( end end -function npzwritearray(f::IO, x::AbstractArray) - npzwritearray(f, reinterpret(UInt8, vec(x)), eltype(x), size(x)) +""" + npzwriteto(file::IO, x) + +Write the variable `x` to the `npy` file `file`. +The file should be open in writable binary mode "wb". + +# Examples + +```julia +julia> open("abc.npy", "wb") do f + npzwriteto(f, zeros(3)) +end + +julia> npzread("abc.npy") +3-element Array{Float64,1}: + 0.0 + 0.0 + 0.0 +``` +""" +function npzwriteto(f::IO, x::AbstractArray) + npzwriteto(f, reinterpret(UInt8, vec(x)), eltype(x), size(x)) end -function npzwritearray(f::IO, x::Number) - npzwritearray(f, reinterpret(UInt8, [x]), typeof(x), ()) +function npzwriteto(f::IO, x::Number) + npzwriteto(f, reinterpret(UInt8, [x]), typeof(x), ()) end """ @@ -387,7 +439,7 @@ julia> npzread("abc.npy") """ function npzwrite(filename::AbstractString, x) open(filename, "w") do f - npzwritearray(f, x) + npzwriteto(f, x) end end @@ -434,7 +486,7 @@ function npzwrite(filename::AbstractString, vars::Dict{<:AbstractString}) for (k, v) in vars f = ZipFile.addfile(dir, k * ".npy") - npzwritearray(f, v) + npzwriteto(f, v) close(f) end