Gradient descent for fitting models to stars

This is a topic I've written about before, but I wanted to update my code and simplify things. You can clone the GitHub repo if you'd like an example to run for yourself as well.

Basically, the goal is to fit a two-dimensional Gaussian to a star in astronomical data (though it'd be easy to generalize this algorithm to the N-dimensional case).

The Gaussian is a good model in most cases, and it's easy to compute; in the past I also tried fitting the Moffat function, but I found its parameter \(\beta\) hard to fit, so an iterative method probably isn't optimal for it.
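For reference, the Moffat profile (in its normalized two-dimensional form, with \(r^{2}=(x-\mu_{x})^{2}+(y-\mu_{y})^{2}\)) is
\begin{align}M(x,y;\alpha,\beta)&=\dfrac{\beta-1}{\pi\alpha^{2}}\left[1+\dfrac{r^{2}}{\alpha^{2}}\right]^{-\beta},\end{align}
where \(\alpha\) here is a width parameter; because \(\beta\) sits in the exponent, \(\partial M/\partial\beta\) picks up a factor of \(\ln\left(1+r^{2}/\alpha^{2}\right)\), which makes its gradient steps hard to scale well.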

To start, I'll restate the gradient descent (GD) algorithm (you'll find it peppered throughout the literature because it's so well known).
\begin{align}\vec{\theta}_{\kappa+1}&=\vec{\theta}_{\kappa}-\eta\vec{\nabla}_{\vec{\theta}}\text{cost}(\vec{\theta}_{\kappa}),\end{align}
where
\begin{align}\text{cost}(\vec{\theta}_{\kappa}):=\dfrac{1}{MN}\displaystyle\sum_{i,j=1}^{N,M}[m(i,j;\vec{\theta}_{\kappa})-I(i,j)]^{2}.\end{align}
Here we substitute for the model \(m\) the Gaussian, defined as
\begin{align}G(x, y; \alpha,\mu_{x},\mu_{y},\sigma)&=\dfrac{\alpha}{2\pi\sigma^2}\exp\left[\dfrac{-(x-\mu_{x})^2-(y-\mu_{y})^2}{2\sigma^2}\right].\end{align}
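Stripped of everything model-specific, the update rule is just a loop. Here's a minimal sketch, where grad_cost stands in for whatever callable returns \(\vec{\nabla}_{\vec{\theta}}\,\text{cost}(\vec{\theta}_{\kappa})\) as a NumPy array (the full version for the Gaussian is below):
import numpy as np


def descend(theta: np.ndarray, grad_cost, eta: float = 1.0, steps: int = 1000) -> np.ndarray:
    """ Apply theta_{k+1} = theta_k - eta * grad_cost(theta_k) for a fixed number of steps """
    for _ in range(steps):
        theta = theta - eta * grad_cost(theta)

    return theta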
I've written \(G\) and its partial derivatives (used to form \(\vec{\nabla}_{\vec{\theta}}\,\text{cost}\)) as follows. Because they're called so often, I found that Numba's jit decorator sped things up considerably (there's also an example of its use on the project's home page).
#! /usr/bin/env python3.6
# -*- coding: utf-8 -*-

from numpy.linalg import norm
from math import pi, log, exp, sqrt
from numba import jit


@jit
def gaussian(X: int, Y: int, alpha: float, mu_X: float, mu_Y: float, sigma: float) -> float:
    # compute the 2d gaussian function at this point
    res = alpha * exp(-(((X - mu_X) / sigma) ** 2 + ((Y - mu_Y) / sigma) ** 2) / 2) / (2 * pi * sigma * sigma)

    return res


@jit
def partialalphagaussian(X: int, Y: int, *args: float) -> float:
    return gaussian(X, Y, *args) / args[0]


@jit
def partialsigmagaussian(X: int, Y: int, *args: float) -> float:
    alpha, mu_X, mu_Y, sigma = args
    diff = norm([X - mu_X, Y - mu_Y])

    res = gaussian(X, Y, *args) * ((diff * diff) / (sigma ** 3) - 2 / sigma)

    return res


@jit
def partialmu_Xgaussian(X: int, Y: int, *args: float) -> float:
    alpha, mu_X, _, sigma = args

    res = gaussian(X, Y, *args) * (X - mu_X) / (sigma * sigma)

    return res


@jit
def partialmu_Ygaussian(X: int, Y: int, *args: float) -> float:
    alpha, _, mu_Y, sigma = args

    res = gaussian(X, Y, *args) * (Y - mu_Y) / (sigma * sigma)

    return res


_partials = [partialalphagaussian, partialmu_Xgaussian, partialmu_Ygaussian, partialsigmagaussian]


def fwhm_gaussian(*args: float) -> float:
    """ Convert the fitted sigma to a full width at half maximum """
    return 2 * sqrt(2 * log(2)) * args[-1]


def get_amplitude(*args: float) -> float:
    """ Evaluate the Gaussian at its peak, i.e. at (mu_X, mu_Y) """
    return args[0] / (2 * pi * args[-1] ** 2)
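As a sanity check on partialsigmagaussian: writing \(r^{2}=(x-\mu_{x})^{2}+(y-\mu_{y})^{2}\) and differentiating \(\ln G=\ln\alpha-\ln(2\pi)-2\ln\sigma-r^{2}/(2\sigma^{2})\) with respect to \(\sigma\) gives
\begin{align}\dfrac{\partial G}{\partial\sigma}&=G\left(\dfrac{r^{2}}{\sigma^{3}}-\dfrac{2}{\sigma}\right),\end{align}
which is exactly the expression in the code; the other partials follow the same pattern.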
Given these equations, I've written the following class to "digest" stellar sources into the four best-fitting parameters of \(G\).
from itertools import product
from random import random
from typing import Callable, List
from warnings import warn

import numpy as np


class GD:
    """ A callable to perform GD """
    _counter: int = 0

    def __init__(self, model: Callable, partials: List[Callable]) -> None:
        self.model = model
        self.partials = partials

    def __call__(self, data: np.ndarray, learning_rate: float = 1.0,
                 steps: int = 1000) -> List[float]:
        """ `Learn` the parameters of best fit for the given data and model """

        _min = data.min()
        _max = data.max()

        # scale amplitude to [0, 1]
        self.data = (data - _min) / (_max - _min)

        self.cubeX, self.cubeY = data.shape
        self.learning_rate = learning_rate
        self.steps = steps

        # perform the fit
        result = self.simplefit()

        # unscale amplitude of resultant
        result[0] = result[0] * (_max - _min) + _min

        result_as_list = result.tolist()

        self._counter += 1

        return result_as_list

    def simplefit(self) -> np.ndarray:
        """ Perform plain batch gradient descent """

        # Determine the center of mass of the star for a rough starting position on mu_X and mu_Y
        com_X, com_Y = sum(np.array([i, j]) * self.data[i][j]
                           for i, j in product(range(self.cubeX), range(self.cubeY))) / np.sum(self.data)

        if not (0 <= com_X <= self.cubeX and 0 <= com_Y <= self.cubeY):
            warn(f'** star {self._counter} centroid lies outside boundaries')
            com_X, com_Y = self.cubeX // 2, self.cubeY // 2

        # initialize parameters: a random amplitude in (0, 1), the centroid
        # estimate, and a rough width heuristic for sigma
        parameters = np.array([random(), com_X, com_Y, (com_X + com_Y) / 4])

        # train the parameters using gradient descent
        for _ in range(self.steps):
            grad = self.gradient(*parameters)

            # I found in practice that a differential learning rate yielded better results
            grad[0] *= 150
            grad[3] *= 4.5

            parameters -= self.learning_rate * grad

        return parameters

    def gradient(self, *args: float) -> np.ndarray:
        """ Compute the gradient of the cost of applying our model to the data as-is """
        n_args = len(args)
        total = np.zeros(n_args)
        nabla_args = np.empty(n_args)

        for i in range(self.cubeX):
            for j in range(self.cubeY):
                residual = self.model(i, j, *args) - self.data[i][j]

                # loop over and evaluate the partial derivatives
                for k in range(n_args):
                    nabla_args[k] = self.partials[k](i, j, *args)

                total += residual * nabla_args

        # average over all pixels
        return total / (self.cubeX * self.cubeY)
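For reference, the gradient method is computing
\begin{align}\vec{\nabla}_{\vec{\theta}}\,\text{cost}(\vec{\theta})&=\dfrac{2}{MN}\displaystyle\sum_{i,j=1}^{N,M}[m(i,j;\vec{\theta})-I(i,j)]\,\vec{\nabla}_{\vec{\theta}}\,m(i,j;\vec{\theta}),\end{align}
minus the constant factor of 2, which I let the learning rate \(\eta\) absorb.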
Now, calling this class on a generated source, we get a pretty decent-looking fit.
The awesome part is that it only took about 60 iterations (<2 seconds) for the above fit.
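For a concrete example of the call, here's roughly how generating a synthetic source and fitting it goes; the grid size, true parameters, and noise level here are purely illustrative:
import numpy as np

# build a noisy 32x32 synthetic star with known parameters
truth = (120.0, 15.0, 17.0, 2.5)  # alpha, mu_X, mu_Y, sigma
star = np.array([[gaussian(i, j, *truth) for j in range(32)] for i in range(32)])
star += np.random.normal(scale=0.05, size=star.shape)

# fit it and report the recovered parameters and FWHM
fit = GD(gaussian, _partials)
alpha, mu_X, mu_Y, sigma = fit(star, steps=1000)
print(f'FWHM: {fwhm_gaussian(alpha, mu_X, mu_Y, sigma):.2f} px')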
Another mini-project I'm working on is pulling together various packages to build a simple sqlite3 database of sources, organized in a table by position for each image. That way I can just query the database, which gives a common interface for pattern-matching or classification algorithms.
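I haven't settled on a schema yet, but the rough shape is something like this (the table and column names are placeholders):
import sqlite3

conn = sqlite3.connect('sources.db')
conn.execute("""
    CREATE TABLE IF NOT EXISTS sources (
        image     TEXT,  -- filename of the image the source came from
        x         REAL,  -- fitted centroid
        y         REAL,
        amplitude REAL,
        sigma     REAL
    )
""")

# one row per fitted source, keyed by image and position
conn.execute('INSERT INTO sources VALUES (?, ?, ?, ?, ?)',
             ('example.fits', 15.1, 16.9, 119.2, 2.48))
conn.commit()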