# Source code for tupa.action
from .config import Config, COMPOUND
from .labels import Labels
class Action(dict):
    """A parser transition, optionally tagged and linked to the graph elements it creates.

    Each instance is also a ``dict`` holding ``action_type`` and ``tag`` (the same keyword
    names the constructor accepts), so an equivalent action can be rebuilt with
    ``Action(**d)``. Equality and hashing, however, use only the registry ``id``.
    """

    # Shared by all instances: action type string -> small integer, allocated the first time
    # a type is seen, so that type checks (see is_type) are plain integer comparisons.
    type_to_id = {}

    def __init__(self, action_type, tag=None, orig_edge=None, orig_node=None, oracle=None, id_=None):
        """Create an action.

        :param action_type: action type name, e.g. ``"SHIFT"``
        :param tag: usually the tag of the created edge; for a compound swap, the swap distance
        :param orig_edge: edge created by this action, if any (during training)
        :param orig_node: node created by this action, if any (during training)
        :param oracle: oracle to inform of actually created nodes/edges (used by ``apply``)
        :param id_: index of this action in an ``Actions`` registry, if already known
        """
        self.type = action_type  # String
        self.tag = tag
        self.orig_node = orig_node
        self.orig_edge = orig_edge
        self.node = None  # Will be set by State when the node created by this action is known
        self.edge = None  # Will be set by State when the edge created by this action is known
        self.oracle = oracle
        self.index = None  # Index of this action in history
        # len() is evaluated before the insert, so a new type receives the next free id.
        self.type_id = Action.type_to_id.setdefault(self.type, len(Action.type_to_id))
        self.id = id_
        super().__init__(action_type=self.type, tag=self.tag)

    def is_type(self, *others):
        """Return True if this action has the same type as any of ``others``."""
        return any(self.type_id == other.type_id for other in others)

    def apply(self):
        """Call ``oracle.remove(orig_edge, orig_node)`` if an oracle is attached."""
        if self.oracle is not None:
            self.oracle.remove(self.orig_edge, self.orig_node)

    def __repr__(self):
        # Falsy parts (e.g. a missing tag) are omitted.
        return Action.__name__ + "(" + ", ".join(map(str, filter(None, (self.type, self.tag)))) + ")"

    def __str__(self):
        s = self.type
        if self.tag:
            s += "-%s" % self.tag
        return s

    def __eq__(self, other):
        """Compare by registry ``id``; objects without an ``id`` defer to the other operand.

        NOTE(review): two actions whose ``id`` is still None compare equal.
        """
        try:
            other_id = other.id
        except AttributeError:
            return NotImplemented
        return self.id == other_id

    def __ne__(self, other):
        # Must be defined explicitly: otherwise dict.__ne__ is inherited and compares dict
        # contents (action_type/tag) rather than ids, so != would disagree with ==.
        equal = self.__eq__(other)
        return equal if equal is NotImplemented else not equal

    def __hash__(self):
        return hash(self.id)

    def __call__(self, *args, **kwargs):
        """Create a new action of this type; arguments are passed on after ``action_type``."""
        return Action(self.type, *args, **kwargs)

    @property
    def remote(self):
        """True for REMOTE-NODE, LEFT-REMOTE and RIGHT-REMOTE actions."""
        return self.is_type(Actions.RemoteNode, Actions.LeftRemote, Actions.RightRemote)

    @property
    def is_swap(self):
        """True for SWAP actions (including tagged compound swaps)."""
        return self.is_type(Actions.Swap)
class Actions(Labels):
    """Registry of the transition actions available to the parser.

    The class attributes are prototype actions; calling one (e.g. ``Actions.Swap(2)``)
    produces a new action of that type with the given tag. An instance keeps an ordered list
    of known actions and a lookup from ``(type_id, tag)`` to the action's position in that
    list, which serves as the action's ``id``.
    """
    Shift = Action("SHIFT")
    Node = Action("NODE")
    RemoteNode = Action("REMOTE-NODE")
    Implicit = Action("IMPLICIT")
    Label = Action("LABEL")
    Reduce = Action("REDUCE")
    LeftEdge = Action("LEFT-EDGE")
    RightEdge = Action("RIGHT-EDGE")
    LeftRemote = Action("LEFT-REMOTE")
    RightRemote = Action("RIGHT-REMOTE")
    Swap = Action("SWAP")
    Finish = Action("FINISH")

    def __init__(self, actions=None, size=None):
        """Create the registry, optionally seeded with ``actions`` (Action objects or dicts)."""
        super().__init__(size=size)
        self._all = None  # ordered list of known actions; built lazily when left as None
        self._ids = None  # (type_id, tag) -> index into self._all
        if actions is not None:
            self.all = actions

    def init(self):
        """Fill the registry with the untagged actions implied by the current configuration."""
        # Edge and node actions are not listed here; they are created as the oracle returns them.
        args = Config().args
        if args.swap == COMPOUND:
            # NOTE(review): distances run 1..max_swap-1 (max_swap is exclusive) — confirm intent.
            swap_actions = [Actions.Swap(distance) for distance in range(1, args.max_swap)]
        elif args.swap:
            swap_actions = [Actions.Swap]
        else:
            swap_actions = []
        self.all = [Actions.Reduce, Actions.Shift, Actions.Finish] + swap_actions

    def _ensure_initialized(self):
        # Build the default action list the first time either view of the registry is needed.
        if self._all is None:
            self.init()

    @property
    def all(self):
        """The ordered list of known actions, built from the configuration on first access."""
        self._ensure_initialized()
        return self._all

    @all.setter
    def all(self, actions):
        # Action subclasses dict, so Action objects and plain dicts alike are rebuilt through
        # Action(**item): every stored element is a fresh action whose id starts out as None.
        self._all = [Action(**item) if isinstance(item, dict) else item for item in actions]
        lookup = {}
        for position, known in enumerate(self._all):
            lookup[known.type_id, known.tag] = position  # later duplicates win
        self._ids = lookup
        for known in self._all:
            self.generate_id(known)

    @property
    def ids(self):
        """Mapping from ``(type_id, tag)`` to the index of the matching action in ``all``."""
        self._ensure_initialized()
        return self._ids

    def generate_id(self, action, create=True):
        """Give ``action`` its registry id if it does not have one yet.

        The id is looked up by ``(type_id, tag)``. If the action is unknown and ``create`` is
        true, a copy (``action(tag=..., id_=...)``) is appended to ``all`` and the action gets
        the new index; with ``create`` false, an unknown action keeps ``id`` None.
        """
        if action.id is not None:
            return
        key = (action.type_id, action.tag)
        lookup = self.ids
        action.id = lookup.get(key)
        if create and action.id is None:  # New action, add to list
            registry = self.all
            # noinspection PyTypeChecker
            action.id = len(registry)
            registry.append(action(tag=action.tag, id_=action.id))
            lookup[key] = action.id