from __future__ import print_function
import re
from antk.core.node_ops import *
from antk.lib import termcolor
import sys
import os
import traceback
# Snapshot of this module's namespace taken right after the imports above, so it
# holds every name brought in by the star import from antk.core.node_ops.
# AntGraph passes it as the globals for the eval() of each config function call
# and adds user supplied node functions to it.
NODE_GLOBALS = globals().copy()
def ph_rep(ph):
    """
    Build a readable one-line description of a tensorflow placeholder.

    :param ph: A `tensorflow`_ `placeholder`_ (any object exposing ``name``,
        ``get_shape().as_list()`` and ``dtype`` works).
    :return: A string of the form ``Placeholder("<name>", shape=<list>, dtype=<repr>)``.
    """
    shape_as_list = ph.get_shape().as_list()
    fields = (ph.name, shape_as_list, ph.dtype)
    return 'Placeholder("%s", shape=%s, dtype=%r)' % fields
class UndefinedVariableError(Exception):
    """
    Raised when the ``$variables`` in a config file and the variable_bindings
    map do not match: a variable has no binding, or a binding is never used.
    """
class UnsupportedNodeError(NameError):
    """
    Raised when a config file calls a function that is not defined, i.e., has
    not been imported, or is not in the node_ops base file.
    """
class RandomNodeFunctionError(KeyError):
    """Raised when something strange happened with a node function call."""
class MissingTensorError(Exception):
    """
    Raised when a tensor is described by name only in the graph and that name
    is not a key of the tensor dictionary.
    """
class MissingDataError(Exception):
    """Raised when data needed to determine shapes is not found in the :any:`DataSet`."""
class ProcessLookupError(Exception):
    """
    Raised when lookup receives a dataname argument without a corresponding
    value in its :any:`DataSet` and there is not already a Placeholder with
    that name.
    """
class GraphMarkerError(Exception):
    """
    Raised when the leading character of a line (other than the first) in a
    graph config file is not the specified level marker.
    """
class AntGraph(object):
    """
    Object to store graph information from graph built with config file.

    The constructor reads the config, substitutes ``$variable`` bindings,
    evaluates the function call of every node to build the tensors, and writes
    a graphviz dot file which it asks the ``dot`` command to render as a pdf.

    :param config: Path to a plain text config file.
    :param tensordict: A dictionary of premade tensors represented in the config by key.
        Tensors built from the config are added to this dictionary. A new empty
        dictionary is used when omitted.
    :param placeholderdict: A dictionary of premade placeholder tensors represented in the config by key.
        Placeholders built from the config are added to this dictionary. A new
        empty dictionary is used when omitted.
    :param data: A dictionary of data matrices with keys corresponding to placeholder names in graph.
    :param function_map: A dictionary of function_handle:node_op pairs to use in building the graph
    :param imports: A dictionary of module_name:path_to_module key value pairs for custom node_ops modules.
    :param marker: The marker for representing graph structure. Only its first character is
        compared when counting levels, so a multi-character marker has to repeat one character.
    :param variable_bindings: A dictionary with entries of the form *variable_name:value* for variable replacement in config file.
    :param graph_name: The name of the graph. Will be used to name the graph pdf file.
    :param graph_dest: The folder to write the graph pdf and graph dot string to.
    :param develop: True|False. Whether to print tensor info, while constructing the tensorflow graph.
    :raises TypeError: If *data* is given and is not a dictionary.
    :raises GraphMarkerError: If no config line starts with *marker* or a line has a partial marker.
    :raises UndefinedVariableError: If config variables and *variable_bindings* do not match.
    """

    # Exception types that get extra context printed before being re-raised.
    # The order is the order in which the type to re-raise is chosen.
    _REPORTED_ERRORS = (NameError, TypeError, ValueError)
    _ERROR_BANNER = ("==========================Original Handled Exception Above"
                     "============================")
    _TENSOR_ERROR_BANNER = ("==========================Original Handled Exception Above | "
                            "Input Tensors Below============================")

    def __init__(self, config, tensordict=None, placeholderdict=None, data=None, function_map=None,
                 imports=None, marker='-', variable_bindings=None, graph_name='no_name',
                 graph_dest='antpics/', develop=False):
        self.marker = marker
        if data is not None and not isinstance(data, dict):
            raise TypeError('Data argument to AntGraph constructor must be a python dictionary with keys, '
                            'corresponding to placeholder names, and values of numpy arrays, scipy sparse '
                            'csr_matrices, or HotIndex objects.')
        self.data = data
        # A fresh dictionary per instance: a shared default would leak tensors
        # from one graph into the next because these dictionaries are written to.
        self._tensordict = {} if tensordict is None else tensordict
        self._placeholderdict = {} if placeholderdict is None else placeholderdict
        self.develop = develop

        # ======================Node Function Extensions======================
        # NODE_GLOBALS is module level, so extensions stay registered for
        # every graph built afterwards.
        if function_map:
            NODE_GLOBALS.update(function_map)
        self._import_node_files(imports or {})

        # ==============================Make Graph============================
        with open(config, 'r') as config_file:
            graph_spec = config_file.read().strip()  # remove whitespace at end of file
        graph_spec = self._substitute_variables(graph_spec, variable_bindings)
        # Blank lines carry no node, so they are dropped before parsing levels.
        graph_lines = [line for line in graph_spec.split('\n') if line.strip()]
        outputs, node_names = self._get_edges(graph_lines)
        self._dotstring = 'digraph ' + graph_name + ' {'
        self._add_nodes(node_names)
        output_list = [self._traverse_graph(subgraph) for subgraph in outputs]
        if len(output_list) == 1:
            self._tensor_out = output_list[0]
        else:
            self._tensor_out = output_list

        # =====================Make Graphviz Dot Picture======================
        self._dotstring += '\n}'
        if not graph_dest.endswith('/'):
            graph_dest += '/'
        dot_path = graph_dest + graph_name + '.dot'
        self._path_to_graph_pic = graph_dest + graph_name + '.pdf'
        if not os.path.isdir(graph_dest):
            os.makedirs(graph_dest)
        with open(dot_path, 'w') as dot_file:
            dot_file.write(self._dotstring)
        # NOTE(review): the paths are interpolated into a shell command; names
        # containing spaces or shell metacharacters will not work here.
        os.system('dot -Tpdf -o ' + self._path_to_graph_pic + ' ' + dot_path)

    # =========================================================================
    # ==================PROPERTIES=============================================
    # =========================================================================
    @property
    def tensordict(self):
        '''
        A dictionary of tensors which are nodes in the graph.
        '''
        return self._tensordict

    @property
    def placeholderdict(self):
        '''
        A dictionary of tensors which are placeholders in the graph. The key should correspond to the key of
        the corresponding data in a data dictionary.
        '''
        return self._placeholderdict

    @property
    def tensor_out(self):
        '''
        Tensor or list of tensors returned from last node of graph.
        '''
        return self._tensor_out

    # =========================================================================
    # ==================INSTANCE METHODS=======================================
    # =========================================================================
    def display_graph(self, pdfviewer='okular'):
        """
        Display the pdf image of graph from config file to screen.

        :param pdfviewer: Command used to open the pdf. It is started in the background.
        """
        os.system(pdfviewer + ' ' + self._path_to_graph_pic + ' &')

    @staticmethod
    def get_array(collection_name, index, session, graph):
        """
        Evaluate one tensor of a tensorflow collection.

        Declared static because it takes no ``self``; it can be called on the
        class or on an instance with the same four arguments.

        :param collection_name: Name of the collection to read.
        :param index: Position of the wanted tensor inside the collection.
        :param session: Object whose ``run`` method evaluates the tensor.
        :param graph: Unused.
        :return: Whatever ``session.run`` returns for the selected tensor.
        """
        # ``tf`` is not imported by name in this file; presumably it arrives
        # through the star import from antk.core.node_ops -- TODO confirm.
        return session.run(tf.get_collection(collection_name)[index])

    # =========================================================================
    # ==================PRIVATE METHODS========================================
    # =========================================================================
    def _node_name(self, line):
        """Return the first token of a config line without its level markers."""
        return line.strip().split()[0].strip(self.marker[0])

    @staticmethod
    def _print_tensors(intensors, sep):
        """Print input tensors, one per *sep* when *intensors* is a list."""
        if isinstance(intensors, list):
            print(*intensors, sep=sep)
        else:
            print(intensors)

    @staticmethod
    def _report_failure(banner):
        """Print the active traceback followed by a red banner. Call from an except block."""
        traceback.print_exc()
        print(termcolor.colored(banner, 'red'))

    def _error_type(self, err):
        """Return the first reported base type that *err* is an instance of."""
        for exc_type in self._REPORTED_ERRORS:
            if isinstance(err, exc_type):
                return exc_type
        return type(err)

    def _data_for(self, key):
        """Return ``self.data[key]`` for reports, or None when it is unavailable."""
        if self.data is not None and key in self.data:
            return self.data[key]
        return None

    def _traverse_graph(self, graph):
        """
        This is a postorder 'tree' traversal with possibly repeated non-looping nodes.

        :param graph: Config lines of one subgraph; the first line is its root node.
        :return: The tensor made for the root node.
        """
        root_line = graph[0]
        if len(graph) == 1:
            return self._make_tensor(root_line)
        edges, node_names = self._get_edges(graph[1:])
        self._add_edges(self._node_name(root_line), node_names)
        t_list = [self._traverse_graph(end_node) for end_node in edges]
        # A node with one child receives the tensor itself rather than a list.
        if len(edges) == 1:
            return self._make_tensor(root_line, intensors=t_list[0])
        return self._make_tensor(root_line, intensors=t_list)

    def _make_tensor(self, spec, intensors=None):
        '''
        Parses a line from config file to make a tensor.

        :param spec: One config line: markers and node name, optionally followed by a function call.
        :param intensors: Tensor or list of tensors from the child nodes, passed to the
            function call as its first argument when not None.
        :return: The tensor stored for this node.
        :raises MissingTensorError: If the line has no function call and the name is not in the tensor dictionary.
        :raises RandomNodeFunctionError: If the function call has no opening parenthesis.
        :raises NameError, TypeError, ValueError: Re-raised with the function call and inputs in the message.
        '''
        tokens = spec.strip().split()
        name = tokens[0].strip(self.marker[0])
        if name in self._tensordict:
            return self._tensordict[name]
        if len(tokens) == 1:
            raise MissingTensorError('Name %s: from config file is not in the tensor or placeholder '
                                     'dictionary so it must have a function call.' % name)
        # Whitespace inside the call is dropped by joining the tokens.
        function_spec = ''.join(tokens[1:])
        func, paren, arg_text = function_spec.partition('(')
        if not paren:
            raise RandomNodeFunctionError('Node %s: function call "%s" has no opening parenthesis.'
                                          % (name, function_spec))
        # Remove only the closing parenthesis of the call so that arguments
        # which contain parentheses themselves stay intact.
        params = arg_text[:-1] if arg_text.endswith(')') else arg_text
        if intensors is not None:
            params = 'intensors,' + params
        # An empty argument list leaves a doubled or dangling comma behind.
        params = (params + ',name=name').replace(',,', ',').strip(',')
        if func == 'placeholder':
            return self._process_placeholder(func, params, name)
        if func == 'lookup':
            return self._process_lookup(func, params, name)
        function_call = func + '(' + params + ')'
        try:
            # SECURITY NOTE: config text is executed as Python. Only build
            # graphs from config files you trust. The call resolves ``name``
            # and ``intensors`` through locals().
            self._tensordict[name] = eval(function_call, NODE_GLOBALS, locals())
        except self._REPORTED_ERRORS as err:
            self._report_failure(self._TENSOR_ERROR_BANNER)
            print('Input Tensors:\n\t', end="")
            self._print_tensors(intensors, '\n\t')
            raise self._error_type(err)('\nFunction Call: %s\n intensors: %r' % (function_call, intensors))
        if self.develop:
            heading = 'Node %s: %s' % (name, self.tensordict[name])
            print(termcolor.colored(heading, 'green'))
            print('\tFunction Call: %s\n\tTensor Inputs:\n\t\t' %
                  (function_call), end="")
            self._print_tensors(intensors, '\n\t\t')
        return self._tensordict[name]

    def _get_edges(self, graph):
        '''
        Gets subgraphs and names of parent nodes.

        :param graph: Non-empty list of config lines.
        :return: ``(subgraphs, node_names)``: the lines split at every line on the same level
            as the first one, and the node name of each of those lines.
        '''
        level = self._get_level(graph[0])
        starts = [i for i, line in enumerate(graph) if i == 0 or self._get_level(line) == level]
        node_names = [self._node_name(graph[i]) for i in starts]
        bounds = starts + [len(graph)]
        subgraphs = [graph[begin:end] for begin, end in zip(bounds, bounds[1:])]
        return subgraphs, node_names

    def _get_level(self, line):
        '''
        Find level of line from graph markers.

        :param line: One config line.
        :return: Number of leading marker characters (0 for a line without markers).
        :raises GraphMarkerError: If that number is not a multiple of the marker length.
        '''
        line = line.strip()
        level = len(line) - len(line.lstrip(self.marker[0]))
        if level % len(self.marker) != 0:
            raise GraphMarkerError('Need multiples of %s %s to delimit edges. Line: %s'
                                   % (len(self.marker), self.marker, line))
        return level

    def _process_placeholder(self, func=None, params=None, name=None):
        """
        Special treatment for placeholders which may be data dependent.

        :return: The existing placeholder for *name*, or a new one built with ``data=self.data[name]``.
        :raises MissingDataError: If there is no placeholder and no data called *name*.
        """
        if name in self._placeholderdict:
            return self._placeholderdict[name]
        if self.data is None or name not in self.data:
            raise MissingDataError('There is no data called %s in the DataSet for this AntGraph.' % name)
        params += ', data=self.data[name]'
        function_call = func + '(' + params + ')'
        try:
            # SECURITY NOTE: config text is executed as Python; see _make_tensor.
            self._placeholderdict[name] = eval(function_call, NODE_GLOBALS, locals())
        except self._REPORTED_ERRORS as err:
            self._report_failure(self._ERROR_BANNER)
            raise self._error_type(err)('\nFunction Call: %s\nname=%r\ndata: %r hash=%s)' %
                                        (function_call, name, self.data[name], name))
        if self.develop:
            heading = 'Node %s: %r' % (name, self.placeholderdict[name])
            print(termcolor.colored(heading, 'green'))
            print('\tFunction Call: %s\n\tInput Data: %r' %
                  (function_call, self.data[name]))
        return self._placeholderdict[name]

    def _process_lookup(self, func=None, params=None, name=None):
        """
        Special treatment for lookup function which may be data dependent.

        The call is expected to return three values, stored as the node tensor,
        its ``<name>_weights`` tensor and the placeholder for the dataname.

        :return: The tensor stored for *name*.
        :raises ProcessLookupError: If a dataname is given that is neither a placeholder nor a data key.
        """
        dataname = None
        for p in params.split(','):
            if p.startswith('dataname='):
                dataname = p.split('=')[1]
        if dataname is not None:
            dataname = dataname.strip("'\"")
            if dataname in self._placeholderdict:
                params += ', makeplace=False, indices=self._placeholderdict[dataname], ' \
                          'data=self.data[dataname]'
            elif self.data is not None and dataname in self.data:
                params += ', data=self.data[dataname]'
            else:
                function_call = func + '(' + params + ')'
                raise ProcessLookupError('"%s" is not a key in the data dictionary for this AntGraph. '
                                         'Need to provide a valid dataname argument for lookup without tensor input.'
                                         ' \nCall: %s' % (dataname, function_call))
        function_call = func + '(' + params + ')'
        try:
            # SECURITY NOTE: config text is executed as Python; see _make_tensor.
            # The call resolves ``self``, ``name`` and ``dataname`` through locals().
            vals = eval(function_call, NODE_GLOBALS, locals())
            self._tensordict[name] = vals[0]
            self._tensordict[name + '_weights'] = vals[1]
            self._placeholderdict[dataname] = vals[2]
        except self._REPORTED_ERRORS as err:
            self._report_failure(self._ERROR_BANNER)
            # _data_for keeps a missing dataname from masking the real error.
            raise self._error_type(err)('\nFunction Call: %s\nname=%r\ndata: %r hash=%s)' %
                                        (function_call, name, self._data_for(dataname), dataname))
        if self.develop:
            heading = 'Node %s: %s' % (name, self.tensordict[name])
            print(termcolor.colored(heading, 'green'))
            # vals[2] is what was stored as the placeholder, vals[1] as the weights.
            print('\tFunction Call: %s\n\tPlaceholder: %s\n\tWeights: %s\n\tInput Data: %r\n\t' %
                  (function_call, ph_rep(vals[2]), vals[1], self._data_for(dataname)))
        return self._tensordict[name]

    def _add_nodes(self, node_names):
        """
        Add nodes to graphviz dot string.
        """
        self._dotstring += ''.join('\n\t' + node + ';' for node in node_names)

    def _add_edges(self, start, dest):
        """
        Add edges to graphviz dot string.

        :param start: Name of the parent node.
        :param dest: Names of its child nodes.
        """
        self._dotstring += '\n\t' + start + ' -> {' + ','.join(dest) + '} [dir=back];'

    def _substitute_variables(self, graph_spec, variable_bindings):
        """
        String substitution of graph text marked as variables.

        :param graph_spec: Full text of the config file.
        :param variable_bindings: Dictionary of variable_name:value, or None. String values
            are inserted wrapped in single quotes, other values through ``str``.
        :return: The config text with every ``$variable_name`` replaced.
        :raises GraphMarkerError: If no line starts with the marker.
        :raises UndefinedVariableError: If a ``$`` remains unbound or a binding is not mentioned.
        :raises RandomNodeFunctionError: If an unbound variable is not followed by ``,`` or ``)``.
        """
        if not any(line.strip().startswith(self.marker) for line in graph_spec.split('\n')):
            raise GraphMarkerError("There are no instances of the chosen "
                                   "marker '%s' in the graph config file." % self.marker)
        if variable_bindings is None:
            indice = graph_spec.find('$')
            if indice >= 0:
                raise UndefinedVariableError('Need variable_bindings argument in call to AntGraph to bind '
                                             'variable beginning: %s' % graph_spec[indice:indice+10])
            return graph_spec
        # Longest names first, so that replacing $dim cannot corrupt $dim2.
        for symbol in sorted(variable_bindings, key=len, reverse=True):
            replacee = '$' + symbol
            if replacee not in graph_spec:
                raise UndefinedVariableError('%s is not mentioned in config file.' % replacee)
            value = variable_bindings[symbol]
            if isinstance(value, str):
                replacement = "'" + value + "'"
            else:
                replacement = str(value)
            graph_spec = graph_spec.replace(replacee, replacement)
        indice = graph_spec.find('$')
        if indice >= 0:
            variable = graph_spec[indice+1:]
            if ',' in variable:
                variable = variable.split(',')[0]  # parameter in middle of function call
            elif ')' in variable:
                variable = variable.split(')')[0]  # parameter at end of function call
            else:
                raise RandomNodeFunctionError('You forgot a parenthesis.')
            raise UndefinedVariableError('%s was not bound. Include %s in '
                                         'variable_bindings dictionary' % (variable, variable))
        return graph_spec

    def _import_node_files(self, files):
        '''
        Import node functions from modules in import parameter of constructor.

        Every attribute of each module (``__all__`` when defined, ``dir`` otherwise)
        whose name has no double underscore is added to NODE_GLOBALS.
        Exits the process with status 1 when a module cannot be imported.

        :param files: Dictionary of module_name:path_to_module. A path of None adds nothing to ``sys.path``.
        '''
        for name in files:
            try:
                if files[name] is not None:
                    sys.path.append(files[name])
                m = __import__(name=name, globals=globals(), locals=locals(), fromlist="*")
            except ImportError:
                sys.stderr.write('Unable to read %s/%s.py\n' % (files[name], name))
                sys.exit(1)
            try:
                attrlist = m.__all__
            except AttributeError:
                attrlist = dir(m)
            for attr in [a for a in attrlist if '__' not in a]:
                NODE_GLOBALS[attr] = getattr(m, attr)
# ====================================================================
# ===========Graph Format Testing ====================================
# ====================================================================
def testGraph(config, marker='-', graph_dest='antpics/', graph_name='test_graph', pdfviewer='okular'):
    """
    Draw the structure of a graph config file without building any tensors.

    Writes ``<graph_dest>/<graph_name>.dot``, asks the ``dot`` command to render
    it as pdf and opens the pdf with *pdfviewer* in the background.

    :param config: A graph specification in .config format.
    :param marker: A character or string of characters to delimit graph edges.
    :param graph_dest: Where to save the graphviz pdf and associated dot file.
    :param graph_name: A name for the graph (without extension)
    :param pdfviewer: Command used to open the rendered pdf.
    :raises GraphMarkerError: If the config has no non-blank lines or a line has a partial marker.
    """
    with open(config, 'r') as config_file:
        # Blank lines carry no node, so they are dropped before parsing levels.
        graph_spec = [line for line in config_file.read().split('\n') if line.strip()]
    if not graph_spec:
        raise GraphMarkerError("There are no instances of the chosen "
                               "marker '%s' in the graph config file." % marker)
    outputs, node_names = _get_edges(graph_spec, marker)
    dotstring = 'digraph test_graph' + ' {'
    dotstring = _add_nodes(node_names, dotstring)
    for subgraph in outputs:
        dotstring = _traverse_graph(subgraph, marker, dotstring)
    dotstring += '\n}'
    if not graph_dest.endswith('/'):
        graph_dest += '/'
    dot_path = graph_dest + graph_name + '.dot'
    path_to_graph_pic = graph_dest + graph_name + '.pdf'
    if not os.path.isdir(graph_dest):
        os.makedirs(graph_dest)
    with open(dot_path, 'w') as dot_file:
        dot_file.write(dotstring)
    # NOTE(review): the paths are interpolated into shell commands; names
    # containing spaces or shell metacharacters will not work here.
    os.system('dot -Tpdf -o ' + path_to_graph_pic + ' ' + dot_path)
    os.system(pdfviewer + ' ' + path_to_graph_pic + ' &')
def _traverse_graph(graph, marker, dotstring):
    """
    Walk one subgraph and append the edges of every node that has children.

    This is a postorder 'tree' traversal with possibly repeated non-looping nodes.

    :param graph: Config lines of one subgraph; the first line is its root node.
    :param marker: The level marker used in the config.
    :param dotstring: The graphviz dot string built so far.
    :return: The dot string extended with the edges found below the root.
    """
    # A single line is a leaf: nothing to draw below it.
    if len(graph) == 1:
        return dotstring
    root_name = graph[0].strip().split()[0].strip(marker[0])
    children, child_names = _get_edges(graph[1:], marker)
    result = _add_edges(root_name, child_names, dotstring)
    for child in children:
        result = _traverse_graph(child, marker, result)
    return result
def _get_edges(graph, marker):
    '''
    Split config lines into sibling subgraphs and collect the sibling names.

    :param graph: Non-empty list of config lines.
    :param marker: The level marker used in the config.
    :return: ``(subgraphs, node_names)``: the lines split at every line on the same level
        as the first one, and the node name of each of those lines.
    '''
    top_level = _get_level(graph[0], marker)
    starts = [0]
    node_names = [graph[0].strip().split()[0].strip(marker[0])]
    for position, line in enumerate(graph):
        if position == 0:
            continue
        if _get_level(line, marker) == top_level:
            starts.append(position)
            node_names.append(line.strip().split()[0].strip(marker[0]))
    bounds = starts + [len(graph)]
    subgraphs = [graph[begin:end] for begin, end in zip(bounds, bounds[1:])]
    return subgraphs, node_names
def _get_level(line, marker):
    '''
    Find level of line from graph markers.

    Only the first character of *marker* is counted, so a line made of nothing
    but markers, or an empty line, is handled without running off its end.

    :param line: One config line.
    :param marker: The level marker used in the config.
    :return: Number of leading marker characters (0 for a line without markers).
    :raises GraphMarkerError: If that number is not a multiple of the marker length.
    '''
    line = line.strip()
    level = len(line) - len(line.lstrip(marker[0]))
    if level % len(marker) != 0:
        raise GraphMarkerError('Need multiples of %s %s to delimit edges. Line: %s'
                               % (len(marker), marker, line))
    return level
def _add_nodes(node_names, dotstring):
    """
    Append one graphviz node statement per name to the dot string.

    :param node_names: Names of the nodes to declare.
    :param dotstring: The graphviz dot string built so far.
    :return: The extended dot string.
    """
    declarations = ['\n\t' + node + ';' for node in node_names]
    return dotstring + ''.join(declarations)
def _add_edges(start, dest, dotstring):
    """
    Append a graphviz edge statement from one node to its children.

    :param start: Name of the parent node.
    :param dest: Names of its child nodes.
    :param dotstring: The graphviz dot string built so far.
    :return: The extended dot string.
    """
    targets = ''.join(node + ',' for node in dest)
    opened = dotstring + '\n\t' + start + ' -> {' + targets
    # Drops the comma left after the last child (strip acts on the whole string).
    return opened.strip(',') + '} [dir=back];'