from semstr.util.amr import LABEL_ATTRIB, LABEL_SEPARATOR
from ucca import layer1
from .action import Actions
from .config import Config, COMPOUND
from .states.state import InvalidActionError
# Index constants that keep the calls to Oracle.action readable
NODE, EDGE, LABEL = 0, 1, 2  # kind of action requested
PARENT, CHILD = 0, 1  # which end of the edge a node action creates
RIGHT, LEFT = 0, 1  # direction of an edge action

# Each row is (non-remote action, remote action)
_NODE_ACTIONS = (
    (Actions.Node, Actions.RemoteNode),  # PARENT: creating a parent
    (Actions.Implicit, None),  # CHILD: creating a child (remote implicit is not allowed)
)
_EDGE_ACTIONS = (
    (Actions.RightEdge, Actions.RightRemote),  # RIGHT: creating a right edge
    (Actions.LeftEdge, Actions.LeftRemote),  # LEFT: creating a left edge
)

# Indexed as ACTIONS[NODE/EDGE][PARENT/CHILD or RIGHT/LEFT][False/True (remote)]
ACTIONS = (_NODE_ACTIONS, _EDGE_ACTIONS)
class Oracle:
    """
    Oracle to produce gold transition parses given UCCA passages
    To be used for creating training data for a transition-based UCCA parser
    :param passage: gold passage to get the correct edges from
    """

    def __init__(self, passage):
        self.args = Config().args
        self.unlabeled = Config().is_unlabeled()
        l1 = passage.layer(layer1.LAYER_ID)
        # IDs of the layer-1 nodes still to be created: everything except the first head,
        # optionally leaving out linkage nodes and implicit nodes (depending on the configuration)
        self.nodes_remaining = {
            n.ID for n in l1.all
            if n is not l1.heads[0]
            and (self.args.linkage or n.tag != layer1.NodeTags.Linkage)
            and (self.args.implicit or not is_implicit_node(n))
        }
        # Edges still to be created, filtered by the same configuration switches (plus remote edges)
        link_tags = (layer1.EdgeTags.LinkRelation, layer1.EdgeTags.LinkArgument)
        self.edges_remaining = {
            e for n in passage.nodes.values() for e in n
            if (self.args.linkage or e.tag not in link_tags)
            and (self.args.implicit or not is_implicit_node(e.child))
            and (self.args.remote or not is_remote_edge(e))
        }
        if self.unlabeled:  # Keep only one edge between each pair of nodes, since we cannot distinguish between them
            unique_edges = {(e.parent.ID, e.child.ID, is_remote_edge(e)): e for e in self.edges_remaining}
            self.edges_remaining = set(unique_edges.values())
        self.passage = passage
        self.found = False  # Set by action() whenever generate_actions produces an action
        self.log = None  # Last message built by generate_log

    def get_actions(self, state, all_actions, create=True):
        """
        Determine all zero-cost action according to current state
        If validate_oracle is set, each action is checked against the state, and it is asserted
        that at least one valid action was found
        :param state: current State of the parser
        :param all_actions: Actions object used to map actions to IDs
        :param create: whether to create new actions if they do not exist yet
        :return: dict of action ID to Action
        """
        actions = {}
        invalid = []
        for action in self.generate_actions(state):
            all_actions.generate_id(action, create=create)
            if action.id is None:  # No ID was assigned to this action, so it cannot be returned
                continue
            try:
                if self.args.validate_oracle:
                    state.check_valid_action(action, message=True)
                actions[action.id] = action
            except InvalidActionError as e:
                invalid.append((action, e))
        if self.args.validate_oracle:
            # The message (and self.log) is only built if the assertion fails
            assert actions, self.generate_log(invalid, state)
        return actions

    def generate_log(self, invalid, state):
        """
        Build the message describing a situation where no valid action was found
        The message is also stored in self.log
        :param invalid: list of (action, error) pairs for the actions that were rejected
        :param state: current State of the parser
        :return: the message, as a string
        """
        header = ["Oracle found no valid action",
                  state.str("\n"), self.str("\n"),
                  "Actions returned by the oracle:"]
        rejected = [" %s: %s" % (action, e) for (action, e) in invalid]
        # Parenthesized on purpose: "+" binds tighter than "or", so without the parentheses
        # the "None" placeholder would never be used
        self.log = "\n".join(header + (rejected or ["None"]))
        return self.log

    def generate_actions(self, state):
        """
        Determine all zero-cost action according to current state
        Resets self.found, which is then set by self.action for every action produced
        :param state: current State of the parser
        :return: generator of Action items to perform
        """
        self.found = False
        if state.stack:
            s0 = state.stack[-1]
            incoming, outgoing = [[e for e in l if e in self.edges_remaining]
                                  for l in (s0.orig_node.incoming, s0.orig_node.outgoing)]
            if not incoming and not outgoing and not self.need_label(s0):
                yield self.action(Actions.Reduce)
            else:
                # Check for node label action: if all terminals have already been connected
                if self.need_label(s0) and not any(is_terminal_edge(e) for e in outgoing):
                    yield self.action(s0, LABEL, 1)

                # Check for actions to create new nodes
                for edge in incoming:
                    if edge.parent.ID in self.nodes_remaining and not is_implicit_node(edge.parent) and (
                            not is_remote_edge(edge) or
                            # Allow remote parent if all its children are remote/implicit
                            all(is_remote_edge(e) or is_implicit_node(e.child) for e in edge.parent)):
                        yield self.action(edge, NODE, PARENT)  # Node or RemoteNode

                for edge in outgoing:
                    if edge.child.ID in self.nodes_remaining and is_implicit_node(edge.child) and (
                            not is_remote_edge(edge)):  # Allow implicit child if it is not remote
                        yield self.action(edge, NODE, CHILD)  # Implicit

                if len(state.stack) > 1:
                    s1 = state.stack[-2]
                    # Check for node label action: if all terminals have already been connected
                    if self.need_label(s1) and not any(
                            is_terminal_edge(e) for e in
                            self.edges_remaining.intersection(s1.orig_node.outgoing)):
                        yield self.action(s1, LABEL, 2)

                    # Check for actions to create binary edges
                    for edge in incoming:
                        if edge.parent.ID == s1.node_id:
                            yield self.action(edge, EDGE, RIGHT)  # RightEdge or RightRemote

                    for edge in outgoing:
                        if edge.child.ID == s1.node_id:
                            yield self.action(edge, EDGE, LEFT)  # LeftEdge or LeftRemote
                        elif state.buffer and edge.child.ID == state.buffer[0].node_id and \
                                len(state.buffer[0].orig_node.incoming) == 1:
                            yield self.action(Actions.Shift)  # Special case to allow discarding simple children quickly

                    if not self.found:
                        # Check if a swap is necessary, and how far (if compound swap is enabled)
                        # Maps the ID of the node at the other end of each remaining edge to that edge
                        related = {edge.child.ID: edge for edge in outgoing}
                        related.update((edge.parent.ID, edge) for edge in incoming)
                        distance = None  # Swap distance (how many nodes in the stack to swap)
                        for i, s in enumerate(state.stack[-3::-1], start=1):  # Skip top two: checked above, not related
                            edge = related.pop(s.node_id, None)
                            if edge is not None:
                                if not self.args.swap:  # We have no chance to reach it, so stop trying
                                    self.remove(edge)
                                    continue
                                if distance is None and self.args.swap == COMPOUND:  # Save the first one
                                    distance = min(i, Config().args.max_swap)  # Do not swap more than allowed
                                if not related:  # All related nodes are in the stack
                                    yield self.action(Actions.Swap(distance))
                                    break

        if not self.found:  # Nothing else to do: move on, or finish if the buffer is empty
            yield self.action(Actions.Shift if state.buffer else Actions.Finish)

    def action(self, edge, kind=None, direction=None):
        """
        Create an action (or pass a given one through), and mark that an action was found
        :param edge: an action to return as is (if kind is None), a state node (if kind is LABEL),
                     or otherwise the gold edge that the action should create
        :param kind: NODE, EDGE, LABEL or None
        :param direction: PARENT/CHILD for NODE, RIGHT/LEFT for EDGE; for LABEL, it is passed to Actions.Label
        :return: the action
        """
        self.found = True
        if kind is None:
            return edge  # Will be just an Action object in this case
        if kind == LABEL:
            return Actions.Label(direction, orig_node=edge.orig_node, oracle=self)
        node = (edge.parent, edge.child)[direction] if kind == NODE else None
        tag = "" if self.unlabeled else edge.tag
        return ACTIONS[kind][direction][is_remote_edge(edge)](tag=tag, orig_edge=edge, orig_node=node, oracle=self)

    def remove(self, edge, node=None):
        """
        Remove an edge (and optionally a node) from the sets of remaining items
        :param edge: edge to discard from edges_remaining
        :param node: if given, node whose ID is discarded from nodes_remaining
        """
        self.edges_remaining.discard(edge)
        if node is not None:
            self.nodes_remaining.discard(node.ID)

    def need_label(self, node):
        """
        :param node: state node, with `labeled` and `orig_node` attributes
        :return: truthy if node labels are enabled, gold node labels are not used, the node is not
                 labeled yet, and its original node has a label attribute
        """
        return self.args.node_labels and not self.args.use_gold_node_labels \
            and not node.labeled and node.orig_node.attrib.get(LABEL_ATTRIB)

    def get_label(self, state, node):
        """
        Get the gold label of a node
        :param state: current State of the parser, used to validate the label if validate_oracle is set
        :param node: state node whose orig_node (if any) holds the label attribute
        :return: pair of (label up to the first LABEL_SEPARATOR, raw label); both None if there is no label
        :raises InvalidActionError: if validate_oracle is set and the state rejects the label
        """
        true_label = raw_true_label = None
        if node.orig_node is not None:
            raw_true_label = node.orig_node.attrib.get(LABEL_ATTRIB)
        if raw_true_label is not None:
            true_label, _, _ = raw_true_label.partition(LABEL_SEPARATOR)
            if self.args.validate_oracle:
                try:
                    state.check_valid_label(true_label, message=True)
                except InvalidActionError as e:
                    raise InvalidActionError(
                        "True label is invalid: " + "\n".join(map(str, (true_label, state, e)))) from e
        return true_label, raw_true_label

    def str(self, sep):
        """
        :param sep: separator to put between the list of nodes and the list of edges
        :return: string listing the remaining node IDs and the remaining edges
        """
        # Inside a method, the name `str` still refers to the builtin, not to this method
        return "nodes left: [%s]%sedges left: [%s]" % (" ".join(map(str, self.nodes_remaining)), sep,
                                                       " ".join(map(str, self.edges_remaining)))

    def __str__(self):
        # Must go through self.str: the builtin str(" ") would just return a single space
        return self.str(" ")
def is_terminal_edge(edge):
    """
    :param edge: edge with a `tag` attribute
    :return: whether the edge's tag is layer1.EdgeTags.Terminal
    """
    return edge.tag == layer1.EdgeTags.Terminal
def is_remote_edge(edge):
    """
    :param edge: edge with an `attrib` mapping
    :return: value of the edge's "remote" attribute, or False if it is not set
    """
    return edge.attrib.get("remote", False)
def is_implicit_node(node):
    """
    :param node: node with an `attrib` mapping
    :return: value of the node's "implicit" attribute, or False if it is not set
    """
    return node.attrib.get("implicit", False)