-
Notifications
You must be signed in to change notification settings - Fork 4
Added operations #7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
| import torch | ||
| import torch.nn.functional as F | ||
| from dnnv.nn import Operation, OperationGraph, OperationVisitor, operations | ||
| from .utils import ONNX_TO_TORCH_DTYPE | ||
|
|
||
|
|
||
| def convert(op_graph: OperationGraph) -> PytorchModel: | ||
|
|
@@ -235,9 +236,9 @@ def visit_Gather(self, operation: operations.Gather): | |
|
|
||
| def gather(operation_graph): | ||
| x = torch.as_tensor(operation_graph[operation.x]) | ||
| axis = int(operation.axis) | ||
| indices = torch.as_tensor(operation_graph[operation.indices]) | ||
| result = torch.gather(x, axis, indices) | ||
| indices = [slice(None)] * x.ndim | ||
| indices[operation.axis] = operation.indices | ||
| result = x[indices] | ||
| return result | ||
|
|
||
| return gather | ||
|
|
@@ -407,12 +408,12 @@ def split(operation_graph): | |
| def visit_Sub(self, operation: operations.Sub): | ||
| self.generic_visit(operation) | ||
|
|
||
| def add(operation_graph): | ||
| def sub(operation_graph): | ||
| a = operation_graph[operation.a] | ||
| b = operation_graph[operation.b] | ||
| return a - b | ||
|
|
||
| return add | ||
| return sub | ||
|
|
||
| def visit_Tanh(self, operation: operations.Tanh): | ||
| self.generic_visit(operation) | ||
|
|
@@ -445,5 +446,87 @@ def unsqueeze(operation_graph): | |
|
|
||
| return unsqueeze | ||
|
|
||
| def visit_Upsample(self, operation: operations.Upsample): | ||
| self.generic_visit(operation) | ||
|
|
||
| def upsample(operation_graph): | ||
| x = operation_graph[operation.x] | ||
| scales = operation.scales.tolist() | ||
| mode = operation.mode | ||
| result = torch.nn.Upsample(scale_factor=tuple(scales[2:]), mode=mode)(x) | ||
| return result | ||
|
|
||
| return upsample | ||
|
|
||
| def visit_Div(self, operation: operations.Div): | ||
| self.generic_visit(operation) | ||
|
|
||
| def div(operation_graph): | ||
| a = operation_graph[operation.a] | ||
| b = operation_graph[operation.b] | ||
| result = torch.div(a, b) | ||
| return result | ||
|
|
||
| return div | ||
|
|
||
| def visit_Squeeze(self, operation: operations.Squeeze): | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you type all of the new operators using strings, e.g., |
||
| self.generic_visit(operation) | ||
|
|
||
| def squeeze(operation_graph): | ||
| x = operation_graph[operation.x] | ||
| axes = operation.axes | ||
| if axes is None: | ||
| result = torch.squeeze(x) | ||
| else: | ||
| result = torch.squeeze(x, dim=axes) | ||
| return result | ||
|
|
||
| return squeeze | ||
|
|
||
| def visit_Expand(self, operation: operations.Expand): | ||
| self.generic_visit(operation) | ||
|
|
||
| def expand(operation_graph): | ||
| x = operation_graph[operation.x] | ||
| shape = operation_graph[operation.shape] | ||
| result = x.expand(shape) | ||
| return result | ||
|
|
||
| return expand | ||
|
|
||
| def visit_Clip(self, operation: operations.Clip): | ||
| self.generic_visit(operation) | ||
|
|
||
| def clip(operation_graph): | ||
| x = operation_graph[operation.x] | ||
| _min = operation.min | ||
| _max = operation.max | ||
| result = torch.clip(x, _min, _max) | ||
| return result | ||
|
|
||
| return clip | ||
|
|
||
| def visit_ReduceL2(self, operation: operations.ReduceL2): | ||
| self.generic_visit(operation) | ||
|
|
||
| def reducel2(operation_graph): | ||
| x = operation_graph[operation.x] | ||
| axes = operation.axes | ||
| keepdims = operation.keepdims | ||
| result = torch.norm(x, p=2, dim=axes, keepdim=bool(keepdims)) | ||
| return result | ||
|
|
||
| return reducel2 | ||
|
|
||
| def visit_Cast(self, operation: operations.Cast): | ||
| self.generic_visit(operation) | ||
|
|
||
| def cast(operation_graph): | ||
| x = operation_graph[operation.x] | ||
| result = x.type(ONNX_TO_TORCH_DTYPE[operation.to]) | ||
| return result | ||
|
|
||
| return cast | ||
|
|
||
|
|
||
| __all__ = ["convert", "PytorchConverter"] | ||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -220,6 +220,7 @@ def suffixed_op_graph(self) -> OperationGraph: | |||
| import dnnv.nn.operations as operations | ||||
|
|
||||
| output_shape = self.op_graph.output_shape[0] | ||||
| # axis = 0 | ||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please remove commented code
Suggested change
|
||||
| axis = (0, 0, 1)[len(output_shape)] | ||||
| if len(self.op_graph.output_operations) == 1: | ||||
| new_output_op = self.op_graph.output_operations[0] | ||||
|
|
@@ -336,22 +337,23 @@ def add_constraint(self, variables, indices, coeffs, b, is_open): | |||
| self.interval_constraints[0][flat_index] = max(b / coeff, current_bound) | ||||
|
|
||||
| def build(self) -> HPolyProperty: | ||||
| Ab = np.vstack(self.hpoly_constraints) | ||||
| A = Ab[..., :-1] | ||||
| b = Ab[..., -1:] | ||||
| bounds = tuple(zip(*self.interval_constraints)) | ||||
| for i in range(self.num_vars): | ||||
| c = np.zeros(self.num_vars) | ||||
| c[i] = 1 | ||||
| result = linprog(c, A, b, bounds=bounds, method="highs") | ||||
| if result.success: | ||||
| current_bound = self.interval_constraints[0][i] | ||||
| self.interval_constraints[0][i] = max(result.x[i], current_bound) | ||||
| c[i] = -1 | ||||
| result = linprog(c, A, b, bounds=bounds, method="highs") | ||||
| if result.success: | ||||
| current_bound = self.interval_constraints[1][i] | ||||
| self.interval_constraints[1][i] = min(result.x[i], current_bound) | ||||
| if self.hpoly_constraints: | ||||
| Ab = np.vstack(self.hpoly_constraints) | ||||
| A = Ab[..., :-1] | ||||
| b = Ab[..., -1:] | ||||
| bounds = tuple(zip(*self.interval_constraints)) | ||||
| for i in range(self.num_vars): | ||||
| c = np.zeros(self.num_vars) | ||||
| c[i] = 1 | ||||
| result = linprog(c, A, b, bounds=bounds, method="highs") | ||||
| if result.success: | ||||
| current_bound = self.interval_constraints[0][i] | ||||
| self.interval_constraints[0][i] = max(result.x[i], current_bound) | ||||
| c[i] = -1 | ||||
| result = linprog(c, A, b, bounds=bounds, method="highs") | ||||
| if result.success: | ||||
| current_bound = self.interval_constraints[1][i] | ||||
| self.interval_constraints[1][i] = min(result.x[i], current_bound) | ||||
| return HPolyProperty.build( | ||||
| self.input_vars, | ||||
| self.output_vars, | ||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you put these in alphabetical order?