Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion include/flucoma/clients/common/Result.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Result

bool ok() const noexcept { return (mStatus == Status::kOk); }

Status status() { return mStatus; }
Status status() const noexcept { return mStatus; }

void set(Status r) noexcept { mStatus = r; }

Expand Down Expand Up @@ -92,6 +92,16 @@ class MessageResult : public Result
operator T() const { return mData; }
T& value() { return mData; }
const T& value() const { return mData; }

template <typename U>
MessageResult(MessageResult<U> const& x) : Result(x.status(), x.message())
{
if constexpr (std::is_convertible_v<T, U>)
{
if (x.ok()) mData = x.mData;
}
}

private:
T mData;
bool hasData;
Expand Down
142 changes: 74 additions & 68 deletions include/flucoma/clients/nrt/KDTreeClient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,22 +78,14 @@ class KDTreeClient : public FluidBaseClient,
MessageResult<StringVector> kNearest(InputBufferPtr data,
Optional<index> nNeighbours) const
{
// we can deprecate ancillary parameters in favour of optional args by
// falling back to using parameters when arg not present
index k = nNeighbours ? nNeighbours.value() : get<kNumNeighbors>();
// alternatively we could just be hardcore and ignore parameters and have
// message handlers fallback to a default when arg missing (which would be
// eventual behaviour, I guess) index k = nNeighbours.value_or(1);
if (k > mAlgorithm.size()) return Error<StringVector>(SmallDataSet);
// if (k <= 0 && get<kRadius>() <= 0) return Error<StringVector>(SmallK);
if (!mAlgorithm.initialized()) return Error<StringVector>(NoDataFitted);
InBufferCheck bufCheck(mAlgorithm.dims());
if (!bufCheck.checkInputs(data.get()))
return Error<StringVector>(bufCheck.error());
RealVector point(mAlgorithm.dims());
point <<=
BufferAdaptor::ReadAccess(data.get()).samps(0, mAlgorithm.dims(), 0);
auto [dists, ids] = mAlgorithm.kNearest(point, k, get<kRadius>());

auto reply = computeKnearest(data, k);
if (!reply.ok()) return reply;

auto dists = reply.value().first;
auto ids = reply.value().second;

StringVector result(asSigned(ids.size()));
std::transform(ids.cbegin(), ids.cend(), result.begin(),
[](const std::string* x) {
Expand All @@ -105,19 +97,11 @@ class KDTreeClient : public FluidBaseClient,
MessageResult<RealVector> kNearestDist(InputBufferPtr data,
Optional<index> nNeighbours) const
{
// TODO: refactor with kNearest
index k = nNeighbours ? nNeighbours.value() : get<kNumNeighbors>();
if (k > mAlgorithm.size()) return Error<RealVector>(SmallDataSet);
// if (k <= 0 && get<kRadius>() <= 0) return Error<RealVector>(SmallK);
if (!mAlgorithm.initialized()) return Error<RealVector>(NoDataFitted);
InBufferCheck bufCheck(mAlgorithm.dims());
if (!bufCheck.checkInputs(data.get()))
return Error<RealVector>(bufCheck.error());
RealVector point(mAlgorithm.dims());
point <<=
BufferAdaptor::ReadAccess(data.get()).samps(0, mAlgorithm.dims(), 0);
auto [dist, ids] = mAlgorithm.kNearest(point, k, get<kRadius>());
return {dist};
auto reply = computeKnearest(data, k);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to check that the result is ok here as well

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oupsy, sorry, added now running the tests

if (!reply.ok()) return reply;

return {reply.value().first};
}

static auto getMessageDescriptors()
Expand All @@ -141,6 +125,24 @@ class KDTreeClient : public FluidBaseClient,

private:
InputDataSetClientRef mDataSetClient;

MessageResult<algorithm::KDTree::KNNResult>
computeKnearest(InputBufferPtr data, index k) const
{
if (k > mAlgorithm.size())
return Error<algorithm::KDTree::KNNResult>(SmallDataSet);
if (k < 0 && get<kRadius>() < 0)
return Error<algorithm::KDTree::KNNResult>(SmallK);
if (!mAlgorithm.initialized())
return Error<algorithm::KDTree::KNNResult>(NoDataFitted);
InBufferCheck bufCheck(mAlgorithm.dims());
if (!bufCheck.checkInputs(data.get()))
return Error<algorithm::KDTree::KNNResult>(bufCheck.error());
RealVector point(mAlgorithm.dims());
point <<=
BufferAdaptor::ReadAccess(data.get()).samps(0, mAlgorithm.dims(), 0);
return {mAlgorithm.kNearest(point, k, get<kRadius>())};
}
};

using KDTreeRef = SharedClientRef<const KDTreeClient>;
Expand All @@ -149,13 +151,20 @@ constexpr auto KDTreeQueryParams = defineParameters(
KDTreeRef::makeParam("tree", "KDTree"),
LongParam("numNeighbours", "Number of Nearest Neighbours", 1),
FloatParam("radius", "Maximum distance", 0, Min(0)),
InputDataSetClientRef::makeParam("dataSet", "DataSet Name"),
InputDataSetClientRef::makeParam("lookupDataSet", "Lookup DataSet Name"),
InputBufferParam("inputPointBuffer", "Input Point Buffer"),
BufferParam("predictionBuffer", "Prediction Buffer"));

class KDTreeQuery : public FluidBaseClient, ControlIn, ControlOut
{
enum { kTree, kNumNeighbors, kRadius, kDataSet, kInputBuffer, kOutputBuffer };
enum {
kTree,
kNumNeighbors,
kRadius,
kLookupDataSet,
kInputBuffer,
kOutputBuffer
};

public:
using ParamDescType = decltype(KDTreeQueryParams);
Expand All @@ -173,8 +182,7 @@ class KDTreeQuery : public FluidBaseClient, ControlIn, ControlOut

static constexpr auto& getParameterDescriptors() { return KDTreeQueryParams; }

KDTreeQuery(ParamSetViewType& p, FluidContext& c)
: mParams(p), mRTBuffer(c.allocator())
KDTreeQuery(ParamSetViewType& p, FluidContext& c) : mParams(p)
{
controlChannelsIn(1);
controlChannelsOut({1, 1});
Expand All @@ -188,74 +196,72 @@ class KDTreeQuery : public FluidBaseClient, ControlIn, ControlOut
{
if (input[0](0) > 0)
{
output[0](0) = mLastNumPoints = 0;

auto kdtreeptr = get<kTree>().get().lock();
if (!kdtreeptr)
{
// c.reportError("FluidKDTree RT Query: No FluidKDTree found");
return;
}
return; // c.reportError("FluidKDTree RT Query: No FluidKDTree found");

if (!kdtreeptr->initialized())
{
// c.reportError("FluidKDTree RT Query: tree not fitted");
return;
}
return; // c.reportError("FluidKDTree RT Query: tree not fitted");

index k = get<kNumNeighbors>();
if (k > kdtreeptr->size() || k < 0)
return; // c.reportError("FluidKDTree RT Query has wrong k size");
index dims = kdtreeptr->dims();

index dims = kdtreeptr->dims();

InOutBuffersCheck bufCheck(dims);
if (!bufCheck.checkInputs(get<kInputBuffer>().get(),
get<kOutputBuffer>().get()))
return; // c.reportError("FluidKDTree RT Query i/o buffers are
// unavailable");
auto datasetClientPtr = get<kDataSet>().get().lock();
if (!datasetClientPtr)
datasetClientPtr = kdtreeptr->getDataSet().get().lock();

if (!datasetClientPtr)
{
// c.reportError("Could not obtain reference FluidDataSet");
return;
}
auto lookupDSpointer = get<kLookupDataSet>().get().lock();

index pointSize = lookupDSpointer ? lookupDSpointer->dims().value() : 1;

auto outBuf = BufferAdaptor::Access(get<kOutputBuffer>().get());
auto outSamps = outBuf.samps(0);

auto dataset = datasetClientPtr->getDataSet();
index pointSize = dataset.pointSize();
auto outBuf = BufferAdaptor::Access(get<kOutputBuffer>().get());
index maxK = outBuf.samps(0).size() / pointSize;
if (maxK <= 0) return;
index outputSize = maxK * pointSize;
index numPoints = outSamps.size() / pointSize;
if (numPoints <= 0)
return; // c.reportError("FluidKDTree RT Query output buffer is too
// small for one point")

RealVector point(dims, c.allocator());
point <<= BufferAdaptor::ReadAccess(get<kInputBuffer>().get())
.samps(0, dims, 0);
if (mRTBuffer.size() != outputSize)
{
mRTBuffer = RealVector(outputSize, c.allocator());
mRTBuffer.fill(0);
}

auto [dists, ids] = kdtreeptr->algorithm().kNearest(
point, k, get<kRadius>(), c.allocator());

mNumValidKs = std::min(asSigned(ids.size()), maxK);

for (index i = 0; i < mNumValidKs; i++)
if (lookupDSpointer)
{
dataset.get(*ids[asUnsigned(i)],
mRTBuffer(Slice(i * pointSize, pointSize)));
auto lookupDS = lookupDSpointer->getDataSet();

auto lookupFn = [&lookupDS, outSamps, pointSize,
n = 0](auto id) mutable {
if (auto point = lookupDS.get(*id); point.data() != nullptr)
outSamps(Slice(n, pointSize)) <<= point;
n += pointSize;
};

std::for_each_n(ids.begin(), numPoints, lookupFn);
}
outBuf.samps(0, outputSize, 0) <<= mRTBuffer;
else
std::copy_n(dists.begin(), numPoints, outSamps.begin());

mLastNumPoints = std::min(asSigned(ids.size()), numPoints);
}

output[0](0) = mNumValidKs;
output[0](0) =
mLastNumPoints; // updates the output if successful or if not triggered
}


private:
RealVector mRTBuffer;
index mNumValidKs = 0;
index mLastNumPoints{0};
InputDataSetClientRef mDataSetClient;
};

Expand Down