Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ jobs:
choco install protoc -y
protoc --version

- name: Update Rust dependencies
run: cargo update

- name: Cache cargo registry
uses: actions/cache@v4
with:
Expand Down Expand Up @@ -110,6 +113,9 @@ jobs:
sudo apt-get install -y protobuf-compiler
protoc --version

- name: Update Rust dependencies
run: cargo update

- name: Cache cargo registry
uses: actions/cache@v4
with:
Expand Down Expand Up @@ -154,6 +160,9 @@ jobs:
sudo apt-get install -y protobuf-compiler
protoc --version

- name: Update Rust dependencies
run: cargo update

- name: Cache cargo registry
uses: actions/cache@v4
with:
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ half = "2.4"
regex = "1.11"
# Optional runtime dependencies
ndarray = { version = "0.16", optional = true }
ort = { version = "2.0.0-rc.11", optional = true, features = ["ndarray", "half", "load-dynamic"] }
ort = { version = "=2.0.0-rc.11", optional = true, features = ["ndarray", "half", "load-dynamic"] }
objc = { version = "0.2", optional = true }

# Make this crate independent from the parent workspace
Expand Down
4 changes: 2 additions & 2 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -1033,7 +1033,7 @@ Sigmoid activation: `1 / (1 + exp(-x))`

Hyperbolic tangent activation

#### `softmax(x)`
#### `softmax(x, axis)`

Softmax activation (normalizes to probability distribution)

Expand All @@ -1045,7 +1045,7 @@ x = builder.input("x", [1, 10], "float32")
relu_out = builder.relu(x)
sigmoid_out = builder.sigmoid(x)
tanh_out = builder.tanh(x)
softmax_out = builder.softmax(x)
softmax_out = builder.softmax(x, axis=1)
```

### Shape Operations
Expand Down
2 changes: 1 addition & 1 deletion docs/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ class SimpleClassifier:
logits = builder.add(logits, b2)

# Softmax
output = builder.softmax(logits)
output = builder.softmax(logits, axis=len(logits.shape) - 1)

# Build
self.graph = builder.build({"probabilities": output})
Expand Down
2 changes: 1 addition & 1 deletion examples/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def build_simple_classifier(builder, num_classes=1000):
logits = builder.gemm(flattened, fc_weights, b_transpose=True)

# Softmax for class probabilities
output = builder.softmax(logits)
output = builder.softmax(logits, axis=len(logits.shape) - 1)

return output

Expand Down
2 changes: 1 addition & 1 deletion examples/mobilenetv2_complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def build_complete_mobilenetv2(builder, weights):
x = builder.add(x, fc_b)

# Softmax
output = builder.softmax(x)
output = builder.softmax(x, axis=len(x.shape) - 1)

print(" ✓ Complete MobileNetV2 graph built!")
return output
Expand Down
2 changes: 1 addition & 1 deletion examples/mobilenetv2_real.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def build_mobilenetv2_simple(builder, weights):
x = builder.add(x, fc_bias)

# Softmax
output = builder.softmax(x)
output = builder.softmax(x, axis=len(x.shape) - 1)

return output

Expand Down
2 changes: 1 addition & 1 deletion examples/text_generation_enhanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def build_graph_with_kv(self, builder, x_input, cached_k, cached_v):
# Output projection
logits = builder.gemm(out, W_out)
logits = builder.add(logits, b_out)
probs = builder.softmax(logits)
probs = builder.softmax(logits, axis=len(logits.shape) - 1)

# Note: new_k and new_v are returned separately for caching
return probs # In full implementation, return (new_k, new_v, probs)
Expand Down
2 changes: 1 addition & 1 deletion examples/text_generation_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def build_graph(self, builder, x_input):
logits = builder.add(logits, b_out)

# Apply softmax to get probabilities
probs = builder.softmax(logits)
probs = builder.softmax(logits, axis=len(logits.shape) - 1)

return probs

Expand Down
2 changes: 1 addition & 1 deletion python/webnn/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ class MLGraphBuilder:
"""Tanh activation"""
...

def softmax(self, x: MLOperand) -> MLOperand:
def softmax(self, x: MLOperand, axis: int) -> MLOperand:
"""Softmax activation"""
...

Expand Down
29 changes: 29 additions & 0 deletions src/python/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,35 @@ impl PyMLContext {
inputs: &Bound<'_, PyDict>,
_outputs: Option<&Bound<'_, PyDict>>,
) -> PyResult<Py<PyDict>> {
// Validate slice ops: starts/sizes length must match; empty only allowed for 0D input (no-op).
for (idx, op) in graph.graph_info.operations.iter().enumerate() {
if op.op_type.eq_ignore_ascii_case("slice") {
if let Some(o) = op.attributes.as_slice() {
if o.starts.len() != o.sizes.len() {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"Slice operation at index {} has mismatched starts/sizes (starts.len()={}, sizes.len()={}).",
idx, o.starts.len(), o.sizes.len()
)));
}
let both_empty = o.starts.is_empty() && o.sizes.is_empty();
if both_empty {
let input_rank = op
.input_operands
.first()
.and_then(|&id| graph.graph_info.operand(id))
.map(|operand| operand.descriptor.static_or_max_shape().len())
.unwrap_or(usize::MAX);
if input_rank != 0 {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"Slice operation at index {} has empty starts/sizes but input rank is {} (only 0D no-op slice may have empty starts/sizes).",
idx, input_rank
)));
}
}
}
}
}

// Route to appropriate backend based on context's backend selection
match self.backend {
Backend::OnnxCpu | Backend::OnnxGpu => self.compute_onnx(py, graph, inputs),
Expand Down
2 changes: 1 addition & 1 deletion src/python/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ impl PyMLGraph {
continue;
}

let axes = op.attributes.get("axes").and_then(parse_i64_array);
let axes = op.attributes.get("axes").and_then(|v| parse_i64_array(&v));
let Some(axes) = axes else {
continue;
};
Expand Down
Loading
Loading