Skip to content

Commit 9f2c116

Browse files
committed
Create function for pretty printing of graphs
1 parent b6cc5a1 commit 9f2c116

File tree

2 files changed

+130
-1
lines changed

2 files changed

+130
-1
lines changed

src/DynamicExpressions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import .EquationModule: constructorof, preserve_sharing
3636
has_constants,
3737
get_constants,
3838
set_constants!
39-
@reexport import .StringsModule: string_tree, print_tree
39+
@reexport import .StringsModule: string_tree, print_tree, pretty_string_graph
4040
@reexport import .OperatorEnumModule: AbstractOperatorEnum
4141
@reexport import .OperatorEnumConstructionModule:
4242
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!

src/Strings.jl

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module StringsModule
22

3+
import Compat: Returns
4+
35
import ..UtilsModule: deprecate_varmap
46
import ..OperatorEnumModule: AbstractOperatorEnum
57
import ..EquationModule: AbstractExpressionNode, tree_mapreduce
@@ -181,4 +183,131 @@ for io in ((), (:(io::IO),))
181183
end
182184
end
183185

186+
function pretty_string_graph(
187+
tree::AbstractExpressionNode{T},
188+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
189+
f_variable::F1=string_variable,
190+
f_constant::F2=string_constant,
191+
variable_names::Union{Vector{String},Nothing}=nothing,
192+
) where {T,F1,F2}
193+
output = Char[]
194+
195+
# First, we build a mapping of shared nodes. We skip varible nodes
196+
# as it wouldn't simplify printing.
197+
shared_nodes = _build_shared_node_dict(tree, node -> node.degree != 0 || node.constant)
198+
# NOTE THAT THIS IS ASSUMED BY _leaf_string_or_shared_variable
199+
200+
# We collect the shared nodes in order of appearance, and will print
201+
# them in that order (deepest first)
202+
iter = collect(values(shared_nodes))
203+
sort!(iter; by=x -> x.index)
204+
205+
# We also want to print the final expression:
206+
push!(iter, (node=tree, index=length(iter) + 1))
207+
208+
for (; node, index) in iter
209+
raw_output, _ = tree_mapreduce(
210+
leaf -> _leaf_string_or_shared_variable(
211+
index, leaf, shared_nodes; f_variable, f_constant, variable_names
212+
),
213+
branch ->
214+
_branch_string_or_shared_variable(index, branch, shared_nodes, operators),
215+
_combine_op_with_inputs_or_shared_variable,
216+
node,
217+
@NamedTuple{chars::Vector{Char}, shared::Bool};
218+
break_sharing=Val(true),
219+
)
220+
is_output = index == length(iter)
221+
if is_output
222+
if !isempty(shared_nodes)
223+
append!(output, ('', '\n'))
224+
end
225+
append!(output, ('=', ' '))
226+
else
227+
append!(output, (index == 1 ? '' : '', '', '', '', ' '))
228+
append!(output, _get_z_name(index))
229+
append!(output, (' ', '=', ' '))
230+
end
231+
append!(output, strip_brackets(raw_output))
232+
append!(output, ('\n',))
233+
end
234+
return String(output)
235+
end
236+
237+
function _build_shared_node_dict(tree, filter::F=Returns(true)) where {F}
238+
i = Ref(0)
239+
node_counts = Dict{UInt,@NamedTuple{node::typeof(tree), index::Int}}()
240+
tree_mapreduce(
241+
p -> p,
242+
(p, _...) -> p,
243+
tree,
244+
typeof(tree);
245+
f_on_shared=(node, is_shared) -> begin
246+
if is_shared && filter(node)
247+
oid = objectid(node)
248+
if !haskey(node_counts, oid)
249+
node_counts[oid] = (node=node, index=(i[] += 1))
250+
end
251+
end
252+
node
253+
end,
254+
)
255+
return node_counts
256+
end
257+
258+
function _get_z_name(j::Int)
259+
out = ['z']
260+
for k in digits(j)
261+
push!(out, Char('0' + k))
262+
end
263+
return out
264+
end
265+
266+
function _combine_op_with_inputs_or_shared_variable(p, c...)
267+
if p.shared
268+
# We assume that this is an intermediate variable, so don't wish to expand it!
269+
return (chars=p.chars, shared=false)
270+
else
271+
return (
272+
chars=combine_op_with_inputs(p.chars, (ci -> ci.chars).(c)...), shared=false
273+
)
274+
end
275+
end
276+
277+
function _leaf_string_or_shared_variable(
278+
cur_intermediate_variable,
279+
leaf::AbstractExpressionNode{T},
280+
shared_nodes;
281+
f_variable::F1,
282+
f_constant::F2,
283+
variable_names,
284+
) where {T,F1,F2}
285+
if leaf.constant
286+
oid = objectid(leaf)
287+
if haskey(shared_nodes, oid) &&
288+
(cur_i = @inbounds(shared_nodes[oid][2])) < cur_intermediate_variable
289+
return (chars=_get_z_name(cur_i), shared=true)
290+
else
291+
return (chars=collect(f_constant(leaf.val::T)), shared=false)
292+
end
293+
else
294+
return (chars=collect(f_variable(leaf.feature, variable_names)), shared=false)
295+
end
296+
end
297+
function _branch_string_or_shared_variable(
298+
cur_intermediate_variable, branch, shared_nodes, operators
299+
)
300+
oid = objectid(branch)
301+
if haskey(shared_nodes, oid) &&
302+
(cur_i = @inbounds(shared_nodes[oid][2])) < cur_intermediate_variable
303+
return (chars=_get_z_name(cur_i), shared=true)
304+
else
305+
if branch.degree == 1
306+
return (chars=dispatch_op_name(Val(1), operators, branch.op), shared=false)
307+
else
308+
return (chars=dispatch_op_name(Val(2), operators, branch.op), shared=false)
309+
end
310+
end
311+
end
312+
184313
end

0 commit comments

Comments
 (0)