This repository was archived by the owner on Apr 19, 2026. It is now read-only.
What is the difference between the provided implementation (soft_sort) and the torch.sort version? Sorry for the naive question, but I can't see how torch.sort, which is supposedly non-differentiable, is really different from soft sort. In the snippet below, backpropagation works with both.
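import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_ops  # soft_sort from this repository; the import path depends on how it is installed


class Net(nn.Module):
    """Two-layer MLP whose output is passed through the differentiable soft_sort."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(20, 20)
        self.out = nn.Linear(20, 20)

    def forward(self, x):
        out = F.relu(self.layer(x))
        # .float() casts the soft-sorted result to float32
        out = pytorch_ops.soft_sort(self.out(out)).float()
        return out


class Net2(nn.Module):
    """Same MLP, but the output is sorted with the standard torch.sort."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(20, 20)
        self.out = nn.Linear(20, 20)

    def forward(self, x):
        out = F.relu(self.layer(x))
        out = torch.sort(self.out(out))[0]  # [0] = sorted values, [1] = indices
        return out


targets = torch.rand(32, 20)
inputs = torch.rand(32, 20)
net = Net()
net2 = Net2()
criterion = nn.MSELoss()  # example loss

# Try with soft_sort
loss = criterion(net(inputs), targets)  # MSELoss expects (input, target)
loss.backward()

# Try with torch.sort
loss = criterion(net2(inputs), targets)
loss.backward()

# Both backward passes run without errors!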
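A minimal sketch of why both backward passes run yet behave differently, assuming PyTorch 1.5 or later (for `torch.autograd.functional.jacobian`). The values returned by `torch.sort` are piecewise linear in the input, so autograd does give a gradient, but its Jacobian is a 0/1 permutation matrix: each output just copies the gradient of one input and says nothing about how the ordering would change if inputs moved closer together. The integer indices (and `torch.argsort` ranks) carry no gradient at all. A soft sort instead yields a dense Jacobian that mixes all elements. Since `pytorch_ops.soft_sort` isn't reproduced here, the hypothetical helper `neuralsort` below uses a different, simpler relaxation, NeuralSort (Grover et al., 2019), purely to illustrate a dense soft-sort Jacobian; it is not the method implemented in this repository.

```python
import torch


def neuralsort(s: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: NeuralSort relaxation of sorting a 1-D tensor in descending order."""
    n = s.shape[0]
    # B[j] = sum_k |s_j - s_k|
    B = (s[:, None] - s[None, :]).abs().sum(dim=1)
    # scaling[i] = n + 1 - 2 * (i + 1) for 0-based row i
    scaling = n + 1 - 2 * torch.arange(1, n + 1, dtype=s.dtype)
    logits = scaling[:, None] * s[None, :] - B[None, :]
    # Each row is a soft one-hot vector picking the (i+1)-th largest element
    P = torch.softmax(logits / tau, dim=-1)
    return P @ s


x = torch.tensor([0.3, 0.1, 0.7, 0.5], dtype=torch.float64)

# Hard sort: the Jacobian is a permutation matrix (exactly one 1 per row)
J_hard = torch.autograd.functional.jacobian(
    lambda v: torch.sort(v, descending=True).values, x
)

# Soft sort: the Jacobian is dense, so every input receives gradient from every output
J_soft = torch.autograd.functional.jacobian(lambda v: neuralsort(v, tau=1.0), x)

print(J_hard)
print(J_soft)

# Indices/ranks are integer tensors, so they cannot carry a gradient
idx = torch.argsort(x.clone().requires_grad_())
print(idx.requires_grad)  # False

# As tau -> 0 the relaxation approaches the hard sort
print(neuralsort(x, tau=1e-3))
```

The temperature `tau` trades smoothness for accuracy: large values spread gradient widely, while small values approach `torch.sort` and its sparse permutation gradient.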