# yiyiyacc - short for "Yes, Indeed! Yiyiyacc *Is* Yet Another Compiler Compiler!"
#
# A compiler compiler for generating extremely slim inline-style C compiler code.
#
# Copyright (c) 2008 by Matthias Kramm
#
# Usage: <this script> GRAMMARFILE > parser.c
#
# Grammar file format (as read by Spec.load):
#   * "Left -> sym sym ..." declares a production.  A single-quoted token
#     ('+') is a terminal character, <*> is the catch-all terminal, any
#     other word is a non-terminal.  An empty right side ("Left ->") is
#     allowed.
#   * Indented lines following a production are C code executed when it is
#     reduced: %% is the result value, %N the value of the N-th
#     non-terminal on the right side.
#   * A non-indented line naming an already-used terminal (e.g. "<*>")
#     starts C code for that terminal: %% is its value, %1 the input
#     pointer.  Only the catch-all terminal's code is emitted.
#   * Blank lines are ignored.
#
# The generated C defines `double parse(const char*s)`.
# NOTE(review): the generated code calls memset/fprintf but emits no
# #include lines; presumably the including file provides <string.h> and
# <stdio.h>.

import sys


class Spec(object):
    """A context-free grammar: an ordered production list plus its symbols.

    productions[0] is always the synthetic start production "*START* -> S",
    where S is the left side of the first production passed in.  `end` is
    the end-of-input terminal '\\0'.
    """

    def __init__(self, productions, registry):
        if not productions:
            raise ValueError("grammar contains no productions")
        self.registry = registry
        # Always augment the grammar with a fresh start production: the parse
        # table turns completion of production 0 into "accept" instead of a
        # reduce, so a user production in slot 0 would lose its C action.
        old = productions[0].leftside
        self.start = registry.registerNonTerminal('*START*')
        self.productions = [Production(self.start, [old])] + list(productions)
        self.end = registry.registerTerminal('\0')

    @staticmethod
    def load(filename):
        """Parse the grammar file `filename` into a Spec.

        Raises SyntaxError (with a 1-based line number) for lines that are
        neither a production, indented code following a production or a
        declared terminal, nor a line naming an already-used terminal.
        """
        registry = SymbolRegistry()
        productions = []
        current = None  # production or terminal that indented code lines attach to
        with open(filename, "r") as fi:
            for linenr, line in enumerate(fi, 1):
                if not line.strip():
                    continue
                if line[0] in " \t":
                    # Indented lines are C code; they must not look like a
                    # production and must follow something that takes code.
                    if "->" in line or current is None:
                        raise SyntaxError("Unable to parse line %d" % linenr)
                    current.code += line
                    continue
                current = Production.parse(line, registry)
                if current is not None:
                    productions.append(current)
                    continue
                # Not a production: it may name a terminal (e.g. "<*>") whose
                # following indented lines are that terminal's C code.  Such
                # terminals are NOT productions and must not be appended.
                if "<" in line:
                    name = line.strip()
                    for terminal in registry.terminals:
                        if terminal.s == name:
                            terminal.code = ""
                            current = terminal
                            break
                if current is None:
                    raise SyntaxError("Unable to parse line %d" % linenr)
        return Spec(productions, registry)

    def expand(self, leftside):
        """Return every production whose left side is `leftside`."""
        return [p for p in self.productions if p.leftside == leftside]


class State(object):
    """An LR(1) automaton state: the closure of a set of items."""

    def __init__(self, parent, elements):
        self.elements = parent.closure(elements)
        # Sorted item renderings identify the state independent of item order.
        self.id = sorted(str(e) for e in self.elements)
        # The "core" (items without lookahead); states with equal cores are
        # merged by ParseGraph.optimize.
        self.id_without_follow = sorted(set(e.id_without_follow() for e in self.elements))
        self.goto = {}
        self.comefrom = []

    def gotoTo(self, symbol, state):
        """Record the transition self --symbol--> state (and its reverse)."""
        assert symbol not in self.goto
        self.goto[symbol] = state
        state.comefrom += [(symbol, self)]

    def __eq__(self, other):
        return isinstance(other, State) and self.id == other.id

    def __ne__(self, other):
        return not self == other


class ParseGraph(object):
    """LALR(1) automaton: canonical LR(1) states merged by core."""

    def __init__(self, spec):
        self.spec = spec
        self.productions = spec.productions
        self.nonterminals = spec.registry.nonterminals
        self.terminals = spec.registry.terminals
        self._voidable = None  # lazily computed set of voidable nonterminals
        self.start = ProductionPos(self, spec.productions[0], 0, [[spec.end]])
        self.states = [State(self, [self.start])]

        # Worklist construction: each state's transitions depend only on its
        # own items, so every state is expanded exactly once.
        known = {tuple(self.states[0].id): self.states[0]}
        symbols = self.nonterminals + self.terminals
        pending = 0
        while pending < len(self.states):
            state = self.states[pending]
            pending += 1
            for symbol in symbols:
                newstate = self.goto(state, symbol)
                if newstate is None:
                    continue
                key = tuple(newstate.id)
                existing = known.get(key)
                if existing is None:
                    known[key] = newstate
                    self.states.append(newstate)
                    existing = newstate
                state.gotoTo(symbol, existing)
        self.optimize()

    def optimize(self):
        """Perform the LALR reduction: merge states whose items differ only
        in their lookahead (follow) sets into the lowest-numbered one."""
        for num, s in enumerate(self.states):
            s.tmpid = num
            s.deleted = 0
        by_core = {}   # core -> surviving state
        rep_of = []    # tmpid -> surviving state
        survivors = []
        for state in self.states:
            core = tuple(state.id_without_follow)
            rep = by_core.get(core)
            if rep is None:
                by_core[core] = state
                survivors.append(state)
                rep_of.append(state)
            else:
                rep.elements += state.elements
                state.deleted = 1
                rep_of.append(rep)
        # Redirect every surviving transition to its target's representative
        # and rebuild the reverse edges to match.
        for state in survivors:
            state.comefrom = []
        for state in survivors:
            for symbol, target in list(state.goto.items()):
                newtarget = rep_of[target.tmpid]
                state.goto[symbol] = newtarget
                newtarget.comefrom.append((symbol, state))
        self.states = survivors

    def goto(self, state, symbol):
        """Return the State reached from `state` by shifting `symbol`, or None."""
        advanced = [ProductionPos(self, p.production, p.pos + 1, p.follow)
                    for p in state.elements
                    if p.pos < p.len and p.rightside[p.pos] == symbol]
        if advanced:
            return State(self, advanced)
        return None

    def show(self):
        """Print symbols, voidable nonterminals and all states to stdout."""
        print("Terminals: " + ", ".join([str(t) for t in self.terminals]))
        print("Non-Terminals: " + ", ".join([str(t) for t in self.nonterminals]))
        for t in self.nonterminals:
            if self.voidable(t):
                print("Non-Terminal " + str(t) + " can be reduced to ''")
        for i, state in enumerate(self.states):
            state.tmpid = i
        for i, state in enumerate(self.states):
            print("==== State #" + str(i) + " ====")
            for p in state.elements:
                print(str(p))
            print("")
            for symbol, target in state.goto.items():
                print("For " + str(symbol) + " jump to state " + str(target.tmpid))
            print("")

    def _compute_voidable(self):
        """Return the set of nonterminals that can derive the empty string
        (least fixpoint over all productions)."""
        voidable = set()
        changed = True
        while changed:
            changed = False
            for p in self.productions:
                if p.leftside not in voidable and all(s in voidable for s in p.rightside):
                    voidable.add(p.leftside)
                    changed = True
        return voidable

    def voidable(self, symbol, seen=None):
        """Return 1 if `symbol` can derive the empty string, else 0.

        The end marker counts as voidable; other terminals never do.
        `seen` is accepted for backward compatibility and ignored.
        """
        if symbol == self.spec.end:
            return 1
        if symbol.terminal:
            return 0
        if getattr(self, "_voidable", None) is None:
            self._voidable = self._compute_voidable()
        return 1 if symbol in self._voidable else 0

    def first(self, string, seen=None):
        """Return FIRST(`string`) as a list of one-element [terminal] lists.

        [end] is included when every symbol of `string` can vanish (or when
        the end marker itself is reached).  `seen` maps nonterminals that
        are currently being expanded and guards against left recursion.
        """
        if seen is None:
            seen = {}
        result = []
        have = set()

        def add(terminal):
            if terminal not in have:
                have.add(terminal)
                result.append([terminal])

        for symbol in string:
            if symbol.terminal:
                add(symbol)
                break
            if symbol not in seen:
                seen[symbol] = symbol
                for p in self.spec.expand(symbol):
                    for j in self.first(p.rightside, seen):
                        if j[0] is not self.spec.end:
                            add(j[0])
                del seen[symbol]
            if not self.voidable(symbol):
                break
        else:
            add(self.spec.end)
        return result

    def setfirst(self, set):
        """Return the union of FIRST over every string in `set`, as
        [terminal] lists sorted by name so that equal lookahead sets always
        render (and therefore compare) identically."""
        found = {}
        for string in set:
            for entry in self.first(string):
                found[entry[0]] = entry
        return [found[t] for t in sorted(found, key=str)]

    def closure(self, p):
        """Return the LR(1) closure of the item list `p` as a new list."""
        items = list(p)
        ids = set(item.id for item in items)
        i = 0
        while i < len(items):
            e = items[i]
            i += 1
            if e.pos == e.len:
                continue
            a = e.rightside[e.pos]
            if a.terminal:
                continue
            follow = e.nextfirst()
            for pr in self.productions:
                if pr.leftside == a:
                    new = ProductionPos(self, pr, 0, follow)
                    if new.id not in ids:
                        ids.add(new.id)
                        items.append(new)
        return items


class ParseTable(object):
    """Numeric action/goto table built from a ParseGraph.

    Row = state, column = terminal (by .num) then nonterminal (offset by the
    number of terminals).  Entries: positive = shift/goto target state,
    negative = reduce by production -value, `accept` (= number of states) =
    accept, 0 = error.
    """

    def __init__(self, graph):
        self.accept = len(graph.states)
        self.width = len(graph.terminals) + len(graph.nonterminals)
        self.height = len(graph.states)
        self.productions = graph.productions

        # assign numbers to everything
        for num, s in enumerate(graph.states):
            s.num = num
        for num, p in enumerate(self.productions):
            p.num = num
        for num, t in enumerate(graph.terminals):
            t.num = num
        for num, t in enumerate(graph.nonterminals):
            t.num = num
        add = len(graph.terminals)

        self.catchall = None
        self.terminals = graph.terminals
        for a in graph.terminals:
            if a.catchall:
                self.catchall = a

        self.left = [(p.leftside.num + add) for p in self.productions]
        self.plen = [p.len for p in self.productions]
        symbols = graph.terminals + graph.nonterminals
        self.table = [self._build_row(state, symbols) for state in graph.states]

    def _build_row(self, state, symbols):
        """Return the row for `state`; reduces are written after shifts and
        overwrite them (conflicts are reported on stderr)."""
        line = []
        for symbol in symbols:
            if symbol in state.goto:
                g = state.goto[symbol].num
                assert g > 0
                line.append(g)
            else:
                line.append(0)
        for p in state.elements:
            if not p.iscomplete():
                continue
            if p.production.num == 0:  # "accept" production
                value = self.accept
            else:
                value = -p.production.num
            for symbol in [s[0] for s in p.follow]:
                col = symbol.num
                if line[col] and line[col] != value:
                    if line[col] > 0:
                        sys.stderr.write("Shift/Reduce conflict for symbol %s and this production:\n" % repr(symbol.s))
                        sys.stderr.write("%s\n" % str(p.production))
                    elif line[col] < 0:
                        sys.stderr.write("Reduce/Reduce conflict for symbol %s and these two productions:\n" % repr(symbol.s))
                        sys.stderr.write("%s\n" % str(p.production))
                        sys.stderr.write("%s\n" % str(self.productions[-line[col]]))
                line[col] = value
        return line

    def write_c(self, out=None):
        """Write the C function `double parse(const char*s)` implementing
        this table to `out` (a file-like object; default sys.stdout)."""
        w = sys.stdout if out is None else out
        w.write("""
/* automatically generated by yiyiyacc, http://www.quiss.org/yiyiyacc/ */

double parse(const char*s)
{
    static int chr2index[256];
    static int initialized=0;
    if(!initialized) {
        memset(chr2index, -1, sizeof(chr2index));
""")
        for num, c in enumerate(self.terminals):
            if c.catchall:
                continue
            s = c.s
            if s == '\\':
                s = '\\\\'
            elif s == '\'':
                s = '\\\''
            elif s == '\0':
                s = '\\0'
            # unsigned char: chars >= 0x80 must not produce negative indices
            w.write("        chr2index[(unsigned char)'%s'] = %d;\n" % (s, num))
        w.write("""        initialized=1;
    }
    int stackpos = 1;
    int stack[256];
    double values[256];
    stack[0]=0;
    int accept = %s;
""" % (self.accept))
        w.write("    static int left[%d]=" % len(self.left))
        w.write("{" + ",".join(["%d" % n for n in self.left]) + "}; //production left side\n")
        w.write("    static int plen[%d]=" % len(self.plen))
        w.write("{" + ",".join(["%d" % n for n in self.plen]) + "}; //production size\n")
        w.write("    static int table[%d][%d] = {\n" % (self.height, self.width))
        w.write(",\n".join(["    {" + ", ".join(["%d" % v for v in line]) + "}"
                            for line in self.table]))
        w.write("};\n")
        w.write("""
    const char*p = s;
    while(1) {
        const char*pnext = p+1;
        int action;
        double value = 0.0;
        int idx;
        if(!stackpos) {
            fprintf(stderr, "Error in expression\\n");
            return 0.0;
        }
        idx = chr2index[(unsigned char)*p];
""")
        if self.catchall is not None:
            w.write("""
        if(idx<0) {
            action = table[stack[stackpos-1]][%d];
            if(action>0) {
                while(chr2index[(unsigned char)*pnext]<0) pnext++;
""" % self.catchall.num)
            for line in self.catchall.code.split("\n"):
                line = line.strip()
                if not line:
                    continue
                line = line.replace('%%', 'value').replace('%1', 'p')
                w.write("                " + line + "\n")
            w.write("""            }
        } else {
            action = table[stack[stackpos-1]][idx];
        }
""")
        else:
            w.write("""
        /* characters that are not terminals of the grammar are syntax errors */
        action = idx<0 ? 0 : table[stack[stackpos-1]][idx];
""")
        w.write("""
        if(action == accept) {
            return values[stackpos-1];
        } else if(action>0) {
            // shift
            if(stackpos>254) {
                fprintf(stderr, "Stack overflow while parsing expression\\n");
                return 0.0;
            }
            values[stackpos]=value;
            stack[stackpos++]=action;
            p=pnext;
        } else if(action<0) {
            // reduce
            stackpos-=plen[-action];
            stack[stackpos] = table[stack[stackpos-1]][left[-action]];
            switch(-action) {
""")
        for p in self.productions:
            if not p.code:
                continue
            w.write("            case %d:\n" % p.num)
            # %N refers to the N-th *nonterminal*; its value sits at the
            # symbol's position on the value stack.
            refs = []
            count1 = 1
            for count2, r in enumerate(p.rightside):
                if r.nonterminal:
                    refs.append(('%' + str(count1), 'values[stackpos+' + str(count2) + ']'))
                    count1 += 1
            # Replace higher numbers first so "%1" cannot clobber "%10".
            refs.reverse()
            for line in p.code.split("\n"):
                line = line.strip().replace('%%', 'values[stackpos]')
                for name, target in refs:
                    line = line.replace(name, target)
                if line:
                    w.write("                " + line + "\n")
            w.write("            break;\n")
        w.write("""            }
            stackpos++;
        } else {
            fprintf(stderr, "Syntax error in expression\\n");
            return 0.0;
        }
    }
}
""")


class ProductionPos(object):
    """An LR(1) item: `production` with the dot before rightside[pos] and
    lookahead `follow`, a list of one-element [terminal] lists."""

    def __init__(self, parent, production, pos, follow):
        self.parent = parent
        self.production = production
        self.leftside = production.leftside
        self.rightside = production.rightside
        self.len = production.len
        self.pos = pos
        self.follow = follow
        self.id = str(self)

    def nextfirst(self):
        """Return the lookahead for items predicted by the symbol after the
        dot: FIRST(rest-after-that-symbol + f) for every f in follow."""
        rest = self.rightside[self.pos + 1:]
        return self.parent.setfirst([rest + f for f in self.follow])

    def iscomplete(self):
        """Return True if the dot is at the end of the production."""
        return self.pos == self.len

    def id_without_follow(self):
        """Return the production rendered with " . " at the dot position,
        without the lookahead set."""
        s = str(self.leftside) + " ->"
        for i, r in enumerate(self.rightside):
            if self.pos == i:
                s += " . "
            s += " " + str(r) + " "
        if self.pos == self.len:
            s += " . "
        return s

    def __str__(self):
        s = self.id_without_follow()
        s += " {" + ",".join([str(f[0]) for f in self.follow]) + "}"
        return s


class SymbolRegistry(object):
    """Interns grammar symbols so that each name maps to a single object."""

    def __init__(self):
        self.counter = 0
        self.terminals = []
        self.nonterminals = []
        self._terminals = {}
        self._nonterminals = {}

    def registerTerminal(self, s):
        """Return the Terminal named `s`, creating it on first use."""
        if s not in self._terminals:
            new = self._terminals[s] = Terminal(s)
            self.terminals.append(new)
            return new
        return self._terminals[s]

    def registerNonTerminal(self, s):
        """Return the NonTerminal named `s`, creating it on first use."""
        if s not in self._nonterminals:
            new = self._nonterminals[s] = NonTerminal(s)
            self.nonterminals.append(new)
            return new
        return self._nonterminals[s]

    def register(self, s):
        """Return the symbol for grammar token `s`: 'c' is a terminal
        (\\' is a quote), <*> the catch-all terminal, anything else a
        nonterminal.  Raises SyntaxError for malformed quoted/<...> tokens."""
        if s[0] == '\'' and s[-1] == '\'':
            s = s[1:-1]
            if s == '\\\'':
                s = '\''
            if not s:
                raise SyntaxError("Empty terminal")
            return self.registerTerminal(s)
        elif s[0] == '<' and s[-1] == '>':
            if s[1:-1] == '*':
                t = self.registerTerminal("<*>")
                t.catchall = 1
                return t
            raise SyntaxError("Unknown <...> expression: %s" % s)
        else:
            return self.registerNonTerminal(s)


class Symbol(object):
    """A grammar symbol (terminal or nonterminal)."""

    def __init__(self, s, terminal):
        self.s = s
        self.terminal = terminal
        self.nonterminal = not terminal
        self.catchall = 0
        self.code = ""  # C code attached via a "<*>" section in the grammar file

    def __str__(self):
        if self.terminal:
            if self.s == '\0':
                return "*END*"
            return "'" + self.s + "'"
        return self.s


class Terminal(Symbol):
    def __init__(self, s):
        Symbol.__init__(self, s, 1)


class NonTerminal(Symbol):
    def __init__(self, s):
        Symbol.__init__(self, s, 0)


class Production(object):
    """A grammar rule `leftside -> rightside` with optional C `code`."""

    def __init__(self, leftside, rightside, code=""):
        self.leftside = leftside
        self.rightside = rightside
        self.len = len(self.rightside)
        self.code = code

    @staticmethod
    def parse(s, registry):
        """Parse "Left -> a 'b' <*>" into a Production, or return None if
        `s` contains no "->"."""
        try:
            i = s.index("->")
        except ValueError:
            return None
        leftside = registry.registerNonTerminal(s[0:i].strip())
        rightside = [registry.register(term)
                     for term in (s[i + 2:].strip()).split(" ") if term]
        return Production(leftside, rightside)

    def __str__(self):
        return str(self.leftside) + " -> " + (" ".join([str(s) for s in self.rightside]))


def main(argv=None):
    """Read the grammar file named in argv[1] and write the C parser to stdout."""
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        sys.stderr.write("usage: %s GRAMMARFILE\n" % (argv[0] if argv else "yiyiyacc"))
        return 1
    spec = Spec.load(argv[1])
    graph = ParseGraph(spec)
    table = ParseTable(graph)
    table.write_c()
    return 0


if __name__ == "__main__":
    sys.exit(main())