from collections import defaultdict from disjoint_set import DisjointSet import networkx as nx import matplotlib.pyplot as plt from graphviz import Source class DFA(object): def __init__(self,states_or_filename,terminals=None,start_state=None, \ transitions=None,final_states=None): if terminals is None: self._get_graph_from_file(states_or_filename) else: assert isinstance(states_or_filename,list) or \ isinstance(states_or_filename,tuple) self.states = states_or_filename assert isinstance(terminals,list) or isinstance(terminals,tuple) self.terminals = terminals assert isinstance(start_state,str) self.start_state = start_state assert isinstance(transitions,dict) self.transitions = transitions assert isinstance(final_states,list) or \ isinstance(final_states,tuple) self.final_states = final_states def draw(self): ''' Draws the dfa using networkx and matplotlib ''' g = nx.DiGraph() for x in self.states: g.add_node(x,shape='doublecircle' if x in self.final_states else 'circle',fillcolor='grey' if x == self.start_state else 'white',style='filled') temp = defaultdict(list) for k,v in self.transitions.items(): temp[(k[0],v)].append(k[1]) for k,v in temp.items(): g.add_edge(k[0],k[1],label=','.join(v)) return Source(nx.drawing.nx_agraph.to_agraph(g)) def _remove_unreachable_states(self): ''' Removes states that are unreachable from the start state ''' g = defaultdict(list) for k,v in self.transitions.items(): g[k[0]].append(v) # do DFS stack = [self.start_state] reachable_states = set() while stack: state = stack.pop() if state not in reachable_states: stack += g[state] reachable_states.add(state) self.states = [state for state in self.states \ if state in reachable_states] self.final_states = [state for state in self.final_states \ if state in reachable_states] self.transitions = { k:v for k,v in self.transitions.items() \ if k[0] in reachable_states} def minimize(self): self._remove_unreachable_states() def order_tuple(a,b): return (a,b) if a < b else (b,a) table = {} sorted_states = sorted(self.states) # initialize the table for i,item in enumerate(sorted_states): for item_2 in sorted_states[i+1:]: table[(item,item_2)] = (item in self.final_states) != (item_2\ in self.final_states) flag = True # table filling method while flag: flag = False for i,item in enumerate(sorted_states): for item_2 in sorted_states[i+1:]: if table[(item,item_2)]: continue # check if the states are distinguishable for w in self.terminals: t1 = self.transitions.get((item,w),None) t2 = self.transitions.get((item_2,w),None) if t1 is not None and t2 is not None and t1 != t2: marked = table[order_tuple(t1,t2)] flag = flag or marked table[(item,item_2)] = marked if marked: break d = DisjointSet(self.states) # form new states for k,v in table.items(): if not v: d.union(k[0],k[1]) self.states = [str(x) for x in range(1,1+len(d.get()))] new_final_states = [] self.start_state = str(d.find_set(self.start_state)) for s in d.get(): for item in s: if item in self.final_states: new_final_states.append(str(d.find_set(item))) break self.transitions = {(str(d.find_set(k[0])),k[1]):str(d.find_set(v)) for k,v in self.transitions.items()} self.final_states = new_final_states def __str__(self): ''' String representation ''' num_of_state = len(self.states) start_state = self.start_state num_of_final = len(self.final_states) return '{} states. {} final states. start state - {}'.format( \ num_of_state,num_of_final,start_state) def _get_graph_from_file(self,filename): ''' Load the graph from file ''' with open(filename,'r') as f: try: lines = f.readlines() states,terminals,start_state,final_states = lines[:4] if states: self.states = states[:-1].split() else: raise Exception('Invalid file format: cannot read states') if terminals: self.terminals = terminals[:-1].split() else: raise Exception('Invalid file format: cannot read terminals') if start_state: self.start_state = start_state[:-1] else: raise Exception('Invalid file format: cannot read start state') if final_states: self.final_states = final_states[:-1].split() else: raise Exception('Invalid file format: cannot read final states') lines = lines[4:] self.transitions = {} for line in lines: current_state,terminal,next_state = line[:-1].split() self.transitions[(current_state,terminal)] = next_state except Exception as e: print("ERROR: ",e)