Source code for NiBAx.plugins.harmonization.harmonization

from PyQt5.QtGui import *
from PyQt5 import QtGui, QtCore, QtWidgets, uic
import sys, os
import neuroHarmonize as nh
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as pat
import numpy as np
import pandas as pd
import re
from NiBAx.core.plotcanvas import PlotCanvas
from NiBAx.core.baseplugin import BasePlugin
from NiBAx.core.gui.SearchableQComboBox import SearchableQComboBox

from NiBAx.core import iStagingLogger

# Module-level logger shared by every method of the plugin.
logger = iStagingLogger.get_logger(__name__)
class Harmonization(QtWidgets.QWidget, BasePlugin):
    """Plugin widget for applying a neuroHarmonize model to the loaded data.

    The widget lets the user pick a pickled harmonization model, apply it to
    ``datamodel.data``, inspect per-site residuals before/after harmonization
    next to the per-site gamma/delta parameters, and write the harmonized
    columns back to the data model.
    """

    # A new site needs at least this many rows flagged in
    # `UseForComBatGAMHarmonization` to be harmonized out-of-sample.
    MIN_REFERENCE_SAMPLES = 5

    # constructor
    def __init__(self):
        """Load the .ui layout and embed the ROI selector and the plot canvas."""
        super(Harmonization, self).__init__()
        self.datamodel = None
        root = os.path.dirname(__file__)
        self.readAdditionalInformation(root)
        self.ui = uic.loadUi(os.path.join(root, 'harmonization.ui'), self)
        self.ui.Harmonization_Model_Loaded_Lbl.setHidden(True)
        self.ui.comboBoxROI = SearchableQComboBox(self.ui)
        self.plotCanvas = PlotCanvas(self.ui.page_2)
        # Three side-by-side panels: raw residuals | parameters | harmonized residuals.
        self.plotCanvas.axes1 = self.plotCanvas.fig.add_subplot(131)
        self.plotCanvas.axes2 = self.plotCanvas.fig.add_subplot(132)
        self.plotCanvas.axes3 = self.plotCanvas.fig.add_subplot(133)
        self.ui.horizontalLayout_4.insertWidget(0, self.plotCanvas)
        self.ui.horizontalLayout_3.insertWidget(0, self.comboBoxROI)
        self.MUSE = None
        # Defined up front so handlers that run before a model has been
        # loaded/applied do not raise AttributeError.
        self.filename = ''
        self.parameters = None
        self.ui.stackedWidget.setCurrentIndex(0)

    def getUI(self):
        """Return the widget object produced by ``uic.loadUi``."""
        return self.ui

    def SetupConnections(self):
        """Wire button/combobox signals and set the initial enabled states."""
        self.ui.load_harmonization_model_Btn.clicked.connect(lambda: self.OnLoadHarmonizationModelBtnClicked())
        if self.datamodel.data is None:
            self.ui.load_harmonization_model_Btn.setEnabled(False)
            self.ui.Harmonized_Data_Information_Lbl.setText('Data must be loaded before model selection.\nReturn to Load and Save Data tab to select data.')
        self.ui.load_other_model_Btn.clicked.connect(lambda: self.OnLoadHarmonizationModelBtnClicked())
        self.ui.show_data_Btn.clicked.connect(lambda: self.OnShowDataBtnClicked())
        self.ui.apply_model_to_dataset_Btn.clicked.connect(lambda: self.OnApplyModelToDatasetBtnClicked())
        self.ui.add_to_dataframe_Btn.clicked.connect(lambda: self.OnAddToDataFrame())
        self.ui.comboBoxROI.currentIndexChanged.connect(self.UpdatePlot)
        self.ui.add_to_dataframe_Btn.setStyleSheet("background-color: rgb(230,255,230); color: black")
        self.datamodel.data_changed.connect(lambda: self.OnDataChanged())

        self.ui.apply_model_to_dataset_Btn.setEnabled(False)

        if self._has_residual_columns():
            self.ui.show_data_Btn.setEnabled(True)
            self.ui.show_data_Btn.setStyleSheet("background-color: rgb(230,230,255); color: black")
        else:
            self.ui.show_data_Btn.setEnabled(False)

    def _has_residual_columns(self):
        """Return True when the data already holds RES_/RAW_RES_ columns for MUSE_Volume_47."""
        headers = self.datamodel.GetColumnHeaderNames()
        return 'RES_MUSE_Volume_47' in headers and 'RAW_RES_MUSE_Volume_47' in headers

    def _model_summary(self, display_name):
        """Build the multi-line description of the currently loaded model.

        `display_name` is the file name shown in the first line.
        """
        model = self.datamodel.harmonization_model
        summary = display_name + ' loaded'
        sites_text = 'SITES in training set: ' + ' '.join([str(elem) for elem in list(model['SITE_labels'])])
        summary += '\n\n' + wrap_by_word(sites_text, 4)
        if 'Covariates' in model:
            summary += '\n' + 'Harmonization Covariates: ' + str(model['Covariates'])
        else:
            summary += '\n' + 'Harmonization Covariates Unavailable'
        knot_bounds = model['smooth_model']['bsplines_constructor'].knot_kwds[0]
        age_max = knot_bounds['upper_bound']
        age_min = knot_bounds['lower_bound']
        summary += '\n' + 'Valid Age Range: [' + str(age_min) + ', ' + str(age_max) + ']'
        return summary

    @staticmethod
    def _model_parameter_table(model):
        """Return a site-indexed frame of the model's gamma_*/delta_* columns, sorted by site."""
        gamma_cols = ['gamma_' + x for x in model['ROIs']]
        delta_cols = ['delta_' + x for x in model['ROIs']]
        sites = [x for x in model['SITE_labels']]
        gamma = pd.DataFrame(model['gamma_star'], columns=gamma_cols, index=sites)
        delta = pd.DataFrame(model['delta_star'], columns=delta_cols, index=sites)
        return pd.concat([gamma, delta], axis=1).sort_index()

    @staticmethod
    def _append_nan_rows(table, sites):
        """Return `table` with an all-NaN row added for every site in `sites` it lacks."""
        missing = [s for s in sites if s not in table.index]
        if not missing:
            return table
        filler = pd.DataFrame(np.nan, index=missing, columns=table.columns)
        return pd.concat([table, filler], axis=0).sort_index()

    def LoadHarmonizationModel(self, filename):
        """Read a pickled harmonization model and update the information labels.

        The model is stored on the data model (and `self.filename` updated)
        only after it passed validation, so a cancelled dialog or an unusable
        file leaves a previously loaded model untouched.
        """
        info_lbl = self.ui.Harmonized_Data_Information_Lbl
        if not filename:
            info_lbl.setText('Harmonization model not selected')
            info_lbl.setObjectName('Missing_label')
            info_lbl.setStyleSheet('QLabel#Missing_label {color: red}')
        else:
            # SECURITY NOTE: unpickling executes code embedded in the file;
            # only model files from a trusted source should be opened here.
            try:
                candidate = pd.read_pickle(filename)
            except Exception as err:
                # Unpickling a foreign or corrupt file can raise almost any
                # exception type; all of them mean "not a usable model".
                logger.info('Could not read harmonization model: %s' % err)
                candidate = None

            if not (isinstance(candidate, dict) and 'SITE_labels' in candidate):
                info_lbl.setText('Selected file is not a viable harmonization model')
                info_lbl.setObjectName('Error_label')
                info_lbl.setStyleSheet('QLabel#Error_label {color: red}')
            else:
                self.datamodel.harmonization_model = candidate
                self.filename = os.path.basename(filename)

                loaded_lbl = self.ui.Harmonization_Model_Loaded_Lbl
                loaded_lbl.setHidden(False)
                loaded_lbl.setObjectName('correct_label')
                loaded_lbl.setStyleSheet('QLabel#correct_label {color: green}')
                loaded_lbl.setText('Harmonization model compatible')
                info_lbl.setObjectName('correct_label')
                info_lbl.setStyleSheet('QLabel#correct_label {color: black}')
                info_lbl.setText(self._model_summary(self.filename))

                if self.datamodel.data is None:
                    self.ui.apply_model_to_dataset_Btn.setEnabled(False)
                    info_lbl.setText('Data must be loaded before model selection or application.\nReturn to Load and Save Data tab to select data.')
                else:
                    self.ui.apply_model_to_dataset_Btn.setEnabled(True)
                    self.ui.apply_model_to_dataset_Btn.setStyleSheet("background-color: rgb(230,255,230); color: black")
        self.ui.stackedWidget.setCurrentIndex(0)

    def OnLoadHarmonizationModelBtnClicked(self):
        """Ask the user for a model file and hand the chosen path to LoadHarmonizationModel."""
        # Kept in a local: `self.filename` must only change after a successful load.
        selected, _ = QtWidgets.QFileDialog.getOpenFileName(
            None,
            'Open harmonization model file',
            QtCore.QDir().homePath(),
            "Pickle files (*.pkl.gz *.pkl)")
        self.LoadHarmonizationModel(selected)

    def PopulateROI(self):
        """Fill the ROI combobox with the ROI columns present in the data.

        Derived ROIs are offered only when the loaded model lists
        `MUSE_Volume_301` among its ROIs.
        """
        MUSEDictDataFrame = self.datamodel.GetMUSEDictDataFrame()
        _, MUSEDictIDtoNAME = self.datamodel.GetMUSEDictionaries()
        headers = set(self.datamodel.GetColumnHeaderNames())

        def display_names(level):
            # Columns of the requested ROI_LEVEL that exist in the data,
            # sorted by column id and shown by their dictionary name.
            level_cols = MUSEDictDataFrame[MUSEDictDataFrame['ROI_LEVEL'] == level]['ROI_COL']
            present = sorted(headers.intersection(set(level_cols)))
            return ['(MUSE) ' + MUSEDictIDtoNAME.get(k) if k.startswith('MUSE_') else k
                    for k in present]

        roiList = display_names('SINGLE')
        if 'MUSE_Volume_301' in list(self.datamodel.harmonization_model['ROIs']):
            logger.info('Model includes derived volumes')
            roiList = roiList + display_names('DERIVED')
        else:
            logger.info('No derived volumes in model')

        # add the list items to comboBox; signals are muted only while clearing,
        # so adding the items still triggers UpdatePlot.
        self.ui.comboBoxROI.blockSignals(True)
        self.ui.comboBoxROI.clear()
        self.ui.comboBoxROI.blockSignals(False)
        self.ui.comboBoxROI.addItems(roiList)

    def OnShowDataBtnClicked(self):
        """Plot the residual columns that are already stored in the data model."""
        self.MUSE = self.datamodel.data
        self.PopulateROI()

    def OnApplyModelToDatasetBtnClicked(self):
        """Run the harmonization and plot its result."""
        self.MUSE = self.DoHarmonization()
        self.PopulateROI()

    def UpdatePlot(self):
        """Translate the selected combobox entry to an ROI id and redraw the plot."""
        combo = self.ui.comboBoxROI
        # get current selected combobox item
        currentROI = combo.currentText()

        AllItems = [combo.itemText(i) for i in range(combo.count())]
        MUSEDictNAMEtoID, _ = self.datamodel.GetMUSEDictionaries()
        # NOTE(review): the last entry is excluded from the valid choices -
        # presumably SearchableQComboBox appends free-typed text there; confirm.
        valid_items = AllItems[:-1]
        if currentROI not in valid_items:
            combo.blockSignals(True)
            combo.clear()
            combo.blockSignals(False)
            combo.addItems(valid_items)
            currentROI = combo.itemText(0)
            combo.setCurrentText(currentROI)
            print("Invalid input. Setting to %s." % (currentROI))

        # Translate ROI name back to ROI ID. Only entries that carry the
        # '(MUSE) ' display prefix are dictionary names; others are raw column ids.
        prefix = '(MUSE) '
        if currentROI.startswith(prefix):
            roi_id = MUSEDictNAMEtoID.get(currentROI[len(prefix):])
        else:
            roi_id = currentROI
        if not roi_id:
            logger.info('No ROI id found for selection; plot not updated.')
            return

        # create dictionary of plot options
        plotOptions = dict()
        plotOptions['ROI'] = roi_id
        self.plotMUSE(plotOptions)

    def plotMUSE(self, plotOptions):
        """Draw the three-panel harmonization plot for ``plotOptions['ROI']``.

        Left: box plot of RAW_RES_<ROI> per site. Middle: per-site gamma
        (marker) and delta (error bar) values. Right: box plot of RES_<ROI>
        per site. Nothing is drawn when no data set is selected for plotting.
        """
        if self.MUSE is None:
            return

        self.ui.stackedWidget.setCurrentIndex(1)
        ax_raw = self.plotCanvas.axes1
        ax_par = self.plotCanvas.axes2
        ax_har = self.plotCanvas.axes3
        for ax in (ax_raw, ax_par, ax_har):
            ax.clear()

        # select roi
        currentROI = plotOptions['ROI']
        h_res = 'RES_' + currentROI
        raw_res = 'RAW_RES_' + currentROI
        selected_gamma = 'gamma_' + currentROI
        selected_delta = 'delta_' + currentROI
        model = self.datamodel.harmonization_model

        if 'isTrainMUSEHarmonization' in self.MUSE:
            print('Plotting controls only')
            data = self.MUSE[self.MUSE['isTrainMUSEHarmonization'] == 1]
        else:
            data = self.MUSE
        # Work on a private copy: self.MUSE may be the data model's own frame,
        # and plotting must not drop rows from it or retype its SITE column.
        data = data.dropna(subset=[raw_res]).copy()
        data['SITE'] = pd.Categorical(data['SITE'])
        data['SITE'] = data['SITE'].cat.remove_unused_categories()

        # make palette
        sites_in_data = list(data.SITE.unique())
        if 'SITE_colors' in model:
            print('Creating color palette from model...')
            model_colors = model['SITE_colors']
            wanted = set(sites_in_data).intersection(set(model['SITE_labels']))
            cSite = {site: model_colors[site] for site in wanted}
            # Sites unknown to the model get colors from a separate palette.
            site_extra = list(set(sites_in_data) - set(model['SITE_labels']))
            palette_extra = sns.color_palette("Set2", n_colors=len(site_extra))
            cSite.update(dict(zip(site_extra, palette_extra)))
        else:
            print('Color palette not available in model. Creating new color palette...')
            colors = sns.color_palette("cubehelix", n_colors=len(sites_in_data))
            cSite = dict(zip(sites_in_data, colors))

        labels = sorted([x + ' (N=' for x in list(cSite.keys())])

        sd_raw = data[raw_res].std()
        sd_h = data[h_res].std()
        ci_plus_raw = 0.65 * sd_raw
        ci_minus_raw = -0.65 * sd_raw
        ci_plus_h = 0.65 * sd_h
        ci_minus_h = -0.65 * sd_h

        PROPS = {
            'boxprops': {'edgecolor': 'none'},
            'medianprops': {'color': 'black'},
            'whiskerprops': {'color': 'black'},
            'capprops': {'color': 'black'}
        }

        # Parameter rows for the plotted sites. Without a previous
        # DoHarmonization run the model's own table is used; sites without
        # parameters get NaN rows so the middle panel stays row-aligned
        # with the box plots.
        if self.parameters is None:
            parameters = self._model_parameter_table(model)
        else:
            parameters = self.parameters
        parameters = self._append_nan_rows(parameters, list(cSite.keys()))
        parameters = parameters[parameters.index.isin(list(cSite.keys()))].copy()
        parameters['SITE'] = parameters.index
        parameters['gamma_values'] = ["{:.3f}".format(x)
                                      for x in parameters[selected_gamma].values.round(3).tolist()]
        parameters['delta_values'] = ["{:.3f}".format(x)
                                      for x in parameters[selected_delta].values.round(3).tolist()]

        # ---- left panel: residuals before harmonization
        ax_raw.get_figure().set_tight_layout(True)
        ax_raw.set_xlim(-4 * sd_raw, 4 * sd_raw)
        sns.set(style='white')
        a = sns.boxplot(x=raw_res, y="SITE", data=data, palette=cSite, linewidth=.25,
                        showfliers=False, ax=ax_raw, **PROPS)
        nobs = [str(x) for x in data['SITE'].value_counts().sort_index(ascending=True).values.tolist()]
        labels = [label + count + ')' for label, count in zip(labels, nobs)]
        ax_raw.axvline(ci_plus_raw, color='grey', ls='--')
        ax_raw.axvline(ci_minus_raw, color='grey', ls='--')
        ax_raw.yaxis.set_ticks_position('left')
        ax_raw.xaxis.set_ticks_position('bottom')
        a.tick_params(axis='both', which='major', length=4)
        a.set_xlabel('Residuals before harmonization')

        # ---- middle panel: location/scale parameters per site
        ax_par.get_figure().set_tight_layout(True)
        upper_limit = np.nanmax(parameters[selected_gamma] + parameters[selected_delta])
        lower_limit = np.nanmin(parameters[selected_gamma] - parameters[selected_delta])
        limit = max(abs(upper_limit), abs(lower_limit))
        ax_par.set_xlim(-limit, limit)
        ax_par.set_ylim(ax_raw.get_ylim())
        sns.set(style='white')
        ax_par.errorbar(x=parameters[selected_gamma], y=parameters['SITE'],
                        xerr=parameters[selected_delta], ecolor='black', elinewidth=0.25,
                        capsize=0.25, zorder=-1, fmt='none')
        kws = {"s": 4, "facecolor": "black", "linewidth": 0.5}
        color = parameters[parameters[selected_gamma].notna()]['SITE'].map(cSite)
        b = sns.scatterplot(x=selected_gamma, y='SITE', data=parameters.reset_index(), marker='s',
                            edgecolor=color, zorder=1, **kws, ax=ax_par, legend=False)
        yticks = ax_par.get_yticks()
        b.text(-0.05, yticks[0] - 1.3, 'Location (\u03B3*)', ha='right', fontsize='small')
        b.text(0, yticks[0] - 1.3, '|', ha='center', fontsize='small')
        b.text(0.05, yticks[0] - 1.3, 'Scale (\u03B4*)', ha='left', fontsize='small')
        for count, site in enumerate(parameters['SITE']):
            gamma_text = parameters.loc[site, 'gamma_values']
            if 'nan' in gamma_text:
                gamma_text = '     nan'
            b.text(-0.05, yticks[count] - 0.2, gamma_text, fontsize='x-small', ha='right')
            b.text(0.05, yticks[count] - 0.2, parameters.loc[site, 'delta_values'],
                   fontsize='x-small', ha='left')
        b.set_xlabel('')
        b.set_ylabel('')
        b.set(yticklabels=[])
        ax_par.xaxis.set_ticks_position('bottom')
        b.tick_params(axis='both', left=False, right=False, length=4)
        ax_par.axvline(0, color='black', linewidth=0.25)
        sns.despine(ax=ax_par, left=True)

        # get title as ROI name relative to scale/shift label
        _, MUSEDictIDtoNAME = self.datamodel.GetMUSEDictionaries()
        title = currentROI
        if title.startswith('MUSE_'):
            title = '(MUSE) ' + MUSEDictIDtoNAME.get(currentROI)
        elif title.startswith('WMLS_'):
            title = '(WMLS) ' + MUSEDictIDtoNAME.get(currentROI.replace('WMLS_', 'MUSE_'))
        ax_par.get_figure().subplots_adjust(top=.8)
        ax_par.get_figure().suptitle(title)

        # ---- right panel: residuals after harmonization (same x-range as the left panel)
        ax_har.get_figure().set_tight_layout(True)
        ax_har.set_xlim(-4 * sd_raw, 4 * sd_raw)
        ax_har.set_ylim(ax_raw.get_ylim())
        sns.set(style='white')
        c = sns.boxplot(x=h_res, y="SITE", data=data, palette=cSite, linewidth=0.25,
                        showfliers=False, ax=ax_har, **PROPS)
        a.set_yticklabels(labels)
        ax_har.axvline(ci_plus_h, color='grey', ls='--')
        ax_har.axvline(ci_minus_h, color='grey', ls='--')
        ax_har.yaxis.set_ticks_position('left')
        ax_har.xaxis.set_ticks_position('bottom')
        c.tick_params(axis='both', which='major', length=4)
        c.set_xlabel('Residuals after harmonization')
        c.set_ylabel('')
        c.set(yticklabels=[])

        sns.despine(ax=ax_raw, trim=True)
        sns.despine(ax=ax_har, trim=True)

        self.plotCanvas.draw()

    def OnAddToDataFrame(self):
        """Copy harmonized and residual columns from `self.MUSE` into the data model.

        Columns are written only when `H_MUSE_Volume_47` is not yet present
        in the data; `data_changed` is emitted afterwards.
        """
        if self.MUSE is None:
            logger.info('No harmonized data available to add.')
            return

        print('Saving modified data to pickle file...')
        model_rois = self.datamodel.harmonization_model['ROIs']
        ROI_list = list(model_rois)
        if 'MUSE_Volume_301' not in ROI_list:
            logger.info('No derived volumes in model')
            MUSEDictDataFrame = self.datamodel.GetMUSEDictDataFrame()
            Derived_numbers = list(MUSEDictDataFrame[MUSEDictDataFrame['ROI_LEVEL'] == 'DERIVED']['ROI_INDEX'])
            ROI_list = ROI_list + ['MUSE_Volume_' + str(x) for x in Derived_numbers]
            # DoHarmonization drops H_MUSE_Volume_702 in this case, so it
            # must not be requested here either.
            ROI_list.remove('MUSE_Volume_702')
        else:
            logger.info('Model includes derived volumes')

        H_ROIs = ['H_' + str(x) for x in ROI_list]
        ROIs_ICV_Sex_Residuals = ['RES_ICV_Sex_' + x for x in model_rois]
        ROIs_Residuals = ['RES_' + x for x in model_rois]
        RAW_Residuals = ['RAW_RES_' + x for x in model_rois]

        if 'H_MUSE_Volume_47' not in self.datamodel.data.keys():
            # NOTE(review): assignment aligns on the index, and DoHarmonization
            # returns a frame with a fresh 0..n-1 index - confirm that
            # datamodel.data uses the same index.
            self.datamodel.data.loc[:, H_ROIs] = self.MUSE[H_ROIs]
            self.datamodel.data.loc[:, ROIs_ICV_Sex_Residuals] = self.MUSE[ROIs_ICV_Sex_Residuals]
            self.datamodel.data.loc[:, ROIs_Residuals] = self.MUSE[ROIs_Residuals]
            self.datamodel.data.loc[:, RAW_Residuals] = self.MUSE[RAW_Residuals]
        self.datamodel.data_changed.emit()

    def OnDataChanged(self):
        """Reset the plot and refresh button states and labels after a data change."""
        self.ui.stackedWidget.setCurrentIndex(0)
        self.plotCanvas.axes1.clear()
        self.plotCanvas.axes2.clear()
        self.plotCanvas.axes3.clear()
        self.MUSE = None

        if self._has_residual_columns():
            self.ui.show_data_Btn.setEnabled(True)
            self.ui.show_data_Btn.setStyleSheet("background-color: lightBlue; color: white")
        else:
            self.ui.show_data_Btn.setEnabled(False)

        info_lbl = self.ui.Harmonized_Data_Information_Lbl
        apply_btn = self.ui.apply_model_to_dataset_Btn
        if self.datamodel.data is None:
            self.ui.load_harmonization_model_Btn.setEnabled(False)
            apply_btn.setEnabled(False)
            info_lbl.setText('Data must be loaded before model selection.\nReturn to Load and Save Data tab to select data.')
        else:
            self.ui.load_harmonization_model_Btn.setEnabled(True)
            if self.datamodel.harmonization_model is None:
                apply_btn.setEnabled(False)
                info_lbl.setText('No harmonization model has been selected')
            else:
                apply_btn.setEnabled(True)
                apply_btn.setStyleSheet("background-color: rgb(230,255,230); color: black")
                info_lbl.setObjectName('correct_label')
                info_lbl.setStyleSheet('QLabel#correct_label {color: black}')
                info_lbl.setText(self._model_summary(self.filename))

    def DoHarmonization(self):
        """Apply the loaded model to `datamodel.data` and return the result.

        Sites present in the data but unknown to the model are harmonized
        out-of-sample when the data has a `UseForComBatGAMHarmonization`
        column and the site has at least MIN_REFERENCE_SAMPLES non-null
        entries in it; otherwise they are skipped and receive NaN rows in
        `self.parameters`.

        Returns a new frame (index 0..n-1) holding the covariates, H_<ROI>
        harmonized values, and RES_ICV_Sex_<ROI>, RES_<ROI>, RAW_RES_<ROI>
        residual columns. Also sets `self.parameters`.
        """
        print('Running harmonization.')
        data = self.datamodel.data
        model = self.datamodel.harmonization_model
        roi_cols = [x for x in model['ROIs']]

        if 'Covariates' in model:
            covariates = model['Covariates']
            logger.info('Covariates hard-coded in model.')
        else:
            covariates = ['SITE', 'Age', 'Sex', 'DLICV_baseline']
            logger.info('Covariates default to `SITE`, `Age`, `Sex`, and `DLICV_baseline`.')

        # Sites that occur in the data but were not part of the model.
        # (Set difference: model sites absent from the data are not "new".)
        new_sites = set(data['SITE'].value_counts().index.tolist()) - set(model['SITE_labels'])

        covars = data[['SITE', 'Age', 'Sex', 'DLICV_baseline']].reset_index(drop=True).copy()
        covars['Sex'] = covars['Sex'].map({'M': 1, 'F': 0})
        covars.loc[covars.Age > 100, 'Age'] = 100

        skipped_sites = []      # new sites that stay unharmonized
        active_model = model    # model handed to harmonizationApply
        if 'UseForComBatGAMHarmonization' in data.columns:
            is_reference = np.array(data['UseForComBatGAMHarmonization'].notnull(), dtype=bool)
            sites_to_harmonize = []
            for site in new_sites:
                in_site = np.array(data['SITE'] == site, dtype=bool)
                if np.count_nonzero(np.logical_and(in_site, is_reference)) < self.MIN_REFERENCE_SAMPLES:
                    skipped_sites.append(site)
                    print('New site `%s` has less than %d reference data points. Skipping harmonization.'
                          % (site, self.MIN_REFERENCE_SAMPLES))
                else:
                    print('Harmonizing ' + str(site))
                    sites_to_harmonize.append(site)

            if not sites_to_harmonize:
                print('No new sites that meet out-of-sample harmonization requirement. Proceeding with harmonization.')
            else:
                oos_rows = data[data['SITE'].isin(sites_to_harmonize)].dropna(subset=covariates)
                oos_data = oos_rows[roi_cols].values
                oos_covars = oos_rows[covariates].copy()
                oos_covars['Sex'] = oos_covars['Sex'].map({'M': 1, 'F': 0})
                self.model, _ = nh.harmonizationLearn(
                    oos_data, oos_covars,
                    smooth_terms=['Age'],
                    smooth_term_bounds=(np.floor(np.min(data.Age)), np.ceil(np.max(data.Age))),
                    orig_model=model,
                    seed=20220601)
                active_model = self.model
        else:
            print('Skipping out-of-sample harmonization because `UseForComBatGAMHarmonization` does not exist.')
            skipped_sites = list(new_sites)

        bayes_data, stand_mean = nh.harmonizationApply(data[roi_cols].values, covars, active_model, True)

        # Parameter table for plotting: model parameters plus NaN rows for
        # every skipped site, so each plotted site has a row.
        self.parameters = self._append_nan_rows(self._model_parameter_table(active_model), skipped_sites)

        Raw_ROIs_Residuals = data[roi_cols].values - stand_mean

        frames = [covars, pd.DataFrame(bayes_data, columns=['H_' + s for s in model['ROIs']])]
        if 'isTrainMUSEHarmonization' in data.columns:
            frames.insert(0, data['isTrainMUSEHarmonization'].reset_index(drop=True).copy())
        muse = pd.concat(frames, axis=1)

        # harmonize derived volumes
        if 'MUSE_Volume_301' not in list(model['ROIs']):
            logger.info('No derived volumes in model.')
            logger.info('Calculating using derived mapping dictionary.')
            MUSEDictDataFrame = self.datamodel.GetMUSEDictDataFrame()
            muse_mappings = self.datamodel.GetDerivedMUSEMap()
            for ROI in MUSEDictDataFrame[MUSEDictDataFrame['ROI_LEVEL'] == 'DERIVED']['ROI_INDEX']:
                # Each derived volume is the sum of the single ROIs listed
                # in its mapping row; any missing component makes it NaN.
                members = muse_mappings.loc[ROI].replace('NaN', np.nan).dropna().astype(float)
                member_cols = ['H_MUSE_Volume_%0d' % x for x in members]
                muse['H_MUSE_Volume_%d' % ROI] = muse[member_cols].sum(axis=1, skipna=False)
            muse.drop(columns=['H_MUSE_Volume_702'], inplace=True)

        # The two B_hat rows right after the site rows are used as the
        # Sex and DLICV_baseline coefficients.
        start_index = len(model['SITE_labels'])
        sex_icv_effect = np.dot(muse[['Sex', 'DLICV_baseline']].copy(),
                                model['B_hat'][start_index:(start_index + 2), :])
        ROIs_ICV_Sex_Residuals = ['RES_ICV_Sex_' + x for x in model['ROIs']]
        muse.loc[:, ROIs_ICV_Sex_Residuals] = muse[['H_' + x for x in model['ROIs']]].values - sex_icv_effect

        muse['Sex'] = muse['Sex'].map({1: 'M', 0: 'F'})
        ROIs_Residuals = ['RES_' + x for x in model['ROIs']]
        RAW_Residuals = ['RAW_RES_' + x for x in model['ROIs']]
        muse.loc[:, ROIs_Residuals] = bayes_data - stand_mean
        muse.loc[:, RAW_Residuals] = Raw_ROIs_Residuals
        print('Harmonization done.')
        return muse
def wrap_by_word(s, n):
    """Re-flow *s* so that every line holds at most *n* whitespace-separated words.

    Each output line, including the last one, ends with a newline; an input
    without words yields an empty string.

    >>> wrap_by_word("a b c d e", 2)
    'a b\\nc d\\ne\\n'
    """
    words = s.split()
    chunks = (words[start:start + n] for start in range(0, len(words), n))
    return ''.join(' '.join(chunk) + '\n' for chunk in chunks)