Source code for simphox.sim

from functools import lru_cache

import jax
import numpy as np
import jax.numpy as jnp

from .grid import YeeGrid
from .mode import ModeSolver, ModeLibrary
from .parse import parse_excitation
from .typing import Excitation, Spacing, Shape, Union, Size, Optional, List, Tuple, Shape2, Size2, Dict, Size3, \
    Callable, \
    MeasureInfo, Op, PortLabel
from .utils import fix_dataclass_init_docs
from .viz import get_extent_2d

try:
    HOLOVIEWS_IMPORTED = True
    import holoviews as hv
    from holoviews.streams import Pipe
    from holoviews import opts
    import panel as pn
except ImportError:
    HOLOVIEWS_IMPORTED = False

try:
    from dphox.pattern import Pattern

    DPHOX_INSTALLED = True
except ImportError:
    DPHOX_INSTALLED = False

import dataclasses
import xarray as xr


@fix_dataclass_init_docs
@dataclasses.dataclass
class SimCrossSection:
    """A planar cross section of a simulation grid together with the modes solved on it.

    Attributes:
        io: Mode library holding the modes solved on this cross section.
        center: Center of the cross section.
        size: Size of the cross section; the entry that equals 0 marks the propagation axis.
        wavelength: Wavelength stored alongside the modes.
    """
    io: ModeLibrary
    center: Size
    size: Size
    wavelength: float

    def place(self, mode_idx: int, grid: YeeGrid) -> np.ndarray:
        """Place mode :code:`mode_idx` into :code:`grid` at this cross section.

        Args:
            mode_idx: Index of the mode in the library.
            grid: Grid that receives the mode.

        Returns:
            Whatever :code:`self.io.place` returns for this center and size.
        """
        center, size = self.center, self.size
        return self.io.place(mode_idx, grid, center, size)

    def slice(self, grid: YeeGrid):
        """Index tuple selecting this cross section in :code:`grid`, with a full leading slice.

        Args:
            grid: Grid whose :code:`slice` method converts center/size into an index tuple.

        Returns:
            :code:`(slice(None),)` followed by :code:`grid.slice(self.center, self.size)`.
        """
        leading = (slice(None),)
        return leading + grid.slice(self.center, self.size)

    @property
    def prop_axis(self) -> int:
        """The propagation axis (the poynting direction where the port slice size should be 0)

        Returns:
            The propagation axis for the mode in the simulation cross section.

        """
        # The first axis whose extent is exactly zero is the propagation direction.
        is_flat = np.array(self.size) == 0
        return np.nonzero(is_flat)[0][0]

    def profile(self, mode_idx: int, use_h: float = False):
        """Returns the mode profile oriented based on specified xyz size of the profile.

        Args:
            mode_idx: The mode index of the profile
            use_h: Return the h field for the profile (used as a truthy flag)

        Returns:
            The oriented profile.

        """
        # Component permutation, indexed by the propagation axis:
        #   0: (0, 1, 2) -> (2, 0, 1)
        #   1: (0, 1, 2) -> (0, 2, 1)
        #   2: (0, 1, 2) -> (0, 1, 2)
        axis_table = ((2, 0, 1), (0, 2, 1), (0, 1, 2))
        # Sign applied to each permuted component, indexed by the propagation axis.
        sign_table = ((1, -1, 1), (1, 1, 1), (1, 1, -1))

        prop_axis = self.prop_axis
        axes = np.asarray(axis_table[prop_axis], dtype=int)
        signs = np.asarray(sign_table[prop_axis], dtype=int)

        if use_h:
            field = self.io.h(mode_idx)
        else:
            field = self.io.e(mode_idx)

        oriented = np.stack([sign * field[axis] for sign, axis in zip(signs, axes)])
        # Axis 0 holds the field component; the spatial axes follow the same permutation.
        return oriented.transpose((0, *(1 + axes)))
class SimGrid(YeeGrid):
    """Simulation grid: a :code:`YeeGrid` with port sources, port measurements and s-parameter helpers."""

    def __init__(self, size: Size, spacing: Spacing, eps: Union[float, np.ndarray] = 1,
                 bloch_phase: Union[Size, float] = 0.0, pml: Optional[Union[int, Shape, Size]] = None,
                 pml_params: Size3 = (4, -16, 1.0, 5), pml_sep: int = 5,
                 use_jax: bool = False, name: str = 'simgrid'):
        """The base :code:`SimGrid` class (adding things to :code:`Grid` like Yee grid support, Bloch phase,
        PML shape, etc.).

        Args:
            size: Tuple of size 1, 2, or 3 representing the number of pixels in the grid.
            spacing: Spacing (microns) between each pixel along each axis (must be same dim as `grid_shape`).
            eps: Relative permittivity :math:`\\epsilon_r`.
            bloch_phase: Bloch phase (generally useful for angled scattering sims).
            pml: Perfectly matched layer (PML) of thickness on both sides of the form
                :code:`(x_pml, y_pml, z_pml)`.
            pml_params: The parameters of the form :code:`(exp_scale, log_reflectivity, pml_eps)`.
                NOTE(review): the default carries a fourth entry whose meaning is not documented here.
            pml_sep: The PML separation distance in number of pixels for sources.
            use_jax: Stored on the instance as :code:`self.use_jax`.
            name: Name forwarded to the parent grid (also used in the viz panel titles).
        """
        # The parent takes pml_sep before pml_params, so the positional order here is deliberate.
        super().__init__(size, spacing, eps, bloch_phase, pml, pml_sep, pml_params, name)
        self.use_jax = use_jax

    # NOTE: lru_cache requires every argument (including self) to be hashable, so center/size must be
    # tuples rather than arrays. The cache also holds a reference to the instance and is not
    # invalidated if :code:`self.eps` changes afterwards.
    @lru_cache()
    def modes(self, center: Size3, size: Size3, wavelength: float = 1.55, num_modes: int = 1) -> SimCrossSection:
        """Eigenmode profile of a 2d or 3d :code:`SimGrid` object.

        This function is cached, so the modes are only recomputed when the input parameters change.

        Args:
            center: center tuple of the form :code:`(x, y, z)` (in sim units, NOT pixels)
            size: size of the source (in sim units, NOT pixels)
            wavelength: wavelength (arb. units, should match with spacing)
            num_modes: number of modes to find

        Returns:
            A :code:`SimCrossSection` bundling the mode library with the center, size and wavelength used.

        """
        # Permittivity on the cross section, with the zero-thickness (propagation) axis squeezed out.
        mode_eps = self.eps.reshape(self.shape3)[self.slice(center, size)].squeeze()
        mode_size = tuple(np.array(mode_eps.shape) * self.spacing[0])
        solver = ModeSolver(size=mode_size, spacing=self.spacing[0], eps=mode_eps, wavelength=wavelength)
        library = ModeLibrary(solver, max_num_modes=num_modes)
        return SimCrossSection(library, center, size, wavelength)

    def port_modes(self, excitation: List[Tuple[str, int]] = None, profile_size_factor: float = 3,
                   wavelength: float = 1.55) -> Dict[PortLabel, SimCrossSection]:
        """Profile for all the ports in the grid (always assumed to be along x or y axes!).

        Args:
            excitation: Port name / mode index pairs to solve for (all ports, mode 0, if :code:`None`)
            profile_size_factor: Factor to rescale the mode view slice compared to the port
            wavelength: Wavelength for the modes

        Returns:
            A dictionary from port name to :code:`SimCrossSection` containing the modes, and the center and
            size in this grid needed to reconstruct the source using :code:`place`.

        """
        excitation = self.parse_excitation(excitation)
        # Each port needs enough modes to cover the highest mode index requested on it.
        num_modes = {}
        for port_name, mode_idx in excitation:
            num_modes[port_name] = max(num_modes.get(port_name, 0), mode_idx + 1)
        return {
            name: self.modes(center=self.pml_safe_placement(self.port[name].xyz),
                             size=tuple(self.port[name].size * profile_size_factor),
                             num_modes=count, wavelength=wavelength)
            for name, count in num_modes.items()
        }

    def port_source(self, source: Optional[Union[Dict[Tuple[str, int], float], Dict[str, float]]] = None,
                    profile_size_factor: float = 3, unidirectional: bool = True,
                    wavelength: float = 1.55) -> np.ndarray:
        """Return a non-sparse source array based on the ports defined in the simulation grid.

        Args:
            source: Map each port and mode index to a weight to yield a weighted port source.
                If a dictionary is specified, it can be of the form :code:`{(port_name, mode_idx): weight}`
                or :code:`{port_name: weight}`, where in the latter case, a default mode index of 0 is used.
                A plain sequence of weights (tuple, list or array) is matched to the ports in order, using
                each port's fundamental mode. If :code:`None`, the first port's fundamental mode is excited
                with unit weight. A source is then created by summing the contributions from all entries.
            profile_size_factor: Factor to rescale the mode view slice compared to the port
            unidirectional: In FDFD, this specifies whether to send the mode in unidirectionally,
                determined using the port angle.
            wavelength: Wavelength for the source

        Returns:
            The non-sparse source array that can be used as a source profile for either FDFD or FDTD
            (an empty array if the source has no entries).

        """
        ports = list(self.port.keys())
        if source is None:
            source = {(ports[0], 0): 1}
        elif isinstance(source, (tuple, list, np.ndarray)):
            # A bare sequence of weights applies to the fundamental mode of each port, in port order.
            source = {(port, 0): weight for port, weight in zip(ports, source)}
        # Accept the documented {port_name: weight} shorthand by defaulting the mode index to 0.
        source = {((key, 0) if isinstance(key, str) else key): weight for key, weight in source.items()}
        if not source:
            return np.array([])

        # Solve exactly the (port, mode) pairs that are requested so higher-order modes are available.
        source_library = self.port_modes(excitation=list(source), profile_size_factor=profile_size_factor,
                                         wavelength=wavelength)
        sources_to_sum = []
        for (port, mode), weight in source.items():
            cross_section = source_library[port]
            axis = cross_section.prop_axis
            beta = cross_section.io.betas[mode]
            # +1 for port angles in [0, 180), -1 otherwise.
            shift = 2 * (np.mod(self.port[port].a, 360) < 180) - 1
            src = cross_section.place(mode, self) * weight
            src = np.roll(src, axis=axis, shift=2 * shift)
            if unidirectional:
                # Add a copy displaced by one pixel with an extra phase of (spacing * beta + pi).
                src = src + np.roll(src, axis=axis, shift=shift) * np.exp(-1j * self.spacing[axis] * beta
                                                                        - 1j * np.pi)
            sources_to_sum.append(src)
        return sum(sources_to_sum)

    def parse_excitation(self, excitation: Optional[Excitation] = None):
        """Normalize an excitation spec; :code:`None` means the fundamental mode (index 0) of every port.

        Args:
            excitation: Excitation spec handed to the module-level :code:`parse_excitation`, or :code:`None`.

        Returns:
            A list of :code:`(port_name, mode_idx)` pairs when :code:`excitation` is :code:`None`, otherwise
            the result of the module-level :code:`parse_excitation`.
        """
        if excitation is None:
            return [(port, 0) for port in self.port]
        return parse_excitation(excitation)

    def get_measure_fn(self, measure_port: Optional[Excitation] = None, wavelength: float = 1.55,
                       profile_size_factor: float = 3, use_jax: bool = False, tm_2d: bool = True) -> Op:
        """Measure function: measure the fields using the Modes object provided for each port.

        Args:
            measure_port: List of port name and mode index at that port.
            wavelength: The wavelength for the measurement.
            profile_size_factor: Factor to rescale the mode view slice compared to the port.
            use_jax: Whether to use jax in the measure function (relevant for simulations).
            tm_2d: Whether to use TM polarization (applies to the 2D case only).

        Returns:
            Callable function that gives port-wise measurements, one per :code:`measure_port` entry and in
            that order.

        """
        # Set up the port profiles for measurement at each port (by default assumes single mode waveguides)
        measure_port = list(self.parse_excitation(measure_port))
        port_to_modes = self.port_modes(measure_port, profile_size_factor, wavelength)

        # We measure polarity, which is the orientation of the measurement interface
        # to determine whether the wave is entering or leaving the device.
        # The polarity below assumes that ports are near edge of simulation.
        # It is computed per measure_port entry so it lines up with measure_fns / view_fns below.
        angles = [self.port[port_name].a for port_name, _ in measure_port]
        polarity = 2 * (np.mod(angles, 360) < 180).astype(int) - 1

        measure_fns = [port_to_modes[port_name].io.measure_fn(mode_idx, use_jax, tm_2d=tm_2d)
                       for port_name, mode_idx in measure_port]
        view_fns = [self.view_fn(port_to_modes[port_name].center, port_to_modes[port_name].size, use_jax)
                    for port_name, _ in measure_port]
        xp = jnp if use_jax else np
        num_measurements = len(measure_port)

        def measure_fn(fields):
            e, h = fields
            # A negative polarity reverses the order of the pair returned by the mode measurement.
            return xp.stack([measure_fns[i](view_fns[i](e), view_fns[i](h))[::polarity[i]]
                             for i in range(num_measurements)]).T

        return measure_fn

    def get_fields_fn(self, src: Union[np.ndarray, Callable], transform_fn: Optional[Callable] = None,
                      tm_2d: bool = True) -> Callable:
        """Returns a function that yields the fields given a transform function and source.

        We first initialize the problem solver given two callable functions:
        1. A numpy array source :code:`src`
        2. The JAX-transformable transform function :code:`transform_fn` (e.g. transform) (identity if None)

        Args:
            src: source for the solver (either a callable for time domain or array for frequency domain)
            transform_fn: Transforms parameters to yield the epsilon function used by jax
            tm_2d: Whether to use TM polarization (applies to the 2D case only).

        Returns:
            A solve function (2d or 3d based on defined :code:`ndim` specified for the instance of :code:`FDFD`)

        Raises:
            NotImplementedError: Always; child classes must override this method.

        """
        raise NotImplementedError("A child class of SimGrid needs to implement get_fields_fn")

    def get_sim_fn(self, src: Union[np.ndarray, Callable], transform_fn: Optional[Callable] = None,
                   tm_2d: bool = True) -> Callable:
        """Returns a function that measures the sparams and fields.

        We first initialize the optimization problem solver given two callable functions:
        1. A numpy array or callable source :code:`src` (only an array is needed in fdfd).
        2. The JAX-transformable transform function :code:`transform_fn` (e.g. transform) (identity if None)

        We then extract the sparams using the port locations provided in this class.

        Args:
            src: source for the solver
            transform_fn: Transforms parameters to yield the epsilon function used by jax
            tm_2d: Whether to use TM polarization (applies to the 2D case only).

        Returns:
            A function mapping :code:`rho` to a :code:`(measurements, fields)` tuple.

        """
        fields_fn = self.get_fields_fn(src, transform_fn, tm_2d=tm_2d)
        measure_fn = self.get_measure_fn(use_jax=True, tm_2d=tm_2d)

        # @jax.jit
        def sim_fn(rho: jnp.ndarray):
            fields = fields_fn(rho)
            measurements = measure_fn(fields)
            return measurements, fields

        return sim_fn

    def get_sim_sparams_fn(self, port_name: Optional[str] = None, transform_fn: Optional[Callable] = None,
                           mode_idx: int = 0, profile_size_factor: int = 3,
                           measure_info: Optional[MeasureInfo] = None, tm_2d: bool = True) -> Callable:
        """Returns a function that measures the sparams and fields.

        We first initialize the optimization problem solver given a JAX-transformable transform function
        :code:`transform_fn` (e.g. transform) and the port, mode pair for the input source
        (used to normalize the output measurements to get the s params).

        We then extract the sparams using the port locations provided in this class.

        Args:
            port_name: Port name for the source (first entry of :code:`measure_info` if :code:`None`)
            transform_fn: Transforms parameters to yield the epsilon function used by jax (identity if None)
            mode_idx: Mode index for the source
            profile_size_factor: Profile size factor to rescale the port size to get mode size
            measure_info: Measurement info consisting of a list of port name and mode index pairs
                (every port at mode 0 if :code:`None`)
            tm_2d: Whether to use TM polarization (applies to the 2D case only).

        Returns:
            A function mapping :code:`rho` to a :code:`(sparams, fields)` tuple.

        Raises:
            ValueError: If the source port/mode pair is not one of the :code:`measure_info` entries.

        """
        measure_info = [(name, 0) for name in self.port] if measure_info is None else measure_info
        source_info = (port_name, mode_idx) if port_name is not None else measure_info[0]
        # The input amplitude is read at the source's position within measure_info, so the measure
        # function below has to be built from the same measure_info.
        port_idx = measure_info.index(source_info)

        fields_fn = self.get_fields_fn(src=self.port_source({source_info: 1},
                                                            profile_size_factor=profile_size_factor),
                                       transform_fn=transform_fn, tm_2d=tm_2d)
        measure_fn = self.get_measure_fn(measure_port=measure_info, profile_size_factor=profile_size_factor,
                                         use_jax=True, tm_2d=tm_2d)

        # @jax.jit
        def sim_fn(rho: jnp.ndarray):
            fields = fields_fn(rho)
            s_out, s_in = measure_fn(fields)
            # Normalize every outgoing amplitude by the incoming amplitude at the source port.
            sparams = s_out / s_in[port_idx]
            return sparams, fields

        return sim_fn

    def viz_panel(self, img_width: float = 700, xs_axis: int = 2) -> Tuple["pn.layout.Panel",
                                                                           Tuple["Pipe", "Pipe", "Pipe"]]:
        """Visualizes a 2D slice of a simulation.

        Args:
            img_width: Width of the visualization panel for the simulation
            xs_axis: Axis for the intersection

        Returns:
            Panel / dashboard for visualizing the permittivity and field overlay and
            Tuple of Pipes for feeding data to the panel (e.g. from an optimization or
            FDTD real-time simulation).

        Raises:
            ImportError: If holoviews / panel could not be imported.
            NotImplementedError: If the grid is one-dimensional.

        """
        if not HOLOVIEWS_IMPORTED:
            raise ImportError("Holoviews not imported, so a viz panel cannot be generated")
        if self.ndim == 1:
            raise NotImplementedError("Only implemented for ndim == 2, ndim == 3!")

        extent = get_extent_2d(self.shape, self.spacing[0])
        aspect = (extent[1] - extent[0]) / (extent[3] - extent[2])
        bounds = (extent[0], extent[2], extent[1], extent[3])

        if self.ndim == 2:
            eps_norm = self.eps.T / np.max(self.eps.T)
        else:
            # In 3D only the shape of the first slice along xs_axis is used; the overlay starts as ones.
            eps_slice = [slice(None), slice(None), slice(None)]
            eps_slice[xs_axis] = 0
            eps_norm = np.ones_like(self.eps[tuple(eps_slice)].T)

        def bounded_img(data):
            return hv.Image(data, bounds=bounds)

        eps_pipe = Pipe(data=[])
        eps_dmap = hv.DynamicMap(bounded_img, streams=[eps_pipe])
        field_pipe = Pipe(data=[])
        field_dmap = hv.DynamicMap(bounded_img, streams=[field_pipe])
        power_pipe = Pipe(data=[])
        power_dmap = hv.DynamicMap(bounded_img, streams=[power_pipe])

        # Seed the pipes so the panel renders before any simulation data arrives.
        eps_pipe.send(eps_norm)
        field_pipe.send(np.zeros_like(eps_norm))
        power_pipe.send(np.zeros_like(eps_norm))

        img_height = int(img_width / aspect)
        ed = eps_dmap.opts(alpha=0.2, width=img_width, height=img_height, cmap='gray')
        fd = field_dmap.opts(cmap='RdBu', width=img_width, height=img_height)
        pd = power_dmap.opts(cmap='hot', width=img_width, height=img_height)
        return pn.Row((fd * ed).opts(title=f'{self.name}: Fields (hz)'),
                      (pd * ed).opts(title=f'{self.name}: Power (|hz|²)')
                      ), (eps_pipe, field_pipe, power_pipe)

    def to_2d(self, wavelength: float = None, slab_loc: Size2 = None, tm: bool = True) -> "SimGrid":
        """Project a 3D simulation into a 2D simulation using the variational 2.5D method.

        .. seealso:: https://ris.utwente.nl/ws/files/5413011/ishpiers09.pdf

        Args:
            wavelength: The wavelength to use for calculating the effective 2.5 FDFD
                (useful to stabilize multi-wavelength optimizations)
            slab_loc: Slab location :code:`(x, y)`; if :code:`None`, it is derived from the first port
                in :code:`self.port`
            tm: Whether the 2D simulation is a TM mode simulation

        Returns:
            A 2D simulation to approximate the 3D simulation.

        Raises:
            RuntimeError: If this grid is not three-dimensional.
            ValueError: If no wavelength is given, or if :code:`slab_loc` is :code:`None` and no ports exist.

        """
        if self.ndim != 3:
            raise RuntimeError("Require ndim = 3 for 2d variational effective index method.")
        if not wavelength:
            raise ValueError("Must specify a projection wavelength for the effective index method.")
        if slab_loc is None:
            if not self.port:
                raise ValueError('Must define x, y inputs since the port width and/or locations '
                                 'are not automatically discoverable.')
            port = list(self.port.values())[0]
            # NOTE(review): port_modes calls pml_safe_placement with a single xyz argument while this
            # call unpacks it; confirm which form the parent class expects.
            slab_x, slab_y, _ = self.pml_safe_placement(*port.xyz)
            slab_loc = (slab_x, slab_y)

        # Vertical permittivity column at the slab location.
        slab_mode_eps = self.eps[int(slab_loc[0] / self.spacing[0]), int(slab_loc[1] / self.spacing[1])]
        # NOTE(review): `size` is scaled by the full spacing vector although the column is 1D;
        # presumably self.spacing[-1] is intended. Confirm against ModeSolver.
        beta, slab_mode = ModeSolver(
            size=np.asarray(slab_mode_eps.shape) * self.spacing,
            spacing=self.spacing[-1],
            eps=slab_mode_eps,
            wavelength=wavelength
        ).solve(1)

        # Free-space wavenumber k0 = 2 * pi / wavelength, so that beta / k0 is an effective index.
        k0 = 2 * np.pi / wavelength
        slab_mode_eps = slab_mode_eps[np.newaxis, np.newaxis, :]
        mode_power = np.abs(slab_mode) ** 2

        if tm:
            eps_inv_diff = 1 / self.eps - 1 / slab_mode_eps
            num_1 = (1 / slab_mode_eps) @ mode_power
            den_1 = (1 / self.eps) @ mode_power
            # NOTE(review): stacking the differenced mode with a scalar 0 depends on the shape returned
            # by ModeSolver.solve; confirm it pads as intended.
            num_2 = eps_inv_diff @ np.abs(np.vstack((np.diff(slab_mode), 0))) ** 2
            den_2 = (k0 ** 2 * self.eps) @ mode_power
            eps_effective = (beta[0] / k0) ** 2 * num_1 / den_1 + num_2 / den_2
        else:
            eps_diff = self.eps - slab_mode_eps
            eps_effective = (beta[0] / k0) ** 2 + (eps_diff @ mode_power) / np.sum(mode_power)

        # NOTE(review): `size` uses the full spacing vector while `spacing` uses only the first two
        # entries; confirm the shapes broadcast as intended.
        sim = SimGrid(
            size=np.asarray(eps_effective.shape) * self.spacing,
            spacing=self.spacing[:2],
            eps=eps_effective.real,
            pml=self.pml_shape[:2],
            name=self.name
        )
        sim.port = self.port
        return sim

    def decorate(self, sparams: np.ndarray, fields: np.ndarray) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
        """Decorates the :code:`sparams` and :code:`fields` using :code:`xarray.DataArray`

        Args:
            sparams: The sparams resulting from a call to the returned callable from :code:`get_sim_sparams_fn`.
            fields: The fields resulting from a call to the returned callable from :code:`get_sim_sparams_fn`.

        Returns:
            The decorated :code:`sparams`, e field and h field.

        """
        axis_names = ["x", "y", "z"]
        field_dims = ["direction"] + axis_names

        decorated_sparams = xr.DataArray(
            data=sparams,
            dims=["port"],
            coords={"port": list(self.port.keys())}
        )

        e, h = fields

        # Axes with a single position get the placeholder coordinate [0].
        e_coords = {"direction": list(axis_names)}
        h_coords = {"direction": list(axis_names)}
        for axis, axis_name in enumerate(axis_names):
            positions = self.pos[axis]
            if positions.size > 1:
                e_coords[axis_name] = positions[:-1]
                # shift h by half yee cell
                h_coords[axis_name] = positions[:-1] + 0.5 * self.spacing3[axis]
            else:
                e_coords[axis_name] = [0]
                h_coords[axis_name] = [0]

        decorated_e = xr.DataArray(data=e, dims=field_dims, coords=e_coords)
        decorated_h = xr.DataArray(data=h, dims=field_dims, coords=h_coords)
        return decorated_sparams, decorated_e, decorated_h

    def fidelity(self, desired_sparams: Union[Dict[Tuple[str, int], np.complex128], Dict[str, np.complex128]],
                 measure_info: List[Tuple[str, int]] = None, insertion_weight: float = 0) -> Callable:
        """Returns the fidelity for the :code:`sparams`.

        Args:
            desired_sparams: The desired sparams, provided in dictionary form mapping port (or port, mode
                pair) to relative amplitude; the resulting vector is normalized to unit norm.
            measure_info: Measurement info consisting of a list of port name and mode index pairs (used to index s)
            insertion_weight: Renormalize s-params to separate insertion loss and sparams.
                Then, weight insertion by :code:`insertion_weight` and sparams by
                :code:`1 - insertion_weight`. If zero, ignore.

        Returns:
            An objective taking a :code:`(sparams, fields)` tuple and returning the (negated) fidelity cost
            together with the gradient-stopped :code:`(sparams, fields)`.

        """
        measure_info = [(name, 0) for name in self.port] if measure_info is None else measure_info
        s = np.zeros(len(measure_info), dtype=np.complex128)
        for port, weight in desired_sparams.items():
            # A bare port name refers to its fundamental mode.
            key = (port, 0) if isinstance(port, str) else port
            s[measure_info.index(key)] = weight
        s = jnp.array(s / np.linalg.norm(s))

        if insertion_weight != 0:
            def obj(sparams_fields: Tuple[jnp.ndarray, jnp.ndarray]):
                sparams, fields = sparams_fields
                insertion_sqrt = jnp.linalg.norm(sparams)
                # Weighted sum of the overlap with the normalized sparams and the total measured power.
                cost = -jnp.abs(s @ (sparams / insertion_sqrt)) ** 2 * (
                    1 - insertion_weight) - insertion_sqrt ** 2 * insertion_weight
                return cost, jax.lax.stop_gradient((sparams, fields))
        else:
            def obj(sparams_fields: Tuple[jnp.ndarray, jnp.ndarray]):
                sparams, fields = sparams_fields
                return -jnp.abs(s @ sparams) ** 2, jax.lax.stop_gradient((sparams, fields))

        return obj