Iterative reconstruction with illumination retrieval

[1]:
import numpy as np
import cupy as cp
import matplotlib.pyplot as plt
import h5py
from holotomocupy.holo import G, GT
from holotomocupy.magnification import M, MT
from holotomocupy.shift import S, ST, registration_shift
from holotomocupy.proc import remove_outliers
from holotomocupy.chunking import gpu_batch
from holotomocupy.recon_methods import CTFPurePhase, multiPaganin
from holotomocupy.proc import dai_yuan, linear
import holotomocupy.chunking as chunking
from holotomocupy.utils import *


%matplotlib inline

# NOTE(review): presumably the number of items each gpu_batch-decorated function
# processes per GPU chunk when given host (NumPy) arrays — confirm in holotomocupy.chunking
chunking.global_chunk = 1
[ ]:
# Initialize data sizes and parameters of the PXM at ID16A
[2]:
n = 2048  # size of the n x n detector window that is processed [pixels]
ndist = 4  # number of sample positions (propagation distances); defined once only
ntheta = 1  # number of projections processed (leading axis of the data arrays)

# ID16a setup
detector_pixelsize = 3e-6  # [m]
energy = 33.35  # [keV] xray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wave length

focusToDetectorDistance = 1.28  # [m]
# NOTE(review): offset subtracted from the nominal focus-to-sample positions — confirm its meaning
sx0 = 3.7e-4
z1 = np.array([4.584e-3, 4.765e-3, 5.488e-3, 6.9895e-3])[:ndist]-sx0
z2 = focusToDetectorDistance-z1
distances = (z1*z2)/focusToDetectorDistance
magnifications = focusToDetectorDistance/z1
voxelsize = detector_pixelsize/magnifications[0]*2048/n  # object voxel size

norm_magnifications = magnifications/magnifications[0]
# scaled propagation distances due to magnified probes
distances = distances*norm_magnifications**2

z1p = z1[0]  # positions of the probe for reconstruction
z2p = z1-z1p  # distance from the probe plane to each sample position
# magnification when propagating from the probe plane to the detector
magnifications2 = (z1p+z2p)/z1p
# propagation distances after switching from the point source wave to plane wave,
distances2 = (z1p*z2p)/(z1p+z2p)
norm_magnifications2 = magnifications2/(z1p/z1[0])  # normalized magnifications
# scaled propagation distances due to magnified probes
distances2 = distances2*norm_magnifications2**2
distances2 = distances2*(z1p/z1)**2

# allow padding if there are shifts of the probe
pad = n//16
# sample size after demagnification: the padded frame must fit at the smallest normalized
# magnification (np.min instead of assuming the last distance is the smallest one);
# rounded up to a multiple of 8
ne = int(np.ceil((n+2*pad)/np.min(norm_magnifications)/8))*8
[3]:

# Load the data, flat-field (ref) and dark frames of every distance, cropped to an n x n window,
# remove outliers, and show one frame per distance plus the flat and dark fields.
nref = 20  # number of flat-field frames read per distance
ndark = 20  # number of dark frames read per distance
# NOTE(review): machine-specific absolute path; consider making it configurable
data_dir = '/data/viktor'
scan_name = 'SiemensLH_33keV_010nm_holoNfpScan_0'
dataset = '/entry_0000/measurement/data'
# n x n window centred at pixel 1024 of the raw frames
rows = slice(1024-n//2, 1024+n//2)
cols = slice(1024-n//2, 1024+n//2)

data00 = np.zeros([ntheta, ndist, n, n], dtype='uint16')
ref00 = np.zeros([nref, ndist, n, n], dtype='uint16')
dark00 = np.zeros([ndark, ndist, n, n], dtype='uint16')
for k in range(ndist):
    scan_dir = f'{data_dir}/{scan_name}{k+1}'
    with h5py.File(f'{scan_dir}/{scan_name}{k+1}0000.h5', 'r') as fid:
        # read ntheta frames (was hard-coded to 1) so the slice always matches data00
        data00[:, k] = fid[dataset][:ntheta, rows, cols]
    with h5py.File(f'{scan_dir}/ref_0000.h5', 'r') as fid:
        ref00[:, k] = fid[dataset][:nref, rows, cols]
    with h5py.File(f'{scan_dir}/dark_0000.h5', 'r') as fid:
        dark00[:, k] = fid[dataset][:ndark, rows, cols]

# remove outliers from the data and flat fields (dark frames are left as read)
radius = 7
threshold = 20000
for k in range(ndist):
    data00[:, k] = remove_outliers(data00[:, k], radius, threshold)
    ref00[:, k] = remove_outliers(ref00[:, k], radius, threshold)

fig, axs = plt.subplots(2, 2, figsize=(8, 6))
for k, ax in zip(range(ndist), axs.flat):
    im = ax.imshow(data00[0, k], cmap='gray', vmax=10000)
    ax.set_title(f'data for dist {k}')
    # attach the colorbar to its own panel (without ax= older matplotlib uses the current axes)
    fig.colorbar(im, ax=ax)

fig, axs = plt.subplots(1, 2, figsize=(7, 3))
im = axs[0].imshow(ref00[0, 0], cmap='gray', vmax=10000)
fig.colorbar(im, ax=axs[0])
axs[0].set_title('flat field')
im = axs[1].imshow(dark00[0, 0], cmap='gray', vmax=3000)
fig.colorbar(im, ax=axs[1])
axs[1].set_title('dark field')
plt.show()
../../../_images/notebook_experimental_siemens_star_rec_iterative_4_0.png
../../../_images/notebook_experimental_siemens_star_rec_iterative_4_1.png
[ ]:
## Take mean for flat and dark
[4]:
# Average the flat-field and dark frames over their frame axis; keepdims leaves a
# leading singleton axis so the arrays keep their 4-D layout.
ref00 = np.mean(ref00, axis=0, keepdims=True)
dark00 = np.mean(dark00, axis=0, keepdims=True)
[ ]:
### Normalize everything with respect to the mean of the reference image
[5]:
# Divide dark, flat and data frames by the mean of the averaged flat field,
# subtract the mean dark frame, and zero the negative values this can produce.
mean_value = np.mean(ref00)
dark00, ref00, data00 = (arr.astype('float32')/mean_value for arr in (dark00, ref00, data00))

dark_mean = np.mean(dark00, axis=0)
data00 = data00-dark_mean
ref00 = ref00-dark_mean

np.putmask(data00, data00 < 0, 0)
np.putmask(ref00, ref00 < 0, 0)
[ ]:
# Find shifts of reference images
[6]:
# Shift of every distance's averaged flat field relative to the flat field of distance 0
shifts_ref0 = np.zeros([1, ndist, 2], dtype='float32')
for k in range(ndist):
    shifts_ref0[:, k] = registration_shift(ref00[:, k], ref00[:, 0], upsample_factor=1000)
print(f'Found shifts: \n{shifts_ref0=}')

# Shift of the data frames of every distance relative to the distance-0 flat field.
# The registration target does not depend on k, so it is built once.
ref_target = np.tile(ref00[0, 0], [ntheta, 1, 1])
shifts_ref = np.zeros([ntheta, ndist, 2], dtype='float32')
for k in range(ndist):
    shifts_ref[:, k] = registration_shift(data00[:, k], ref_target, upsample_factor=1000)
print(f'Found shifts: \n{shifts_ref=}')

Found shifts:
shifts_ref0=array([[[ 0.   ,  0.   ],
        [ 0.003,  0.018],
        [-0.003,  0.076],
        [ 0.149,  0.197]]], dtype=float32)
Found shifts:
shifts_ref=array([[[ 0.003, -0.023],
        [ 0.007, -0.006],
        [-0.003,  0.05 ],
        [ 0.154,  0.166]]], dtype=float32)
[ ]:
### With the shifts calculated, shift the reference images back
[7]:
# Undo the flat-field shifts (align every distance's flat field with distance 0), then
# build a copy re-shifted to the positions measured during data acquisition.
data0 = data00.copy()
ref0 = ref00.copy()
# shifted refs for correction
for k in range(ndist):
    # shift refs back
    ref0[:, k] = ST(ref0[:, k].astype('complex64'), shifts_ref0[:, k]).real

# NOTE(review): ref0c is not used by the cells that follow (the flat-field
# correction divides by ref0) — confirm whether it is still needed.
ref0c = np.tile(np.array(ref0), (ntheta, 1, 1, 1))
for k in range(ndist):
    # shift refs the position where they were when collecting data
    ref0c[:, k] = S(ref0c[:, k].astype('complex64'), shifts_ref[:, k]).real

# compare the flat-field differences before/after the shift correction for every distance
# (was hard-coded to range(4))
for k in range(ndist):
    fig, axs = plt.subplots(1, 2, figsize=(8, 3))
    im = axs[0].imshow(ref00[0, 0]-ref00[0, k], cmap='gray', vmax=.03, vmin=-.03)
    axs[0].set_title('ref[0]-ref[k]')
    fig.colorbar(im, ax=axs[0])
    im = axs[1].imshow(ref0[0, 0]-ref0[0, k], cmap='gray', vmax=.03, vmin=-.03)
    axs[1].set_title('shifted ref[0]-ref[k] ')
    fig.colorbar(im, ax=axs[1])
../../../_images/notebook_experimental_siemens_star_rec_iterative_12_0.png
../../../_images/notebook_experimental_siemens_star_rec_iterative_12_1.png
../../../_images/notebook_experimental_siemens_star_rec_iterative_12_2.png
../../../_images/notebook_experimental_siemens_star_rec_iterative_12_3.png
[ ]:
### Divide data by the reference image
[8]:
# Flat-field correction; the small epsilon avoids division by zero where the
# dark-subtracted reference was clipped to 0.
# NOTE(review): this divides by ref0 (refs aligned to distance 0), not by ref0c
# (refs shifted to the data-acquisition positions) built in the previous cell —
# confirm which one is intended.
rdata = data0/(ref0+1e-9)

[9]:
# Dark-subtracted data next to the flat-field-corrected data, one figure per distance
for k in range(ndist):
    fig, axs = plt.subplots(1, 2, figsize=(8, 3))
    im = axs[0].imshow(data0[0, k], cmap='gray', vmax=2)
    axs[0].set_title(f'data dist {k}')
    # explicit ax= so each colorbar belongs to its own panel on every matplotlib version
    fig.colorbar(im, ax=axs[0])
    im = axs[1].imshow(rdata[0, k], cmap='gray', vmax=1.1, vmin=0.9)
    axs[1].set_title(f'rdata dist {k}')
    fig.colorbar(im, ax=axs[1])
../../../_images/notebook_experimental_siemens_star_rec_iterative_15_0.png
../../../_images/notebook_experimental_siemens_star_rec_iterative_15_1.png
../../../_images/notebook_experimental_siemens_star_rec_iterative_15_2.png
../../../_images/notebook_experimental_siemens_star_rec_iterative_15_3.png
[ ]:
### Scale images
[10]:
# Rescale every distance by 1/norm_magnifications[k] so that all distances share the
# sampling of distance 0 (assumes M(f, m, n) resamples f by factor m onto an n x n grid).
rdata_scaled = rdata.copy()

for k in range(ndist):
    rdata_scaled[:, k] = M(rdata_scaled[:, k], 1/norm_magnifications[k], n).real

# Compare every rescaled distance with distance 0 (no drift alignment yet).
# The left panel always shows distance 0, so it is labelled as such.
for k in range(ndist):
    fig, axs = plt.subplots(1, 3, figsize=(12, 3))
    im = axs[0].imshow(rdata_scaled[0, 0], cmap='gray', vmin=0.9, vmax=1.1)
    axs[0].set_title('rdata_scaled dist 0')
    fig.colorbar(im, ax=axs[0])
    im = axs[1].imshow(rdata_scaled[0, k], cmap='gray', vmin=0.9, vmax=1.1)
    axs[1].set_title(f'rdata_scaled dist {k}')
    fig.colorbar(im, ax=axs[1])
    im = axs[2].imshow(rdata_scaled[0, k]-rdata_scaled[0, 0], cmap='gray', vmin=-0.1, vmax=0.1)
    axs[2].set_title('difference')
    fig.colorbar(im, ax=axs[2])
../../../_images/notebook_experimental_siemens_star_rec_iterative_17_0.png
../../../_images/notebook_experimental_siemens_star_rec_iterative_17_1.png
../../../_images/notebook_experimental_siemens_star_rec_iterative_17_2.png
../../../_images/notebook_experimental_siemens_star_rec_iterative_17_3.png
[ ]:
### Align images between different planes

[ ]:
#### Approach 2. Align CTF reconstructions from 1 distance
[11]:
# CTF phase retrieval from each distance on its own. The per-distance reconstructions are
# registered against distance 0 to estimate the sample drift between distances.
distances_ctf = (distances/norm_magnifications**2)[:ndist]  # unscaled (physical) distances
recCTF_1dist = np.zeros([ntheta, ndist, n, n], dtype='float32')
for k in range(ndist):
    single = slice(k, k+1)
    recCTF_1dist[:, k] = CTFPurePhase(
        rdata_scaled[:, single], distances_ctf[single], wavelength, voxelsize, 1e-2)

fig, ax = plt.subplots(figsize=(4, 4))
ax.set_title(f'CTF reconstruction for distance {ndist-1}')
ax.imshow(recCTF_1dist[0, -1], cmap='gray', vmax=0.06, vmin=-0.06)
plt.show()

# drift of every distance relative to distance 0 (distance 0 keeps a zero shift)
shifts_drift = np.zeros([ntheta, ndist, 2], dtype='float32')
for k in range(1, ndist):
    shifts_drift[:, k] = registration_shift(
        recCTF_1dist[:, k], recCTF_1dist[:, 0], upsample_factor=1000)

# The registration ran on rescaled images. Multiply by norm_magnifications so the shifts are
# expressed after magnification; later code divides them by norm_magnifications again.
shifts_drift *= norm_magnifications[np.newaxis, :, np.newaxis]

print(f'Found shifts: \n{shifts_drift=}')

../../../_images/notebook_experimental_siemens_star_rec_iterative_20_0.png
Found shifts:
shifts_drift=array([[[  0.       ,   0.       ],
        [-12.4732485,  79.67193  ],
        [ 21.501446 , -35.64609  ],
        [ 72.24373  ,  32.312107 ]]], dtype=float32)
[12]:
# Undo the drift of every distance. The shifts are divided by norm_magnifications because
# rdata_scaled is already in the rescaled (distance-0) sampling.
rdata_scaled_aligned = rdata_scaled.copy()
for k in range(ndist):
    rdata_scaled_aligned[:, k] = ST(rdata_scaled[:, k], shifts_drift[:, k]/norm_magnifications[k]).real

# Compare every aligned distance with distance 0; the left panel always shows distance 0.
for k in range(ndist):
    fig, axs = plt.subplots(1, 3, figsize=(11, 3))
    im = axs[0].imshow(rdata_scaled_aligned[0, 0], cmap='gray', vmin=.9, vmax=1.1)
    axs[0].set_title('shifted rdata_scaled dist 0')
    fig.colorbar(im, ax=axs[0])
    im = axs[1].imshow(rdata_scaled_aligned[0, k], cmap='gray', vmin=.9, vmax=1.1)
    axs[1].set_title(f'shifted rdata_scaled dist {k}')
    fig.colorbar(im, ax=axs[1])
    im = axs[2].imshow(rdata_scaled_aligned[0, k] - rdata_scaled_aligned[0, 0], cmap='gray', vmin=-0.1, vmax=.1)
    axs[2].set_title('difference')
    fig.colorbar(im, ax=axs[2])
../../../_images/notebook_experimental_siemens_star_rec_iterative_21_0.png
../../../_images/notebook_experimental_siemens_star_rec_iterative_21_1.png
../../../_images/notebook_experimental_siemens_star_rec_iterative_21_2.png
../../../_images/notebook_experimental_siemens_star_rec_iterative_21_3.png
[ ]:
#### set global shifts to drift shifts
[13]:
# The drift shifts become the per-distance object shifts used by the operators below.
# NOTE: this binds the same array (no copy), so in-place changes through either name affect both.
shifts = shifts_drift
[14]:
# Multi-distance Paganin phase retrieval. Its result is used as the initial object for the
# iterative reconstruction below. Like the CTF step, it takes the physical (unnormalized) distances.
# NOTE(review): the meaning of the constants 19 and 1e-12 is defined by multiPaganin — document them.
distances_pag = (distances/norm_magnifications**2)[:ndist]
phase_pag = multiPaganin(rdata_scaled_aligned, distances_pag, wavelength, voxelsize, 19, 1e-12)
recMultiPaganin = np.exp(1j*phase_pag)
mshow(np.angle(recMultiPaganin[0]))
../../../_images/notebook_experimental_siemens_star_rec_iterative_24_0.png
[ ]:
# Construct operators

[ ]:
#### Forward holo: $d_{kj}=\mathcal{G}_{z_j}\left((\mathcal{G}_{z_j'}\mathcal{S}_{s'_{kj}}q)\,\mathcal{M}_j\mathcal{S}_{s_{kj}}\psi_k\right)$,
#### Adjoint holo: $\psi_k=\sum_j\mathcal{S}^H_{s_{kj}}\mathcal{M}_j^H\left((\mathcal{G}_{z_j'}\mathcal{S}_{s'_{kj}}q)^*\mathcal{G}^H_{z_j}d_{kj}\right)$,
#### Adjoint holo wrt probe: $q=\sum_{j,k}\mathcal{S}^H_{s'_{kj}}\mathcal{G}_{z_j'}^H\left((\mathcal{M}_j\mathcal{S}_{s_{kj}}\psi_k)^*\mathcal{G}^H_{z_j}d_{kj}\right)$


[15]:
@gpu_batch
def _fwd_holo(psi, shifts_ref, shifts, prb):
    """Forward holography operator for one batch of objects.

    For every distance i:
    1. Shift the illumination by shifts_ref[:, i] and propagate it by distances2[i].
    2. Shift the object by shifts[:, i]/norm_magnifications[i].
    3. Rescale the object onto the padded (n+2*pad) grid and multiply it by the illumination.
    4. Propagate by distances[i] and crop the central n x n window.

    Parameters
    ----------
    psi : complex object batch on the ne x ne grid (leading axis = batch)
    shifts_ref : illumination shifts, indexed [batch, distance, :]
    shifts : object shifts, indexed [batch, distance, :]
    prb : illumination on the padded (n+2*pad) grid with a leading singleton axis

    Returns
    -------
    complex64 array of shape (batch, ndist, n, n).
    """
    prb = cp.array(prb)
    shifts_ref = cp.array(shifts_ref)
    shifts = cp.array(shifts)
    nbatch = psi.shape[0]

    data = cp.zeros([nbatch, ndist, n, n], dtype='complex64')
    for i in range(ndist):
        # ill shift for each acquisition
        prbr = cp.tile(prb, [nbatch, 1, 1])
        prbr = S(prbr, shifts_ref[:, i])
        # propagate illumination
        prbr = G(prbr, wavelength, voxelsize, distances2[i])
        # object shift for each acquisition
        psir = S(psi, shifts[:, i]/norm_magnifications[i])

        # Scale the object onto the padded grid. Applied for every distance: the
        # demagnification by norm_magnifications[i] is needed even when ne == n, and M itself
        # short-circuits to a copy for an identity resize.
        psir = M(psir, norm_magnifications[i]*ne/(n+2*pad), n+2*pad)

        # multiply the ill and object
        psir *= prbr

        # propagate both
        psir = G(psir, wavelength, voxelsize, distances[i])
        data[:, i] = psir[:, pad:n+pad, pad:n+pad]
    return data


def fwd_holo(psi, prb):
    """Forward holography operator using the global shifts_ref and shifts."""
    return _fwd_holo(psi, shifts_ref, shifts, prb)


@gpu_batch
def _adj_holo(data, shifts_ref, shifts, prb):
    """Adjoint of the forward holography operator with respect to the object, for one batch.

    For every distance j:
    1. Zero-pad data[:, j] by pad and back-propagate it by distances[j].
    2. Multiply by the conjugate of the shifted and propagated illumination.
    3. Apply MT to return to the ne x ne object grid and undo the object shift.
    The contributions of all distances are summed.

    Returns
    -------
    complex64 array of shape (batch, ne, ne).
    """
    prb = cp.array(prb)
    shifts_ref = cp.array(shifts_ref)
    shifts = cp.array(shifts)
    nbatch = data.shape[0]
    psi = cp.zeros([nbatch, ne, ne], dtype='complex64')
    for j in range(ndist):
        psir = cp.pad(data[:, j], ((0, 0), (pad, pad), (pad, pad)))

        # propagate data back
        psir = GT(psir, wavelength, voxelsize, distances[j])

        # ill shift for each acquisition
        prbr = cp.tile(prb, [nbatch, 1, 1])
        prbr = S(prbr, shifts_ref[:, j])

        # propagate illumination
        prbr = G(prbr, wavelength, voxelsize, distances2[j])

        # multiply the conj ill and object
        psir *= cp.conj(prbr)

        # Adjoint of the object rescaling. Applied for every distance: the per-distance
        # demagnification exists even when ne == n, and skipping MT would also leave psir on
        # the padded grid instead of ne x ne.
        psir = MT(psir, norm_magnifications[j]*ne/(n+2*pad), ne)
        # object shift for each acquisition
        psi += ST(psir, shifts[:, j]/norm_magnifications[j])
    return psi


def adj_holo(data, prb):
    """Adjoint holography operator wrt the object, using the global shifts_ref and shifts."""
    return _adj_holo(data, shifts_ref, shifts, prb)


@gpu_batch
def _adj_holo_prb(data, shifts_ref, shifts, psi):
    """Adjoint of the forward holography operator with respect to the illumination, for one batch.

    For every distance j:
    1. Zero-pad data[:, j] and back-propagate it by distances[j].
    2. Multiply by the conjugate of the shifted and rescaled object.
    3. Back-propagate by distances2[j] and undo the illumination shift.

    Returns
    -------
    complex64 array of shape (batch, n+2*pad, n+2*pad). These are per-batch contributions,
    which adj_holo_prb sums over the batch axis.
    """
    shifts_ref = cp.array(shifts_ref)
    shifts = cp.array(shifts)
    prb = cp.zeros([data.shape[0], n+2*pad, n+2*pad], dtype='complex64')
    for j in range(ndist):
        # cp.pad (as in the other operators) rather than relying on np.pad dispatching to CuPy
        prbr = cp.pad(data[:, j], ((0, 0), (pad, pad), (pad, pad)))
        psir = psi.copy()

        # propagate data back
        prbr = GT(prbr, wavelength, voxelsize, distances[j])

        # object shift for each acquisition
        psir = S(psir, shifts[:, j]/norm_magnifications[j])

        # scale object
        psir = M(psir, norm_magnifications[j]*ne/(n+2*pad), n+2*pad)

        # multiply the conj object and ill
        prbr *= cp.conj(psir)

        # propagate illumination
        prbr = GT(prbr, wavelength, voxelsize, distances2[j])

        # ill shift for each acquisition
        prbr = ST(prbr, shifts_ref[:, j])
        prb += prbr
    return prb


def adj_holo_prb(data, psi):
    ''' Adjoint holography operator wrt the illumination, summed over the batch axis '''
    return np.sum(_adj_holo_prb(data, shifts_ref, shifts, psi), axis=0)[np.newaxis]


# adjoint test: <psi, A^H A psi> and <q, A_q^H A psi> must both match <A psi, A psi>
data = data0.copy()
ref = ref0.copy()

half_margin = ne//2-n//2
obj_padding = ((0, 0), (half_margin, half_margin), (half_margin, half_margin))
arr1 = np.pad(np.array(data[:, 0]+1j*data[:, 0]).astype('complex64'), obj_padding, 'symmetric')

prb_padding = ((0, 0), (pad, pad), (pad, pad))
prb1 = np.pad(np.array(ref[0, :1]+1j*ref[0, :1]).astype('complex64'), prb_padding)

arr2 = fwd_holo(arr1, prb1)
arr3 = adj_holo(arr2, prb1)
arr4 = adj_holo_prb(arr2, arr1)

fwd_energy = np.sum(arr2*np.conj(arr2))
print(f'{np.sum(arr1*np.conj(arr3))}==\n{fwd_energy}')
print(f'{np.sum(prb1*np.conj(arr4))}==\n{fwd_energy}')
(60660204+5.663764953613281j)==
(60663612+0.0001766374771250412j)
(60663532-9.498046875j)==
(60663612+0.0001766374771250412j)
[ ]:
#### Forward holo without sample: $d_j=\mathcal{G}_{z_0}\mathcal{S}_{s'_{j}}q$ (every reference acquisition uses the same distance $z_0$),
#### Adjoint holo without sample: $q=\sum_j\mathcal{S}^H_{s'_{j}}\mathcal{G}^H_{z_0}d_j$
[16]:
@gpu_batch
def _fwd_holo0(prb, shifts_ref0):
    """Propagate the shifted illumination to the detector for every reference acquisition.

    All acquisitions use the same propagation distance, distances[0]; only the illumination
    shift shifts_ref0[:, j] changes with j. The result is cropped to the central n x n window
    and returned as a complex64 array of shape (1, ndist, n, n).
    """
    shifts_ref0 = cp.array(shifts_ref0)
    window = (slice(None), slice(pad, n+pad), slice(pad, n+pad))
    data = cp.zeros([1, ndist, n, n], dtype='complex64')
    for dist_id in range(ndist):
        shifted = S(prb, shifts_ref0[:, dist_id])
        data[:, dist_id] = G(shifted, wavelength, voxelsize, distances[0])[window]
    return data


def fwd_holo0(prb):
    """Reference-image forward operator using the global shifts_ref0."""
    return _fwd_holo0(prb, shifts_ref0)


@gpu_batch
def _adj_holo0(data, shifts_ref0):
    """Adjoint of the reference-image forward operator.

    Each acquisition is zero-padded to the (n+2*pad) grid, back-propagated by distances[0],
    and un-shifted by shifts_ref0[:, j]. The contributions are summed into a (1, n+2*pad, n+2*pad)
    complex64 illumination.
    """
    shifts_ref0 = cp.array(shifts_ref0)
    padding = ((0, 0), (pad, pad), (pad, pad))
    prb = cp.zeros([1, n+2*pad, n+2*pad], dtype='complex64')
    for dist_id in range(ndist):
        back = GT(cp.pad(data[:, dist_id], padding), wavelength, voxelsize, distances[0])
        prb += ST(back, shifts_ref0[:, dist_id])
    return prb


def adj_holo0(data):
    """Adjoint reference-image operator using the global shifts_ref0."""
    return _adj_holo0(data, shifts_ref0)


# adjoint test for the reference operator: <q, A^H A q> must match <A q, A q>
data = data0[0, :].copy()
ref = ref0.copy()
prb_padding = ((0, 0), (pad, pad), (pad, pad))
prb1 = np.pad(np.array(ref[0, :1]+1j*ref[0, :1]).astype('complex64'), prb_padding)
arr2 = fwd_holo0(prb1)
arr3 = adj_holo0(arr2)

lhs = np.sum(prb1*np.conj(arr3))
rhs = np.sum(arr2*np.conj(arr2))
print(f'{lhs}==\n{rhs}')
(25824470-0.56884765625j)==
(25824472+0.00010867322271224111j)
[ ]:
#### Approximate the probe by solving the L2-norm minimization problem for reference images
[17]:
def line_search(minf, gamma, fu, fd):
    """Backtracking line search for the step size gamma.

    Halves ``gamma`` while ``minf(fu + gamma*fd)`` is larger than ``minf(fu)``.

    Parameters
    ----------
    minf : callable
        Functional to minimize, evaluated on forward-model values.
    gamma : float
        Initial step size.
    fu : array-like
        Value at which minf is currently evaluated (e.g. the forward model of the estimate).
    fd : array-like
        Forward-model image of the search direction.

    Returns
    -------
    The accepted step, or 0 when the step shrank to 1e-12 or below (no descent direction found).
    """
    f0 = minf(fu)  # loop-invariant: evaluate the current functional only once
    while f0-minf(fu+gamma*fd) < 0 and gamma > 1e-12:
        gamma *= 0.5
    if gamma <= 1e-12:  # direction not found
        gamma = 0
    return gamma


def cg_holo(ref, init_prb,  pars):
    """Estimate the illumination from the reference images with nonlinear conjugate gradients.

    Minimizes || |fwd_holo0(prb)| - sqrt(ref) ||^2 using Dai-Yuan search directions and the
    backtracking ``line_search``.

    Parameters
    ----------
    ref : measured reference intensities (their square root is taken here)
    init_prb : initial illumination (copied, not modified)
    pars : dict with 'niter', 'err_step', 'vis_step' and 'gammaprb' (initial step)

    Returns
    -------
    The updated illumination.
    """
    amp_ref = np.sqrt(ref)

    def minf(fprb):
        return np.linalg.norm(np.abs(fprb)-amp_ref)**2

    prb = init_prb.copy()
    for it in range(pars['niter']):
        fprb0 = fwd_holo0(prb)
        # gradient of the amplitude misfit
        gradprb = adj_holo0(fprb0-amp_ref*np.exp(1j*np.angle(fprb0)))
        dprb = -gradprb if it == 0 else dai_yuan(dprb, gradprb, gradprb0)
        gradprb0 = gradprb

        # line search along the forward image of the direction
        gammaprb = line_search(minf, pars['gammaprb'], fprb0, fwd_holo0(dprb))
        prb = prb + gammaprb*dprb

        if it % pars['err_step'] == 0:
            err = minf(fwd_holo0(prb))
            print(f'{it}) {gammaprb=}, {err=:1.5e}')

        if it % pars['vis_step'] == 0:
            mshow_polar(prb[0])

    return prb


# Fit the illumination to the averaged flat fields, starting from a constant (all-ones)
# illumination on the padded (n+2*pad) grid.
ref = ref00.copy()
pars = dict(niter=17, err_step=1, vis_step=16, gammaprb=0.5)
rec_prb0 = cg_holo(ref, np.ones([1, n+2*pad, n+2*pad], dtype='complex64'), pars)
0) gammaprb=0.5, err=9.20249e+05
../../../_images/notebook_experimental_siemens_star_rec_iterative_31_1.png
1) gammaprb=0.5, err=2.94941e+05
2) gammaprb=0.5, err=2.05872e+05
3) gammaprb=0.5, err=9.08458e+04
4) gammaprb=0.5, err=4.41711e+04
5) gammaprb=0.5, err=2.52774e+04
6) gammaprb=0.5, err=1.13493e+04
7) gammaprb=0.5, err=7.13952e+03
8) gammaprb=0.5, err=3.78985e+03
9) gammaprb=0.5, err=2.67607e+03
10) gammaprb=0.5, err=1.82202e+03
11) gammaprb=0.5, err=1.51369e+03
12) gammaprb=0.5, err=1.25078e+03
13) gammaprb=0.5, err=1.14044e+03
14) gammaprb=0.5, err=1.03703e+03
15) gammaprb=0.5, err=9.79795e+02
16) gammaprb=0.5, err=9.15427e+02
../../../_images/notebook_experimental_siemens_star_rec_iterative_31_3.png
[ ]:
#### Main reconstruction. $\sum_k\sum_j\left\||\mathcal{G}_{z_j}((\mathcal{G}_{z'_j}\mathcal{S}_{s'_{kj}}q)(\mathcal{M}_j\mathcal{S}_{s_{kj}}\psi_k))|-\sqrt{d_{kj}}\right\|^2_2 + \sum_j\left\||\mathcal{G}_{z_0}\mathcal{S}_{s^r_j}q|-\sqrt{d^r_j}\right\|_2^2\to \min_{\psi_k,q}$
[18]:

def line_search(minf, gamma, fu, fu0, fd, fd0):
    """Backtracking line search for the step size gamma (joint object/illumination functional).

    Halves ``gamma`` while ``minf(fu + gamma*fd, fu0 + gamma*fd0)`` is larger than
    ``minf(fu, fu0)``. Returns 0 when gamma drops below 1/64 (no direction found).
    ``fu0``/``fd0`` are passed as 0 when only the data term should be evaluated.
    """
    f0 = minf(fu, fu0)  # loop-invariant: evaluate the current functional only once
    while f0-minf(fu+gamma*fd, fu0+gamma*fd0) < 0 and gamma >= 1/64:
        gamma *= 0.5
    if gamma < 1/64:  # direction not found
        gamma = 0
    return gamma


@gpu_batch
def _gradient(psi, data, shifts_ref, shifts, prb):
    """Gradient of the amplitude misfit with respect to the object, for one batch.

    ``data`` holds measured amplitudes (cg_holo passes the square root of the intensities).

    Returns
    -------
    [grad, fpsi]:
    - grad: the (batch, ne, ne) gradient, divided by max|prb|^2.
    - fpsi: the (batch, ndist, n, n) forward model, reused by the line search.
    """
    prb = cp.array(prb)
    shifts_ref = cp.array(shifts_ref)
    shifts = cp.array(shifts)
    nbatch = psi.shape[0]
    res = cp.zeros([nbatch, ne, ne], dtype='complex64')
    fpsires = cp.zeros([nbatch, ndist, n, n], dtype='complex64')
    for j in range(ndist):
        # Shifted and propagated illumination. Computed once and shared by the forward and
        # adjoint parts: 'psir *= prbr' below modifies psir only.
        prbr = cp.tile(prb, [nbatch, 1, 1])
        prbr = S(prbr, shifts_ref[:, j])
        prbr = G(prbr, wavelength, voxelsize, distances2[j])

        # forward model: shift, rescale onto the padded grid (needed for every distance),
        # multiply by the illumination, propagate, crop
        psir = S(psi, shifts[:, j]/norm_magnifications[j])
        psir = M(psir, norm_magnifications[j]*ne/(n+2*pad), n+2*pad)
        psir *= prbr
        psir = G(psir, wavelength, voxelsize, distances[j])
        fpsi = psir[:, pad:n+pad, pad:n+pad]
        fpsires[:, j] = fpsi

        # adjoint of the forward model applied to the amplitude residual
        psir = fpsi-data[:, j]*cp.exp(1j*cp.angle(fpsi))
        psir = cp.pad(psir, ((0, 0), (pad, pad), (pad, pad)))
        psir = GT(psir, wavelength, voxelsize, distances[j])
        psir *= cp.conj(prbr)
        psir = MT(psir, norm_magnifications[j]*ne/(n+2*pad), ne)
        res += ST(psir, shifts[:, j]/norm_magnifications[j])

    # probe normalization
    res /= cp.amax(cp.abs(prb))**2
    return [res, fpsires]


def gradient(psi, data, prb):
    ''' Gradient wrt psi'''
    return _gradient(psi, data, shifts_ref, shifts, prb)


@gpu_batch
def _gradientprb(psi, data, shifts_ref, shifts, prb):
    """Gradient of the amplitude misfit with respect to the illumination, for one batch.

    Returns
    -------
    [grad, fprb]:
    - grad: per-batch contributions of shape (batch, n+2*pad, n+2*pad), which gradientprb
      sums over the batch.
    - fprb: the (batch, ndist, n, n) forward model.
    """
    prb = cp.array(prb)
    shifts_ref = cp.array(shifts_ref)
    shifts = cp.array(shifts)
    nbatch = psi.shape[0]
    res = cp.zeros([nbatch, n+2*pad, n+2*pad], dtype='complex64')
    fpsires = cp.zeros([nbatch, ndist, n, n], dtype='complex64')
    for j in range(ndist):
        # shifted and propagated illumination
        prbr = cp.tile(prb, [nbatch, 1, 1])
        prbr = S(prbr, shifts_ref[:, j])
        prbr = G(prbr, wavelength, voxelsize, distances2[j])

        # Shifted and rescaled object, computed once and shared by the forward and adjoint
        # parts. Rescaling is needed for every distance.
        psir = S(psi, shifts[:, j]/norm_magnifications[j])
        psir = M(psir, norm_magnifications[j]*ne/(n+2*pad), n+2*pad)

        # forward model (out-of-place product keeps psir intact for the adjoint)
        fpsi = G(psir*prbr, wavelength, voxelsize, distances[j])[:, pad:n+pad, pad:n+pad]
        fpsires[:, j] = fpsi

        # adjoint wrt the illumination applied to the amplitude residual
        resid = fpsi-data[:, j]*cp.exp(1j*cp.angle(fpsi))
        prbr = cp.pad(resid, ((0, 0), (pad, pad), (pad, pad)))
        prbr = GT(prbr, wavelength, voxelsize, distances[j])
        prbr *= cp.conj(psir)
        prbr = GT(prbr, wavelength, voxelsize, distances2[j])
        res += ST(prbr, shifts_ref[:, j])
    return [res, fpsires]


def gradientprb(psi, data, prb):
    ''' Gradient wrt prb'''
    [gradprb, fprb] = _gradientprb(psi, data, shifts_ref, shifts, prb)
    gradprb = np.sum(gradprb, axis=0)[np.newaxis]
    return [gradprb, fprb]


def cg_holo(data, ref, init, init_prb, pars):
    """Joint object/illumination reconstruction with nonlinear conjugate gradients.

    Minimizes
        || |fwd_holo(psi, prb)| - sqrt(data) ||^2 + || |fwd_holo0(prb)| - sqrt(ref) ||^2
    using Dai-Yuan directions. Each iteration updates the object first, then the illumination.

    Parameters
    ----------
    data, ref : measured data / reference intensities (their square roots are taken here)
    init, init_prb : initial object and illumination (copied, not modified)
    pars : dict with 'niter', 'upd_psi', 'upd_prb', 'err_step', 'vis_step',
           'gammapsi', 'gammaprb'

    Returns
    -------
    psi, prb, conv. conv[i//err_step] holds the functional value at every err_step-th iteration.
    """
    # minimization functional; fprb is passed as 0 to skip the reference-image term
    def minf(fpsi, fprb):
        f = np.linalg.norm(np.abs(fpsi)-data)**2
        if isinstance(fprb, (np.ndarray, cp.ndarray)):
            f += np.linalg.norm(np.abs(fprb)-ref)**2
        return f

    data = np.sqrt(data)
    ref = np.sqrt(ref)
    psi = init.copy()
    prb = init_prb.copy()

    conv = np.zeros(1+pars['niter']//pars['err_step'])
    # Step sizes shown in the log. They stay 0 for a variable that is not updated, so the
    # log line also works when 'upd_psi' or 'upd_prb' is False.
    gammapsi = 0
    gammaprb = 0
    for i in range(pars['niter']):
        if pars['upd_psi']:
            [grad, fpsi] = gradient(psi, data, prb)
            # Dai-Yuan direction
            if i == 0:
                d = -grad
            else:
                d = dai_yuan(d, grad, grad0)
            grad0 = grad
            fd = fwd_holo(d, prb)
            gammapsi = line_search(minf, pars['gammapsi'], fpsi, 0, fd, 0)
            psi = linear(psi, d, 1, gammapsi)

        if pars['upd_prb']:
            [gradprb, fprb] = gradientprb(psi, data, prb)
            fprb0 = fwd_holo0(prb)
            gradprb += adj_holo0(fprb0-ref*np.exp(1j*np.angle(fprb0)))
            # NOTE(review): presumably balances the ntheta data terms against the single
            # reference term — confirm
            gradprb *= 1/(ntheta+1)
            # Dai-Yuan direction
            if i == 0:
                dprb = -gradprb
            else:
                dprb = dai_yuan(dprb, gradprb, gradprb0)
            gradprb0 = gradprb
            # line search (both forward models are linear in prb)
            fdprb = fwd_holo(psi, dprb)
            fdprb0 = fwd_holo0(dprb)
            gammaprb = line_search(
                minf, pars['gammaprb'], fprb, fprb0, fdprb, fdprb0)
            prb = prb + gammaprb*dprb

        if i % pars['err_step'] == 0:
            fprb = fwd_holo(psi, prb)
            fprb0 = fwd_holo0(prb)
            err = minf(fprb, fprb0)
            conv[i//pars['err_step']] = err
            print(f'{i}) {gammapsi=} {gammaprb=}, {err=:1.5e}')

        if i % pars['vis_step'] == 0:
            mshow_polar(psi[0, ne//2-n//2:ne//2+n//2, ne//2-n//2:ne//2+n//2])
            mshow_polar(prb[0])

    return psi, prb, conv


# if by chunk on gpu (pass NumPy arrays; gpu_batch then processes them by chunks)
# rec = np.pad(recMultiPaganin, ((0, 0), (ne//2-n//2, ne//2-n//2),
#              (ne//2-n//2, ne//2-n//2)), 'edge')
# rec_prb = rec_prb0.copy()
# ref = ref00.copy()
# data = data00.copy()

# if fully on gpu: edge-pad the Paganin initial guess from n x n to ne x ne
rec = cp.pad(cp.array(recMultiPaganin), ((0, 0), (ne//2-n//2, ne//2-n//2),
             (ne//2-n//2, ne//2-n//2)), 'edge')
rec_prb = cp.array(rec_prb0)
ref = cp.array(ref00)
data = cp.array(data00)

pars = {'niter': 65, 'upd_psi': True, 'upd_prb': True,
        'err_step': 4, 'vis_step': 32, 'gammapsi': 0.5, 'gammaprb': 0.5}
rec, rec_prb, conv = cg_holo(data, ref, rec, rec_prb, pars)
0) gammapsi=0.5 gammaprb=0.5, err=4.73713e+03
../../../_images/notebook_experimental_siemens_star_rec_iterative_33_1.png
../../../_images/notebook_experimental_siemens_star_rec_iterative_33_2.png
4) gammapsi=0.5 gammaprb=0, err=2.19222e+03
8) gammapsi=0.5 gammaprb=0.25, err=1.82810e+03
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[18], line 220
    214 data = cp.array(data00)
    217 pars = {'niter': 65, 'upd_psi': True, 'upd_prb': True,
    218         'err_step': 4, 'vis_step': 32, 'gammapsi': 0.5, 'gammaprb': 0.5}
--> 220 rec, rec_prb, conv = cg_holo(data, ref, rec, rec_prb, pars)

Cell In[18], line 168, in cg_holo(data, ref, init, init_prb, pars)
    165     psi = linear(psi,d,1,gammapsi)
    167 if pars['upd_prb']:
--> 168     [gradprb, fprb] = gradientprb(psi, data, prb)
    169     fprb0 = fwd_holo0(prb)
    170     gradprb += adj_holo0(fprb0-ref*np.exp(1j*np.angle(fprb0)))

Cell In[18], line 131, in gradientprb(psi, data, prb)
    129 def gradientprb(psi, data, prb):
    130     ''' Gradient wrt prb'''
--> 131     [gradprb, fprb] = _gradientprb(psi, data, shifts_ref, shifts, prb)
    132     gradprb = np.sum(gradprb, axis=0)[np.newaxis]
    133     return [gradprb, fprb]

File ~/conda/miniforge3/envs/holotomo/lib/python3.10/site-packages/holotomocupy/chunking.py:38, in gpu_batch.<locals>.inner(*args, **kwargs)
     36 # if array is on gpu then just run the function
     37 if isinstance(args[0], cp.ndarray):
---> 38     out = func(*args, **kwargs)
     39     return out
     41 #else do processing by chunks

Cell In[18], line 114, in _gradientprb(psi, data, shifts_ref, shifts, prb)
    111 psir = S(psir, shifts[:, j]/norm_magnifications[j])
    113 # scale object
--> 114 psir = M(psir, norm_magnifications[j]*ne/(n+2*pad), n+2*pad)
    116 # multiply the conj object and ill
    117 prbr *= cp.conj(psir)

File ~/conda/miniforge3/envs/holotomo/lib/python3.10/site-packages/holotomocupy/chunking.py:38, in gpu_batch.<locals>.inner(*args, **kwargs)
     36 # if array is on gpu then just run the function
     37 if isinstance(args[0], cp.ndarray):
---> 38     out = func(*args, **kwargs)
     39     return out
     41 #else do processing by chunks

File ~/conda/miniforge3/envs/holotomo/lib/python3.10/site-packages/holotomocupy/magnification.py:53, in M(f, magnification, n)
     50 if ne == n and (magnification-1.0) < 1e-6:
     51     return f.copy()
---> 53 m, mu, phi, c2dfftshift, c2dfftshift0 = _init(ne)
     54 # FFT2D
     55 fde = cp.fft.fft2(f*c2dfftshift0)*c2dfftshift0

File ~/conda/miniforge3/envs/holotomo/lib/python3.10/site-packages/holotomocupy/magnification.py:11, in _init(ne)
      9 eps = 1e-3  # accuracy of usfft
     10 mu = -cp.log(eps) / (2 * ne * ne)
---> 11 m = int(cp.ceil(2 * ne * 1 / cp.pi * cp.sqrt(-mu *
     12         cp.log(eps) + (mu * ne) * (mu * ne) / 4)))
     13 # extra arrays
     14 # interpolation kernel
     15 t = cp.linspace(-1/2, 1/2, ne, endpoint=False).astype('float32')

KeyboardInterrupt:
[ ]:
# Phase of the reconstructed object, cropped from the ne x ne grid back to the central
# n x n field of view, then a zoomed region.
# NOTE(review): in the saved run cg_holo was interrupted (KeyboardInterrupt), so `rec` still
# holds the padded Paganin initial guess. Re-run the reconstruction before interpreting these images.
fov = slice(ne//2-n//2, ne//2+n//2)
rrec = rec[0, fov, fov]

fig, ax = plt.subplots()
shown = ax.imshow(cp.angle(rrec).get(), cmap='gray', vmax=0.13, vmin=0.02)
fig.colorbar(shown)
plt.show()

# zoom; an alternative display range for this region is vmin=-0.05, vmax=0.05
fig, ax = plt.subplots()
shown = ax.imshow(cp.angle(rrec[750:750+500, 500:1000]).get(), cmap='gray', vmax=0.13, vmin=0.02)
fig.colorbar(shown)
plt.show()
../../../_images/notebook_experimental_siemens_star_rec_iterative_34_0.png
../../../_images/notebook_experimental_siemens_star_rec_iterative_34_1.png
[ ]: