This repository was archived by the owner on Oct 15, 2019. It is now read-only.

MNIST Solver cannot check accuracy when y is onehot encoded #180

@kkweon

Setup

  • Ubuntu 16.04
  • MXNet (ver 0.10.0)
  • Minpy (ver 0.3.4)

What I was trying to do

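  • Train on MNIST with one-hot encoded labels (y_onehot)
    • X.shape = (n, 784)
    • y.shape = (n, 10)
  • Training starts, but the run then fails with the following message
    • The accuracy check appears to fail on a shape mismatch: it compares the argmax of the predictions, shape (n,), against labels of shape (n, 10), so it seems to expect y.shape to be (n,)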
(skipped)
(Iteration 91 / 1074) loss: 0.710763
(Iteration 101 / 1074) loss: 0.691108
[15:44:48] /home/travis/build/dmlc/mxnet-distro/mxnet-build/dmlc-core/include/dmlc/logging.h:304: [15:44:48] src/operator/tensor/./elemwise_binary_broadcast_op.h:48: Check failed: l == 1 || r == 1 operands could not be broadcast together with shapes (512,) (512,10)

Stack trace returned 10 entries:
[bt] (0) /home/kkweon/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x18b0dc) [0x7f1d2c2f30dc]
[bt] (1) /home/kkweon/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x4c2ce2) [0x7f1d2c62ace2]
[bt] (2) /home/kkweon/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0xb3dadd) [0x7f1d2cca5add]
[bt] (3) /home/kkweon/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(MXImperativeInvoke+0x4b8) [0x7f1d2cca8c78]
[bt] (4) /home/kkweon/anaconda3/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(ffi_call_unix64+0x4c) [0x7f1d50390550]
[bt] (5) /home/kkweon/anaconda3/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(ffi_call+0x1f5) [0x7f1d5038fcf5]
[bt] (6) /home/kkweon/anaconda3/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(_ctypes_callproc+0x3dc) [0x7f1d5038783c]
[bt] (7) /home/kkweon/anaconda3/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(+0x9da3) [0x7f1d5037fda3]
[bt] (8) /home/kkweon/anaconda3/bin/../lib/libpython3.6m.so.1.0(_PyObject_FastCallDict+0x9e) [0x7f1d51adfade]
[bt] (9) /home/kkweon/anaconda3/bin/../lib/libpython3.6m.so.1.0(+0x1482bb) [0x7f1d51bbc2bb]

Traceback (most recent call last):
  File "lab-10-1-mnist-softmax.py", line 64, in <module>
    solver.train()
  File "/home/kkweon/anaconda3/lib/python3.6/site-packages/minpy-0.3.4-py3.6.egg/minpy/nn/solver.py", line 275, in train
    self.train_dataiter, num_samples=self.train_acc_num_samples)
  File "/home/kkweon/anaconda3/lib/python3.6/site-packages/minpy-0.3.4-py3.6.egg/minpy/nn/solver.py", line 231, in check_accuracy
    acc_count += np.sum(np.argmax(predict, axis=1) == each_batch.label[0])
  File "/home/kkweon/anaconda3/lib/python3.6/site-packages/minpy-0.3.4-py3.6.egg/minpy/array.py", line 74, in __eq__
    return Value._ns.equal(self, other)
  File "/home/kkweon/anaconda3/lib/python3.6/site-packages/minpy-0.3.4-py3.6.egg/minpy/primitive.py", line 141, in __call__
    return self.call(args, kwargs)
  File "/home/kkweon/anaconda3/lib/python3.6/site-packages/minpy-0.3.4-py3.6.egg/minpy/primitive.py", line 238, in call
    result_value = self._func(*arg_values, **kwarg_values)
  File "/home/kkweon/anaconda3/lib/python3.6/site-packages/mxnet/ndarray.py", line 1723, in equal
    None)
  File "/home/kkweon/anaconda3/lib/python3.6/site-packages/mxnet/ndarray.py", line 1247, in _ufunc_helper
    return fn_array(lhs, rhs)
  File "<string>", line 14, in broadcast_equal
  File "/home/kkweon/anaconda3/lib/python3.6/site-packages/mxnet/_ctypes/ndarray.py", line 72, in _imperative_invoke
    c_array(ctypes.c_char_p, [c_str(str(val)) for val in vals])))
  File "/home/kkweon/anaconda3/lib/python3.6/site-packages/mxnet/base.py", line 84, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [15:44:48] src/operator/tensor/./elemwise_binary_broadcast_op.h:48: Check failed: l == 1 || r == 1 operands could not be broadcast together with shapes (512,) (512,10)

Stack trace returned 10 entries:
[bt] (0) /home/kkweon/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x18b0dc) [0x7f1d2c2f30dc]
[bt] (1) /home/kkweon/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x4c2ce2) [0x7f1d2c62ace2]
[bt] (2) /home/kkweon/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0xb3dadd) [0x7f1d2cca5add]
[bt] (3) /home/kkweon/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(MXImperativeInvoke+0x4b8) [0x7f1d2cca8c78]
[bt] (4) /home/kkweon/anaconda3/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(ffi_call_unix64+0x4c) [0x7f1d50390550]
[bt] (5) /home/kkweon/anaconda3/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(ffi_call+0x1f5) [0x7f1d5038fcf5]
[bt] (6) /home/kkweon/anaconda3/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(_ctypes_callproc+0x3dc) [0x7f1d5038783c]
[bt] (7) /home/kkweon/anaconda3/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(+0x9da3) [0x7f1d5037fda3]
[bt] (8) /home/kkweon/anaconda3/bin/../lib/libpython3.6m.so.1.0(_PyObject_FastCallDict+0x9e) [0x7f1d51adfade]
[bt] (9) /home/kkweon/anaconda3/bin/../lib/libpython3.6m.so.1.0(+0x1482bb) [0x7f1d51bbc2bb]
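The traceback points at the comparison `np.argmax(predict, axis=1) == each_batch.label[0]`, where the left side has shape (512,) and the right side has shape (512, 10). Below is a minimal sketch in plain NumPy (not minpy) of an accuracy check that accepts both label layouts. The helper name `batch_accuracy` and the toy arrays are made up for illustration. This is not the library's implementation, only one way the check could be made shape-tolerant.

```python
import numpy as np


def batch_accuracy(predict, labels):
    """Fraction of correct predictions for index or one-hot labels.

    Hypothetical helper for illustration; not part of minpy.
    """
    # Class with the highest score per sample, shape (n,)
    predicted_classes = np.argmax(predict, axis=1)
    labels = np.asarray(labels)
    if labels.ndim == 2:
        # One-hot labels of shape (n, num_classes) -> indices of shape (n,)
        labels = np.argmax(labels, axis=1)
    return float(np.mean(predicted_classes == labels))


# Toy scores for 3 samples and 3 classes (made-up values)
scores = np.array([[0.1, 0.7, 0.2],
                   [0.8, 0.1, 0.1],
                   [0.2, 0.3, 0.5]])
index_labels = np.array([1, 0, 1])
onehot_labels = np.eye(3)[index_labels]

# Both label layouts give the same accuracy (2 of 3 correct)
print(batch_accuracy(scores, index_labels))
print(batch_accuracy(scores, onehot_labels))
```

Reducing the labels with argmax only when they are 2-D leaves the existing behaviour for (n,) labels unchanged.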

If I run without one-hot encoded labels

  • It works fine, but I get repeated deprecation warnings about onehot_encode
[15:56:36] src/ndarray/./ndarray_function-inl.h:68: The operator onehot_encode is deprecated; use one_hot instead.
[15:56:36] src/ndarray/./ndarray_function-inl.h:68: The operator onehot_encode is deprecated; use one_hot instead.
[15:56:36] src/ndarray/./ndarray_function-inl.h:68: The operator onehot_encode is deprecated; use one_hot instead.
[15:56:36] src/ndarray/./ndarray_function-inl.h:68: The operator onehot_encode is deprecated; use one_hot instead.
[15:56:36] src/ndarray/./ndarray_function-inl.h:68: The operator onehot_encode is deprecated; use one_hot instead.
[15:56:36] src/ndarray/./ndarray_function-inl.h:68: The operator onehot_encode is deprecated; use one_hot instead.
[15:56:36] src/ndarray/./ndarray_function-inl.h:68: The operator onehot_encode is deprecated; use one_hot instead.
[15:56:36] src/ndarray/./ndarray_function-inl.h:68: The operator onehot_encode is deprecated; use one_hot instead.
[15:56:36] src/ndarray/./ndarray_function-inl.h:68: The operator onehot_encode is deprecated; use one_hot instead.
(Epoch 10 / 10) train acc: 0.9033203125, val acc: 0.921875, time: 0.3320729732513428.

Full code

from minpy.nn import layers
from minpy.nn import model
from minpy.nn import solver
from minpy.nn import io
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

X_train = mnist.train.images
y_train = mnist.train.labels
print(X_train.shape, y_train.shape)  # (55000, 784) (55000, 10)
X_val = mnist.validation.images
y_val = mnist.validation.labels
print(X_val.shape, y_val.shape)  # (5000, 784) (5000, 10)
X_test = mnist.test.images
y_test = mnist.test.labels
print(X_test.shape, y_test.shape)  # (10000, 784) (10000, 10)

class SingleLayerNetwork(model.ModelBase):

    def __init__(self):
        super(SingleLayerNetwork, self).__init__()
        self.add_param("W1", shape=(784, 10))
        self.add_param("b1", shape=(10,))

    def forward(self, X, mode):
        net = X
        net = layers.affine(net, self.params['W1'], self.params['b1'])
        return net

    def loss(self, predict, y):
        return layers.softmax_loss(predict, y)


BATCH_SIZE = 512

train_iter = io.NDArrayIter(data=X_train,
                            label=y_train,
                            batch_size=BATCH_SIZE,
                            shuffle=True)
val_iter = io.NDArrayIter(data=X_val,
                          label=y_val,
                          batch_size=BATCH_SIZE)
test_iter = io.NDArrayIter(data=X_test,
                           label=y_test,
                           batch_size=BATCH_SIZE)


model = SingleLayerNetwork()

solver_config = {
    'model': model,
    'train_dataiter': train_iter,
    'test_dataiter': val_iter,
    'update_rule': 'adam',
    'num_epochs': 10,
    'init_rule': 'gaussian',
}
solver = solver.Solver(**solver_config)
solver.init()
solver.train()
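A possible workaround on the user side, assuming the solver only needs labels of shape (n,), is to convert the one-hot labels back to class indices before building the iterators. The sketch below uses plain NumPy with a small stand-in array in place of `mnist.train.labels`. The helper name `to_class_indices` is hypothetical.

```python
import numpy as np


def to_class_indices(y_onehot):
    """Convert one-hot labels of shape (n, num_classes) to indices of shape (n,).

    Hypothetical helper for illustration; not part of minpy or MXNet.
    """
    y_onehot = np.asarray(y_onehot)
    if y_onehot.ndim != 2:
        raise ValueError("expected a 2-D one-hot array, got shape %s"
                         % (y_onehot.shape,))
    # Every row of a valid one-hot matrix sums to 1
    if not np.allclose(y_onehot.sum(axis=1), 1.0):
        raise ValueError("rows are not one-hot encoded")
    return np.argmax(y_onehot, axis=1)


# Stand-in for mnist.train.labels: 4 samples, 10 classes
y_train_onehot = np.eye(10)[[3, 0, 9, 3]]
y_train_idx = to_class_indices(y_train_onehot)
print(y_train_onehot.shape, y_train_idx.shape)  # (4, 10) (4,)
```

In the script above, the converted arrays would then be passed as `label=` to `io.NDArrayIter` in place of the one-hot arrays.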
