foldTree does not optimize well #946

Open
meooow25 opened this issue Apr 30, 2023 · 2 comments

@meooow25
Contributor

We have, in Data.Tree:

```haskell
foldTree :: (a -> [b] -> b) -> Tree a -> b
```
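For reference, here is a minimal sketch of what `foldTree` computes. It should be equivalent to the definition in `containers`, but `foldTreeSketch` and `treeSum` are hypothetical names used only here. Every child is folded first, the results are collected in a list, and that list is combined with the node's label.

```haskell
import Data.Tree (Tree (..))

-- Sketch of Data.Tree.foldTree: fold each child, collect the results
-- in a list, then combine them with the node's label using f.
foldTreeSketch :: (a -> [b] -> b) -> Tree a -> b
foldTreeSketch f = go
  where
    go (Node x ts) = f x (map go ts)

-- Hypothetical usage example: the sum of all labels in a tree.
treeSum :: Num a => Tree a -> a
treeSum = foldTreeSketch (\x rs -> x + sum rs)
```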

Unfortunately, foldTree does not optimize as well as it could when b is a function type.

As an example, suppose we want to calculate the sum of the depths of all nodes in a tree.
We can write a recursive function by hand:

```haskell
depthSum_rec :: Tree a -> Int
depthSum_rec t = go t 0 0 where
    go (Node _ ts) depth acc = foldl' (\acc' t' -> go t' (depth+1) acc') (acc + depth) ts

--    depthSum_rec: OK (0.18s)
--      5.18 ms ± 508 μs
```

Now let's use foldTree:

```haskell
depthSum_foldTree :: Tree a -> Int
depthSum_foldTree t = foldTree f t 0 0 where
    f _ ks depth acc = foldl' (\acc' k -> k (depth+1) acc') (acc + depth) ks

--    depthSum_foldTree: OK (0.34s)
--      43.6 ms ± 3.2 ms
```

That's a lot worse! The problem is that the list of partially applied functions [b] is actually built in memory (see GHC#23319). According to SPJ, this cannot easily be improved in GHC.
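To make the intermediate list visible, here is `depthSum_foldTree` with `f` plugged into `foldTree` by hand. This is a hand-written approximation for illustration, not actual GHC Core, and `depthSumInlined` is a hypothetical name. Because `go` has to return a function of `depth` and `acc`, `map go ts` produces a list with one partial application `go t'` per child, which `foldl'` then consumes.

```haskell
import Data.List (foldl')
import Data.Tree (Tree (..))

-- depthSum_foldTree with f substituted into foldTree by hand.
depthSumInlined :: Tree a -> Int
depthSumInlined t = go t 0 0
  where
    -- go returns a function of (depth, acc); the children are first
    -- turned into a list of such functions by `map go ts`.
    go :: Tree a -> Int -> Int -> Int
    go (Node _ ts) = \depth acc ->
      foldl' (\acc' k -> k (depth + 1) acc') (acc + depth) (map go ts)
```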


Consider instead a different fold function that also folds over the children's results, so the [b] is never created:

```haskell
foldTree2 :: (a -> b -> c) -> (c -> b -> b) -> b -> Tree a -> c
foldTree2 f c z = go where go (Node x ts) = f x (foldr (c . go) z ts)
```
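`foldTree2` generalizes `foldTree`: choosing `b = [c]` together with `(:)` and `[]` rebuilds the list of folded children, which recovers the existing function (and its intermediate list). A minimal sketch; `foldTreeViaFoldTree2` is a hypothetical name used only for this comparison.

```haskell
import Data.Tree (Tree (..), foldTree)

foldTree2 :: (a -> b -> c) -> (c -> b -> b) -> b -> Tree a -> c
foldTree2 f c z = go
  where
    go (Node x ts) = f x (foldr (c . go) z ts)

-- With b = [c], foldTree2 rebuilds the list of folded children,
-- so this behaves like Data.Tree.foldTree.
foldTreeViaFoldTree2 :: (a -> [c] -> c) -> Tree a -> c
foldTreeViaFoldTree2 f = foldTree2 f (:) []
```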

Now we can write:

```haskell
depthSum_foldTree2 :: Tree a -> Int
depthSum_foldTree2 t = foldTree2 f f' (const id) t 0 0 where
    f _ k depth acc = k depth (acc + depth)
    f' k1 k2 depth acc = k2 depth (k1 (depth+1) acc)

--    depthSum_foldTree2: OK (0.23s)
--      5.16 ms ± 376 μs
```

As good as depthSum_rec! Could we have foldTree2 (perhaps with a better name) in Data.Tree?
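The same pattern applies whenever the result is a function. For example, a preorder flatten can use difference lists (`[a] -> [a]`) for both `b` and `c`, so no per-node list of the children's results is built. This is a sketch; `flattenVia2` is a hypothetical name, and its output should match `Data.Tree.flatten`.

```haskell
import Data.Tree (Tree (..), flatten)

foldTree2 :: (a -> b -> c) -> (c -> b -> b) -> b -> Tree a -> c
foldTree2 f c z = go
  where
    go (Node x ts) = f x (foldr (c . go) z ts)

-- Preorder: the node's label first, then the children left to right.
-- Both b and c are difference lists; id is the empty one and (.)
-- appends two of them. Should agree with Data.Tree.flatten.
flattenVia2 :: Tree a -> [a]
flattenVia2 t = foldTree2 (\x k -> (x :) . k) (.) id t []
```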


The benchmark setup, for completeness:

```haskell
import Data.List
import Data.Tree
import Test.Tasty.Bench

main :: IO ()
main = defaultMain
    [ env (pure binTree) $ \t -> bgroup ""
        [ bench "depthSum_rec" $ whnf depthSum_rec t
        , bench "depthSum_foldTree" $ whnf depthSum_foldTree t
        , bench "depthSum_foldTree2" $ whnf depthSum_foldTree2 t
        ]
    ]

binTree :: Tree Int
binTree = unfoldTree (\x -> (x, takeWhile (<1000000) [2*x + 1, 2*x + 2])) 1
```
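Before comparing timings, it can be worth checking that the three implementations agree. A sketch that collects the definitions from above into one module; `boundedTree` (the same shape as `binTree`, with a configurable bound) and `allAgree` are hypothetical helpers.

```haskell
import Data.List (foldl')
import Data.Tree (Tree (..), foldTree, unfoldTree)

foldTree2 :: (a -> b -> c) -> (c -> b -> b) -> b -> Tree a -> c
foldTree2 f c z = go
  where
    go (Node x ts) = f x (foldr (c . go) z ts)

depthSum_rec :: Tree a -> Int
depthSum_rec t = go t 0 0
  where
    go (Node _ ts) depth acc = foldl' (\acc' t' -> go t' (depth + 1) acc') (acc + depth) ts

depthSum_foldTree :: Tree a -> Int
depthSum_foldTree t = foldTree f t 0 0
  where
    f _ ks depth acc = foldl' (\acc' k -> k (depth + 1) acc') (acc + depth) ks

depthSum_foldTree2 :: Tree a -> Int
depthSum_foldTree2 t = foldTree2 f f' (const id) t 0 0
  where
    f _ k depth acc = k depth (acc + depth)
    f' k1 k2 depth acc = k2 depth (k1 (depth + 1) acc)

-- Same shape as binTree, but with the bound n as a parameter.
boundedTree :: Int -> Tree Int
boundedTree n = unfoldTree (\x -> (x, takeWhile (< n) [2 * x + 1, 2 * x + 2])) 1

-- True when all three implementations return the same depth sum.
allAgree :: Tree a -> Bool
allAgree t = all (== depthSum_rec t) [depthSum_foldTree t, depthSum_foldTree2 t]
```

For `boundedTree 10` the nodes are 1 (depth 0), 3 and 4 (depth 1), and 7, 8 and 9 (depth 2), so each function should return 8.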
@treeowl
Contributor

treeowl commented Apr 30, 2023

The type of the function gives me no clue what it does. That makes me a bit suspicious. Can you write documentation that makes it easy for people to think about? What benefit does this have over doing the folding by hand?

@meooow25
Contributor Author

> The type of the function gives me no clue what it does.

It simply replaces every constructor involved in a Tree: foldTree2 f c z replaces Node with f, (:) with c, and [] with z. So foldTree2 Node (:) [] = id.
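Concretely, passing the constructors themselves rebuilds the input, and other choices give ordinary folds. A small sketch; `rebuild` and `size2` are hypothetical helpers for illustration.

```haskell
import Data.Tree (Tree (..), flatten)

foldTree2 :: (a -> b -> c) -> (c -> b -> b) -> b -> Tree a -> c
foldTree2 f c z = go
  where
    go (Node x ts) = f x (foldr (c . go) z ts)

-- Replacing each constructor by itself gives back the same tree.
rebuild :: Tree a -> Tree a
rebuild = foldTree2 Node (:) []

-- Node adds 1 to its children's total, (:) adds, [] is 0:
-- this counts the nodes (same as length . flatten).
size2 :: Tree a -> Int
size2 = foldTree2 (\_ n -> n + 1) (+) 0
```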

> What benefit does this have over doing the folding by hand?

It lets us avoid writing a recursive function by hand, and the resulting code is often shorter or simpler. Another benefit is that it could participate in fold/build fusion. I have been thinking about this a bit, but perhaps it deserves a separate issue.
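For the fusion idea, one possible shape, by analogy with `foldr`/`build` for lists, is a producer that abstracts over the three constructors. `buildTree` and `binTreeG` below are hypothetical (nothing like them is in `Data.Tree`), and no `RULES` pragma is shown. The sketch only checks, on one example, the law `foldTree2 f c z (buildTree g) = g f c z` that such a rewrite rule would rely on.

```haskell
{-# LANGUAGE RankNTypes #-}

import Data.Tree (Tree (..), unfoldTree)

foldTree2 :: (a -> b -> c) -> (c -> b -> b) -> b -> Tree a -> c
foldTree2 f c z = go
  where
    go (Node x ts) = f x (foldr (c . go) z ts)

-- Hypothetical tree analogue of GHC.Exts.build (not in Data.Tree):
-- the producer g is abstract over Node, (:) and [].
buildTree :: (forall b c. (a -> b -> c) -> (c -> b -> b) -> b -> c) -> Tree a
buildTree g = g Node (:) []

-- Hypothetical producer: the benchmark's binTree shape, bounded by n,
-- written against abstract node/cons/nil arguments instead of constructors.
binTreeG :: Int -> (Int -> b -> c) -> (c -> b -> b) -> b -> c
binTreeG n node cons nil = go 1
  where
    go x = node x (foldr (cons . go) nil (takeWhile (< n) [2 * x + 1, 2 * x + 2]))

-- The same tree built directly with unfoldTree, for comparison.
binTreeRef :: Int -> Tree Int
binTreeRef n = unfoldTree (\x -> (x, takeWhile (< n) [2 * x + 1, 2 * x + 2])) 1

-- Law a fusion rule would rely on: foldTree2 f c z (buildTree g) = g f c z.
-- sumViaTree builds the tree and folds it; sumDirect never builds it.
sumViaTree, sumDirect :: Int -> Int
sumViaTree n = foldTree2 (+) (+) 0 (buildTree (binTreeG n))
sumDirect n = binTreeG n (+) (+) 0
```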
