Commit 70fee117 authored by Viktor Reshniak's avatar Viktor Reshniak
Browse files

add `def integrate`

parent 35eaad64
Loading
Loading
Loading
Loading
+111 −0
Original line number Diff line number Diff line
@@ -1028,7 +1028,113 @@ class Histogram(object):
		plt.close('all')


	def integrate(self, peak_estimate='data', background_estimate='data', peak_std=4, bkgr_std=None):
		r'''Integrate peak intensity with background correction

		Inputs
		------
		  hist_ws:	histogram workspace of the peak
		  params:	parameters of the fit
		  peak_estimate:		how to calculate peak,       one of ['fit','data']
		  background_estimate:	how to calculate background, one of ['fit','data']

		Output
		------
		  corrected_intensity
		  corrected_sigma
		'''

		if bkgr_std is None: bkgr_std = peak_std + 3

		# parameters of the model
		nbkgr   = 1 + self.ndims
		npeak   = self.fit_params.size - nbkgr
		ncnt    = self.ndims
		ncov    = (self.ndims*(self.ndims+1))//2
		nangles = (self.ndims*(self.ndims-1))//2
		nskew   = self.ndims

		bkgr     = self.fit_params[:nbkgr]
		intst    = self.fit_params[nbkgr]
		mu       = self.fit_params[1+nbkgr:1+nbkgr+ncnt]
		sqrtP    = self.fit_params[1+nbkgr+ncnt:1+nbkgr+ncnt+ncov]
		angles   = sqrtP[:nangles]
		svals    = sqrtP[nangles:]
		skew     = self.fit_params[1+nbkgr+ncnt+ncov:1+nbkgr+ncnt+ncov+nskew]


		# inverse rotation matrix
		R = rotation_matrix(angles)

		#
		data, points, edges = self.get_grid_data(return_edges=True)
		points = points.reshape((3,-1))
		fit = self.fit_params[0]**2 + gaussian_mixture(self.fit_params[nbkgr:],points,npeaks=1,covariance_parameterization='givens').reshape(data.shape)

		data = data.ravel()
		fit  = fit.ravel()

		# detector_mask = None
		if self.detector_mask is None:
			detector_mask = (data==data)
		else:
			detector_mask = self.detector_mask

		mah_dist = mahalanobis_distance(mu, svals*R, points)

		# true distance mask
		# dist_mask = np.sqrt(np.sum((points.reshape((3,-1))-mu.reshape((3,1)))**2,axis=0)) < 0.3

		# mahalanobis distance masks
		peak_mask = np.logical_and(detector_mask.ravel(), mah_dist<peak_std)
		bkgr_mask = np.logical_and(detector_mask.ravel(), mah_dist>peak_std)
		bkgr_mask = np.logical_and(bkgr_mask,mah_dist<bkgr_std)
		# bkgr_mask = np.logical_and(bkgr_mask,dist_mask)

		peak_vol  = peak_mask.sum()
		bkgr_vol  = bkgr_mask.sum()
		peak2bkgr = peak_vol / bkgr_vol

		# peak_mask_2d = marginalize_2d(peak_mask.reshape(data.shape))
		# fig = plt.figure(constrained_layout=True, figsize=(10,6))
		# axes = fig.subplots(1,3)
		# for i,ax in enumerate(axes):
		# 	yind, xind = [j for j in range(3) if j!=i]
		# 	left, right, bottom, top = edges[xind][0], edges[xind][-1], edges[yind][0], edges[yind][-1]
		# 	ax.imshow(peak_mask_2d[i], interpolation='none', extent=(left,right,bottom,top), origin='lower')
		# # plt.imshow(peak_mask_2d[0], interpolation='none', origin='lower')
		# plt.savefig(f'mask.png')

		# print(f"Estimated  background density {bkgr:.2e}")
		# print(f"Calculated background density {data[bkgr_mask].sum()/bkgr_mask.sum():.2e}")
		# print(bkgr,data[bkgr_mask].mean())

		# total peak intensity
		if peak_estimate=='data':
			total_peak_intensity = data[peak_mask].sum()
		elif peak_estimate=='fit':
			total_peak_intensity = fit[peak_mask].sum()

		# background correction
		if background_estimate=='data':
			total_bkgr_intensity = data[bkgr_mask].sum()
		elif background_estimate=='fit':
			total_bkgr_intensity = bkgr
		peak_bkgr_correction = peak2bkgr    * total_bkgr_intensity
		peak_bkgr_variance   = peak2bkgr**2 * total_bkgr_intensity

		peak_chi2 = ((data-fit)**2/fit)[peak_mask].mean()
		# print('Chi2: ',chi2)

		# intensity = total_peak_intensity - bkgr * peak_vol
		# sigma     = total_peak_intensity + peak_vol**2 * smth
		intensity = total_peak_intensity - peak_bkgr_correction
		sigma     = total_peak_intensity + peak_bkgr_variance
		if peak_estimate=='fit':
			sigma = sigma + ((data-fit)**2/(data+1))[peak_mask].sum()
		sigma = np.sqrt(sigma)

		return intensity, sigma, peak_chi2, total_bkgr_intensity



@@ -1067,6 +1173,11 @@ if __name__ == '__main__':
		h.plot(bins=plot_bins, prefix=str(peak_id))
		print(f"Plot: {time.time()-start} sec")

		start = time.time()
		print(h.integrate())
		print(f"Integrate: {time.time()-start} sec")