Source code for pychangcooper.tridiagonal_solver

import numpy as np
from numba import jit, njit


@njit(fastmath=True)
def jit_forward_sweep(n, cprime, dprime, a, b, c, d):
    """
    In-place forward elimination sweep of the Thomas algorithm.

    For i = 1 .. n-1, cprime[i] and dprime[i] are overwritten using the
    already-updated entry i-1. Entry 0 of cprime and dprime is read but
    never written. It must therefore already hold the seed values
    (c[0] / b[0] and d[0] / b[0] in the Thomas algorithm).

    :param n: number of unknowns in the system
    :param cprime: modified super-diagonal, updated in place
    :param dprime: modified right-hand side, updated in place
    :param a: sub-diagonal coefficients (a[0] is not read)
    :param b: diagonal coefficients
    :param c: super-diagonal coefficients
    :param d: right-hand side
    """
    i = 1
    while i < n:
        lower = a[i]
        # what remains on the diagonal once the sub-diagonal term of row i
        # has been eliminated using row i-1
        pivot = b[i] - lower * cprime[i - 1]
        cprime[i] = c[i] / pivot
        dprime[i] = (d[i] - lower * dprime[i - 1]) / pivot
        i += 1
@njit(fastmath=True)
def jit_backward_sweep(n, n_j_plus_1, cprime, dprime):
    """
    In-place back-substitution sweep of the Thomas algorithm.

    Walks j from n-2 down to 0 and sets
    n_j_plus_1[j] = dprime[j] - cprime[j] * n_j_plus_1[j + 1].
    The last entry n_j_plus_1[n-1] is read but never written here, so the
    caller must set it (to dprime[n-1]) before calling.

    :param n: number of unknowns in the system
    :param n_j_plus_1: solution vector, filled in place
    :param cprime: modified super-diagonal from the forward sweep
    :param dprime: modified right-hand side from the forward sweep
    """
    j = n - 2
    while j >= 0:
        n_j_plus_1[j] = dprime[j] - cprime[j] * n_j_plus_1[j + 1]
        j -= 1
class TridiagonalSolver(object):
    def __init__(self, a, b, c):
        """
        A tridiagonal (Thomas algorithm) solver for the system

            a_i x_{i-1} + b_i x_i + c_i x_{i+1} = d_i

        The coefficient arrays are fixed at construction; the right-hand
        side d is supplied to :meth:`solve`. a[0] and c[-1] fall outside
        the matrix and do not affect the solution.

        :param a: the x_{i-1} (sub-diagonal) terms
        :param b: the x_i (diagonal) terms
        :param c: the x_{i+1} (super-diagonal) terms
        :raises ValueError: if a, b and c do not all have the same length
        """
        # get the number of elements
        self._n_grid_points = len(a)

        # all three diagonals must describe the same number of unknowns;
        # raise rather than assert so the check survives `python -O`
        if not (len(a) == len(b) == len(c)):
            raise ValueError(
                "a, b and c must have the same length, got %d, %d and %d"
                % (len(a), len(b), len(c))
            )

        # assign (copies of) the elements to the class
        self._a = np.array(a)
        self._b = np.array(b)
        self._c = np.array(c)

        # if the a terms are all zero the matrix is already upper
        # bidiagonal, so the forward elimination of a can be skipped
        self._a_non_zero = np.any(self._a != 0)

        # c' = c / b is the first step of the forward sweep (and the whole
        # of it when a is zero); _forward_sweep resets it on every solve
        self._cprime = self._c / self._b

    def _forward_sweep(self, d):
        """
        Forward sweep of the tridiagonal solver.

        Sets self._cprime and self._dprime for the right-hand side d.

        :param d: right-hand side as a 1-D array with one entry per
            grid point
        """
        # start from the plain ratios c / b and d / b; these are the final
        # values when a is zero, otherwise the jitted sweep overwrites
        # entries 1..n-1 in place
        self._cprime = self._c / self._b
        self._dprime = d / self._b

        if self._a_non_zero:
            jit_forward_sweep(
                self._n_grid_points,
                self._cprime,
                self._dprime,
                self._a,
                self._b,
                self._c,
                d,
            )

    def _backwards_substitution(self):
        """
        Backwards substitution step of the tridiagonal solver.

        Requires _forward_sweep to have been run and the system to have
        at least one unknown.

        :returns: the solution vector as a new array
        """
        n_j_plus_1 = np.zeros(self._n_grid_points)

        # the last unknown is determined directly by the forward sweep
        n_j_plus_1[-1] = self._dprime[-1]

        # backwards step to the beginning
        jit_backward_sweep(
            self._n_grid_points, n_j_plus_1, self._cprime, self._dprime
        )

        return n_j_plus_1

    def solve(self, d_j):
        """
        Solve the tridiagonal system for the right-hand side d_j.

        :param d_j: right-hand side; must be 1-D with exactly one entry
            per grid point (the length of a, b and c)
        :returns: the solution vector as a new array (empty for an empty
            system)
        :raises ValueError: if d_j does not have shape (n,)
        """
        d_j = np.asarray(d_j)

        # the jitted sweeps index d directly; numba does not bounds-check
        # by default, so a short or broadcast d_j would read past its end
        if d_j.shape != (self._n_grid_points,):
            raise ValueError(
                "d_j must have shape (%d,), got %s"
                % (self._n_grid_points, d_j.shape)
            )

        # an empty system has an empty solution; there is no last element
        # to seed the back substitution with
        if self._n_grid_points == 0:
            return np.zeros(0)

        self._forward_sweep(d_j)

        d_j_plus_one = self._backwards_substitution()

        return d_j_plus_one