import ast

class CFACreator(ast.NodeVisitor):
    """Build a control-flow automaton (CFA) from a Python AST.

    The visitor threads the "current" CFA node through ``node_stack``: each
    ``visit_*`` handler pops the node control is at, adds the edge(s) for its
    statement, and pushes the node control reaches afterwards.  ``root`` is
    the CFA's initial node.

    Node types without a handler here fall back to
    ``ast.NodeVisitor.generic_visit``, which adds no edge of its own but still
    visits child nodes.

    NOTE(review): ``CFAEdge(...)`` results are never stored, so constructing
    an edge is assumed to link it into both endpoint nodes; ``CFANode.merge``
    is assumed to fold its second argument into a single join node and return
    that node (``visit_If`` uses the return value).  Confirm both against
    their definitions.
    """

    def __init__(self):
        self.root = CFANode()
        # The top of this stack is the node at which control currently is.
        self.node_stack = [self.root]
        # Head node of each enclosing while loop (target of ``continue``).
        self.loop_head_stack = []
        # Break target of each enclosing while loop (target of ``break``).
        self.exit_loop_stack = []

    def _visit_body(self, start, statements):
        """Visit ``statements`` beginning at ``start``; return the node reached."""
        self.node_stack.append(start)
        for statement in statements:
            self.visit(statement)
        return self.node_stack.pop()

    def visit_While(self, node):
        """Add a loop guarded by ``node.test``.

        The current node becomes the loop head.  An assume-edge on the test
        enters the body, whose final node is merged back into the head; an
        assume-edge on the negated test leaves the loop.  ``continue`` jumps
        to the head.  A ``while ... else`` body runs only when the loop exits
        through its condition, so in that case ``break`` targets a separate
        node that bypasses the ``else`` statements.
        """
        loop_head = self.node_stack.pop()
        body_entry = CFANode()
        CFAEdge(loop_head, body_entry, Instruction.assumption(node.test))
        loop_exit = CFANode()
        CFAEdge(
            loop_head, loop_exit, Instruction.assumption(node.test, negated=True)
        )
        # Without an ``else`` clause, ``break`` lands where the loop exits
        # normally; with one, it must skip the ``else`` statements.
        break_target = CFANode() if node.orelse else loop_exit

        self.loop_head_stack.append(loop_head)
        self.exit_loop_stack.append(break_target)
        body_exit = self._visit_body(body_entry, node.body)
        # Close the loop: the end of the body flows back to the head.
        CFANode.merge(loop_head, body_exit)
        # Pop before visiting ``else``: a break/continue there refers to an
        # enclosing loop, not to this one.
        self.exit_loop_stack.pop()
        self.loop_head_stack.pop()

        if node.orelse:
            else_exit = self._visit_body(loop_exit, node.orelse)
            self.node_stack.append(CFANode.merge(else_exit, break_target))
        else:
            self.node_stack.append(loop_exit)

    def visit_Break(self, node):
        """Add an edge from the current node to the innermost break target.

        Raises:
            SyntaxError: if the statement is not inside a while loop.
        """
        if not self.exit_loop_stack:
            raise SyntaxError("'break' not inside a while loop")
        entry_node = self.node_stack.pop()
        CFAEdge(entry_node, self.exit_loop_stack[-1], Instruction.statement(node))
        # Code after ``break`` is unreachable: continue from a fresh node
        # that has no incoming edge.
        self.node_stack.append(CFANode())

    def visit_Continue(self, node):
        """Add an edge from the current node back to the innermost loop head.

        Raises:
            SyntaxError: if the statement is not inside a while loop.
        """
        if not self.loop_head_stack:
            raise SyntaxError("'continue' not inside a while loop")
        entry_node = self.node_stack.pop()
        CFAEdge(entry_node, self.loop_head_stack[-1], Instruction.statement(node))
        # Code after ``continue`` is unreachable: continue from a fresh node
        # that has no incoming edge.
        self.node_stack.append(CFANode())

    def visit_If(self, node):
        """Add a branch on ``node.test`` and join both arms afterwards.

        The then-arm starts on an assume-edge for the test, the else-arm
        (possibly empty) on the negated test; the two arm exits are merged
        and the join node becomes the current node.
        """
        entry_node = self.node_stack.pop()
        left = CFANode()
        CFAEdge(entry_node, left, Instruction.assumption(node.test))
        right = CFANode()
        CFAEdge(entry_node, right, Instruction.assumption(node.test, negated=True))
        left_exit = self._visit_body(left, node.body)
        right_exit = self._visit_body(right, node.orelse)
        self.node_stack.append(CFANode.merge(left_exit, right_exit))

    def visit_Expr(self, node):
        """Add a statement edge for an expression statement's value."""
        entry_node = self.node_stack.pop()
        exit_node = CFANode()
        CFAEdge(entry_node, exit_node, Instruction.statement(node.value))
        self.node_stack.append(exit_node)

    def visit_Assign(self, node):
        """Add a statement edge carrying the whole ``ast.Assign`` node."""
        entry_node = self.node_stack.pop()
        exit_node = CFANode()
        CFAEdge(entry_node, exit_node, Instruction.statement(node))
        self.node_stack.append(exit_node)
