import dataclasses
from functools import lru_cache
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from .typing import Size, Size3, Spacing, Optional, List, Union, Dict, Op, Tuple
from .utils import curl_fn, yee_avg, fix_dataclass_init_docs, Box
# dphox is an optional dependency: record whether it could be imported so the
# rest of the module can check the flag instead of re-attempting the import.
try:
    from dphox.pattern import Pattern
except ImportError:
    DPHOX_IMPORTED = False
else:
    DPHOX_IMPORTED = True
@fix_dataclass_init_docs
@dataclasses.dataclass
class Port:
    """Port to define where sources and measurements lie in photonic simulations.

    A port defines the center and angle/orientation in a design.

    Args:
        x: x position of the port
        y: y position of the port
        a: angle (orientation) of the port (in degrees)
        w: the width of the port (specified in design, mostly used for simulation)
        z: z position of the port (not specified in design, mostly used for simulation)
        h: the height of the port (not specified in design, mostly used for simulation)

    Attributes:
        xy: The tuple :code:`(x, y)`.
        xya: The tuple :code:`(x, y, a)`.
        xyz: The tuple :code:`(x, y, z)`.
        center: :code:`xyz` as a numpy array.
    """
    x: float
    y: float
    a: float = 0
    w: float = 0
    z: float = 0
    h: float = 0

    def __post_init__(self):
        # Derived views are computed once here; they are not refreshed if the
        # fields are reassigned after construction.
        self.xy = (self.x, self.y)
        self.xya = (self.x, self.y, self.a)
        self.xyz = (self.x, self.y, self.z)
        self.center = np.array(self.xyz)

    @property
    def size(self):
        """Extent of the port as an array of three values.

        Returns:
            :code:`(0, w, h)` when the angle is a multiple of 180 degrees and
            :code:`(w, 0, h)` for the remaining multiples of 90 degrees.

        Raises:
            ValueError: If the angle is not a multiple of 90 degrees.
        """
        if np.mod(self.a, 90) != 0:
            raise ValueError(f"Require angle to be a multiple of 90 but got {self.a}")
        if np.mod(self.a, 180) != 0:
            return np.array((self.w, 0, self.h))
        return np.array((0, self.w, self.h))
[docs]class Grid:
def __init__(self, size: Size, spacing: Spacing, eps: Union[float, np.ndarray] = 1.0):
"""Grid object accomodating any electromagnetic simulation (FDFD, FDTD, BPM, etc.)
Args:
size: Tuple of size 1, 2, or 3 representing the size of the grid
spacing: Spacing (microns) between each pixel along each axis (must be same dim as `grid_shape`)
eps: Relative permittivity (
"""
self.size = np.asarray(size)
self.spacing = spacing * np.ones(len(size)) if isinstance(spacing, int) or isinstance(spacing, float) else np.asarray(spacing)
self.ndim = len(size)
if not self.ndim == self.spacing.size:
raise AttributeError(f'Require size.size == ndim == spacing.size but got '
f'{self.size.size} != {self.spacing.size}')
self.shape = np.around(self.size / self.spacing).astype(int)
self.shape3 = np.hstack((self.shape, np.ones((3 - self.ndim,), dtype=self.shape.dtype)))
self.spacing3 = np.hstack((self.spacing, np.ones((3 - self.ndim,), dtype=self.spacing.dtype) * np.inf))
self.size3 = np.hstack((self.size, np.zeros((3 - self.ndim,), dtype=self.size.dtype)))
self.center = self.size3 / 2
self.field_shape = (3, *self.shape3)
self.n = np.prod(self.shape)
self.eps: np.ndarray = np.ones(self.shape) * eps if not isinstance(eps, np.ndarray) else eps
if not tuple(self.shape) == self.eps.shape:
raise AttributeError(f'Require grid.shape == eps.shape but got '
f'{self.shape} != {self.eps.shape}')
self.cells = [(self.spacing[i] * np.ones((self.shape[i],)) if self.ndim > 1 else self.spacing * np.ones(self.shape))
if i < self.ndim else np.ones((1,)) for i in range(3)]
self.pos = [np.hstack((0, np.cumsum(dx))) if dx.size > 1 else np.asarray((0,)) for dx in self.cells]
self.components = []
# used to handle special functions of waveguide-based components
self.port: Dict[str, Port] = {}
[docs] def fill(self, height: float, eps: float) -> "Grid":
"""Fill grid up to `height`, typically used for substrate + cladding epsilon settings
Args:
height: Maximum final dimension of the fill operation (`y` if 2D, `z` if 3D).
eps: Relative permittivity to fill.
Returns:
The modified :code:`Grid` for chaining (:code:`self`)
"""
if height > 0:
self.eps[..., :int(height / self.spacing[-1])] = eps
else:
self.eps = np.ones_like(self.eps) * eps
return self
[docs] def add(self, component: "Pattern", eps: float, zmin: float = None, thickness: float = None) -> "Grid":
"""Add a component to the grid.
Args:
component: component to add
eps: permittivity of the component being added (isotropic only, for now)
zmin: minimum z extent of the component
thickness: component thickness (`zmax = zmin + thickness`)
Returns:
The modified :code:`Grid` for chaining (:code:`self`)
"""
b = component.bounds
if not b[0] >= 0 and b[1] >= 0 and b[2] <= self.size[0] and b[3] <= self.size[1]:
raise ValueError('The pattern must have min x, y >= 0 and max x, y less than size.')
self.components.append(component)
mask = component.mask(self.shape[:2], self.spacing)[:self.eps.shape[0], :self.eps.shape[1]]
if self.ndim == 2:
self.eps[mask == 1] = eps
else:
zidx = (int(zmin / self.spacing[0]), int((zmin + thickness) / self.spacing[1]))
self.eps[mask == 1, zidx[0]:zidx[1]] = eps
self.port = {port_name: Port(*port.xya, port.w, zmin + thickness / 2, thickness)
for port_name, port in component.port.items()}
return self
[docs] def set_eps(self, center: Size3, size: Size3, eps: float):
"""Set the region specified by :code:`center`, :code:`size` (in grid units) to :code:`eps`.
Args:
center: Center of the region.
size: Size of the region.
eps: Epsilon (relative permittivity) to set.
Returns:
The modified :code:`Grid` for chaining (:code:`self`)
"""
s = self.slice(center, size, squeezed=True)
eps_3d = self.eps.reshape(self.shape3)
eps_3d[s] = eps
self.eps = eps_3d.squeeze()
return self
[docs] def mask(self, center: Size3, size: Size3):
"""Given a size and center, this function defines a mask which sets pixels in the region corresponding to
:code:`center` and :code:`size` to 1 and all other pixels to zero.
Args:
center: position of the mask in (x, y, z) in the units of the simulation (note: NOT in terms of array index)
size: size of the mask box in (x, y, z) in the units of the simulation (note: NOT in terms of array index)
Returns:
The mask array of size :code:`grid.shape`.
"""
s = self.slice(center, size, squeezed=True)
mask = np.zeros(self.shape3)
mask[s] = 1
return mask.squeeze()
[docs] def reshape(self, v: np.ndarray) -> np.ndarray:
"""A simple method to reshape flat 3d field array into the grid shape
Args:
v: vector of size :code:`3n` to rearrange into array of size :code:`(3, nx, ny, nz)`
Returns:
The reshaped array
"""
return v.reshape((3, *self.shape3))
[docs] def slice(self, center: Size3, size: Size3, squeezed: bool = True):
"""Pick a slide of this grid
Args:
center: center of the slice in (x, y, z) in the units of the simulation (note: NOT in terms of array index)
size: size of the slice in (x, y, z) in the units of the simulation (note: NOT in terms of array index)
squeezed: whether to squeeze the slice to the minimum dimension (the squeeze order is z, then y).
Returns:
The slices to access the array
"""
# if self.ndim == 1:
# raise ValueError(f"Simulation dimension ndim must be 2 or 3 but got {self.ndim}.")
if not len(size) == 3:
raise ValueError(f"For simulation that is 3d, must provide size arraylike of size 3 but got {size}")
if not len(center) == 3:
raise ValueError(f"For simulation that is 3d, must provide center arraylike of size 3 but got {center}")
c = np.around(np.asarray(center) / self.spacing3).astype(int) # assume isotropic for now...
shape = np.around(np.asarray(size) / self.spacing3).astype(int)
s0, s1, s2 = shape[0] // 2, shape[1] // 2, shape[2] // 2
c0 = c[0] if squeezed else slice(c[0], c[0] + 1)
c1 = c[1] if squeezed else slice(c[1], c[1] + 1)
c2 = c[2] if squeezed else slice(c[2], c[2] + 1)
# if s0 == s1 == s2 == 0:
# raise ValueError(f"Require the size result in a nonzero-sized shape, but got a single point in the grid"
# f"(i.e., the size {size} may be less than the spacing {self.spacing3})")
return (slice(c[0] - s0, c[0] - s0 + shape[0]) if shape[0] > 0 else c0,
slice(c[1] - s1, c[1] - s1 + shape[1]) if shape[1] > 0 else c1,
slice(c[2] - s2, c[2] - s2 + shape[2]) if shape[2] > 0 else c2)
[docs] def view_fn(self, center: Size3, size: Size3, use_jax: bool = True):
"""Return a function that views a field at specific region.
The view function is specified by center and size in the grid. This is typically used for
mode-based sources and measurements. Once a slice is found, the fields need to be reoriented
such that the field components point in the right direction despite a change in axis assignment.
This function will handle this logic automatically in 1d, 2d, and 3d cases.
Args:
center: Center of the region
size: Size of the region
use_jax: Use jax
Returns:
A view callable function that orients the field and finds the appropriate slice.
"""
if np.count_nonzero(size) == 3:
raise ValueError(f"At least one element of size must be zero, but got {size}")
s = self.slice(center, size, squeezed=False)
xp = jnp if use_jax else np
# Find the view axis (the poynting direction)
view_axis = 0
for i in range(self.ndim):
if size[i] == 0:
view_axis = i
# Find the reorientation of field axes based on view_axis
# 0 -> (1, 2, 0)
# 1 -> (0, 2, 1)
# 2 -> (0, 1, 2)
axes = [
np.asarray((1, 2, 0), dtype=int),
np.asarray((0, 2, 1), dtype=int),
np.asarray((0, 1, 2), dtype=int)
][view_axis]
def view(field):
oriented_field = xp.stack(
(field[axes[0]].reshape(self.shape3),
field[axes[1]].reshape(self.shape3),
field[axes[2]].reshape(self.shape3))
) # orient the field by axis (useful for mode calculation)
return oriented_field[:, s[0], s[1], s[2]].transpose((0, *tuple(1 + axes)))
return view
[docs] def mask_fn(self, size: Size3, center: Optional[Size3] = None):
"""Given a box with :code:`size` and :code:`center`, return a function that sets pixels in :code:`rho`,
where :code:`rho.shape == grid.eps.shape`, outside the box to :code:`eps`.
This is important in inverse design to avoid modifying the material region near the source and measurement
regions.
Args:
center: position of the mask in (x, y, z) in the units of the simulation (note: NOT in terms of array index)
size: size of the mask box in (x, y, z) in the units of the simulation (note: NOT in terms of array index)
Returns:
The mask function
"""
rho_init = self.eps
center = self.center if center is None else center
mask = self.mask(center, size)
return lambda rho: jnp.array(rho_init) * (1 - mask) + rho * mask
def block_design(self, waveguide: Box, wg_height: Optional[float] = None, sub_eps: float = 1,
                 sub_height: float = 0, gap: float = 0, block: Optional[Box] = None, sep: Size = (0, 0),
                 vertical: bool = False, rib_y: float = 0):
    """A helper function for designing a useful port or cross section for a mode solver.

    Args:
        waveguide: The base waveguide material and size in the form of :code:`Box`.
        wg_height: The waveguide height. When not given (or zero), :code:`waveguide.min[1]` is used.
        sub_eps: The substrate epsilon (defaults to air)
        sub_height: The height of the substrate (or the min height of the waveguide built on top of it)
        gap: The coupling gap; a positive value yields a pair of base waveguides separated by :code:`gap`.
        block: Perturbing block that is to be aligned either vertically or horizontally with waveguide (MEMS).
            When :code:`None`, no perturbing blocks are added.
        sep: Separation of the block from the base waveguide layer, either a single value used for both
            blocks or a pair of values (one per block).
        vertical: Whether the perturbing block moves vertically, or laterally otherwise.
        rib_y: Rib section.

    Returns:
        The resulting :code:`Grid` with the modified :code:`eps` property.
    """
    if rib_y > 0:
        self.fill(rib_y + sub_height, waveguide.eps)
    self.fill(sub_height, sub_eps)
    waveguide.align(self.center)
    if wg_height:
        waveguide.valign(wg_height)
    else:
        wg_height = waveguide.min[1]

    # Accept a scalar (applied to both blocks) or any pair-like sequence.
    sep = tuple(sep) if isinstance(sep, (tuple, list, np.ndarray)) else (sep, sep)

    # Half the center-to-center distance of the waveguide pair; zero offset when there is no gap.
    d = gap / 2 + waveguide.size[0] / 2 if gap > 0 else 0
    waveguides = [waveguide.copy.translate(-d), waveguide.copy.translate(d)]

    blocks = []
    if block is not None:  # guard both orientations so `vertical=True` without a block is a no-op
        if vertical:
            blocks = [block.copy.align(waveguides[0]).valign(waveguides[0]).translate(dy=sep[0]),
                      block.copy.align(waveguides[1]).valign(waveguides[1]).translate(dy=sep[1])]
        else:
            blocks = [block.copy.valign(wg_height).halign(waveguides[0], left=False).translate(-sep[0]),
                      block.copy.valign(wg_height).halign(waveguides[1]).translate(sep[1])]

    for wg in waveguides + blocks:
        self.set_eps((wg.center[0], wg.center[1], 0), (wg.size[0], wg.size[1], 0), wg.eps)
    return self
[docs]class YeeGrid(Grid):
def __init__(self, size: Size, spacing: Spacing, eps: Union[float, np.ndarray] = 1,
             bloch_phase: Union[Size, float] = 0.0, pml: Optional[Size] = None, pml_sep: int = 5,
             pml_params: Size3 = (4, -16, 1.0), name: str = 'simgrid'):
    """The base :code:`YeeGrid` class (adding things to :code:`Grid` like Yee grid support, Bloch phase,
    PML shape, etc.).

    Args:
        size: Tuple of size 1, 2, or 3 representing the size of the grid
        spacing: Spacing (microns) between each pixel along each axis (must be same dim as `grid_shape`)
        eps: Relative permittivity :math:`\\epsilon_r`
        bloch_phase: Bloch phase (generally useful for angled scattering sims). A scalar is applied to
            every grid axis; otherwise one phase per grid axis is required.
        pml: Perfectly matched layer (PML) of thickness on both sides of the form :code:`(x_pml, y_pml, z_pml)`
        pml_sep: Specifies the number of pixels that any source must be placed away from a PML region.
        pml_params: The parameters of the form :code:`(exp_scale, log_reflectivity, pml_eps)`.
        name: Name stored on the grid.

    Raises:
        AttributeError: If the PML shape or Bloch phase does not have one entry per grid axis, or if the
            PML is not more than 3 pixels and less than half the grid shape on each axis.
    """
    super(YeeGrid, self).__init__(size, spacing, eps)
    self.pml = pml
    self.pml_sep = pml_sep
    # PML thickness in pixels; the builtin `int` replaces the removed `np.int` alias.
    self.pml_shape = pml if pml is None else (np.asarray(pml, dtype=float) / self.spacing).astype(int)
    self.pml_params = pml_params
    self.name = name

    if self.pml_shape is not None:
        # Check the length first so a mismatch gets the dedicated message rather than
        # surfacing inside the elementwise comparison below.
        if not len(self.pml_shape) == len(self.shape):
            raise AttributeError(f'Need len(pml_shape) == grid.shape,'
                                 f'got ({len(pml)}, {len(self.shape)}).')
        if np.any(self.pml_shape <= 3) or np.any(self.pml_shape >= self.shape // 2):
            raise AttributeError('PML shape must be more than 3 and less than half the shape on each axis.')

    # Any scalar phase (Python int/float or numpy scalar) is broadcast to every grid axis.
    phase = np.asarray(bloch_phase)
    if phase.ndim == 0:
        self.bloch = np.ones_like(self.shape) * np.exp(1j * phase)
    else:
        self.bloch = np.exp(1j * phase)
    if not len(self.bloch) == len(self.shape):
        raise AttributeError(f'Need bloch_phase.size == grid.shape,'
                             f'got ({len(self.bloch)}, {len(self.shape)}).')

    # Real dtype only when there is no PML and every Bloch phase is zero; np.any handles
    # scalar and array-like phases alike (a bare `== 0` is ambiguous for arrays).
    self.dtype = np.float64 if pml is None and not np.any(phase) else np.complex128
    # Two independent sets of cell-size meshgrids, indexed as [0] and [1] by deriv / diff_fn.
    self._dxes = np.meshgrid(*self.cells, indexing='ij'), np.meshgrid(*self.cells, indexing='ij')
[docs] def deriv(self, back: bool = False) -> List[sp.spmatrix]:
"""Calculate directional derivative.
Args:
back: Return backward derivative.
Returns:
Discrete directional derivative :code:`d` of the form :code:`(d_x, d_y, d_z)`
"""
# account for 1d and 2d cases
b = np.hstack((self.bloch, np.ones((3 - self.ndim,), dtype=self.bloch.dtype)))
s = np.hstack((self.shape, np.ones((3 - self.ndim,), dtype=self.shape.dtype)))
if back:
# get backward derivative
_, dx = self._dxes
d = [sp.diags([1, -1, -np.conj(b[ax])], [0, -1, n - 1], shape=(n, n))
if n > 1 else 0 for ax, n in enumerate(s)] # get single axis back-derivs
else:
# get forward derivative
dx, _ = self._dxes
d = [sp.diags([-1, 1, b[ax]], [0, 1, -n + 1], shape=(n, n))
if n > 1 else 0 for ax, n in enumerate(s)] # get single axis forward-derivs
d = [sp.kron(d[0], sp.eye(s[1] * s[2])).astype(np.complex128),
sp.kron(sp.kron(sp.eye(s[0]), d[1]), sp.eye(s[2])).astype(np.complex128),
sp.kron(sp.eye(s[0] * s[1]), d[2]).astype(np.complex128)] # tile over the other axes using sp.kron
d = [sp.diags(1 / dx[ax].ravel()) @ d[ax] for ax in range(len(s))] # scale by dx (incl pml)
return d
@property
def deriv_forward(self):
"""The forward derivative
Returns:
The forward derivative
"""
return self.deriv()
@property
def deriv_backward(self):
"""The backward derivative
Returns:
The backward derivative
"""
return self.deriv(back=True)
[docs] def diff_fn(self, of_h: bool = False, use_jax: bool = False):
"""Return a function that takes the discrete derivative of a field in a functional manner based on grid.
Args:
of_h: Take the derivative of :math:`\\mathbf{H}`, otherwise :math:`\\mathbf{E}`.
use_jax: Whether to use jax.
Returns:
The discrete derivative function
"""
xp = jnp if use_jax else np
dx = jnp.array(self._dxes[of_h]) if use_jax else self._dxes[of_h]
if of_h:
def _diff(f, a):
return (f - xp.roll(f, 1, axis=a)) / dx[a]
else:
def _diff(f, a):
return (xp.roll(f, -1, axis=a) - f) / dx[a]
return _diff
def curl_fn(self, beta: Optional[float] = None, of_h: bool = False, use_jax: bool = False) -> Op:
    """Get the function that computes the curl of a field on this grid.

    The grid's finite-difference function (see :code:`diff_fn`) is handed to the module-level
    :code:`curl_fn` helper together with :code:`beta` and :code:`use_jax`.

    Args:
        beta: Propagation constant in the z direction (note: x, y are the `cross section` axes).
        of_h: Whether to take the curl of h
        use_jax: Whether the returned function should use jax.

    Returns:
        A function that computes discretized curl :math:`\\nabla \\times \\mathbf{E}`.
    """
    return curl_fn(self.diff_fn(of_h=of_h, use_jax=use_jax), use_jax=use_jax, beta=beta)
[docs] def pml_safe_placement(self, loc: Size3) -> Size3:
"""Specifies a source/ measurement placement that is safe from the PML region / edge of the simulation.
Args:
loc: Location of the form (x, y, z) to move safely away from the PML
Returns:
New x, y, z tuple that is safe from PML (at least one Yee grid point away from the pml region).
"""
x, y, z = loc
pml = (self.pml_shape + self.pml_sep) * self.spacing if self.pml_shape is not None else (0, 0)
maxx, maxy = self.size[:2] - self.spacing[:2]
new_x = min(max(x, pml[0]), maxx - pml[0])
new_y = min(max(y, pml[1]), maxy - pml[1])
return new_x, new_y, z
@property
def eps_t(self):
    """This attribute specifies the grid-averaged epsilon in the grid.

    The average is recomputed from the current :code:`eps` on every access. It is deliberately not
    memoized per instance: a cache keyed on :code:`self` would keep returning the first result after
    :code:`eps` is modified and would keep every grid instance alive.

    Returns:
        The grid-averaged epsilon.
    """
    return yee_avg(self.eps.reshape(self.shape3))