Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 63 additions & 11 deletions src/NPZ.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module NPZ
using ZipFile, FileIO
import Base.CodeUnits

export npzread, npzwrite
export npzread, npzwrite, npzreadfrom, npzwriteto

const NPYMagic = UInt8[0x93, 'N', 'U', 'M', 'P', 'Y']
const ZIPMagic = UInt8['P', 'K', 3, 4]
Expand Down Expand Up @@ -220,7 +220,39 @@ function _npzreadarray(f, hdr::Header{T}) where {T}
ndims(x) == 0 ? x[1] : x
end

function npzreadarray(f::IO)
"""
    npzreadfrom(file::IO)

Read one array stored in `npy` format from the stream `file` and return it.
The stream is expected to be positioned at the start of the `npy` header.
This function takes no `vars` argument and does not open `npz` archives itself;
use `npzread` to read an `npz` file or to read only selected variables from it.

!!! note "Zero-dimensional arrays"
    Zero-dimensional arrays are stripped while being read in, and the values that they
    contain are returned. This is a notable difference from numpy, where
    numerical values are written out and read back in as zero-dimensional arrays.

# Examples

```julia
julia> npzwrite("temp.npy", ones(3))

julia> open("temp.npy", "r") do f
           npzreadfrom(f)
       end
3-element Array{Float64,1}:
 1.0
 1.0
 1.0
```
"""
function npzreadfrom(f::IO)
    # Parse the header first; it is then handed to the typed reader together with
    # the same stream so the array data can be read.
    return _npzreadarray(f, readheader(f))
end
Expand Down Expand Up @@ -277,7 +309,7 @@ function npzread(filename::AbstractString, vars...)
close(fz)
elseif samestart(b, NPYMagic)
seekstart(f)
data = npzreadarray(f)
data = npzreadfrom(f)
else
close(f)
error("not a NPY or NPZ/Zip file: $filename")
Expand All @@ -289,7 +321,7 @@ end
function npzread(dir::ZipFile.Reader,
vars = map(f -> _maybetrimext(f.name), dir.files))

Dict(_maybetrimext(f.name) => npzreadarray(f)
Dict(_maybetrimext(f.name) => npzreadfrom(f)
for f in dir.files
if f.name in vars || _maybetrimext(f.name) in vars)
end
Expand Down Expand Up @@ -329,7 +361,7 @@ function readheader(dir::ZipFile.Reader,
if f.name in vars || _maybetrimext(f.name) in vars)
end

function npzwritearray(
function npzwriteto(
f::IO, x::AbstractArray{UInt8}, T::DataType, shape)

if !haskey(Julia2Numpy, T)
Expand All @@ -356,12 +388,32 @@ function npzwritearray(
end
end

function npzwritearray(f::IO, x::AbstractArray)
npzwritearray(f, reinterpret(UInt8, vec(x)), eltype(x), size(x))
"""
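    npzwriteto(file::IO, x)

Write the variable `x` to the stream `file` in `npy` format.
The stream must be open for writing, for example via `open(filename, "w")`.
Julia's `open` does not accept Python-style mode strings such as `"wb"`; passing
one throws an `ArgumentError`.

# Examples

```julia
julia> open("abc.npy", "w") do f
           npzwriteto(f, zeros(3))
       end

julia> npzread("abc.npy")
3-element Array{Float64,1}:
 0.0
 0.0
 0.0
```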
"""
function npzwriteto(f::IO, x::AbstractArray)
    # Flatten the array and view its storage as raw bytes; the element type and
    # the original shape are passed alongside the bytes.
    bytes = reinterpret(UInt8, vec(x))
    return npzwriteto(f, bytes, eltype(x), size(x))
end

function npzwritearray(f::IO, x::Number)
npzwritearray(f, reinterpret(UInt8, [x]), typeof(x), ())
function npzwriteto(f::IO, x::Number)
    # A scalar is wrapped in a one-element vector to obtain its bytes and is
    # written with an empty shape tuple.
    bytes = reinterpret(UInt8, [x])
    return npzwriteto(f, bytes, typeof(x), ())
end

"""
Expand All @@ -387,7 +439,7 @@ julia> npzread("abc.npy")
"""
function npzwrite(filename::AbstractString, x)
open(filename, "w") do f
npzwritearray(f, x)
npzwriteto(f, x)
end
end

Expand Down Expand Up @@ -434,7 +486,7 @@ function npzwrite(filename::AbstractString, vars::Dict{<:AbstractString})

for (k, v) in vars
f = ZipFile.addfile(dir, k * ".npy")
npzwritearray(f, v)
npzwriteto(f, v)
close(f)
end

Expand Down