Skip to content

Commit 6f8bf2b

Browse files
committed
WIP: color by selection mapping
1 parent e7cffe7 commit 6f8bf2b

File tree

1 file changed

+95
-6
lines changed

1 file changed

+95
-6
lines changed

src/ScatterplotPlugin.cpp

Lines changed: 95 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131

3232
#include <algorithm>
3333
#include <cassert>
34+
#include <map>
35+
#include <optional>
36+
#include <ranges>
3437
#include <vector>
3538

3639
#define VIEW_SAMPLING_HTML
@@ -41,6 +44,75 @@ Q_PLUGIN_METADATA(IID "studio.manivault.ScatterplotPlugin")
4144
using namespace mv;
4245
using namespace mv::util;
4346

47+
static std::optional<const mv::LinkedData*> getSelectionMapping(const mv::Dataset<Points>& source, const mv::Dataset<Points>& target) {
48+
const std::vector<mv::LinkedData>& linkedDatas = source->getLinkedData();
49+
50+
qDebug() << "Source: " << source->getGuiName();
51+
52+
for (const auto& linkedData : linkedDatas) {
53+
qDebug() << linkedData.getSourceDataSet()->getGuiName();
54+
qDebug() << linkedData.getTargetDataset()->getGuiName();
55+
}
56+
57+
const auto it = std::ranges::find_if(linkedDatas, [&target](const mv::LinkedData& obj) {
58+
59+
// TODO: This should be recursive
60+
auto isParentOf = [&target](const mv::Dataset<Points>& linkedTarget) -> bool {
61+
if (target->isDerivedData()) {
62+
const auto parent = target->getParent();
63+
if (parent->getDataType() == PointType) {
64+
const auto parentPoints = mv::Dataset<Points>(parent);
65+
66+
qDebug() << parentPoints->getGuiName();
67+
qDebug() << target->getGuiName();
68+
69+
return parentPoints->getNumPoints() == target->getNumPoints();
70+
}
71+
}
72+
return false;
73+
};
74+
75+
return obj.getTargetDataset() == target || isParentOf(obj.getTargetDataset());
76+
});
77+
78+
if (it != linkedDatas.end()) {
79+
return &(*it); // return pointer to the found object
80+
}
81+
82+
return std::nullopt; // nothing found
83+
}
84+
85+
static bool checkSelectionMapping(const mv::Dataset<Points>& source, const mv::Dataset<Points>& target) {
86+
const std::vector<mv::LinkedData>& linkedDatas = source->getLinkedData();
87+
88+
// First, check if there is a mapping
89+
const auto it = getSelectionMapping(source, target);
90+
91+
if (!it.has_value())
92+
return false;
93+
94+
// Second, check if the mapping is surjective, i.e. hits all elements in the target
95+
const std::map<std::uint32_t, std::vector<std::uint32_t>>& linkedMap = it.value()->getMapping().getMap();
96+
const std::uint32_t numPointsInTarget = target->getNumPoints();
97+
98+
std::vector<bool> found(numPointsInTarget, false);
99+
std::uint32_t count = 0;
100+
101+
for (const auto& [key, vec] : linkedMap) {
102+
for (std::uint32_t val : vec) {
103+
if (val >= numPointsInTarget) continue; // Skip values that are too large
104+
105+
if (!found[val]) {
106+
found[val] = true;
107+
if (++count == numPointsInTarget)
108+
return true;
109+
}
110+
}
111+
}
112+
113+
return false; // The previous loop would have returned early if the entire taget set was covered
114+
}
115+
44116
ScatterplotPlugin::ScatterplotPlugin(const PluginFactory* factory) :
45117
ViewPlugin(factory),
46118
_dropWidget(nullptr),
@@ -181,8 +253,9 @@ ScatterplotPlugin::ScatterplotPlugin(const PluginFactory* factory) :
181253
/*if*/ _positionDataset->isDerivedData() ?
182254
/*then*/ _positionDataset->getSourceDataset<Points>()->getFullDataset<Points>()->getNumPoints() == numPointsCandidate :
183255
/*else*/ false;
256+
const bool hasSelectionMapping = checkSelectionMapping(candidateDataset, _positionDataset);
184257

185-
if (sameNumPoints || sameNumPointsAsFull) {
258+
if (sameNumPoints || sameNumPointsAsFull || hasSelectionMapping) {
186259
// Offer the option to use the points dataset as source for points colors
187260
dropRegions << new DropWidget::DropRegion(this, "Point color", QString("Colorize %1 points with %2").arg(_positionDataset->text(), candidateDataset->text()), "palette", true, [this, candidateDataset]() {
188261
_settingsAction.getColoringAction().setCurrentColorDataset(candidateDataset); // calls addColorDataset internally
@@ -647,19 +720,18 @@ void ScatterplotPlugin::positionDatasetChanged()
647720
updateData();
648721
}
649722

650-
void ScatterplotPlugin::loadColors(const Dataset<Points>& points, const std::uint32_t& dimensionIndex)
723+
void ScatterplotPlugin::loadColors(const Dataset<Points>& pointsColor, const std::uint32_t& dimensionIndex)
651724
{
652725
// Only proceed with valid points dataset
653-
if (!points.isValid())
726+
if (!pointsColor.isValid())
654727
return;
655728

656729
// Generate point scalars for color mapping
657730
std::vector<float> scalars;
658731

659-
points->extractDataForDimension(scalars, dimensionIndex);
660-
661-
const auto numColorPoints = points->getNumPoints();
732+
pointsColor->extractDataForDimension(scalars, dimensionIndex);
662733

734+
const auto numColorPoints = pointsColor->getNumPoints();
663735

664736
if (numColorPoints != _numPoints) {
665737

@@ -668,6 +740,8 @@ void ScatterplotPlugin::loadColors(const Dataset<Points>& points, const std::uin
668740
/*then*/ _positionSourceDataset->getFullDataset<Points>()->getNumPoints() == numColorPoints :
669741
/*else*/ false;
670742

743+
const auto validSelectionMapping = getSelectionMapping(pointsColor, _positionDataset);
744+
671745
if (sameNumPointsAsFull) {
672746
std::vector<std::uint32_t> globalIndices;
673747
_positionDataset->getGlobalIndices(globalIndices);
@@ -680,6 +754,21 @@ void ScatterplotPlugin::loadColors(const Dataset<Points>& points, const std::uin
680754

681755
std::swap(localScalars, scalars);
682756
}
757+
else if (validSelectionMapping.has_value() && validSelectionMapping.value() != nullptr) {
758+
std::vector<float> localScalars(_numPoints, 0);
759+
760+
// Map values like selection
761+
const mv::SelectionMap::Map& linkedMap = validSelectionMapping.value()->getMapping().getMap();
762+
const std::uint32_t numPointsInTarget = _positionDataset->getNumPoints();
763+
764+
for (const auto& [fromID, vecOfIDs] : linkedMap) {
765+
for (std::uint32_t toID : vecOfIDs) {
766+
localScalars[toID] = scalars[fromID];
767+
}
768+
}
769+
770+
std::swap(localScalars, scalars);
771+
}
683772
else {
684773
qWarning("Number of points used for coloring does not match number of points in data, aborting attempt to color plot");
685774
return;

0 commit comments

Comments
 (0)