import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from jax.config import config
from jax.scipy.sparse.linalg import bicgstab
from scipy.sparse.linalg import eigs
from .primitives import spsolve, TMOperator
from .sim import SimGrid
from .typing import Callable, Dict, List, Optional, Shape2, Size, Size2, Size3, Spacing, SpSolve, Tuple, Union
from .utils import curl_fn, d2curl_op, yee_avg_2d_z, yee_avg_jax
try:  # pardiso (using Intel MKL) is much faster than scipy's solver
    from .mkl import spsolve_pardiso, feast_eigs
except (ImportError, OSError):  # if mkl isn't installed (missing module or missing shared library)
    # Fall back to scipy's direct solver under the name the solver code actually calls. Importing it as
    # ``spsolve`` would shadow the JAX ``spsolve`` primitive imported from ``.primitives`` (which takes
    # ``(entries, b, indices)``) and would leave ``spsolve_pardiso`` undefined.
    from scipy.sparse.linalg import spsolve as spsolve_pardiso

    def feast_eigs(*args, **kwargs):
        """Placeholder for the MKL FEAST eigensolver, which cannot be used without Intel MKL."""
        raise ImportError('feast_eigs requires Intel MKL, but the `.mkl` module could not be loaded.')
try:
    from dphox.pattern import Pattern
    DPHOX_INSTALLED = True
except ImportError:
    DPHOX_INSTALLED = False
from logging import getLogger
# NOTE: getLogger() without a name returns the root logger.
logger = getLogger()
# Let JAX read any --jax* command-line flags (e.g. to configure x64 / platform) at import time.
config.parse_flags_with_absl()
class FDFD(SimGrid):
    r"""Finite Difference Frequency Domain (FDFD) simulator.

    Notes:
        Finite difference frequency domain works by performing a linear solve of discretized Maxwell's equations
        at a *single* frequency (wavelength).

        The discretized version of Maxwell's equations in frequency domain is:

        .. math::
            \nabla \times \mu^{-1} \nabla \times \mathbf{e} - k_0^2 \epsilon \mathbf{e} = k_0 \mathbf{j},

        which can be written in the form :math:`A \mathbf{e} = \mathbf{b}`, where

        .. math::
            A &= \nabla \times \mu^{-1} \nabla \times - k_0^2 \epsilon \\
            \mathbf{b} &= k_0 \mathbf{j}

        and :math:`A` is the discretized EM wave operator at frequency
        :math:`\omega = k_0 = \frac{2\pi}{\lambda}`.
        Therefore, :math:`\mathbf{e} = A^{-1}\mathbf{b}`.

        For 2D simulations, it can be more efficient to solve for just the :math:`z`-component of the fields.
        For the TE polarization, only :math:`\mathbf{e}_z` is non-zero, and we solve the smaller problem
        :math:`A_z \mathbf{e}_z = \mathbf{b}_z` (see :attr:`mat_ez`), where

        .. math::
            A_z &= (\nabla \times \mu^{-1} \nabla \times)_z - k_0^2 \epsilon_z \\
            \mathbf{b}_z &= k_0 \mathbf{j}_z

        The analogous problem for :math:`\mathbf{h}_z` (TM polarization) uses :attr:`mat_hz`; it is the default
        2D polarization in :meth:`solve`.

    Attributes:
        size: Tuple of size 1, 2, or 3 representing the size in the grid in arbitrary units.
        spacing: Spacing (microns) between each pixel along each axis (must be same dim as :code:`size`).
        wavelength: Wavelength :math:`\lambda` of the simulation; sets :math:`k_0 = 2\pi / \lambda`.
        eps: Relative permittivity :math:`\epsilon_r`.
        bloch_phase: Bloch phase (generally useful for angled scattering sims, not yet implemented!).
        pml: Perfectly matched layer (PML) of thickness on both sides of the form :code:`(x_pml, y_pml, z_pml)`.
        pml_sep: The PML separation distance in number of pixels for sources.
        pml_params: The PML parameters of the form :code:`(exp_scale, log_reflectivity, pml_eps)`.
        name: Name of the simulation.
    """

    def __init__(self, size: Size, spacing: Spacing,
                 wavelength: float = 1.55, eps: Union[float, np.ndarray] = 1,
                 bloch_phase: Union[Size, float] = 0.0, pml: Optional[Union[Size, float]] = None,
                 pml_sep: int = 5, pml_params: Size3 = (4, -16, 1.0), name: str = 'fdfd'):
        super(FDFD, self).__init__(
            size=size,
            spacing=spacing,
            eps=eps,
            bloch_phase=bloch_phase,
            pml=pml,
            pml_params=pml_params,
            pml_sep=pml_sep,
            name=name
        )
        self.wavelength = wavelength
        # k0 must be set before scpml() is called below, since scpml() reads it.
        self.k0 = 2 * np.pi / self.wavelength
        # Overwrite dxes with complex PML-scaled cell sizes if a PML is specified.
        if self.pml_shape is not None:
            dxes_pml_e, dxes_pml_h = [], []
            for ax, _ in enumerate(self.pos):
                scpml_e, scpml_h = self.scpml(ax)
                dxes_pml_e.append(self.cells[ax] * scpml_e)
                dxes_pml_h.append(self.cells[ax] * scpml_h)
            self._dxes = np.meshgrid(*dxes_pml_e, indexing='ij'), np.meshgrid(*dxes_pml_h, indexing='ij')

    @classmethod
    def from_simgrid(cls, simgrid: SimGrid, wavelength: float):
        r"""Prepare an :code:`FDFD` instance from a generic :code:`SimGrid` and wavelength :math:`\lambda`.

        The size, spacing, permittivity, PML shape, name and ports are copied from :code:`simgrid`.

        Args:
            simgrid: :code:`SimGrid` instance.
            wavelength: Wavelength (:math:`\lambda`).

        Returns:
            The :code:`FDFD` instance.
        """
        fdfd = cls(
            size=simgrid.size,
            spacing=simgrid.spacing,
            wavelength=wavelength,
            eps=simgrid.eps,
            pml=simgrid.pml_shape,
            name=simgrid.name
        )
        fdfd.port = simgrid.port
        return fdfd

    @property
    def mat(self) -> sp.csr_matrix:
        r"""Build the discrete Maxwell operator :math:`A(k_0)` acting on :math:`\mathbf{e}`.

        The operator is rebuilt on every access, so cache the result when it is used more than once.

        Returns:
            Electric field operator :math:`A` for solving Maxwell's equations at frequency :math:`\omega`.
        """
        curl_curl: sp.spmatrix = d2curl_op(self.deriv_backward) @ d2curl_op(self.deriv_forward)
        curl_curl.sort_indices()  # for the solver
        mat = curl_curl - self.k0 ** 2 * sp.diags(self.eps_t.flatten())
        return mat

    A = mat  # alias A (common symbol for FDFD matrix) to mat

    @property
    def mat_ez(self) -> sp.csr_matrix:
        r"""Build the discrete Maxwell operator :math:`A_z(k_0)` acting on :math:`\mathbf{e}_z` (for 1D/2D problems).

        Returns:
            Electric field operator :math:`A_z` for a source with ez-polarized field.
        """
        df, db = self.deriv_forward, self.deriv_backward
        ddz = -db[0] @ df[0] - db[1] @ df[1]
        ddz.sort_indices()  # for the solver
        mat = ddz - self.k0 ** 2 * sp.diags(self.eps_t[2].flatten())
        return mat

    @property
    def mat_hz(self) -> sp.csr_matrix:
        r"""Build the discrete Maxwell operator :math:`A_z(k_0)` acting on :math:`\mathbf{h}_z` (for 2D problems).

        Returns:
            Magnetic field operator :math:`A_z` for a source with hz-polarized field.
        """
        df, db = self.deriv_forward, self.deriv_backward
        inv_eps_x = sp.diags(1 / self.eps_t[0].flatten())
        inv_eps_y = sp.diags(1 / self.eps_t[1].flatten())
        # Each derivative term is weighted by the inverse permittivity of the e-field component it produces
        # (axis-0 derivative -> e_y, axis-1 derivative -> e_x), matching the field reconstruction in solve().
        div_grad = -db[0] @ inv_eps_y @ df[0] - db[1] @ inv_eps_x @ df[1]
        div_grad.sort_indices()  # for the solver (same as mat_ez)
        mat = div_grad - self.k0 ** 2 * sp.identity(self.n)
        return mat

    def e2h(self, e: np.ndarray, beta: Optional[float] = None) -> np.ndarray:
        r"""Convert electric field :math:`\mathbf{e}` to magnetic field :math:`\mathbf{h}`.

        Usage is: :code:`h = fdfd.e2h(e)`, where :code:`e` is grid-shaped (not flattened). If :code:`e` is flattened,
        this function attempts to reshape it.

        Mathematically, this represents rearranging the Maxwell equation in the frequency domain:

        .. math::
            i \omega \mu \mathbf{h} = \nabla \times \mathbf{e}.

        Args:
            e: The electric field.
            beta: Optional value forwarded to :code:`self.curl_fn` (presumably a propagation constant for
                mode calculations -- confirm against :code:`SimGrid.curl_fn`).

        Returns:
            The h-field converted from the e-field.
        """
        return self.curl_fn(of_h=False, beta=beta)(self.reshape(e)) / (1j * self.k0)

    def h2e(self, h: np.ndarray, beta: Optional[float] = None) -> np.ndarray:
        r"""Convert magnetic field :math:`\mathbf{h}` to electric field :math:`\mathbf{e}`.

        Usage is: :code:`e = fdfd.h2e(h)`, where :code:`h` is grid-shaped (not flattened). If :code:`h` is flattened,
        this function attempts to reshape it.

        Mathematically, this represents rearranging the Maxwell equation in the frequency domain:

        .. math::
            -i \omega \epsilon \mathbf{e} = \nabla \times \mathbf{h}.

        NOTE(review): the implementation divides by :code:`+1j * k0 * eps_t`, i.e. it computes
        :math:`i k_0 \epsilon \mathbf{e} = \nabla \times \mathbf{h}`, which differs in sign from the equation
        above (the TM branches of :code:`solve` / :code:`get_fields_fn` use the same sign). Confirm the intended
        time convention.

        Args:
            h: The magnetic field.
            beta: Optional value forwarded to :code:`self.curl_fn` (presumably a propagation constant for
                mode calculations -- confirm against :code:`SimGrid.curl_fn`).

        Returns:
            The e-field converted from the h-field.
        """
        return self.curl_fn(of_h=True, beta=beta)(self.reshape(h)) / (1j * self.k0 * self.eps_t)

    def solve(self, src: np.ndarray, solver_fn: Optional[SpSolve] = None,
              iterative: int = -1, tm_2d: bool = True, callback: Optional[Callable] = None) -> np.ndarray:
        r"""Solves the FDFD problem (see class description for math).

        Args:
            src: Normalized source (can be wgm or tfsf).
            solver_fn: Any function :code:`solver_fn(mat, b)` that performs a sparse linalg solve
                (defaults to :code:`spsolve_pardiso`).
            iterative: Solver selection for 3D problems when :code:`solver_fn` is :code:`None`: a value
                :code:`<= 0` uses the direct solver (default :code:`-1`), :code:`1` uses GMRES and any other positive
                value uses BiCGSTAB (both with an incomplete-LU preconditioner). Ignored for 2D.
            tm_2d: Use the TM polarization (only relevant for 2D, ignored for 3D).
            callback: A function to run during the solve (only applies in 3d iterative solver case,
                not yet implemented; currently unused).

        Returns:
            Stacked fields :code:`(e, h)`, where :code:`e` solves :math:`A\mathbf{e} = \mathbf{b} = k_0 \mathbf{j}`.

        Raises:
            ValueError: If :code:`src` does not have the number of entries expected for this grid.
        """
        b = self.k0 * src.flatten()
        if self.ndim == 3:
            if src.size != 3 * self.n:
                raise ValueError(f'Expected src.size == {3 * self.n}, but got {b.size}.')
            mat = self.mat  # build the operator once; the property rebuilds it on every access
            if iterative > 0 and solver_fn is None:
                # iterative solve with an incomplete-LU preconditioner (spilu expects CSC input)
                precond = sp.linalg.LinearOperator(mat.shape, sp.linalg.spilu(mat.tocsc()).solve)
                iterative_solver = sp.linalg.gmres if iterative == 1 else sp.linalg.bicgstab
                e, info = iterative_solver(mat, b, M=precond)
                if info != 0:
                    logger.warning('Iterative FDFD solve did not converge (info = %s).', info)
            else:
                e = solver_fn(mat, b) if solver_fn else spsolve_pardiso(mat, b)
            e = self.reshape(e)
            curl_e = curl_fn(self.diff_fn(of_h=False))
            h = curl_e(e) / (1j * self.k0)
            return np.array((e, h))

        # 2D: solve only for the z component (h_z for TM, e_z for TE)
        if src.size != self.n:
            raise ValueError(f'Expected src.size == {self.n}, but got {b.size}.')
        mat = self.mat_hz if tm_2d else self.mat_ez
        fz = solver_fn(mat, b) if solver_fn else spsolve_pardiso(mat, b)
        o = np.zeros_like(fz)
        field = np.vstack((o, o, fz)).reshape((3, *self.shape3))
        df = self.diff_fn(of_h=tm_2d, use_jax=False)
        zeros = np.zeros_like(field[2])
        # transverse fields from the z component: (d f_z along axis 1, -d f_z along axis 0, 0)
        transverse = np.stack((df(field[2], 1), -df(field[2], 0), zeros))
        if tm_2d:
            # field is h; e follows from the curl of h divided by (i k0 eps)
            return np.stack([transverse / (1j * self.k0 * self.eps_t), field])
        # field is e; h follows from the curl of e divided by (i k0)
        return np.stack([field, transverse / (1j * self.k0)])

    def scpml(self, ax: int) -> Tuple[np.ndarray, np.ndarray]:
        """Compute the stretched-coordinate PML (SC-PML) scaling factors along an axis.

        Args:
            ax: Axis index.

        Returns:
            Tuple :code:`(scpml_e, scpml_h)` of complex scaling factors evaluated at the midpoints of consecutive
            entries of :code:`self.pos[ax]` and at :code:`self.pos[ax][:-1]`, respectively. The factors are
            :code:`1` outside the PML region (and everywhere if the axis has no PML).
        """
        exp_scale, log_reflection, pml_eps = self.pml_params
        if self.cells[ax].size == 1:
            return np.ones(1), np.ones(1)
        p = self.pos[ax]
        pe, ph = (p[:-1] + p[1:]) / 2, p[:-1]
        t = self.pml_shape[ax]
        if t == 0:
            # No PML along this axis. Without this guard the slices below degenerate (d[t:-t] is empty and
            # d[-t:] is the whole array), which would apply a PML profile across the entire axis.
            return np.ones(pe.size, dtype=np.complex128), np.ones(ph.size, dtype=np.complex128)
        absorption_corr = self.k0 * pml_eps

        def _scpml(d: np.ndarray):
            # normalized depth into the PML: zero in the interior, growing towards ~1 at the outer boundaries
            d_pml = np.hstack((
                (d[t] - d[:t]) / (d[t] - p[0]),
                np.zeros_like(d[t:-t]),
                (d[-t:] - d[-t]) / (p[-1] - d[-t])
            ))
            # polynomial grading of order exp_scale; log_reflection (default -16) sets the absorption strength
            return 1 + 1j * (exp_scale + 1) * (d_pml ** exp_scale) * log_reflection / (2 * absorption_corr)

        return _scpml(pe), _scpml(ph)

    @classmethod
    def from_pattern(cls, component: "Pattern", core_eps: float, clad_eps: float, spacing: float, boundary: Size,
                     pml: float, wavelength: float, component_t: float = 0, component_zmin: Optional[float] = None,
                     rib_t: float = 0, sub_z: float = 0, height: float = 0, bg_eps: float = 1, name: str = 'fdfd'):
        """Initialize an FDFD from a Pattern defined in DPhox.

        Args:
            component: Component provided by DPhox.
            core_eps: Core epsilon (in the pattern region).
            clad_eps: Clad epsilon.
            spacing: Grid spacing, either a scalar (used for every axis) or one value per axis.
            boundary: Boundary size around the component along x and y.
            pml: PML boundary size.
            wavelength: Wavelength for the simulation (specific to FDFD).
            component_t: Component thickness.
            component_zmin: Component minimum height (defaults to :code:`sub_z`).
            rib_t: Rib thickness for component (partial etch).
            sub_z: Substrate minimum height.
            height: Height for a 3D simulation (a 2D simulation is built when :code:`height <= 0`).
            bg_eps: Background epsilon (usually 1 or air/vacuum).
            name: Name of the component.

        Returns:
            An :code:`FDFD` instance for the component.

        Raises:
            ImportError: If DPhox is not installed.
        """
        if not DPHOX_INSTALLED:
            raise ImportError('DPhox not installed, but it is required to run this function.')
        b = component.size
        x = b[0] + 2 * boundary[0]
        y = b[1] + 2 * boundary[1]
        component_zmin = sub_z if component_zmin is None else component_zmin
        size = (x, y, height) if height > 0 else (x, y)
        # one spacing entry per grid axis (previously keyed on component_t, which could disagree with size)
        spacing = spacing * np.ones(len(size)) if np.isscalar(spacing) else np.asarray(spacing)
        grid = cls(size, spacing, wavelength=wavelength, eps=bg_eps, pml=pml, name=name)
        grid.fill(sub_z + rib_t, core_eps)
        grid.fill(sub_z, clad_eps)
        grid.add(component, core_eps, component_zmin, component_t)
        return grid

    def sparams(self, port_name: str, mode_idx: int = 0,
                measure_info: Optional[List[Tuple[str, int]]] = None):
        """Measure sparams using a port profile provided for a given port and mode index.

        Args:
            port_name: The source port name for the sparams to measure.
            mode_idx: Mode index to access for the source.
            measure_info: A list of :code:`(port_name, mode_idx)` pairs to get mode measurements at each port
                (defaults to mode 0 at every port in :code:`self.port`).

        Returns:
            Sparams measured at the ports specified in the class.

        Raises:
            ValueError: If :code:`(port_name, mode_idx)` is not one of the measured ports.
        """
        measure_fn = self.get_measure_fn(measure_info)
        measure_info = [(name, 0) for name in self.port] if measure_info is None else measure_info
        src_key = (port_name, mode_idx)
        # validate before the (expensive) solve rather than failing afterwards
        if src_key not in measure_info:
            raise ValueError(f'Source {src_key} must be one of the measured ports {measure_info}.')
        fields = self.solve(self.port_source({src_key: 1}))
        s_out, s_in = measure_fn(fields)
        return s_out / s_in[measure_info.index(src_key)]

    def get_fields_fn(self, src: np.ndarray, transform_fn: Optional[Callable] = None, tm_2d: bool = True) -> Callable:
        """Build a fields function of a set of parameters (e.g., density, epsilon, etc.)
        given the source and transform function.

        The returned function is built from:

        1. A numpy array source :code:`src`.
        2. The JAX-transformable transform function :code:`transform_fn` (e.g. transform) that yields
           the epsilon distribution used by jax.

        Args:
            src: Source for the solver.
            transform_fn: Transforms parameters to yield the epsilon parameters used by jax (if None, use identity).
            tm_2d: Whether to solve the TM polarization for this FDFD (only relevant for 2D, ignored for 3D).

        Returns:
            A solve function (2d or 3d based on defined :code:`ndim` specified for the instance of :code:`FDFD`)
            mapping parameters :code:`rho` to stacked fields :code:`(e, h)`.
        """
        src = jnp.ravel(jnp.array(src))
        k0 = self.k0
        transform_fn = transform_fn if transform_fn is not None else lambda x: x
        shape3 = self.shape3
        field_shape = self.field_shape

        def coo_to_jnp(mat: sp.spmatrix):
            """Convert a sparse matrix to JAX arrays :code:`(data, indices)` with :code:`indices = [rows; cols]`."""
            mat.sort_indices()
            mat = mat.tocoo()
            return jnp.array(mat.data, dtype=np.complex128), jnp.vstack((jnp.array(mat.row), jnp.array(mat.col)))

        if self.ndim < 3:
            o = jnp.zeros(shape3, jnp.complex128)
            if tm_2d:
                # exact 2d FDFD for TM polarization
                constant_term = -jnp.ones_like(self.eps.flatten()) * k0 ** 2
                constant_term_indices = jnp.stack((jnp.arange(self.n), jnp.arange(self.n)))
                # this is temporary while we wait for sparse-sparse support in JAX.
                operator = TMOperator(self.deriv_forward, self.deriv_backward)
                x_op = operator.compile_operator_along_axis(0)
                y_op = operator.compile_operator_along_axis(1)
                x_ind, y_ind = operator.x_indices, operator.y_indices
                dh = self.diff_fn(of_h=True, use_jax=True)

                # @jax.jit
                def solve(rho: jnp.ndarray):
                    eps_t = yee_avg_jax(transform_fn(rho).reshape(shape3))
                    eps_x, eps_y = jnp.ravel(eps_t[0]), jnp.ravel(eps_t[1])
                    ddx_entries = x_op(-1 / eps_y)
                    ddy_entries = y_op(-1 / eps_x)
                    mat_entries = jnp.hstack((constant_term, ddx_entries, ddy_entries))
                    hz = spsolve(mat_entries, k0 * src, jnp.hstack((constant_term_indices, x_ind, y_ind)))
                    hz = hz.reshape(shape3)
                    h = jnp.stack((o, o, hz))
                    e = jnp.stack((dh(h[2], 1), -dh(h[2], 0), o)) / (1j * k0 * eps_t)
                    return jnp.stack((e, h))
            else:
                # exact 2d FDFD for TE polarization
                df, db = self.deriv_forward, self.deriv_backward
                ddz = -db[0] @ df[0] - db[1] @ df[1]
                ddz_entries, ddz_indices = coo_to_jnp(ddz)
                mat_indices = jnp.hstack((jnp.vstack((jnp.arange(self.n), jnp.arange(self.n))), ddz_indices))
                de = self.diff_fn(of_h=False, use_jax=True)

                # @jax.jit
                def solve(rho: jnp.ndarray):
                    eps = yee_avg_2d_z(transform_fn(rho).reshape(shape3)).ravel()
                    mat_entries = jnp.hstack((-eps * k0 ** 2, ddz_entries))
                    ez = spsolve(mat_entries, k0 * src, mat_indices)
                    ez = ez.reshape(shape3)
                    e = jnp.stack((o, o, ez))
                    h = jnp.stack((de(e[2], 1), -de(e[2], 0), o)) / (1j * k0)
                    return jnp.stack((e, h))
        else:
            # iterative 3d FDFD (simpler than 2D code-wise, but takes way more memory and time)
            curl_e = curl_fn(self.diff_fn(of_h=False, use_jax=True), use_jax=True)
            curl_h = curl_fn(self.diff_fn(of_h=True, use_jax=True), use_jax=True)

            def op(eps: jnp.ndarray):
                """Matrix-free Maxwell operator for the given permittivity (for use with bicgstab)."""
                return lambda b: curl_h(curl_e(b.reshape(field_shape))) - k0 ** 2 * eps * b.reshape(field_shape)

            # @jax.jit
            def solve(rho: jnp.ndarray):
                eps = yee_avg_jax(transform_fn(rho))
                e, _ = bicgstab(op(eps), k0 * src.reshape(field_shape))
                e = jnp.stack([split_v.reshape(shape3) for split_v in jnp.split(e, 3)])
                h = curl_e(e) / (1j * k0)
                return jnp.stack((e, h))
        return solve

    def to_2d(self, wavelength: float = None, slab_x: Union[Shape2, Size2] = None,
              slab_y: Union[Shape2, Size2] = None, tm: bool = False):
        """Project a 3D FDFD into a 2D FDFD using the variational 2.5D method laid out in the
        [paper](https://ris.utwente.nl/ws/files/5413011/ishpiers09.pdf).

        Args:
            wavelength: The wavelength to use for calculating the effective 2.5 FDFD
                (useful to stabilize multi-wavelength optimizations).
            slab_x: Port location x (if None, the port is provided by reading the port location specified by
                the component).
            slab_y: Port location y (if None, the port is provided by reading the port location specified by
                the component).
            tm: Whether the mode in the 2D simulation is a TM mode (dominated by Hz component).

        Returns:
            A 2D FDFD to approximate the 3D FDFD.
        """
        wavelength = self.wavelength if wavelength is None else wavelength
        # Forward the slab locations only when given, so the default call to the parent is unchanged.
        # NOTE(review): assumes SimGrid.to_2d accepts slab_x / slab_y keywords, as documented here -- confirm.
        slab_kwargs = {key: value for key, value in (('slab_x', slab_x), ('slab_y', slab_y)) if value is not None}
        return FDFD.from_simgrid(super(FDFD, self).to_2d(wavelength, tm=tm, **slab_kwargs), wavelength)

    def tfsf_profile(self, q_mask: np.ndarray, k: Size, wavelength: float = None):
        """Total-field/scattered-field (TFSF) source profile (not yet implemented).

        Args:
            q_mask: Mask of the total-field region.
            k: Direction of the incident plane wave.
            wavelength: Wavelength of the incident wave (defaults to :code:`self.wavelength`).

        Raises:
            NotImplementedError: Always; raised before any (expensive) operator construction.
        """
        # Draft approach: with Q = diag(q_mask) and A = self.mat, the TFSF source is (Q A - A Q) f, where f is
        # the incident plane wave exp(i k . r) sampled on the grid (k should be normalized to length k0).
        raise NotImplementedError('TFSF profile not yet implemented.')