From 84012c12c52aa574f8b464c7ee7c603b9dc3bb61 Mon Sep 17 00:00:00 2001 From: manavggupta Date: Sun, 22 Feb 2026 23:15:44 +0530 Subject: [PATCH] Fix: Validate join columns before performing join to prevent unintended cartesian product (#118) --- src/DataFrame/Operations/Join.hs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/DataFrame/Operations/Join.hs b/src/DataFrame/Operations/Join.hs index 06a66971..f62aa62e 100644 --- a/src/DataFrame/Operations/Join.hs +++ b/src/DataFrame/Operations/Join.hs @@ -28,6 +28,21 @@ data JoinType | RIGHT | FULL_OUTER + validateJoinColumns :: [T.Text] -> DataFrame -> DataFrame -> () +validateJoinColumns cs df1 df2 = + let + df1Cols = D.columnNames df1 + df2Cols = D.columnNames df2 + + missingInDf1 = filter (`notElem` df1Cols) cs + missingInDf2 = filter (`notElem` df2Cols) cs + + missing = missingInDf1 ++ missingInDf2 + in + if not (null missing) + then error $ "Column not found: " <> show missing + else () + {- | Join two dataframes using SQL join semantics. Only inner join is implemented for now. @@ -66,6 +81,7 @@ ghci> D.innerJoin ["key"] df other innerJoin :: [T.Text] -> DataFrame -> DataFrame -> DataFrame innerJoin cs right left = let + _ = validateJoinColumns cs right left -- Prepare Keys for the Right DataFrame rightIndicesToGroup = [c | (k, c) <- M.toList (D.columnIndices right), k `elem` cs] @@ -170,6 +186,7 @@ leftJoin :: [T.Text] -> DataFrame -> DataFrame -> DataFrame leftJoin cs right left = let + _ = validateJoinColumns cs right left leftIndicesToGroup = M.elems $ M.filterWithKey (\k _ -> k `elem` cs) (D.columnIndices left) leftRowRepresentations = D.computeRowHashes leftIndicesToGroup left rightIndicesToGroup = M.elems $ M.filterWithKey (\k _ -> k `elem` cs) (D.columnIndices right) @@ -250,6 +267,7 @@ fullOuterJoin :: [T.Text] -> DataFrame -> DataFrame -> DataFrame fullOuterJoin cs right left = let + _ = validateJoinColumns cs right left leftIndicesToGroup = M.elems $ M.filterWithKey (\k _ -> k `elem` cs) (D.columnIndices left) leftRowRepresentations = D.computeRowHashes leftIndicesToGroup left leftKeyCountsAndIndices =