Skip to content
This repository was archived by the owner on Apr 19, 2026. It is now read-only.

Understanding soft-sorting #15

@vamp-ire-tap

Description

@vamp-ire-tap

Hello,

What is the difference between the provided implementation (`soft_sort`) and `torch.sort`? Sorry if this is a naive question, but I don't see how `torch.sort`, which is supposedly non-differentiable, really differs from `soft_sort`: as the example below shows, gradients flow through both.

import torch.nn.functional as F
import torch
import torch.nn as nn
import pytorch_ops
import numpy as np

class Net(nn.Module):
    """Two-layer MLP whose 20 outputs are passed through ``pytorch_ops.soft_sort``.

    Architecture: Linear(20 -> 20) -> ReLU -> Linear(20 -> 20) -> soft_sort.
    The result is cast to float32 with ``.float()``.
    NOTE(review): the cast presumably exists because soft_sort may return
    float64; confirm against the pytorch_ops implementation.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept as `layer` / `out` so state_dict keys are unchanged.
        self.layer = nn.Linear(20, 20)
        self.out = nn.Linear(20, 20)

    def forward(self, x):
        """Map ``x`` (last dim of size 20) to its soft-sorted projection."""
        hidden = F.relu(self.layer(x))
        scores = self.out(hidden)
        # soft_sort is called with the library's default direction and
        # regularization strength.
        softly_sorted = pytorch_ops.soft_sort(scores)
        return softly_sorted.float()

class Net2(nn.Module):
    """Two-layer MLP whose 20 outputs are sorted exactly with ``torch.sort``.

    Architecture: Linear(20 -> 20) -> ReLU -> Linear(20 -> 20) -> torch.sort.
    ``torch.sort`` sorts along the last dimension in ascending order by default.
    Only the sorted values are returned; the permutation indices are discarded.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept as `layer` / `out` so state_dict keys are unchanged.
        self.layer = nn.Linear(20, 20)
        self.out = nn.Linear(20, 20)

    def forward(self, x):
        """Map ``x`` (last dim of size 20) to its ascending-sorted projection."""
        hidden = F.relu(self.layer(x))
        sorted_values, _indices = torch.sort(self.out(hidden))
        return sorted_values

# NOTE(review): `criterion` was never defined in the original snippet, so the first
# call raised NameError. MSELoss is assumed here as a representative loss.
criterion = nn.MSELoss()

targets = torch.rand(32, 20)
inputs = torch.rand(32, 20)
net = Net()
net2 = Net2()

# Try with soft-sort.
# PyTorch loss modules take (prediction, target) in that order.
loss = criterion(net(inputs), targets)
loss.backward()

# Try with torch.sort.
loss = criterion(net2(inputs), targets)
loss.backward()

# Both run without error because the *values* returned by torch.sort are
# differentiable: autograd sends each upstream gradient back to the position the
# element came from. The exact sort is therefore piecewise linear in its input,
# and its gradient is just a permutation that jumps whenever two inputs swap order.
# The *indices* (ranks) are integers and carry no gradient at all.
# soft_sort is a regularized relaxation of sorting (Blondel et al., 2020).
# Presumably its output, and the companion soft_rank, vary smoothly with the input.
# That matters mostly when you need gradients of ranks, or a smoothed sort; confirm
# against the fast-soft-sort documentation.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions