diff --git a/src/Database/Persist/Sql/Graph/TransitiveReduction.hs b/src/Database/Persist/Sql/Graph/TransitiveReduction.hs index da34a30..f995b0b 100644 --- a/src/Database/Persist/Sql/Graph/TransitiveReduction.hs +++ b/src/Database/Persist/Sql/Graph/TransitiveReduction.hs @@ -17,6 +17,7 @@ module Database.Persist.Sql.Graph.TransitiveReduction ( -- * Transitive reduction of DAGs trrSelect , trrApply + , trrFix ) where @@ -212,3 +213,107 @@ trrApply proxy = do , " )" ] rawExecuteCount sql [] + +-- | Given an edge (u, v) that was just added to a reduced DAG, remove edges if +-- necessary to make sure the graph stays reduced. +-- +-- It more-or-less looks like this: +-- +-- > WITH RECURSIVE +-- > temp (id, path, cycle, contains) AS ( +-- > SELECT node.id, ARRAY[node.id], FALSE, FALSE +-- > FROM node +-- > UNION ALL +-- > SELECT edge.dest, +-- > temp.path || edge.dest, +-- > edge.dest = ANY(temp.path), +-- > temp.contains OR edge.dest = v +-- > FROM edge INNER JOIN temp +-- > ON edge.source = temp.id +-- > WHERE NOT temp.cycle AND +-- > ( edge.source = u AND edge.dest = v OR +-- > edge.source <> u AND edge.dest <> v +-- > ) +-- > ) +-- > DELETE FROM edge +-- > WHERE id IN ( +-- > SELECT edge.id +-- > FROM edge INNER JOIN temp +-- > ON edge.source = temp.path[1] AND +-- > edge.dest = temp.id +-- > WHERE array_length(temp.path, 1) > 2 AND +-- > NOT temp.cycle AND +-- > temp.contains +-- > ) +trrFix + :: ( MonadIO m + , PersistEntityGraph node edge + , SqlBackend ~ PersistEntityBackend node + , SqlBackend ~ PersistEntityBackend edge + ) + => Key edge + -> Key edge + -> Proxy (node, edge) + -> ReaderT SqlBackend m Int64 +trrFix from to proxy = do + conn <- ask + let tNode = entityDef $ dummyFromFst proxy + tEdge = entityDef $ dummyFromSnd proxy + fwd = persistFieldDef $ destFieldFromProxy proxy + bwd = persistFieldDef $ sourceFieldFromProxy proxy + temp = DBName "temp_hierarchy_cte" + tid = DBName "id" + tpath = DBName "path" + tcontains = DBName "cycle" + tcycle = DBName "contains" + dbname = connEscapeName conn + t ^* f = dbname t <> "." <> dbname f + t <#> s = dbname t <> " INNER JOIN " <> dbname s + t <# s = dbname t <> " LEFT OUTER JOIN " <> dbname s + + sqlStep forward backward = mconcat + [ "SELECT " + , entityDB tEdge ^* fieldDB forward, ", " + , temp ^* tpath, " || ", entityDB tEdge ^* fieldDB forward, ", " + , entityDB tEdge ^* fieldDB forward, " = ANY(", temp ^* tpath, ")," + , temp ^* tcontains, " OR " + , entityDB tEdge ^* fieldDB forward, " = ?" + , " FROM ", entityDB tEdge <#> temp + , " ON ", entityDB tEdge ^* fieldDB backward, " = ", temp ^* tid + , " WHERE NOT ", temp ^* tcycle, " AND (" + , entityDB tEdge ^* fieldDB backward, " = ? AND " + , entityDB tEdge ^* fieldDB forward, " = ?" + , " OR " + , entityDB tEdge ^* fieldDB backward, " <> ? AND " + , entityDB tEdge ^* fieldDB forward, " <> ?" + , ")" + ] + + sql = mconcat + [ "WITH RECURSIVE " + , dbname temp + , " (" + , T.intercalate "," $ map dbname [tid, tpath, tcycle, tcontains] + , ") AS ( SELECT " + , entityDB tNode ^* fieldDB (entityId tNode), ", " + , "ARRAY[", entityDB tNode ^* fieldDB (entityId tNode), "], " + , "FALSE, FALSE" + , " FROM ", dbname $ entityDB tNode + , " WHERE ", entityDB tNode ^* fieldDB (entityId tNode) + , " IN ?" + , " UNION ALL " + , sqlStep fwd bwd + , " ) DELETE FROM ", dbname $ entityDB tEdge + , " WHERE ", entityDB tEdge ^* fieldDB (entityId tEdge), " IN (" + , " SELECT ", entityDB tEdge ^* fieldDB (entityId tEdge) + , " FROM ", entityDB tEdge <#> temp + , " ON " + , entityDB tEdge ^* fieldDB bwd, " = ", temp ^* tpath + , "[1] AND ", entityDB tEdge ^* fieldDB fwd, " = ", temp ^* tid + , " WHERE array_length(", temp ^* tpath, ", 1) > 2 AND NOT " + , temp ^* tcycle, " AND ", temp ^* tcontains + , " )" + ] + u = toPersistValue from + v = toPersistValue to + rawExecuteCount sql [v, u, v, u, v]