# Source code for the save module.

import numpy as np  
import os  
import sys
from src.data import diffusion_schemes
import nibabel as nb
import time
import logging
import torch 
from datetime import datetime
import shutil

def _add_noise(signal, snr):
    """Add zero-mean Gaussian noise to the forward simulated signal [2]_.

    Noise is drawn independently for every element of ``signal`` with
    standard deviation ``1 / snr``. This presumably assumes the signal is
    normalised so that its b0 value is 1 -- TODO confirm with callers.

    :param signal: The forward simulated signal
    :type signal: np.ndarray
    :param snr: The desired signal to noise ratio of the b0 image; must be
        strictly positive
    :type snr: float
    :raises ValueError: If ``snr`` is not strictly positive
    :return: The noise-added signal, with the same shape as ``signal``
    :rtype: np.ndarray
    """
    if snr <= 0:
        raise ValueError(f"snr must be strictly positive, got {snr!r}")
    sigma = 1.0 / snr
    # Draw one sample per element using the full shape of the signal; using
    # only the first axis would broadcast the same noise vector across the
    # remaining axes (or fail to broadcast at all).
    real_channel_noise = np.random.normal(0, sigma, np.shape(signal))
    return signal + real_channel_noise
def _signal(phase: np.ndarray, SNR: float = None) -> np.ndarray:
    """Calculate the PGSE signal from the accumulated spin phases [3]_.

    The signal is the real part of the ensemble average of
    ``exp(1j * phase)`` taken over axis 0 (spins). NaN phases are skipped by
    the sum but still counted in the normalisation ``phase.shape[0]``.
    Computation is done with NumPy.

    :param phase: Accumulated phase of each spin, with spins along axis 0
        (presumably one column per diffusion encoding along axis 1 -- confirm)
    :type phase: np.ndarray
    :param SNR: The SNR of the b0 image. If ``None`` (the default), no noise
        is added, i.e. the SNR of the signal is infinite
    :type SNR: float, optional
    :raises ValueError: If ``phase`` contains no spins
    :return: The forward simulated PGSE signal, of shape ``phase.shape[1:]``
    :rtype: np.ndarray
    """
    n_spins = phase.shape[0]
    if n_spins == 0:
        # Avoid an opaque ZeroDivisionError from the normalisation below.
        raise ValueError("cannot compute a signal from an empty spin ensemble")
    signal = (1 / n_spins) * np.real(np.nansum(np.exp(1j * phase), axis=0))
    if SNR is not None:
        return _add_noise(signal, SNR)
    return signal
def _timed_compartment_signal(phase, description):
    """Return ``_signal(phase)``, logging progress and the elapsed time."""
    logging.info(f" Computing {description} signal...")
    start = time.perf_counter()
    signal = _signal(phase)
    elapsed = time.perf_counter() - start
    # Report the elapsed time to 4 decimal places.
    logging.info(' Done! Signal computed in {} sec'.format(round(elapsed, 4)))
    return signal


def _generate_signals_and_trajectories(self):
    """Helper function to organize and store compartment specific and combined trajectories and their incident signals.

    Spins are grouped by fiber bundle (``spin._get_bundle_index()``, with
    ``None`` mapped to ``-1``), by cell (``spin._get_cell_index()``) and by
    ``self.water_key``. Each compartment is a boolean mask over spins. Its
    signal is computed from the matching rows of ``self.phase``, and its
    trajectories are the matching rows of the initial (``position_t1m``) and
    final (``position_t2p``) spin positions. A compartment with no spins is
    skipped, except the total signal, which is always computed.

    :raises ValueError: If ``self.spins`` is empty
    :return: signals with associated labels (``signals_dict``), trajectories
        with associated labels (``trajectories_dict``); each trajectory entry
        is a ``(trajectory_t1m, trajectory_t2p)`` tuple of arrays
    :rtype: tuple[dict, dict]
    """
    fiber_spins = []
    cells = []
    trajectory_t1m = []
    trajectory_t2p = []
    for spin in self.spins:
        bundle_index = spin._get_bundle_index()
        # -1 marks a spin that does not belong to any fiber bundle.
        fiber_spins.append(-1 if bundle_index is None else bundle_index)
        cells.append(spin._get_cell_index())
        trajectory_t1m.append(spin.position_t1m)
        trajectory_t2p.append(spin.position_t2p)
    if not fiber_spins:
        raise ValueError("cannot generate signals: the spin ensemble is empty")

    fiber_spins = np.array(fiber_spins)
    cells = np.array(cells)
    trajectory_t1m = np.array(trajectory_t1m)
    trajectory_t2p = np.array(trajectory_t2p)
    water = self.water_key

    signals_dict = {}
    trajectories_dict = {}

    def add_compartment(mask, description, name):
        # Compute and store the signal and trajectories of the spins selected
        # by `mask`; compartments without any spins are skipped.
        compartment_phase = self.phase[mask, :]
        if compartment_phase.shape[0] == 0:
            return
        signals_dict[f"{name}_signal"] = _timed_compartment_signal(compartment_phase, description)
        trajectories_dict[f"{name}_trajectories"] = (trajectory_t1m[mask, :], trajectory_t2p[mask, :])

    logging.info('------------------------------')
    logging.info(' Signal Generation')
    logging.info('------------------------------')

    # Individual fiber signals: iterate over the bundle indices actually
    # present. Index 0 is a valid bundle (the ``> -1`` total-fiber mask below
    # includes it), so it must not be skipped.
    for bundle in np.unique(fiber_spins[fiber_spins > -1]):
        bundle = int(bundle)
        add_compartment(fiber_spins == bundle, f"fiber {bundle}", f"fiber_{bundle}")

    # Combined compartments.
    add_compartment(fiber_spins > -1, 'total fiber', 'total_fiber')
    add_compartment(cells > -1, 'cell', 'cell')
    add_compartment(water > -1, 'water', 'water')

    # Total signal over every spin, computed unconditionally.
    signals_dict['total_signal'] = _timed_compartment_signal(self.phase, 'total')
    trajectories_dict['total_trajectories'] = (trajectory_t1m, trajectory_t2p)

    return signals_dict, trajectories_dict
def _save_data(self):
    """Helper function that generates signals and trajectories and saves them to ``self.results_directory``.

    Files written inside the results directory (created if missing):

    * ``input_simulation_parameters.ini`` -- a copy of ``self.cfg_path``;
    * ``signals/<key>.nii`` -- one NIfTI image (identity affine) per signal
      returned by ``_generate_signals_and_trajectories``;
    * ``trajectories/<key>_t1m.npy`` and ``trajectories/<key>_t2p.npy`` --
      the two position arrays of each trajectory entry;
    * ``log`` -- a copy of the file named ``log`` in the current working
      directory, taken after everything else has been written.
    """
    results_dir = self.results_directory
    signals_dir = os.path.join(results_dir, 'signals')
    traj_dir = os.path.join(results_dir, 'trajectories')

    os.makedirs(results_dir, exist_ok=True)
    shutil.copyfile(src=self.cfg_path, dst=os.path.join(results_dir, 'input_simulation_parameters.ini'))

    signals_dict, trajectories_dict = _generate_signals_and_trajectories(self)

    logging.info('------------------------------')
    logging.info(' Saving outputs to {} ...'.format(results_dir))

    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(signals_dir, exist_ok=True)
    for key, signal in signals_dict.items():
        nifti = nb.Nifti1Image(signal, affine=np.eye(4))
        nb.save(nifti, os.path.join(signals_dir, '{}.nii'.format(key)))

    os.makedirs(traj_dir, exist_ok=True)
    for key, (positions_t1m, positions_t2p) in trajectories_dict.items():
        np.save(os.path.join(traj_dir, '{}_t1m.npy'.format(key)), positions_t1m)
        np.save(os.path.join(traj_dir, '{}_t2p.npy'.format(key)), positions_t2p)

    logging.info(' Program complete!')
    logging.info('------------------------------')

    # Copy the log last so the archived copy also contains the signal
    # generation and saving messages above. NOTE(review): this assumes the
    # logging handler writes to ./log -- confirm where logging is configured.
    shutil.copyfile(src=os.path.join(os.getcwd(), 'log'), dst=os.path.join(results_dir, 'log'))