Source code for generic_model

import scipy.sparse as sps
import tensorflow as tf
import numpy as np
import time
from antk.core import loader
import os
import datetime
import matplotlib.pyplot as plt
from pprint import pprint
# ============================================================================================
# ============================CONVENIENCE DICTIONARY==========================================
# ============================================================================================
OPT = {'adam': tf.train.AdamOptimizer,
       'ada': tf.train.AdagradOptimizer,
       'grad': tf.train.GradientDescentOptimizer,
       'mom': tf.train.MomentumOptimizer}
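
# A sketch of how this table is used (not part of the original module): the
# string passed as the `opt` argument to Model selects an optimizer class.
#
#     optimizer = OPT['adam']                        # -> tf.train.AdamOptimizer
#     train_step = optimizer(0.01).minimize(loss)    # 0.01 is a hypothetical learnrate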

# ============================================================================================
# ============================GLOBAL MODULE FUNCTIONS=========================================
# ============================================================================================
def get_feed_list(batch, placeholderdict, supplement=None, train=1, debug=False):
    """
    :param batch: A dataset object.
    :param placeholderdict: A dictionary where the keys match keys in batch, and the
        values are placeholder tensors.
    :param supplement: A dictionary of numpy input matrices with keys corresponding to
        placeholders in placeholderdict, where the row size of the matrices does not
        correspond to the number of datapoints. For use with input data intended for
        `embedding_lookup`_.
    :param train: Whether to feed the training dropout probabilities (1) or a
        keep-probability of 1.0 (0) for the feed forward pass.
    :param debug: Whether to print placeholder and data shapes.
    :return: A feed dictionary with keys of placeholder tensors and values of numpy matrices.
    """
    ph, dt = [], []
    datadict = batch.features.copy()
    datadict.update(batch.labels)
    if supplement:
        datadict.update(supplement)
    for desc in placeholderdict:
        ph.append(placeholderdict[desc])
        if sps.issparse(datadict[desc]):
            dt.append(datadict[desc].todense().astype(float, copy=False))
        elif type(datadict[desc]) is loader.HotIndex:
            dt.append(datadict[desc].vec)
        else:
            dt.append(datadict[desc])
        if debug:
            print('%s\n\tph: %s\n\tdt: %s' % (desc,
                                              placeholderdict[desc].get_shape().as_list(),
                                              datadict[desc].shape))
    dropouts = tf.get_collection('dropout_prob')
    if dropouts:
        for prob in dropouts:
            ph.append(prob[0])
            if train == 1:
                dt.append(prob[1])
            else:
                dt.append(1.0)
    fd = {i: d for i, d in zip(ph, dt)}
    bn_deciders = tf.get_collection('bn_deciders')
    if bn_deciders:
        fd.update({decider: [train] for decider in bn_deciders})
    return fd
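
# Usage sketch (not part of the original module), assuming a hypothetical
# loader.DataSet `data` with a feature 'x' and label 'y', and matching
# placeholders:
#
#     x_ph = tf.placeholder(tf.float32, [None, 10], name='x')
#     y_ph = tf.placeholder(tf.float32, [None, 1], name='y')
#     fd = get_feed_list(data.next_batch(100), {'x': x_ph, 'y': y_ph})
#     session.run(train_step, feed_dict=fd)
#
# With train=0 (as predict/eval use below), any tensors in the 'dropout_prob'
# collection are fed a keep-probability of 1.0 instead of their training value.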
def parse_summary_val(summary_str):
    """
    Helper function to parse numeric values from a tf.scalar_summary string.

    :param summary_str: Return value from running session on tf.scalar_summary.
    :return: A dictionary containing the numeric values.
    """
    summary_proto = tf.Summary()
    summary_proto.ParseFromString(summary_str)
    summaries = {}
    for val in summary_proto.value:
        summaries[val.tag] = val.simple_value
    return summaries
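
# Usage sketch, assuming `loss_summary` is a tf.scalar_summary op already in
# the graph and `fd` is a feed dictionary from get_feed_list:
#
#     summary_str = session.run(loss_summary, feed_dict=fd)
#     parse_summary_val(summary_str)   # e.g. {'Loss': 0.4215}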
# ============================================================================================
# ============================GENERIC MODEL CLASS=============================================
# ============================================================================================
class Model(object):
    """
    Generic model builder for training and predictions.

    :param objective: Loss function.
    :param placeholderdict: A dictionary of placeholders.
    :param maxbadcount: For early stopping.
    :param momentum: The momentum for tf.MomentumOptimizer.
    :param mb: The mini-batch size.
    :param verbose: Whether to print dev error and save_tensor evals.
    :param epochs: Maximum number of epochs to train for.
    :param learnrate: Learning rate for gradient descent.
    :param save: Save best model to *best_model_path*.
    :param opt: Optimization strategy. May be 'adam', 'ada', 'grad', or 'mom'.
    :param decay: A [decay_step, decay_rate] pair for decaying the learn rate.
    :param evaluate: Evaluation metric.
    :param predictions: Predictions selected from feed forward pass.
    :param logdir: Where to put the tensorboard data.
    :param random_seed: Random seed for TensorFlow initializers.
    :param model_name: Name for model.
    :param clip_gradients: The limit on gradient size. If 0.0 no clipping is performed.
    :param make_histograms: Whether or not to make histograms for model weights and activations.
    :param best_model_path: File to save best model to during training.
    :param save_tensors: A hashmap of str:Tensor mappings. Tensors are evaluated during
        training. Evaluations of these tensors on the best model are accessible via the
        property :any:`evaluated_tensors`.
    :param tensorboard: Whether to make tensorboard histograms of weights and activations,
        and graphs of dev_error.
    :param train_evaluate: Evaluation metric computed periodically on the training set
        (see *train_dev_eval_factor* in :any:`train`).
    :param debug: Whether to print placeholder and data shapes when building feed dictionaries.
    :return: :any:`Model`
    """
    def __init__(self, objective, placeholderdict, maxbadcount=20, momentum=None, mb=1000,
                 verbose=True, epochs=50, learnrate=0.003, save=False, opt='grad',
                 decay=[1, 1.0], evaluate=None, predictions=None, logdir='log/',
                 random_seed=None, model_name='generic', clip_gradients=0.0,
                 make_histograms=False, best_model_path='/tmp/model.ckpt',
                 save_tensors={}, tensorboard=False, train_evaluate=None, debug=False):
        self.objective = objective
        self.debug = debug
        for t in tf.get_collection('losses'):
            self.objective += t
        self._placeholderdict = placeholderdict
        self.maxbadcount = maxbadcount
        self.momentum = momentum
        self.mb = mb
        self.verbose = verbose
        self.epochs = epochs
        self.learnrate = learnrate
        self.save = save
        self.opt = opt
        self.decay = decay
        self.epoch_times = []
        self.evaluate = evaluate
        self.train_evaluate = train_evaluate
        self._best_dev_error = float('inf')
        self.predictor = predictions
        self.random_seed = random_seed
        self.session = tf.Session()
        if self.random_seed is not None:
            tf.set_random_seed(self.random_seed)
        self.model_name = model_name
        self.clip_gradients = clip_gradients
        self.tensorboard = tensorboard
        self.make_histograms = make_histograms
        if self.make_histograms:
            self.tensorboard = True
        self.histogram_summaries = []
        if not logdir.endswith('/'):
            self.logdir = logdir + '/'
        else:
            self.logdir = logdir
        os.system('mkdir ' + self.logdir)
        self.save_tensors = save_tensors
        self._completed_epochs = 0.0
        self._best_completed_epochs = 0.0
        self._evaluated_tensors = {}
        self.deverror = []
        self._badcount = 0
        self.batch = tf.Variable(0)
        self.train_eval = []
        self.dev_spot = []
        self.train_spot = []
        # ================================================================
        # ======================For tensorboard===========================
        # ================================================================
        if tensorboard:
            self._init_summaries()
        # =============================================================================
        # ===================OPTIMIZATION STRATEGY=====================================
        # =============================================================================
        optimizer = OPT[self.opt]
        decay_step = self.decay[0]
        decay_rate = self.decay[1]
        global_step = tf.Variable(0, trainable=False)  # keeps track of the mini-batch iteration
        if not (decay_step == 1 and decay_rate == 1.0):
            self.learnrate = tf.train.exponential_decay(self.learnrate, self.batch*self.mb,
                                                        decay_step, decay_rate,
                                                        name='learnrate_decay')
        if self.clip_gradients > 0.0:
            params = tf.trainable_variables()
            self.gradients = tf.gradients(self.objective, params)
            self.gradients, self.gradients_norm = tf.clip_by_global_norm(self.gradients,
                                                                         self.clip_gradients)
            grads_and_vars = zip(self.gradients, params)
            if self.opt == 'mom':
                self.train_step = optimizer(self.learnrate, self.momentum).apply_gradients(
                    grads_and_vars, global_step=self.batch, name="train")
            else:
                self.train_step = optimizer(self.learnrate).apply_gradients(
                    grads_and_vars, global_step=self.batch, name="train")
        else:
            if self.opt == 'mom':
                self.train_step = optimizer(self.learnrate, self.momentum).minimize(
                    self.objective, global_step=self.batch)
            else:
                self.train_step = optimizer(self.learnrate).minimize(
                    self.objective, global_step=self.batch)
        # =============================================================================
        # ===================Initialize graph==========================================
        # =============================================================================
        self.session.run(tf.initialize_all_variables())
        if save:
            self.saver = tf.train.Saver()
            self.best_model_path = best_model_path
            self.save_path = self.saver.save(self.session, self.best_model_path)

    # ======================================================================
    # ================Properties============================================
    # ======================================================================
    @property
    def placeholderdict(self):
        """Dictionary of model placeholders."""
        return self._placeholderdict

    @property
    def best_dev_error(self):
        """The best dev error reached during training."""
        return self._best_dev_error

    @property
    def average_secs_per_epoch(self):
        """The average number of seconds to complete an epoch."""
        return np.sum(np.array(self.epoch_times))/self._completed_epochs

    @property
    def evaluated_tensors(self):
        """
        A dictionary of evaluations on the best model for the tensors and keys
        specified by the *save_tensors* argument to the constructor.
        """
        return self._evaluated_tensors

    @property
    def completed_epochs(self):
        """Number of epochs completed during training (fractional)."""
        return self._completed_epochs

    @property
    def best_completed_epochs(self):
        """Number of epochs completed at the point of the best dev eval during training (fractional)."""
        return self._best_completed_epochs
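
    # Construction sketch (not part of the original module): a least-squares
    # model with hypothetical placeholders x and y.
    #
    #     x = tf.placeholder(tf.float32, [None, 10], name='x')
    #     y = tf.placeholder(tf.float32, [None, 1], name='y')
    #     w = tf.Variable(tf.zeros([10, 1]))
    #     yhat = tf.matmul(x, w)
    #     objective = tf.reduce_sum(tf.square(y - yhat))
    #     model = Model(objective, {'x': x, 'y': y},
    #                   evaluate=objective, predictions=yhat,
    #                   opt='adam', learnrate=0.01, mb=100)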
    def plot_train_dev_eval(self, figure_file='testfig.pdf'):
        plt.plot(self.dev_spot, self.deverror, label='dev')
        plt.plot(self.train_spot, self.train_eval, label='train')
        plt.ylabel('Error')
        plt.xlabel('Epoch')
        plt.legend(loc='upper right')
        plt.savefig(figure_file)
    def predict(self, data, supplement=None):
        """
        :param data: :any:`DataSet` to make predictions from.
        :param supplement: Optional dictionary of supplemental matrices (see :any:`get_feed_list`).
        :return: Predictions from the feed forward pass defined by the *predictions*
            argument to the constructor.
        """
        fd = get_feed_list(data, self.placeholderdict, supplement=supplement, train=0,
                           debug=self.debug)
        return self.session.run(self.predictor, feed_dict=fd)
    def eval(self, tensor_in, data, supplement=None):
        """
        Evaluation of model.

        :param tensor_in: Tensor to evaluate.
        :param data: :any:`DataSet` to evaluate on.
        :param supplement: Optional dictionary of supplemental matrices (see :any:`get_feed_list`).
        :return: The result of evaluating *tensor_in* on *data*.
        """
        fd = get_feed_list(data, self.placeholderdict, supplement=supplement, train=0,
                           debug=self.debug)
        return self.session.run(tensor_in, feed_dict=fd)
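
    # Usage sketch for eval/predict, assuming `model`, `dev` and `test` follow
    # the construction sketch above:
    #
    #     dev_error = model.eval(model.evaluate, dev)   # scalar metric
    #     predictions = model.predict(test)             # numpy predictions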
    def train(self, train, dev=None, supplement=None, eval_schedule='epoch',
              train_dev_eval_factor=3):
        """
        :param train: :any:`DataSet` to train on.
        :param dev: :any:`DataSet` for dev evaluation and early stopping.
        :param supplement: Optional dictionary of supplemental matrices (see :any:`get_feed_list`).
        :param eval_schedule: Number of training points to process between dev evaluations,
            or 'epoch' to evaluate once per epoch.
        :param train_dev_eval_factor: Evaluate on the training set once every
            *train_dev_eval_factor* dev evaluations (when *train_evaluate* is set).
        :return: None. The model is trained in place; the best weights are saved to
            *best_model_path* when *save* is True.
        """
        self._completed_epochs = 0.0
        if self.save:
            self.saver.restore(self.session, self.best_model_path)
        # ========================================================
        # ===========Check data to see if dev eval================
        # ========================================================
        if eval_schedule == 'epoch':
            eval_schedule = train.num_examples
        self._badcount = 0
        start_time = time.time()
        # ====================================================================================
        # =============================TRAINING===============================================
        # ====================================================================================
        counter = 0
        train_eval_counter = 0
        while self._completed_epochs < self.epochs:  # keeps track of the epoch iteration
            # ==============PER MINI-BATCH=====================================
            newbatch = train.next_batch(self.mb)
            fd = get_feed_list(newbatch, self.placeholderdict, supplement, debug=self.debug)
            self.session.run(self.train_step, feed_dict=fd)
            counter += self.mb
            train_eval_counter += self.mb
            self._completed_epochs += float(self.mb)/float(train.num_examples)
            if self.train_evaluate is not None and train_eval_counter >= train_dev_eval_factor*eval_schedule:
                self.train_eval.append(self.eval(self.evaluate, train, supplement))
                self.train_spot.append(self._completed_epochs)
                if np.isnan(self.train_eval[-1]):
                    print("Aborting training...train evaluates to nan.")
                    break
                if self.verbose:
                    print("epoch: %f train eval: %.10f" % (self._completed_epochs,
                                                           self.train_eval[-1]))
                train_eval_counter = 0
            if counter >= eval_schedule or self._completed_epochs >= self.epochs:
                # =================PER eval_schedule===========================
                self._log_summaries(dev, supplement)
                counter = 0
                if dev:
                    self.deverror.append(self.eval(self.evaluate, dev, supplement))
                    self.dev_spot.append(self._completed_epochs)
                    if np.isnan(self.deverror[-1]):
                        print("Aborting training...dev evaluates to nan.")
                        break
                    if self.verbose:
                        print("epoch: %f dev error: %.10f" % (self._completed_epochs,
                                                              self.deverror[-1]))
                    for tname in self.save_tensors:
                        self._evaluated_tensors[tname] = self.eval(self.save_tensors[tname],
                                                                   dev, supplement)
                        if self.verbose:
                            print("\t%s: %s" % (tname, self._evaluated_tensors[tname]))
                    # ================Early Stopping===========================
                    if self.deverror[-1] < self.best_dev_error:
                        self._badcount = 0
                        self._best_dev_error = self.deverror[-1]
                        if self.save:
                            self.save_path = self.saver.save(self.session, self.best_model_path)
                        self._best_completed_epochs = self._completed_epochs
                    else:
                        self._badcount += 1
                    if self._badcount > self.maxbadcount:
                        print('badcount exceeded: %d' % self._badcount)
                        break
                    # =========================================================
                self.epoch_times.append(time.time() - start_time)
                start_time = time.time()
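
    # End-to-end training sketch, assuming `data` is an object with .train and
    # .dev DataSet attributes (e.g. as produced by antk.core.loader):
    #
    #     model.train(data.train, dev=data.dev, eval_schedule='epoch')
    #     print(model.best_dev_error, model.best_completed_epochs)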
    # ================================================================
    # ======================For tensorboard===========================
    # ================================================================
    def _init_summaries(self):
        if self.make_histograms:
            self.histogram_summaries.extend(
                map(tf.histogram_summary,
                    [var.name for var in tf.trainable_variables()],
                    tf.trainable_variables()))
            self.histogram_summaries.extend(
                map(tf.histogram_summary,
                    ['normalization/' + n.name for n in tf.get_collection('normalized_activations')],
                    tf.get_collection('normalized_activations')))
            self.histogram_summaries.extend(
                map(tf.histogram_summary,
                    ['activation/' + a.name for a in tf.get_collection('activation_layers')],
                    tf.get_collection('activation_layers')))
        self.loss_summary = tf.scalar_summary('Loss', self.objective)
        self.dev_error_summary = tf.scalar_summary('dev_error', self.evaluate)
        summary_directory = os.path.join(self.logdir,
                                         self.model_name + '-' +
                                         datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
        self._summary_writer = tf.train.SummaryWriter(summary_directory,
                                                      self.session.graph.as_graph_def())

    def _log_summaries(self, dev, supplement):
        if dev is None:  # the feed dictionary below requires a dev set
            return
        fd = get_feed_list(dev, self.placeholderdict, supplement=supplement, train=0,
                           debug=self.debug)
        if self.tensorboard:
            if self.make_histograms:
                sum_str = self.session.run(self.histogram_summaries, fd)
                for summary in sum_str:
                    self._summary_writer.add_summary(summary, self._completed_epochs)
            loss_sum_str = self.session.run(self.loss_summary, fd)
            self._summary_writer.add_summary(loss_sum_str, self._completed_epochs)
            dev_sum_str = self.session.run(self.dev_error_summary, fd)
            self._summary_writer.add_summary(dev_sum_str, self._completed_epochs)