Skip to content

Commit

Permalink
Investigate whether Traverse acceptance can be made to run in O(n) time
Browse files Browse the repository at this point in the history
  • Loading branch information
jtmaxwell3 committed Feb 6, 2025
1 parent 3a9b17e commit f7d284d
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using System.Linq;
using SIL.Extensions;
using SIL.Machine.Annotations;
using SIL.Machine.DataStructures;

namespace SIL.Machine.FiniteState
{
Expand Down Expand Up @@ -34,14 +33,10 @@ public override void CopyTo(TraversalInstance<TData, TOffset> other)
base.CopyTo(other);

var otherDfst = (DeterministicFstTraversalInstance<TData, TOffset>)other;
Dictionary<Annotation<TOffset>, Annotation<TOffset>> outputMappings = Output
.Annotations.SelectMany(a => a.GetNodesBreadthFirst())
.Zip(Output.Annotations.SelectMany(a => a.GetNodesBreadthFirst()))
.ToDictionary(t => t.Item1, t => t.Item2);
otherDfst.Mappings.AddRange(
_mappings.Select(kvp => new KeyValuePair<Annotation<TOffset>, Annotation<TOffset>>(
kvp.Key,
outputMappings[kvp.Value]
kvp.Value
))
);
foreach (Annotation<TOffset> ann in _queue)
Expand Down
188 changes: 183 additions & 5 deletions src/SIL.Machine/FiniteState/TraversalMethodBase.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using SIL.Extensions;
using SIL.Machine.Annotations;
using SIL.Machine.DataStructures;
using SIL.Machine.FeatureModel;
Expand All @@ -20,6 +22,8 @@ internal abstract class TraversalMethodBase<TData, TOffset, TInst> : ITraversalM
private readonly bool _useDefaults;
private readonly List<Annotation<TOffset>> _annotations;
private readonly Queue<TInst> _cachedInstances;
private readonly IDictionary<TInst, IList<CommandUpdate>> _commandUpdates;
private readonly IDictionary<TInst, IList<TraverseOutput>> _outputs;

protected TraversalMethodBase(
Fst<TData, TOffset> fst,
Expand Down Expand Up @@ -61,6 +65,8 @@ bool useDefaults
}
}
_cachedInstances = new Queue<TInst>();
_commandUpdates = new Dictionary<TInst, IList<CommandUpdate>>();
_outputs = new Dictionary<TInst, IList<TraverseOutput>>();
}

private int CompareAnnotations(Annotation<TOffset> x, Annotation<TOffset> y)
Expand Down Expand Up @@ -94,6 +100,39 @@ public abstract IEnumerable<FstResult<TData, TOffset>> Traverse(
ISet<int> initAnns
);

/// <summary>
/// A recorded transition into a traversal instance: where it came from, which arc (if any)
/// was followed, and the arguments that were handed to <c>ExecuteCommands</c> for that step,
/// so the step can be replayed later.
/// </summary>
protected class CommandUpdate
{
    /// <summary>The instance the transition started from.</summary>
    public TInst Source;

    /// <summary>The arc that was followed; some callers in this file record <c>null</c> here.</summary>
    public Arc<TData, TOffset> Arc;

    /// <summary>The tag-map commands to run when the step is replayed.</summary>
    public IEnumerable<TagMapCommand> Cmds;

    /// <summary>The start register passed along with <see cref="Cmds"/>.</summary>
    public Register<TOffset> Start;

    /// <summary>The end register passed along with <see cref="Cmds"/>.</summary>
    public Register<TOffset> End;

    /// <summary>Captures all five pieces of a recorded transition as given.</summary>
    public CommandUpdate(
        TInst source,
        Arc<TData, TOffset> arc,
        IEnumerable<TagMapCommand> cmds,
        Register<TOffset> start,
        Register<TOffset> end
    )
    {
        End = end;
        Start = start;
        Cmds = cmds;
        Arc = arc;
        Source = source;
    }
}

/// <summary>
/// Remembers that <paramref name="target"/> was reached from <paramref name="source"/> by
/// appending a <see cref="CommandUpdate"/> to the list kept for <paramref name="target"/>.
/// The list is created on first use.
/// </summary>
/// <param name="source">The instance the transition started from.</param>
/// <param name="arc">The arc that was followed (may be <c>null</c>).</param>
/// <param name="cmds">The commands associated with the transition.</param>
/// <param name="start">The start register to replay the commands with.</param>
/// <param name="end">The end register to replay the commands with.</param>
/// <param name="target">The instance the transition led to; used as the dictionary key.</param>
protected void RecordCommands(
    TInst source,
    Arc<TData, TOffset> arc,
    IEnumerable<TagMapCommand> cmds,
    Register<TOffset> start,
    Register<TOffset> end,
    TInst target
)
{
    var recorded = new CommandUpdate(source, arc, cmds, start, end);

    // Single lookup: fetch the existing list for this target, or register a new one.
    IList<CommandUpdate> updatesForTarget;
    if (!_commandUpdates.TryGetValue(target, out updatesForTarget))
    {
        updatesForTarget = new List<CommandUpdate>();
        _commandUpdates[target] = updatesForTarget;
    }
    updatesForTarget.Add(recorded);
}

protected static void ExecuteCommands(
Register<TOffset>[,] registers,
IEnumerable<TagMapCommand> cmds,
Expand Down Expand Up @@ -134,7 +173,8 @@ private void CheckAccepting(
VariableBindings varBindings,
Arc<TData, TOffset> arc,
ICollection<FstResult<TData, TOffset>> curResults,
IList<int> priorities
IList<int> priorities,
TInst inst
)
{
if (arc.Target.IsAccepting && (!_endAnchor || annIndex == _annotations.Count))
Expand All @@ -143,6 +183,15 @@ IList<int> priorities
annIndex < _annotations.Count ? _annotations[annIndex] : _data.Annotations.GetEnd(_fst.Direction);
var matchRegisters = (Register<TOffset>[,])registers.Clone();
ExecuteCommands(matchRegisters, arc.Target.Finishers, new Register<TOffset>(), new Register<TOffset>());
TInst finalInst = CreateInstance();
RecordCommands(inst, null, arc.Target.Finishers, new Register<TOffset>(), new Register<TOffset>(), finalInst);
if (!arc.Target.IsAccepting)
{
var outputs = GetOutputs(finalInst);
Debug.Assert(_fst.RegistersEqualityComparer.Equals(outputs[0].Registers, matchRegisters), "registers didn't match");
if (output != null)
Debug.Assert(outputs[0].Output.ToString().Equals(output.ToString()), "output didn't match");
}
if (arc.Target.AcceptInfos.Count > 0)
{
foreach (AcceptInfo<TData, TOffset> acceptInfo in arc.Target.AcceptInfos)
Expand Down Expand Up @@ -190,6 +239,93 @@ IList<int> priorities
}
}

/// <summary>
/// The state reconstructed for one path through the recorded command updates: the register
/// values, the output data, the annotation mappings and the queue of pending annotations.
/// </summary>
protected class TraverseOutput
{
    /// <summary>Register values for this path.</summary>
    public Register<TOffset>[,] Registers;

    /// <summary>The output data for this path.</summary>
    public TData Output;

    /// <summary>Lookup from one annotation to its counterpart annotation.</summary>
    public IDictionary<Annotation<TOffset>, Annotation<TOffset>> Mappings;

    /// <summary>Annotations waiting to be consumed by output actions.</summary>
    public Queue<Annotation<TOffset>> Queue;

    /// <summary>
    /// Builds an output from the given parts. The arguments are stored as passed (not copied)
    /// and the queue starts empty.
    /// </summary>
    public TraverseOutput(
        Register<TOffset>[,] registers,
        TData output,
        Dictionary<Annotation<TOffset>, Annotation<TOffset>> mappings
    )
    {
        Queue = new Queue<Annotation<TOffset>>();
        Mappings = mappings;
        Output = output;
        Registers = registers;
    }

    /// <summary>
    /// Builds a copy of <paramref name="other"/>: the register array, the output data and the
    /// queue are duplicated, while the mappings dictionary is shared by reference.
    /// </summary>
    public TraverseOutput(TraverseOutput other)
    {
        var registersCopy = (Register<TOffset>[,])other.Registers.Clone();
        Registers = registersCopy;

        var cloneableOutput = (ICloneable<TData>)other.Output;
        Output = cloneableOutput.Clone();

        // NOTE(review): the dictionary is shared with the source, so its values presumably still
        // refer to annotations of the source's Output rather than of the clone made above - confirm
        // that this is intended.
        Mappings = other.Mappings;

        Queue = new Queue<Annotation<TOffset>>(other.Queue);
    }
}

/// <summary>
/// Reconstructs every <see cref="TraverseOutput"/> that can reach <paramref name="inst"/> by
/// replaying the command updates recorded for it on top of the outputs of each update's source.
/// </summary>
/// <remarks>
/// Results for non-null instances are stored in <c>_outputs</c>, so an instance shared by
/// several paths is only computed once. This relies on all command updates for an instance
/// having been recorded before its outputs are first requested; updates recorded afterwards
/// are not reflected in the cached list.
/// </remarks>
/// <param name="inst">The instance to reconstruct outputs for; may be <c>null</c>.</param>
/// <returns>
/// One output per recorded path into <paramref name="inst"/>, or a single fresh output when no
/// updates were recorded for it.
/// </returns>
private IList<TraverseOutput> GetOutputs(TInst inst)
{
    IList<TraverseOutput> outputs;
    if (inst != null && _outputs.TryGetValue(inst, out outputs))
        return outputs;

    outputs = new List<TraverseOutput>();
    IList<CommandUpdate> updates = GetCommandUpdates(inst);
    if (updates.Count == 0)
    {
        // We are at the beginning.
        var registers = inst != null ? inst.Registers : new Register<TOffset>[Fst.RegisterCount, 2];
        var dataOutput = ((ICloneable<TData>)Data).Clone();
        // Pair each input annotation with the annotation at the same breadth-first position in the clone.
        var mappings = new Dictionary<Annotation<TOffset>, Annotation<TOffset>>();
        mappings.AddRange(Data.Annotations.SelectMany(a => a.GetNodesBreadthFirst())
            .Zip(
                dataOutput.Annotations.SelectMany(a => a.GetNodesBreadthFirst()),
                (a1, a2) => new KeyValuePair<Annotation<TOffset>, Annotation<TOffset>>(a1, a2)
            ));
        outputs.Add(new TraverseOutput(registers, dataOutput, mappings));
    }
    else
    {
        foreach (CommandUpdate update in updates)
        {
            foreach (TraverseOutput output in GetOutputs(update.Source))
            {
                // Work on a copy so the (possibly cached) source output is left untouched.
                var newOutput = new TraverseOutput(output);
                ExecuteCommands(newOutput.Registers, update.Cmds, update.Start, update.End);
                if (update.Arc != null)
                {
                    for (int j = 0; j < update.Arc.Input.EnqueueCount; j++)
                        newOutput.Queue.Enqueue(Annotations[update.Source.AnnotationIndex]);

                    Annotation<TOffset> prevNewAnn = null;
                    foreach (Output<TData, TOffset> outputAction in update.Arc.Outputs)
                    {
                        Annotation<TOffset> outputAnn;
                        if (outputAction.UsePrevNewAnnotation && prevNewAnn != null)
                        {
                            outputAnn = prevNewAnn;
                        }
                        else
                        {
                            Annotation<TOffset> inputAnn = newOutput.Queue.Dequeue();
                            // NOTE(review): the lookup goes through the source output's mappings while
                            // the update below is applied to the cloned output - confirm the mapped
                            // annotation is valid for the clone.
                            outputAnn = output.Mappings[inputAnn];
                        }
                        prevNewAnn = outputAction.UpdateOutput(newOutput.Output, outputAnn, Fst.Operations);
                    }
                }
                outputs.Add(newOutput);
            }
        }
    }

    // Populate the cache that is consulted at the top of this method; null cannot be a key.
    if (inst != null)
        _outputs[inst] = outputs;
    return outputs;
}

/// <summary>
/// Returns the command updates recorded for <paramref name="inst"/>.
/// </summary>
/// <param name="inst">The instance to look up; may be <c>null</c>.</param>
/// <returns>
/// A new empty list when <paramref name="inst"/> is <c>null</c>; otherwise the list stored for the
/// instance. If none exists yet, an empty list is created, stored and returned.
/// </returns>
private IList<CommandUpdate> GetCommandUpdates(TInst inst)
{
    if (inst == null)
        return new List<CommandUpdate>();

    IList<CommandUpdate> recorded;
    if (_commandUpdates.TryGetValue(inst, out recorded))
        return recorded;

    // Nothing recorded yet: register an empty list so later lookups see the same instance.
    recorded = new List<CommandUpdate>();
    _commandUpdates[inst] = recorded;
    return recorded;
}

protected IEnumerable<TInst> Initialize(
ref int annIndex,
Register<TOffset>[,] registers,
Expand Down Expand Up @@ -221,6 +357,13 @@ ISet<int> initAnns
}
}

var startInst = CreateInstance();
for (int i = 0; i < registers.Length/2; i++)
{
for (int j = 0; j < 2; j++)
startInst.Registers[i, j] = registers[i, j];
}

ExecuteCommands(registers, cmds, new Register<TOffset>(offset, true), new Register<TOffset>());

for (
Expand All @@ -242,6 +385,9 @@ ISet<int> initAnns
}
}

foreach (var inst in insts)
RecordCommands(startInst, null, cmds, new Register<TOffset>(offset, true), new Register<TOffset>(), inst);

return insts;
}

Expand All @@ -253,6 +399,8 @@ protected IEnumerable<TInst> Advance(
bool optional = false
)
{
TInst source = inst;
inst = CopyInstance(inst);
inst.Priorities?.Add(arc.Priority);
int nextIndex = GetNextNonoverlappingAnnotationIndex(inst.AnnotationIndex);
TOffset nextOffset;
Expand Down Expand Up @@ -292,6 +440,14 @@ protected IEnumerable<TInst> Advance(
anns.Add(i);
}

RecordCommands(
source,
arc,
arc.Commands,
new Register<TOffset>(nextOffset, nextStart),
new Register<TOffset>(end, false),
inst
);
ExecuteCommands(
inst.Registers,
arc.Commands,
Expand All @@ -307,7 +463,8 @@ protected IEnumerable<TInst> Advance(
varBindings,
arc,
curResults,
inst.Priorities
inst.Priorities,
inst
);
}

Expand All @@ -327,13 +484,21 @@ protected IEnumerable<TInst> Advance(
}
else
{
RecordCommands(
source,
arc,
arc.Commands,
new Register<TOffset>(nextOffset, nextStart),
new Register<TOffset>(end, false),
inst
);
ExecuteCommands(
inst.Registers,
arc.Commands,
new Register<TOffset>(nextOffset, nextStart),
new Register<TOffset>(end, false)
);
CheckAccepting(nextIndex, inst.Registers, inst.Output, varBindings, arc, curResults, inst.Priorities);
CheckAccepting(nextIndex, inst.Registers, inst.Output, varBindings, arc, curResults, inst.Priorities, inst);

inst.State = arc.Target;
inst.AnnotationIndex = nextIndex;
Expand All @@ -348,12 +513,22 @@ protected TInst EpsilonAdvance(
ICollection<FstResult<TData, TOffset>> curResults
)
{
TInst source = inst;
inst = CopyInstance(source);
Annotation<TOffset> ann =
inst.AnnotationIndex < _annotations.Count
? _annotations[inst.AnnotationIndex]
: _data.Annotations.GetEnd(_fst.Direction);
int prevIndex = GetPrevNonoverlappingAnnotationIndex(inst.AnnotationIndex);
Annotation<TOffset> prevAnn = _annotations[prevIndex];
RecordCommands(
source,
arc,
arc.Commands,
new Register<TOffset>(ann.Range.GetStart(_fst.Direction), true),
new Register<TOffset>(prevAnn.Range.GetEnd(_fst.Direction), false),
inst
);
ExecuteCommands(
inst.Registers,
arc.Commands,
Expand All @@ -367,7 +542,8 @@ ICollection<FstResult<TData, TOffset>> curResults
inst.VariableBindings,
arc,
curResults,
inst.Priorities
inst.Priorities,
inst
);

inst.State = arc.Target;
Expand Down Expand Up @@ -418,7 +594,9 @@ protected TInst CopyInstance(TInst inst)

/// <summary>
/// Called when a traversal instance is no longer needed. Currently does nothing.
/// </summary>
/// <param name="inst">The instance being released; <c>null</c> is accepted.</param>
protected void ReleaseInstance(TInst inst)
{
    // Intentionally a no-op: released instances are no longer returned to _cachedInstances
    // (the former "_cachedInstances.Enqueue(inst)" call is disabled).
    // NOTE(review): presumably recycling is off because instances now serve as keys in
    // _commandUpdates and _outputs - confirm before re-enabling the pool.
}
}
}
21 changes: 21 additions & 0 deletions tests/SIL.Machine.Tests/Matching/MatcherTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1482,4 +1482,25 @@ public void OptionalAnnotation()
Assert.That(match.GroupCaptures["first"].Range, Is.EqualTo(Range<int>.Create(1, 3)));
Assert.That(match.GroupCaptures["second"].Range, Is.EqualTo(Range<int>.Create(3, 10)));
}

/// <summary>
/// Matches a three-group pattern, anchored at both ends, against "axxb" with the annotations at
/// indices 1 and 2 marked optional, and expects exactly two matches when all submatches are requested.
/// </summary>
[Test]
public void CanonicalizeOptionalAnnotation()
{
    // Arrange: three consecutive groups, each accepting any annotation.
    FeatureStruct anyAnnotation = FeatureStruct.New().Value;
    Pattern<AnnotatedStringData, int> threeGroups = Pattern<AnnotatedStringData, int>
        .New()
        .Group("first", g => g.Annotation(anyAnnotation))
        .Group("second", g => g.Annotation(anyAnnotation))
        .Group("third", g => g.Annotation(anyAnnotation))
        .Value;

    AnnotatedStringData input = CreateStringData("axxb");
    foreach (int optionalIndex in new[] { 1, 2 })
        input.Annotations.ElementAt(optionalIndex).Optional = true;

    var settings = new MatcherSettings<int>
    {
        AnchoredToStart = true,
        AnchoredToEnd = true,
        AllSubmatches = true
    };
    var matcher = new Matcher<AnnotatedStringData, int>(threeGroups, settings);

    // Act
    IList<Match<AnnotatedStringData, int>> found = matcher.AllMatches(input).ToList();

    // Assert
    Assert.That(found, Has.Count.EqualTo(2));
}
}

0 comments on commit f7d284d

Please sign in to comment.