Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 64 additions & 7 deletions maraboupy/parsers/ONNXParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from onnx import TensorProto
import itertools
from copy import copy
from onnx.reference.ops._op_list import Split_18, Unsqueeze_1
from onnx.reference.ops._op_list import Split_18, Unsqueeze_1, Gather

class ONNXParser:
"""
Expand Down Expand Up @@ -171,6 +171,10 @@ def makeMarabouEquations(self, nodeName, makeEquations):
self.reshape(node)
elif node.op_type == 'Flatten':
self.flatten(node)
elif node.op_type == 'Gather':
self.gather(node)
elif node.op_type == 'Shape':
self.shape_(node)
elif node.op_type == "Transpose":
self.transpose(node)
elif node.op_type == 'Unsqueeze':
Expand Down Expand Up @@ -414,6 +418,20 @@ def flatten(self, node):
elif inputName in self.constantMap:
self.constantMap[nodeName] = self.constantMap[inputName].reshape(newShape)

def shape_(self, node):
"""Function representing a shape tensor
Args:
node (node): ONNX node representing shape operation

:meta private:
"""
nodeName = node.output[0]
inputName = node.input[0]

shape = self.shapeMap[inputName]
self.constantMap[nodeName] = np.array(shape)
self.shapeMap[nodeName] = [len(shape)]

def transpose(self, node):
"""Function representing transpose

Expand All @@ -440,6 +458,34 @@ def transpose(self, node):
elif inputName in self.constantMap:
self.constantMap[nodeName] = np.transpose(self.constantMap[inputName], perm)

def gather(self, node):
"""Function representing Gather

Args:
node (node): ONNX node representing gather operation

:meta private:
"""
nodeName = node.output[0]
inputName = node.input[0]
if node.input[1] not in self.constantMap:
raise RuntimeError("Indices of Gather is not a constant.")
indices = self.constantMap[node.input[1]]

axis=None
for attr in node.attribute:
if attr.name == "axis":
axis = get_attribute_value(attr)

if inputName in self.varMap:
output_data = Gather.eval(self.varMap[inputName], indices, axis=axis)
self.shapeMap[nodeName] = output_data.shape
self.varMap[nodeName] = output_data
else:
output_data = Gather.eval(self.constantMap[inputName], indices, axis=axis)
self.shapeMap[nodeName] = output_data.shape
self.constantMap[nodeName] = output_data

def unsqueeze(self, node):
"""Function representing unsqueeze

Expand All @@ -461,7 +507,6 @@ def unsqueeze(self, node):
self.shapeMap[nodeName] = output_data.shape
self.constantMap[nodeName] = output_data


def squeeze(self, node):
"""Function representing squeeze

Expand Down Expand Up @@ -913,11 +958,23 @@ def concatEquations(self, node):
if attr.name == "axis":
axis = get_attribute_value(attr)

# Set maps of shape and var
inputVars = list([self.varMap[input] for input in node.input])
outputVars = np.concatenate(inputVars, axis)
self.shapeMap[nodeName] = outputVars.shape
self.varMap[nodeName] = outputVars
allVars = all(input in self.varMap for input in node.input)
allConstants = all(input in self.constantMap for input in node.input)
if allVars:
# Set maps of shape and var
inputVars = list([self.varMap[input] for input in node.input])
outputVars = np.concatenate(inputVars, axis)
self.shapeMap[nodeName] = outputVars.shape
self.varMap[nodeName] = outputVars
elif allConstants:
# Set maps of shape and constants
inputs = list([self.constantMap[input] for input in node.input])
outputs = np.concatenate(inputs, axis)
self.shapeMap[nodeName] = outputs.shape
self.constantMap[nodeName] = outputs
else:
raise RuntimeError("Concat inputs need to be all variables or all constants.")


def splitEquations(self, node, nodeName, makeEquations):
"""Function to generate equations corresponding to split
Expand Down
7 changes: 7 additions & 0 deletions maraboupy/test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def test_split_onnx_error():

os.remove(presplit_filename)

def test_concat_const_and_gather_network():
"""
Test an onnx file that actually contains two disjoint network
"""
filename = "test_gather.onnx"
evaluateFile(filename)

def test_concat_network():
"""
Test an onnx file that actually contains two disjoint network
Expand Down
Binary file added resources/onnx/test_gather.onnx
Binary file not shown.