diff --git a/src/Database/Persist/Local/Class/PersistQueryRecursive.hs b/src/Database/Persist/Local/Class/PersistQueryRecursive.hs new file mode 100644 index 0000000..0e7594b --- /dev/null +++ b/src/Database/Persist/Local/Class/PersistQueryRecursive.hs @@ -0,0 +1,255 @@ +{- This file is part of Vervis. + - + - Written in 2016 by fr33domlover . + - + - ♡ Copying is an act of love. Please copy, reuse and share. + - + - The author(s) have dedicated all copyright and related and neighboring + - rights to this software to the public domain worldwide. This software is + - distributed without any warranty. + - + - You should have received a copy of the CC0 Public Domain Dedication along + - with this software. If not, see + - . + -} + +{- The code is based on PersistQuery. Actually, most of the difference is + - slightly different names and 3 additional function parameters. + -} + +-- | Recursive queries are performed by taking the output of a recursion step, +-- possibly modifying it, and using the result as the input for the next +-- recursion step. +-- +-- This module currently provides a single way to perform that recursive step: +-- Match between the /id/ column and some other column which has the same type. +-- +-- For example, suppose we have a `Message` type with a `messageParent` field. +-- For given messages `a` and `b`, if `messageParent b == Just a` then `b` is a +-- reply to `a`. Therefore, all the replies to a given message point to it +-- using the `messageParent` field. And there can be replies to replies and so +-- on, creating a tree of messages. +-- +-- > Message +-- > author PersonId +-- > content Text +-- > parent MessageId Maybe +-- +-- If we start with a single message and follow the `messageParent` values +-- recursively, we'll be able to get a list (or a tree) of the __ancestors__ of +-- the message. Our message /a/ may be a reply to some other message /b/, and +-- /b/ may be a reply to message /c/ and so on. Eventually, if there are no +-- cycles and it's really a tree structure, we'll reach the root message, which +-- has no parent. +-- +-- But there's another way to recurse. What if we wanted to find the replies +-- for a given message? And the replies of the replies, and so on? In other +-- words, the __decendants__ of a given message. Suppose we start with a +-- message /a/. We get a list of the replies of /a/, i.e. message whose parent +-- is `Just a`. Then we find the replies of those messages, i.e. the replies of +-- the replies of /a/. And so on, recursively, until we can't find more replies +-- and then we stop. +-- +-- Therefore we can perform the recursion in one of two directions: +-- +-- - __Outwards__, i.e. follow from a message to its parents. More generally, +-- given a persistent entity type `Foobar`, follow recursively using a +-- specific field of it, whose type is `FoobarId`. It's called "outwards" +-- because it's like following out-edges of a graph node, i.e. arrows +-- pointing from a node towards other nodes. +-- - __Inwards__, i.e. find the children (i.e. replies) of a message, and then +-- their children, and so on. More generally, given a persistent entity type +-- `Foobar`, find other values referring to it using a specific field, whose +-- type is `FoobarId`, and recursive find such values for the results we get +-- and so on. It's called "inwards" because it's like following in-edges of a +-- graph node, i.e. arrows pointing from other nodes towards that node. +-- +-- The 'RecursionDirection' type is used for specifying the direction. +-- +-- When you follow all the children of an entity recursively, or all of its +-- parents, we call the result you get the __transitive closure__ of the +-- specific field you used. You can further specify the direction, i.e. +-- __outward transitive closure__ or __inward transitive closure__. For +-- examples, if you follow a message's parents recursively as in the example +-- above, you get an outward transitive closure on the /parent/ field. +-- +-- Note that the definition used here is __not__ the same as the mathematical +-- definition. When you perform a recursive query without filters, you get not +-- only the ancestors (or the decendants) of an entity, but also the root +-- entity itself. In other words, even though a message is not a reply of +-- itself, you'll still get it in the query result. If you want to get just the +-- ancestors (or decendants), i.e. the actual transitive closer of the "is +-- reply of" relation in the mathematical sense, use a filter to omit the root +-- message based on the ID, i.e. @[MessageParent /= msgid]@. +-- +-- Therefore, when the term "transitive closure" is used below, it means not +-- just the ancestors (or decendants), but also the origin entity too. +module Database.Persist.Local.Class.PersistQueryRecursive + ( RecursionDirection (..) + , PersistQueryRecursive (..) + , selectRecursivelySource + , selectRecursivelyKeys + , selectRecursivelyList + , selectRecursivelyKeysList + ) +where + +import Prelude + +import Control.Monad.IO.Class +import Control.Monad.Reader (MonadReader) +import Control.Monad.Trans.Reader (ReaderT) +import Control.Monad.Trans.Resource (MonadResource, release) +import Data.Acquire (Acquire, allocateAcquire, with) +import Database.Persist.Class +import Database.Persist.Types + +import qualified Data.Conduit as C +import qualified Data.Conduit.List as CL + +data RecursionDirection + = RecOut + | RecIn + deriving (Eq, Show) + +-- | Backends supporting recursive conditional operations. +class PersistQuery backend => PersistQueryRecursive backend where + -- | Update individual fields on any record in the transitive closure and + -- matching the given criterion. + updateRecursivelyWhere + :: (MonadIO m, PersistEntity val, backend ~ PersistEntityBackend val) + => RecursionDirection + -> EntityField val (Maybe (Key val)) + -> Key val + -> [Filter val] + -> [Update val] + -> ReaderT backend m () + + -- | Delete all records in the transitive closure which match the given + -- criterion. + deleteRecursivelyWhere + :: (MonadIO m, PersistEntity val, backend ~ PersistEntityBackend val) + => RecursionDirection + -> EntityField val (Maybe (Key val)) + -> Key val + -> [Filter val] + -> ReaderT backend m () + + -- | Get all records in the transitive closure, which match the given + -- criterion, in the specified order. Returns also the identifiers. + selectRecursivelySourceRes + :: ( PersistEntity val + , PersistEntityBackend val ~ backend + , MonadIO m1 + , MonadIO m2 + ) + => RecursionDirection + -> EntityField val (Maybe (Key val)) + -> Key val + -> [Filter val] + -> [SelectOpt val] + -> ReaderT backend m1 (Acquire (C.Source m2 (Entity val))) + + -- | Get the 'Key's of all records in the transitive closure, which match + -- the given criterion. + selectRecursivelyKeysRes + :: ( PersistEntity val + , PersistEntityBackend val ~ backend + , MonadIO m1 + , MonadIO m2 + ) + => RecursionDirection + -> EntityField val (Maybe (Key val)) + -> Key val + -> [Filter val] + -> [SelectOpt val] + -> ReaderT backend m1 (Acquire (C.Source m2 (Key val))) + + -- | The total number of records in the transitive closure which fulfill + -- the given criterion. + countRecursively + :: (MonadIO m, PersistEntity val, backend ~ PersistEntityBackend val) + => RecursionDirection + -> EntityField val (Maybe (Key val)) + -> Key val + -> [Filter val] + -> ReaderT backend m Int + +-- | Get all records in the transitive closure, which match the given +-- criterion, in the specified order. Returns also the identifiers. +selectRecursivelySource + :: ( PersistQueryRecursive backend + , MonadResource m + , PersistEntity val + , PersistEntityBackend val ~ backend + , MonadReader env m + , HasPersistBackend env backend + ) + => RecursionDirection + -> EntityField val (Maybe (Key val)) + -> Key val + -> [Filter val] + -> [SelectOpt val] + -> C.Source m (Entity val) +selectRecursivelySource dir field root filts opts = do + srcRes <- + liftPersist $ selectRecursivelySourceRes dir field root filts opts + (releaseKey, src) <- allocateAcquire srcRes + src + release releaseKey + +-- | Get the 'Key's of all records in the transitive closure, which match the +-- given criterion. +selectRecursivelyKeys + :: ( PersistQueryRecursive backend + , MonadResource m + , PersistEntity val + , backend ~ PersistEntityBackend val + , MonadReader env m + , HasPersistBackend env backend + ) + => RecursionDirection + -> EntityField val (Maybe (Key val)) + -> Key val + -> [Filter val] + -> [SelectOpt val] + -> C.Source m (Key val) +selectRecursivelyKeys dir field root filts opts = do + srcRes <- liftPersist $ selectRecursivelyKeysRes dir field root filts opts + (releaseKey, src) <- allocateAcquire srcRes + src + release releaseKey + +-- | Call 'selectRecursivelySource' but return the result as a list. +selectRecursivelyList + :: ( PersistQueryRecursive backend + , MonadIO m + , PersistEntity val + , PersistEntityBackend val ~ backend + ) + => RecursionDirection + -> EntityField val (Maybe (Key val)) + -> Key val + -> [Filter val] + -> [SelectOpt val] + -> ReaderT backend m [Entity val] +selectRecursivelyList dir field root filts opts = do + srcRes <- selectRecursivelySourceRes dir field root filts opts + liftIO $ with srcRes (C.$$ CL.consume) + +-- | Call 'selectRecursivelyKeys' but return the result as a list. +selectRecursivelyKeysList + :: ( PersistQueryRecursive backend + , MonadIO m + , PersistEntity val + , PersistEntityBackend val ~ backend + ) + => RecursionDirection + -> EntityField val (Maybe (Key val)) + -> Key val + -> [Filter val] + -> [SelectOpt val] + -> ReaderT backend m [Key val] +selectRecursivelyKeysList dir field root filts opts = do + srcRes <- selectRecursivelyKeysRes dir field root filts opts + liftIO $ with srcRes (C.$$ CL.consume) diff --git a/src/Database/Persist/Local/Sql/Orphan/Common.hs b/src/Database/Persist/Local/Sql/Orphan/Common.hs new file mode 100644 index 0000000..562e557 --- /dev/null +++ b/src/Database/Persist/Local/Sql/Orphan/Common.hs @@ -0,0 +1,239 @@ +{- This file contains (slightly modified) copies of unexported functions from + - Database.Persist.Sql.Orphan.PersistQuery, which I need for my + - PersistQueryRecursive implementation. They're released under MIT. + - + - This should be a temporary situation. Either my code moves to persistent and + - the functions are reused there, or these functions become exported in + - persistent and then I can import them instead of holding copies. + -} + +{-# LANGUAGE RankNTypes #-} + +module Database.Persist.Local.Sql.Orphan.Common + ( fieldName + , dummyFromFilts + , getFiltsValues + , updatePersistValue + , filterClause + , orderClause + ) +where + +import Prelude + +import Data.List (inits, transpose) +import Data.Monoid ((<>)) +import Data.Text (Text) +import Database.Persist +import Database.Persist.Sql +import Database.Persist.Sql.Util + +import qualified Data.Text as T + +fieldName + :: forall record typ. + (PersistEntity record , PersistEntityBackend record ~ SqlBackend) + => EntityField record typ + -> DBName +fieldName f = fieldDB $ persistFieldDef f + +dummyFromFilts :: [Filter v] -> Maybe v +dummyFromFilts _ = Nothing + +getFiltsValues + :: forall val. (PersistEntity val, PersistEntityBackend val ~ SqlBackend) + => SqlBackend + -> [Filter val] + -> [PersistValue] +getFiltsValues conn = snd . filterClauseHelper False False conn OrNullNo + +data OrNull = OrNullYes | OrNullNo + +filterClauseHelper + :: (PersistEntity val, PersistEntityBackend val ~ SqlBackend) + => Bool -- ^ include table name? + -> Bool -- ^ include WHERE? + -> SqlBackend + -> OrNull + -> [Filter val] + -> (Text, [PersistValue]) +filterClauseHelper includeTable includeWhere conn orNull filters = + ( if not (T.null sql) && includeWhere + then " WHERE " <> sql + else sql + , vals + ) + where + (sql, vals) = combineAND filters + combineAND = combine " AND " + + combine s fs = + (T.intercalate s $ map wrapP a, mconcat b) + where + (a, b) = unzip $ map go fs + wrapP x = T.concat ["(", x, ")"] + + go (BackendFilter _) = error "BackendFilter not expected" + go (FilterAnd []) = ("1=1", []) + go (FilterAnd fs) = combineAND fs + go (FilterOr []) = ("1=0", []) + go (FilterOr fs) = combine " OR " fs + go (Filter field value pfilter) = + let t = entityDef $ dummyFromFilts [Filter field value pfilter] + in case (isIdField field, entityPrimary t, allVals) of + (True, Just pdef, PersistList ys:_) -> + if length (compositeFields pdef) /= length ys + then error $ "wrong number of entries in compositeFields vs PersistList allVals=" ++ show allVals + else + case (allVals, pfilter, isCompFilter pfilter) of + ([PersistList xs], Eq, _) -> + let sqlcl=T.intercalate " and " (map (\a -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "? ") (compositeFields pdef)) + in (wrapSql sqlcl,xs) + ([PersistList xs], Ne, _) -> + let sqlcl=T.intercalate " or " (map (\a -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "? ") (compositeFields pdef)) + in (wrapSql sqlcl,xs) + (_, In, _) -> + let xxs = transpose (map fromPersistList allVals) + sqls=map (\(a,xs) -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "(" <> T.intercalate "," (replicate (length xs) " ?") <> ") ") (zip (compositeFields pdef) xxs) + in (wrapSql (T.intercalate " and " (map wrapSql sqls)), concat xxs) + (_, NotIn, _) -> + let xxs = transpose (map fromPersistList allVals) + sqls=map (\(a,xs) -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "(" <> T.intercalate "," (replicate (length xs) " ?") <> ") ") (zip (compositeFields pdef) xxs) + in (wrapSql (T.intercalate " or " (map wrapSql sqls)), concat xxs) + ([PersistList xs], _, True) -> + let zs = tail (inits (compositeFields pdef)) + sql1 = map (\b -> wrapSql (T.intercalate " and " (map (\(i,a) -> sql2 (i==length b) a) (zip [1..] b)))) zs + sql2 islast a = connEscapeName conn (fieldDB a) <> (if islast then showSqlFilter pfilter else showSqlFilter Eq) <> "? " + sqlcl = T.intercalate " or " sql1 + in (wrapSql sqlcl, concat (tail (inits xs))) + (_, BackendSpecificFilter _, _) -> error "unhandled type BackendSpecificFilter for composite/non id primary keys" + _ -> error $ "unhandled type/filter for composite/non id primary keys pfilter=" ++ show pfilter ++ " persistList="++show allVals + (True, Just pdef, _) -> error $ "unhandled error for composite/non id primary keys pfilter=" ++ show pfilter ++ " persistList=" ++ show allVals ++ " pdef=" ++ show pdef + + _ -> case (isNull, pfilter, varCount) of + (True, Eq, _) -> (name <> " IS NULL", []) + (True, Ne, _) -> (name <> " IS NOT NULL", []) + (False, Ne, _) -> (T.concat + [ "(" + , name + , " IS NULL OR " + , name + , " <> " + , qmarks + , ")" + ], notNullVals) + -- We use 1=2 (and below 1=1) to avoid using TRUE and FALSE, since + -- not all databases support those words directly. + (_, In, 0) -> ("1=2" <> orNullSuffix, []) + (False, In, _) -> (name <> " IN " <> qmarks <> orNullSuffix, allVals) + (True, In, _) -> (T.concat + [ "(" + , name + , " IS NULL OR " + , name + , " IN " + , qmarks + , ")" + ], notNullVals) + (_, NotIn, 0) -> ("1=1", []) + (False, NotIn, _) -> (T.concat + [ "(" + , name + , " IS NULL OR " + , name + , " NOT IN " + , qmarks + , ")" + ], notNullVals) + (True, NotIn, _) -> (T.concat + [ "(" + , name + , " IS NOT NULL AND " + , name + , " NOT IN " + , qmarks + , ")" + ], notNullVals) + _ -> (name <> showSqlFilter pfilter <> "?" <> orNullSuffix, allVals) + + where + isCompFilter Lt = True + isCompFilter Le = True + isCompFilter Gt = True + isCompFilter Ge = True + isCompFilter _ = False + + wrapSql sqlcl = "(" <> sqlcl <> ")" + fromPersistList (PersistList xs) = xs + fromPersistList other = error $ "expected PersistList but found " ++ show other + + filterValueToPersistValues :: forall a. PersistField a => Either a [a] -> [PersistValue] + filterValueToPersistValues v = map toPersistValue $ either return id v + + orNullSuffix = + case orNull of + OrNullYes -> mconcat [" OR ", name, " IS NULL"] + OrNullNo -> "" + + isNull = any (== PersistNull) allVals + notNullVals = filter (/= PersistNull) allVals + allVals = filterValueToPersistValues value + tn = connEscapeName conn $ entityDB + $ entityDef $ dummyFromFilts [Filter field value pfilter] + name = + (if includeTable + then ((tn <> ".") <>) + else id) + $ connEscapeName conn $ fieldName field + qmarks = case value of + Left _ -> "?" + Right x -> + let x' = filter (/= PersistNull) $ map toPersistValue x + in "(" <> T.intercalate "," (map (const "?") x') <> ")" + varCount = case value of + Left _ -> 1 + Right x -> length x + showSqlFilter Eq = "=" + showSqlFilter Ne = "<>" + showSqlFilter Gt = ">" + showSqlFilter Lt = "<" + showSqlFilter Ge = ">=" + showSqlFilter Le = "<=" + showSqlFilter In = " IN " + showSqlFilter NotIn = " NOT IN " + showSqlFilter (BackendSpecificFilter s) = s + +updatePersistValue :: Update v -> PersistValue +updatePersistValue (Update _ v _) = toPersistValue v +updatePersistValue _ = error "BackendUpdate not implemented" + +filterClause :: (PersistEntity val, PersistEntityBackend val ~ SqlBackend) + => Bool -- ^ include table name? + -> SqlBackend + -> [Filter val] + -> Text +filterClause b c = fst . filterClauseHelper b True c OrNullNo + +orderClause :: (PersistEntity val, PersistEntityBackend val ~ SqlBackend) + => Bool -- ^ include the table name + -> SqlBackend + -> SelectOpt val + -> Text +orderClause includeTable conn o = + case o of + Asc x -> name x + Desc x -> name x <> " DESC" + _ -> error "orderClause: expected Asc or Desc, not limit or offset" + where + dummyFromOrder :: SelectOpt a -> Maybe a + dummyFromOrder _ = Nothing + + tn = connEscapeName conn $ entityDB $ entityDef $ dummyFromOrder o + + name :: (PersistEntityBackend record ~ SqlBackend, PersistEntity record) + => EntityField record typ -> Text + name x = + (if includeTable + then ((tn <> ".") <>) + else id) + $ connEscapeName conn $ fieldName x diff --git a/src/Database/Persist/Local/Sql/Orphan/PersistQueryRecursive.hs b/src/Database/Persist/Local/Sql/Orphan/PersistQueryRecursive.hs new file mode 100644 index 0000000..7e15eff --- /dev/null +++ b/src/Database/Persist/Local/Sql/Orphan/PersistQueryRecursive.hs @@ -0,0 +1,333 @@ +{- This file is part of Vervis. + - + - Written in 2016 by fr33domlover . + - + - ♡ Copying is an act of love. Please copy, reuse and share. + - + - The author(s) have dedicated all copyright and related and neighboring + - rights to this software to the public domain worldwide. This software is + - distributed without any warranty. + - + - You should have received a copy of the CC0 Public Domain Dedication along + - with this software. If not, see + - . + -} + +module Database.Persist.Local.Sql.Orphan.PersistQueryRecursive + ( deleteRecursivelyWhereCount + , updateRecursivelyWhereCount + ) +where + +import Prelude + +import Control.Monad (void) +import Control.Monad.IO.Class +import Control.Monad.Trans.Reader (ReaderT, ask) +import Control.Exception (throwIO) +import Data.ByteString.Char8 (readInteger) +import Data.Conduit (($=)) +import Data.Foldable (find) +import Data.Int (Int64) +import Data.Maybe (isJust) +import Data.Monoid ((<>)) +import Data.Text (Text) +import Database.Persist +import Database.Persist.Sql +import Database.Persist.Sql.Util + +import qualified Data.Conduit.List as CL (head, mapM) +import qualified Data.Text as T (pack, unpack, intercalate) + +import Database.Persist.Local.Class.PersistQueryRecursive +import Database.Persist.Local.Sql.Orphan.Common + +instance PersistQueryRecursive SqlBackend where + updateRecursivelyWhere dir field root filts upds = + void $ updateRecursivelyWhereCount dir field root filts upds + + deleteRecursivelyWhere dir field root filts = + void $ deleteRecursivelyWhereCount dir field root filts + + selectRecursivelySourceRes dir field root filts opts = do + conn <- ask + let (sql, vals, parse) = sqlValsParse conn + srcRes <- rawQueryRes sql vals + return $ fmap ($= CL.mapM parse) srcRes + where + sqlValsParse conn = (sql, vals, parse) + where + (temp, isRoot, cols, qcols, sqlWith) = + withRecursive dir field root conn t (flip entityColumnNames) + + (limit, offset, orders) = limitOffsetOrder opts + + parse xs = + case parseEntityValues t xs of + Left s -> liftIO $ throwIO $ PersistMarshalError s + Right row -> return row + t = entityDef $ dummyFromFilts filts + wher = + if null filts + then "" + else filterClause False conn filts + ord = + case map (orderClause False conn) orders of + [] -> "" + ords -> " ORDER BY " <> T.intercalate "," ords + sql = + mappend sqlWith $ + connLimitOffset conn (limit, offset) (not $ null orders) $ + mconcat + [ "SELECT " + , cols + , " FROM " + , connEscapeName conn temp + , wher + , ord + ] + vals = getFiltsValues conn $ isRoot : filts + + selectRecursivelyKeysRes dir field root filts opts = do + conn <- ask + let (sql, vals, parse) = sqlValsParse conn + srcRes <- rawQueryRes sql vals + return $ fmap ($= CL.mapM parse) srcRes + where + sqlValsParse conn = (sql, vals, parse) + where + (temp, isRoot, cols, qcols, sqlWith) = + withRecursive dir field root conn t dbIdColumns + + (limit, offset, orders) = limitOffsetOrder opts + + parse xs = do + keyvals <- + case entityPrimary t of + Nothing -> + case xs of + [PersistInt64 x] -> + return [PersistInt64 x] + [PersistDouble x] -> + -- oracle returns Double + return [PersistInt64 $ truncate x] + _ -> + liftIO $ throwIO $ PersistMarshalError $ + "Unexpected in selectKeys False: " <> + T.pack (show xs) + Just pdef -> + let pks = map fieldHaskell $ compositeFields pdef + keyvals = + map snd $ + filter + (\ (a, _) -> + let ret = isJust (find (== a) pks) + in ret + ) $ + zip (map fieldHaskell $ entityFields t) xs + in return keyvals + case keyFromValues keyvals of + Right k -> return k + Left _ -> error "selectKeysImpl: keyFromValues failed" + t = entityDef $ dummyFromFilts filts + wher = + if null filts + then "" + else filterClause False conn filts + ord = + case map (orderClause False conn) orders of + [] -> "" + ords -> " ORDER BY " <> T.intercalate "," ords + sql = + mappend sqlWith $ + connLimitOffset conn (limit, offset) (not $ null orders) $ + mconcat + [ "SELECT " + , cols + , " FROM " + , connEscapeName conn temp + , wher + , ord + ] + vals = getFiltsValues conn $ isRoot : filts + + countRecursively dir field root filts = do + conn <- ask + let (sql, vals) = sqlAndVals conn + withRawQuery sql vals $ do + mm <- CL.head + case mm of + Just [PersistInt64 i] -> return $ fromIntegral i + Just [PersistDouble i] ->return $ fromIntegral (truncate i :: Int64) -- gb oracle + Just [PersistByteString i] -> case readInteger i of -- gb mssql + Just (ret,"") -> return $ fromIntegral ret + xs -> error $ "invalid number i["++show i++"] xs[" ++ show xs ++ "]" + Just xs -> error $ "count:invalid sql return xs["++show xs++"] sql["++show sql++"]" + Nothing -> error $ "count:invalid sql returned nothing sql["++show sql++"]" + where + sqlAndVals conn = (sql, vals) + where + (temp, isRoot, cols, qcols, sqlWith) = + withRecursive dir field root conn t dbIdColumns + + t = entityDef $ dummyFromFilts filts + wher = + if null filts + then "" + else filterClause False conn filts + sql = mconcat + [ sqlWith + , "SELECT COUNT(*) FROM " + , connEscapeName conn temp + , wher + ] + vals = getFiltsValues conn $ isRoot : filts + +-- | Same as 'deleteRecursivelyWhere', but returns the number of rows affected. +deleteRecursivelyWhereCount + :: (PersistEntity val, MonadIO m, PersistEntityBackend val ~ SqlBackend) + => RecursionDirection + -> EntityField val (Maybe (Key val)) + -> Key val + -> [Filter val] + -> ReaderT SqlBackend m Int64 +deleteRecursivelyWhereCount dir field root filts = do + conn <- ask + let (sql, vals) = sqlAndVals conn + rawExecuteCount sql vals + where + sqlAndVals conn = (sql, vals) + where + (temp, isRoot, cols, qcols, sqlWith) = + withRecursive dir field root conn t dbIdColumns + + t = entityDef $ dummyFromFilts filts + wher = mconcat + [ if null filts + then " WHERE ( " + else filterClause False conn filts <> " AND ( " + , connEscapeName conn $ fieldDB $ entityId t + , " IN (SELECT " + , connEscapeName conn $ fieldDB $ entityId t + , " FROM " + , connEscapeName conn temp + , ") ) " + ] + sql = mconcat + [ sqlWith + , "DELETE FROM " + , connEscapeName conn $ entityDB t + , wher + ] + vals = getFiltsValues conn $ isRoot : filts + +-- | Same as 'updateRecursivelyWhere', but returns the number of rows affected. +updateRecursivelyWhereCount + :: (PersistEntity val, MonadIO m, SqlBackend ~ PersistEntityBackend val) + => RecursionDirection + -> EntityField val (Maybe (Key val)) + -> Key val + -> [Filter val] + -> [Update val] + -> ReaderT SqlBackend m Int64 +updateRecursivelyWhereCount _ _ _ _ [] = return 0 +updateRecursivelyWhereCount dir field root filts upds = do + conn <- ask + let (sql, vals) = sqlAndVals conn + rawExecuteCount sql vals + where + sqlAndVals conn = (sql, vals) + where + (temp, isRoot, cols, qcols, sqlWith) = + withRecursive dir field root conn t dbIdColumns + + t = entityDef $ dummyFromFilts filts + + go'' n Assign = n <> "=?" + go'' n Add = mconcat [n, "=", n, "+?"] + go'' n Subtract = mconcat [n, "=", n, "-?"] + go'' n Multiply = mconcat [n, "=", n, "*?"] + go'' n Divide = mconcat [n, "=", n, "/?"] + go'' _ (BackendSpecificUpdate up) = + error $ T.unpack $ "BackendSpecificUpdate " <> up <> " not supported" + go' (x, pu) = go'' (connEscapeName conn x) pu + go x = (updateField x, updateUpdate x) + + updateField (Update f _ _) = fieldName f + updateField _ = error "BackendUpdate not implemented" + + wher = mconcat + [ if null filts + then " WHERE ( " + else filterClause False conn filts <> " AND ( " + , connEscapeName conn $ fieldDB $ entityId t + , " IN (SELECT " + , connEscapeName conn $ fieldDB $ entityId t + , " FROM " + , connEscapeName conn temp + , ") ) " + ] + sql = mconcat + [ sqlWith + , "UPDATE " + , connEscapeName conn $ entityDB t + , " SET " + , T.intercalate "," $ map (go' . go) upds + , wher + ] + vals = + getFiltsValues conn [isRoot] ++ + map updatePersistValue upds ++ + getFiltsValues conn filts + +withRecursive + :: (PersistEntity val, SqlBackend ~ PersistEntityBackend val) + => RecursionDirection + -> EntityField val (Maybe (Key val)) + -> Key val + -> SqlBackend + -> EntityDef + -> (SqlBackend -> EntityDef -> [Text]) + -> (DBName, Filter val, Text, DBName -> Text, Text) +withRecursive dir field root conn t getcols = + let temp = DBName "temp_hierarchy_cte" + isRoot = persistIdField ==. root + cols = T.intercalate "," $ getcols conn t + qcols name = + T.intercalate ", " $ + map ((connEscapeName conn name <>) . ("." <>)) $ + getcols conn t + sql = mconcat + [ "WITH RECURSIVE " + , connEscapeName conn temp + , "(" + , cols + , ") AS ( SELECT " + , cols + , " FROM " + , connEscapeName conn $ entityDB t + , filterClause False conn [isRoot] + --, " WHERE " + --, connEscapeName conn $ fieldDB $ entityId t + --, " = ?" + , " UNION SELECT " + , qcols temp + , " FROM " + , connEscapeName conn $ entityDB t + , ", " + , connEscapeName conn temp + , " WHERE " + , connEscapeName conn $ entityDB t + , "." + , connEscapeName conn $ fieldDB $ case dir of + RecOut -> persistFieldDef field + RecIn -> entityId t + , " = " + , connEscapeName conn temp + , "." + , connEscapeName conn $ fieldDB $ case dir of + RecOut -> entityId t + RecIn -> persistFieldDef field + , " ) " + ] + in (temp, isRoot, cols, qcols, sql) diff --git a/vervis.cabal b/vervis.cabal index 9312f9e..bf62658 100644 --- a/vervis.cabal +++ b/vervis.cabal @@ -64,6 +64,9 @@ library Database.Esqueleto.Local Database.Persist.Class.Local Database.Persist.Sql.Local + Database.Persist.Local.Class.PersistQueryRecursive + Database.Persist.Local.Sql.Orphan.Common + Database.Persist.Local.Sql.Orphan.PersistQueryRecursive Development.DarcsRev Formatting.CaseInsensitive Network.SSH.Local @@ -219,6 +222,8 @@ library , memory , monad-control , monad-logger + -- for Database.Persist.Local + , mtl , pandoc , pandoc-types -- for PathPiece instance for CI, Web.PathPieces.Local @@ -227,6 +232,8 @@ library , persistent-postgresql , persistent-template , process + -- for Database.Persist.Local + , resourcet , safe , shakespeare , ssh