diff --git a/include/flucoma/clients/common/Result.hpp b/include/flucoma/clients/common/Result.hpp index 33d66016..9db70abd 100644 --- a/include/flucoma/clients/common/Result.hpp +++ b/include/flucoma/clients/common/Result.hpp @@ -57,7 +57,7 @@ class Result bool ok() const noexcept { return (mStatus == Status::kOk); } - Status status() { return mStatus; } + Status status() const noexcept { return mStatus; } void set(Status r) noexcept { mStatus = r; } @@ -92,6 +92,16 @@ class MessageResult : public Result operator T() const { return mData; } T& value() { return mData; } const T& value() const { return mData; } + + template + MessageResult(MessageResult const& x) : Result(x.status(), x.message()) + { + if constexpr (std::is_convertible_v) + { + if (x.ok()) mData = x.mData; + } + } + private: T mData; bool hasData; diff --git a/include/flucoma/clients/nrt/KDTreeClient.hpp b/include/flucoma/clients/nrt/KDTreeClient.hpp index 49d3f538..1a0dff0d 100644 --- a/include/flucoma/clients/nrt/KDTreeClient.hpp +++ b/include/flucoma/clients/nrt/KDTreeClient.hpp @@ -78,22 +78,14 @@ class KDTreeClient : public FluidBaseClient, MessageResult kNearest(InputBufferPtr data, Optional nNeighbours) const { - // we can deprecate ancillary parameters in favour of optional args by - // falling back to using parameters when arg not present index k = nNeighbours ? nNeighbours.value() : get(); - // alternatively we could just be hardcore and ignore parameters and have - // message handlers fallback to a default when arg missing (which would be - // eventual behaviour, I guess) index k = nNeighbours.value_or(1); - if (k > mAlgorithm.size()) return Error(SmallDataSet); - // if (k <= 0 && get() <= 0) return Error(SmallK); - if (!mAlgorithm.initialized()) return Error(NoDataFitted); - InBufferCheck bufCheck(mAlgorithm.dims()); - if (!bufCheck.checkInputs(data.get())) - return Error(bufCheck.error()); - RealVector point(mAlgorithm.dims()); - point <<= - BufferAdaptor::ReadAccess(data.get()).samps(0, mAlgorithm.dims(), 0); - auto [dists, ids] = mAlgorithm.kNearest(point, k, get()); + + auto reply = computeKnearest(data, k); + if (!reply.ok()) return reply; + + auto dists = reply.value().first; + auto ids = reply.value().second; + StringVector result(asSigned(ids.size())); std::transform(ids.cbegin(), ids.cend(), result.begin(), [](const std::string* x) { @@ -105,19 +97,11 @@ class KDTreeClient : public FluidBaseClient, MessageResult kNearestDist(InputBufferPtr data, Optional nNeighbours) const { - // TODO: refactor with kNearest index k = nNeighbours ? nNeighbours.value() : get(); - if (k > mAlgorithm.size()) return Error(SmallDataSet); - // if (k <= 0 && get() <= 0) return Error(SmallK); - if (!mAlgorithm.initialized()) return Error(NoDataFitted); - InBufferCheck bufCheck(mAlgorithm.dims()); - if (!bufCheck.checkInputs(data.get())) - return Error(bufCheck.error()); - RealVector point(mAlgorithm.dims()); - point <<= - BufferAdaptor::ReadAccess(data.get()).samps(0, mAlgorithm.dims(), 0); - auto [dist, ids] = mAlgorithm.kNearest(point, k, get()); - return {dist}; + auto reply = computeKnearest(data, k); + if (!reply.ok()) return reply; + + return {reply.value().first}; } static auto getMessageDescriptors() @@ -141,6 +125,24 @@ class KDTreeClient : public FluidBaseClient, private: InputDataSetClientRef mDataSetClient; + + MessageResult + computeKnearest(InputBufferPtr data, index k) const + { + if (k > mAlgorithm.size()) + return Error(SmallDataSet); + if (k < 0 && get() < 0) + return Error(SmallK); + if (!mAlgorithm.initialized()) + return Error(NoDataFitted); + InBufferCheck bufCheck(mAlgorithm.dims()); + if (!bufCheck.checkInputs(data.get())) + return Error(bufCheck.error()); + RealVector point(mAlgorithm.dims()); + point <<= + BufferAdaptor::ReadAccess(data.get()).samps(0, mAlgorithm.dims(), 0); + return {mAlgorithm.kNearest(point, k, get())}; + } }; using KDTreeRef = SharedClientRef; @@ -149,13 +151,20 @@ constexpr auto KDTreeQueryParams = defineParameters( KDTreeRef::makeParam("tree", "KDTree"), LongParam("numNeighbours", "Number of Nearest Neighbours", 1), FloatParam("radius", "Maximum distance", 0, Min(0)), - InputDataSetClientRef::makeParam("dataSet", "DataSet Name"), + InputDataSetClientRef::makeParam("lookupDataSet", "Lookup DataSet Name"), InputBufferParam("inputPointBuffer", "Input Point Buffer"), BufferParam("predictionBuffer", "Prediction Buffer")); class KDTreeQuery : public FluidBaseClient, ControlIn, ControlOut { - enum { kTree, kNumNeighbors, kRadius, kDataSet, kInputBuffer, kOutputBuffer }; + enum { + kTree, + kNumNeighbors, + kRadius, + kLookupDataSet, + kInputBuffer, + kOutputBuffer + }; public: using ParamDescType = decltype(KDTreeQueryParams); @@ -173,8 +182,7 @@ class KDTreeQuery : public FluidBaseClient, ControlIn, ControlOut static constexpr auto& getParameterDescriptors() { return KDTreeQueryParams; } - KDTreeQuery(ParamSetViewType& p, FluidContext& c) - : mParams(p), mRTBuffer(c.allocator()) + KDTreeQuery(ParamSetViewType& p, FluidContext& c) : mParams(p) { controlChannelsIn(1); controlChannelsOut({1, 1}); @@ -188,74 +196,72 @@ class KDTreeQuery : public FluidBaseClient, ControlIn, ControlOut { if (input[0](0) > 0) { + output[0](0) = mLastNumPoints = 0; + auto kdtreeptr = get().get().lock(); if (!kdtreeptr) - { - // c.reportError("FluidKDTree RT Query: No FluidKDTree found"); - return; - } + return; // c.reportError("FluidKDTree RT Query: No FluidKDTree found"); if (!kdtreeptr->initialized()) - { - // c.reportError("FluidKDTree RT Query: tree not fitted"); - return; - } + return; // c.reportError("FluidKDTree RT Query: tree not fitted"); index k = get(); if (k > kdtreeptr->size() || k < 0) return; // c.reportError("FluidKDTree RT Query has wrong k size"); - index dims = kdtreeptr->dims(); + + index dims = kdtreeptr->dims(); + InOutBuffersCheck bufCheck(dims); if (!bufCheck.checkInputs(get().get(), get().get())) return; // c.reportError("FluidKDTree RT Query i/o buffers are // unavailable"); - auto datasetClientPtr = get().get().lock(); - if (!datasetClientPtr) - datasetClientPtr = kdtreeptr->getDataSet().get().lock(); - if (!datasetClientPtr) - { - // c.reportError("Could not obtain reference FluidDataSet"); - return; - } + auto lookupDSpointer = get().get().lock(); + + index pointSize = lookupDSpointer ? lookupDSpointer->dims().value() : 1; + + auto outBuf = BufferAdaptor::Access(get().get()); + auto outSamps = outBuf.samps(0); - auto dataset = datasetClientPtr->getDataSet(); - index pointSize = dataset.pointSize(); - auto outBuf = BufferAdaptor::Access(get().get()); - index maxK = outBuf.samps(0).size() / pointSize; - if (maxK <= 0) return; - index outputSize = maxK * pointSize; + index numPoints = outSamps.size() / pointSize; + if (numPoints <= 0) + return; // c.reportError("FluidKDTree RT Query output buffer is too + // small for one point") RealVector point(dims, c.allocator()); point <<= BufferAdaptor::ReadAccess(get().get()) .samps(0, dims, 0); - if (mRTBuffer.size() != outputSize) - { - mRTBuffer = RealVector(outputSize, c.allocator()); - mRTBuffer.fill(0); - } auto [dists, ids] = kdtreeptr->algorithm().kNearest( point, k, get(), c.allocator()); - mNumValidKs = std::min(asSigned(ids.size()), maxK); - - for (index i = 0; i < mNumValidKs; i++) + if (lookupDSpointer) { - dataset.get(*ids[asUnsigned(i)], - mRTBuffer(Slice(i * pointSize, pointSize))); + auto lookupDS = lookupDSpointer->getDataSet(); + + auto lookupFn = [&lookupDS, outSamps, pointSize, + n = 0](auto id) mutable { + if (auto point = lookupDS.get(*id); point.data() != nullptr) + outSamps(Slice(n, pointSize)) <<= point; + n += pointSize; + }; + + std::for_each_n(ids.begin(), numPoints, lookupFn); } - outBuf.samps(0, outputSize, 0) <<= mRTBuffer; + else + std::copy_n(dists.begin(), numPoints, outSamps.begin()); + + mLastNumPoints = std::min(asSigned(ids.size()), numPoints); } - output[0](0) = mNumValidKs; + output[0](0) = + mLastNumPoints; // updates the output if successful or if not triggered } private: - RealVector mRTBuffer; - index mNumValidKs = 0; + index mLastNumPoints{0}; InputDataSetClientRef mDataSetClient; };