"""MIT License

Copyright (c) 2018 Benjamin Bastian

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

A module for a mixture density network layer

For more info on MDNs, see _Mixture Density Networks_ by Bishop, 1994.

https://github.com/sagelywizard/pytorch-mdn/blob/master/mdn/mdn.py
"""
import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions import Categorical

ONEOVERSQRT2PI = 1.0 / math.sqrt(2 * math.pi)

# log(sqrt(2 * pi)): the normalisation term of a 1-D Gaussian log-density.
_HALF_LOG_2PI = 0.5 * math.log(2 * math.pi)


class MDN(nn.Module):
    """A mixture density network layer

    The input maps to the parameters of a MoG probability distribution, where
    each Gaussian has O dimensions and diagonal covariance.

    Arguments:
        in_features (int): the number of dimensions in the input
        out_features (int): the number of dimensions in the output
        num_gaussians (int): the number of Gaussians per output dimensions

    Input:
        minibatch (BxD): B is the batch size and D is the number of input
            dimensions.

    Output:
        (sigma, mu, pi) (BxGxO, BxGxO, BxG or None): B is the batch size, G is
            the number of Gaussians, and O is the number of dimensions for
            each Gaussian. Sigma is the standard deviation of each Gaussian.
            Mu is the mean of each Gaussian. Pi is a multinomial distribution
            over the Gaussians; it is None when `num_gaussians` is not greater
            than 1, because no mixing head is built in that case.
    """

    def __init__(self, in_features, out_features, num_gaussians):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_gaussians = num_gaussians
        # The mixing head is only built for a real mixture. Sub-modules are
        # created in the order pi, sigma, mu.
        if self.num_gaussians > 1:
            self.pi = nn.Sequential(
                nn.Linear(in_features, num_gaussians), nn.Softmax(dim=1)
            )
        else:
            self.pi = None
        self.sigma = nn.Linear(in_features, out_features * num_gaussians)
        self.mu = nn.Linear(in_features, out_features * num_gaussians)

    def forward(self, minibatch):
        """Map `minibatch` to the MoG parameters `(sigma, mu, pi)`."""
        # Explicit None check instead of relying on the container's truthiness.
        pi = self.pi(minibatch) if self.pi is not None else None
        # exp() keeps the standard deviations strictly positive.
        sigma = torch.exp(self.sigma(minibatch))
        sigma = sigma.view(-1, self.num_gaussians, self.out_features)
        mu = self.mu(minibatch)
        mu = mu.view(-1, self.num_gaussians, self.out_features)
        return sigma, mu, pi


def gaussian_probability(sigma, mu, target):
    """Returns the probability of `target` given MoG parameters `sigma` and `mu`.

    Arguments:
        sigma (BxGxO): The standard deviation of the Gaussians. B is the batch
            size, G is the number of Gaussians, and O is the number of
            dimensions per Gaussian.
        mu (BxGxO): The means of the Gaussians. B is the batch size, G is the
            number of Gaussians, and O is the number of dimensions per Gaussian.
        target (BxO): A batch of targets. B is the batch size and O is the
            number of dimensions per Gaussian (it is expanded to the shape of
            `sigma`).

    Returns:
        probabilities (BxG): The density of each target under each Gaussian,
            i.e. the product of the per-dimension densities. The product can
            underflow to zero for distant targets or many dimensions.
    """
    target = target.unsqueeze(1).expand_as(sigma)
    ret = ONEOVERSQRT2PI * torch.exp(-0.5 * ((target - mu) / sigma) ** 2) / sigma
    return torch.prod(ret, 2)


def _gaussian_log_probability(sigma, mu, target):
    """Return the (BxG) log-density of `target` under each diagonal Gaussian."""
    target = target.unsqueeze(1).expand_as(sigma)
    standardized = (target - mu) / sigma
    log_density = -0.5 * standardized ** 2 - torch.log(sigma) - _HALF_LOG_2PI
    # Summing log-densities replaces the product of densities.
    return log_density.sum(dim=2)


def mdn_loss(sigma, mu, target, pi=None):
    """Calculates the error, given the MoG parameters and the target

    The loss is the negative log likelihood of the data given the MoG
    parameters, averaged over the batch. It is evaluated in log space with
    log-sum-exp so that it stays finite when the plain densities would
    underflow to zero.

    Arguments:
        sigma (BxGxO): The standard deviation of the Gaussians.
        mu (BxGxO): The means of the Gaussians.
        target (BxO): A batch of targets.
        pi (BxG, optional): The mixing weights. When None the Gaussians are
            summed unweighted.

    Returns:
        A scalar tensor holding the mean negative log likelihood.
    """
    log_prob = _gaussian_log_probability(sigma, mu, target)
    if pi is not None:
        if not pi.is_floating_point():
            pi = pi.to(log_prob.dtype)
        # Clamp before the log: a weight of exactly zero would otherwise give
        # log(0) = -inf and a NaN gradient.
        smallest = torch.finfo(pi.dtype).tiny
        log_prob = log_prob + torch.log(pi.clamp_min(smallest))
    nll = -torch.logsumexp(log_prob, dim=1)
    return torch.mean(nll)


def sample(sigma, mu, device="cpu", pi=None):
    """Draw samples from a MoG.

    Arguments:
        sigma (BxGxO): The standard deviation of the Gaussians.
        mu (BxGxO): The means of the Gaussians.
        device: Kept for backward compatibility and not used; the noise is
            drawn on the device of `sigma`.
        pi (BxG, optional): The mixing weights.

    Returns:
        With `pi`: a (BxO) tensor, one draw per batch element from the
        Gaussian picked according to `pi`, detached from the graph.
        Without `pi`: `noise * sigma + mu` with the first two axes swapped,
        i.e. a (GxBxO) tensor.
    """
    if pi is None:
        gaussian_noise = torch.randn_like(sigma, requires_grad=False)
        return (gaussian_noise * sigma + mu).transpose(0, 1)

    # Choose which gaussian we'll sample from: one component index per row.
    component = Categorical(pi).sample().view(pi.size(0), 1, 1)
    # Repeat the index over the output dimensions so gather returns the whole
    # O-dimensional parameter row of the chosen component.
    index = component.expand(-1, 1, sigma.size(2))
    chosen_sigma = sigma.detach().gather(1, index).squeeze(1)
    chosen_mu = mu.detach().gather(1, index).squeeze(1)
    gaussian_noise = torch.randn_like(chosen_sigma, requires_grad=False)
    return gaussian_noise * chosen_sigma + chosen_mu