import jax.numpy as jnp
import numpy as np
from jax.scipy.signal import convolve as conv
from skimage.draw import disk
from .typing import Union, List
from .utils import Box
[docs]def get_smooth_fn(beta: float, radius: float, eta: float = 0.5):
"""Using the sigmoid function and convolutional kernel provided in jax, we return a function that
effectively binarizes the design respectively and smooths the density parameters.
Args:
beta: A multiplicative factor in the tanh function to effectively define how binarized the design should be
radius: The radius of the convolutional kernel for smoothing
eta: The average value of the design
Returns:
The smoothing function
"""
rr, cc = disk((radius, radius), radius + 1)
kernel = np.zeros((2 * radius + 1, 2 * radius + 1), dtype=np.float64)
kernel[rr, cc] = 1
kernel = kernel / kernel.sum()
def smooth(rho: jnp.ndarray):
rho = conv(rho, kernel, mode='same')
return jnp.divide(jnp.tanh(beta * eta) + jnp.tanh(beta * (rho - eta)),
jnp.tanh(beta * eta) + jnp.tanh(beta * (1 - eta)))
return smooth
[docs]def get_symmetry_fn(ortho_x: bool = False, ortho_y: bool = False, diag_p: bool = False, diag_n: bool = False,
avg: bool = False):
"""Get the array-based reflection symmetry function based on orthogonal or diagonal axes.
Args:
ortho_x: symmetry along x-axis (axis 0)
ortho_y: symmetry along y-axis (axis 1)
diag_p: symmetry along positive ([1, 1] plane) diagonal (shape of params must be square)
diag_n: symmetry along negative ([1, -1] plane) diagonal (shape of params must be square)
avg: Whether the symmetry should take the average (applies to ortho symmetries ONLY)
Returns:
The overall symmetry function
"""
identity = (lambda x: x)
diag_n_fn = (lambda x: (x + x.T) / 2) if diag_p else identity
diag_p_fn = (lambda x: (x + x[::-1, ::-1].T) / 2) if diag_n else identity
if avg:
ortho_x_fn = (lambda x: (x + x[::-1]) / 2) if ortho_x else identity
ortho_y_fn = (lambda x: (x + x[:, ::-1]) / 2) if ortho_y else identity
else:
ortho_x_fn = (lambda x: x.at[-(x.shape[0] // 2 + 1):, :].set(x[:x.shape[0] // 2 + 1:, :][::-1, :])) if ortho_x else identity
ortho_y_fn = (lambda x: x.at[:, -(x.shape[1] // 2 + 1):].set(x[:, :x.shape[1] // 2 + 1][:, ::-1])) if ortho_y else identity
return lambda x: diag_p_fn(diag_n_fn(ortho_x_fn(ortho_y_fn(x))))
[docs]def get_mask_fn(rho_init: jnp.ndarray, box: Union[Box, List[Box]]):
"""Given an initial param set, this function defines the box region(s) where the params are allowed to change.
Args:
rho_init: initial rho definition
box: Box (or list of boxes) defines position and orientation of the design region(s)
Returns:
The mask function
"""
mask = box.mask(rho_init) if isinstance(box, Box) else (sum([b.mask(rho_init) for b in box]) > 0).astype(np.float)
return lambda rho: jnp.array(rho_init) * (1 - mask) + rho * mask