|
1 | 1 | module StringsModule
|
2 | 2 |
|
| 3 | +import Compat: Returns |
| 4 | + |
3 | 5 | import ..UtilsModule: deprecate_varmap
|
4 | 6 | import ..OperatorEnumModule: AbstractOperatorEnum
|
5 | 7 | import ..EquationModule: AbstractExpressionNode, tree_mapreduce
|
@@ -181,4 +183,131 @@ for io in ((), (:(io::IO),))
|
181 | 183 | end
|
182 | 184 | end
|
183 | 185 |
|
| 186 | +function pretty_string_graph( |
| 187 | + tree::AbstractExpressionNode{T}, |
| 188 | + operators::Union{AbstractOperatorEnum,Nothing}=nothing; |
| 189 | + f_variable::F1=string_variable, |
| 190 | + f_constant::F2=string_constant, |
| 191 | + variable_names::Union{Vector{String},Nothing}=nothing, |
| 192 | +) where {T,F1,F2} |
| 193 | + output = Char[] |
| 194 | + |
| 195 | + # First, we build a mapping of shared nodes. We skip varible nodes |
| 196 | + # as it wouldn't simplify printing. |
| 197 | + shared_nodes = _build_shared_node_dict(tree, node -> node.degree != 0 || node.constant) |
| 198 | + # NOTE THAT THIS IS ASSUMED BY _leaf_string_or_shared_variable |
| 199 | + |
| 200 | + # We collect the shared nodes in order of appearance, and will print |
| 201 | + # them in that order (deepest first) |
| 202 | + iter = collect(values(shared_nodes)) |
| 203 | + sort!(iter; by=x -> x.index) |
| 204 | + |
| 205 | + # We also want to print the final expression: |
| 206 | + push!(iter, (node=tree, index=length(iter) + 1)) |
| 207 | + |
| 208 | + for (; node, index) in iter |
| 209 | + raw_output, _ = tree_mapreduce( |
| 210 | + leaf -> _leaf_string_or_shared_variable( |
| 211 | + index, leaf, shared_nodes; f_variable, f_constant, variable_names |
| 212 | + ), |
| 213 | + branch -> |
| 214 | + _branch_string_or_shared_variable(index, branch, shared_nodes, operators), |
| 215 | + _combine_op_with_inputs_or_shared_variable, |
| 216 | + node, |
| 217 | + @NamedTuple{chars::Vector{Char}, shared::Bool}; |
| 218 | + break_sharing=Val(true), |
| 219 | + ) |
| 220 | + is_output = index == length(iter) |
| 221 | + if is_output |
| 222 | + if !isempty(shared_nodes) |
| 223 | + append!(output, ('│', '\n')) |
| 224 | + end |
| 225 | + append!(output, ('=', ' ')) |
| 226 | + else |
| 227 | + append!(output, (index == 1 ? '┌' : '├', '─', '─', '─', ' ')) |
| 228 | + append!(output, _get_z_name(index)) |
| 229 | + append!(output, (' ', '=', ' ')) |
| 230 | + end |
| 231 | + append!(output, strip_brackets(raw_output)) |
| 232 | + append!(output, ('\n',)) |
| 233 | + end |
| 234 | + return String(output) |
| 235 | +end |
| 236 | + |
| 237 | +function _build_shared_node_dict(tree, filter::F=Returns(true)) where {F} |
| 238 | + i = Ref(0) |
| 239 | + node_counts = Dict{UInt,@NamedTuple{node::typeof(tree), index::Int}}() |
| 240 | + tree_mapreduce( |
| 241 | + p -> p, |
| 242 | + (p, _...) -> p, |
| 243 | + tree, |
| 244 | + typeof(tree); |
| 245 | + f_on_shared=(node, is_shared) -> begin |
| 246 | + if is_shared && filter(node) |
| 247 | + oid = objectid(node) |
| 248 | + if !haskey(node_counts, oid) |
| 249 | + node_counts[oid] = (node=node, index=(i[] += 1)) |
| 250 | + end |
| 251 | + end |
| 252 | + node |
| 253 | + end, |
| 254 | + ) |
| 255 | + return node_counts |
| 256 | +end |
| 257 | + |
| 258 | +function _get_z_name(j::Int) |
| 259 | + out = ['z'] |
| 260 | + for k in digits(j) |
| 261 | + push!(out, Char('0' + k)) |
| 262 | + end |
| 263 | + return out |
| 264 | +end |
| 265 | + |
| 266 | +function _combine_op_with_inputs_or_shared_variable(p, c...) |
| 267 | + if p.shared |
| 268 | + # We assume that this is an intermediate variable, so don't wish to expand it! |
| 269 | + return (chars=p.chars, shared=false) |
| 270 | + else |
| 271 | + return ( |
| 272 | + chars=combine_op_with_inputs(p.chars, (ci -> ci.chars).(c)...), shared=false |
| 273 | + ) |
| 274 | + end |
| 275 | +end |
| 276 | + |
| 277 | +function _leaf_string_or_shared_variable( |
| 278 | + cur_intermediate_variable, |
| 279 | + leaf::AbstractExpressionNode{T}, |
| 280 | + shared_nodes; |
| 281 | + f_variable::F1, |
| 282 | + f_constant::F2, |
| 283 | + variable_names, |
| 284 | +) where {T,F1,F2} |
| 285 | + if leaf.constant |
| 286 | + oid = objectid(leaf) |
| 287 | + if haskey(shared_nodes, oid) && |
| 288 | + (cur_i = @inbounds(shared_nodes[oid][2])) < cur_intermediate_variable |
| 289 | + return (chars=_get_z_name(cur_i), shared=true) |
| 290 | + else |
| 291 | + return (chars=collect(f_constant(leaf.val::T)), shared=false) |
| 292 | + end |
| 293 | + else |
| 294 | + return (chars=collect(f_variable(leaf.feature, variable_names)), shared=false) |
| 295 | + end |
| 296 | +end |
| 297 | +function _branch_string_or_shared_variable( |
| 298 | + cur_intermediate_variable, branch, shared_nodes, operators |
| 299 | +) |
| 300 | + oid = objectid(branch) |
| 301 | + if haskey(shared_nodes, oid) && |
| 302 | + (cur_i = @inbounds(shared_nodes[oid][2])) < cur_intermediate_variable |
| 303 | + return (chars=_get_z_name(cur_i), shared=true) |
| 304 | + else |
| 305 | + if branch.degree == 1 |
| 306 | + return (chars=dispatch_op_name(Val(1), operators, branch.op), shared=false) |
| 307 | + else |
| 308 | + return (chars=dispatch_op_name(Val(2), operators, branch.op), shared=false) |
| 309 | + end |
| 310 | + end |
| 311 | +end |
| 312 | + |
184 | 313 | end
|
0 commit comments