{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Defining your own custom criterion for dynamics evaluation\n", "Defining your own custom criterion is useful because the same instance can be used either to analyse training dynamics or as an IB-based loss objective for HSIC networks." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Importing glow modules" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# import the Estimator base class from glow\n", "from glow.information_bottleneck import Estimator" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Custom criterion class\n", "The following is the general structure your custom criterion class should follow so that PyGlow can detect your criterion function.\n", "NOTE: In most cases you won't need to implement `eval_dynamics_segment`; defining only the `criterion` function is sufficient for PyGlow to accept your instance.\n", "You need to implement `eval_dynamics_segment` only when you want to completely change how a single dynamics segment is processed."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class CustomCriterion(Estimator):\n", "    \"\"\"\n", "    This is a custom class for implementing your\n", "    own criterion function (a measure of dependence\n", "    between random variables).\n", "\n", "    Arguments:\n", "        function (callable): an example argument denoting\n", "            any function that your criterion requires\n", "        gpu (bool): if true, all computation is done on the\n", "            `GPU`, otherwise on the `CPU`\n", "        **kwargs: keyword arguments storing all the parameters\n", "            that are essential for the criterion function\n", "\n", "    \"\"\"\n", "    def __init__(self, function, gpu, **kwargs):\n", "        super().__init__(gpu, **kwargs)\n", "        self.func = function  # attributes other than parameters\n", "\n", "    def criterion(self, x, y):\n", "        \"\"\"\n", "        Custom criterion function.\n", "\n", "        Arguments:\n", "            x (torch.Tensor): contains random variable one\n", "            y (torch.Tensor): contains random variable two\n", "\n", "        Returns:\n", "            (torch.Tensor): calculated criterion of the two random variables 'x' and 'y'\n", "\n", "        \"\"\"\n", "        # Define your logic for calculating the criterion here\n", "        # TODO\n", "        raise NotImplementedError(\"criterion is not implemented yet\")\n", "\n", "    # Implement this only if you want to completely change the\n", "    # logic for evaluating a single dynamics segment; otherwise\n", "    # the inbuilt function handles it.\n", "    def eval_dynamics_segment(self, dynamics_segment):\n", "        \"\"\"\n", "        Evaluates the criterion for a dynamics segment using the\n", "        criterion defined in the class itself.\n", "\n", "        Arguments:\n", "            dynamics_segment (iterable): a list of :class:`torch.Tensor`\n", "                objects whose first entry is the input batch, last entry is the\n", "                labels batch and intermediate entries are hidden layer output batches\n", "\n", "        Returns:\n", "            (iterable): list of calculated coordinates according to the criterion\n", "                with length equal to 'len(dynamics_segment)-2'\n", "\n", "        \"\"\"\n", 
"        segment_size = len(dynamics_segment)\n", "        # NOTE: dynamics_segment is a list where\n", "        # input = dynamics_segment[0]\n", "        # label = dynamics_segment[segment_size - 1]\n", "        # and the remaining tensors are hidden layer outputs\n", "        x = dynamics_segment[0]  # input batch\n", "        y = dynamics_segment[segment_size - 1]  # labels batch\n", "        output_segment = []  # output list which contains calculated coordinates\n", "        for idx in range(1, segment_size - 1):\n", "            h = dynamics_segment[idx]\n", "            # Define your logic of using h, input and label here\n", "            # to apply your custom criterion on them.\n", "            output_segment.append([self.criterion(h, x), self.criterion(h, y)])\n", "\n", "        return output_segment" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }