diff --git a/maraboupy/parsers/ONNXParser.py b/maraboupy/parsers/ONNXParser.py index b255a24478..ea5f0d8fff 100644 --- a/maraboupy/parsers/ONNXParser.py +++ b/maraboupy/parsers/ONNXParser.py @@ -23,7 +23,7 @@ from onnx import TensorProto import itertools from copy import copy -from onnx.reference.ops._op_list import Split_18, Unsqueeze_1 +from onnx.reference.ops._op_list import Split_18, Unsqueeze_1, Gather class ONNXParser: """ @@ -171,6 +171,10 @@ def makeMarabouEquations(self, nodeName, makeEquations): self.reshape(node) elif node.op_type == 'Flatten': self.flatten(node) + elif node.op_type == 'Gather': + self.gather(node) + elif node.op_type == 'Shape': + self.shape_(node) elif node.op_type == "Transpose": self.transpose(node) elif node.op_type == 'Unsqueeze': @@ -414,6 +418,20 @@ def flatten(self, node): elif inputName in self.constantMap: self.constantMap[nodeName] = self.constantMap[inputName].reshape(newShape) + def shape_(self, node): + """Function representing a shape tensor + Args: + node (node): ONNX node representing shape operation + + :meta private: + """ + nodeName = node.output[0] + inputName = node.input[0] + + shape = self.shapeMap[inputName] + self.constantMap[nodeName] = np.array(shape) + self.shapeMap[nodeName] = [len(shape)] + def transpose(self, node): """Function representing transpose @@ -440,6 +458,34 @@ def transpose(self, node): elif inputName in self.constantMap: self.constantMap[nodeName] = np.transpose(self.constantMap[inputName], perm) + def gather(self, node): + """Function representing Gather + + Args: + node (node): ONNX node representing gather operation + + :meta private: + """ + nodeName = node.output[0] + inputName = node.input[0] + if node.input[1] not in self.constantMap: + raise RuntimeError("Indices of Gather is not a constant.") + indices = self.constantMap[node.input[1]] + + axis=None + for attr in node.attribute: + if attr.name == "axis": + axis = get_attribute_value(attr) + + if inputName in self.varMap: + output_data = Gather.eval(self.varMap[inputName], indices, axis=axis) + self.shapeMap[nodeName] = output_data.shape + self.varMap[nodeName] = output_data + else: + output_data = Gather.eval(self.constantMap[inputName], indices, axis=axis) + self.shapeMap[nodeName] = output_data.shape + self.constantMap[nodeName] = output_data + def unsqueeze(self, node): """Function representing unsqueeze @@ -461,7 +507,6 @@ def unsqueeze(self, node): self.shapeMap[nodeName] = output_data.shape self.constantMap[nodeName] = output_data - def squeeze(self, node): """Function representing squeeze @@ -913,11 +958,23 @@ def concatEquations(self, node): if attr.name == "axis": axis = get_attribute_value(attr) - # Set maps of shape and var - inputVars = list([self.varMap[input] for input in node.input]) - outputVars = np.concatenate(inputVars, axis) - self.shapeMap[nodeName] = outputVars.shape - self.varMap[nodeName] = outputVars + allVars = all(input in self.varMap for input in node.input) + allConstants = all(input in self.constantMap for input in node.input) + if allVars: + # Set maps of shape and var + inputVars = list([self.varMap[input] for input in node.input]) + outputVars = np.concatenate(inputVars, axis) + self.shapeMap[nodeName] = outputVars.shape + self.varMap[nodeName] = outputVars + elif allConstants: + # Set maps of shape and constants + inputs = list([self.constantMap[input] for input in node.input]) + outputs = np.concatenate(inputs, axis) + self.shapeMap[nodeName] = outputs.shape + self.constantMap[nodeName] = outputs + else: + raise RuntimeError("Concat inputs need to be all variables or all constants.") + def splitEquations(self, node, nodeName, makeEquations): """Function to generate equations corresponding to split diff --git a/maraboupy/test/test_onnx.py b/maraboupy/test/test_onnx.py index 0087035aa8..55b51c4353 100644 --- a/maraboupy/test/test_onnx.py +++ b/maraboupy/test/test_onnx.py @@ -60,6 +60,13 @@ def test_split_onnx_error(): os.remove(presplit_filename) +def test_concat_const_and_gather_network(): + """ + Test an onnx file that actually contains two disjoint network + """ + filename = "test_gather.onnx" + evaluateFile(filename) + def test_concat_network(): """ Test an onnx file that actually contains two disjoint network diff --git a/resources/onnx/test_gather.onnx b/resources/onnx/test_gather.onnx new file mode 100644 index 0000000000..b31b45c367 Binary files /dev/null and b/resources/onnx/test_gather.onnx differ