Skip to content

Commit fe3c69b

Browse files
authored
Merge pull request #200 from spice-herald/dt0_estimate
Dt0 estimate
2 parents c8cf437 + 2d53728 commit fe3c69b

File tree

6 files changed

+267
-43
lines changed

6 files changed

+267
-43
lines changed

qetpy/core/_of_nsmb.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,6 +1037,7 @@ def of_nsmb_con(pulset, phi, Pfs, P, sbtemplatef, sbtemplate, psddnu, fs, indwin
10371037
# check the gradient at the best fit polarity constrained min
10381038

10391039
# note that we cast ind_tdel_New to an int
1040+
ind_tdel_New_nowindow = int(np.asarray(ind_tdel_New_nowindow).item())
10401041
Pt_tmin, iPt_tmin = of_nsmb_getPt(Pfs, P, combind=((2**nsb)-1),
10411042
bindelay=int(ind_tdel_New_nowindow),bitcomb=bitcomb,bitmask=None)
10421043

@@ -1109,6 +1110,7 @@ def _interpchi2(indmin, chi2, amp, time):
11091110
Amplitude at interpolated chi^2 minimum
11101111
11111112
"""
1113+
indmin = int(np.asarray(indmin).item())
11121114

11131115
t_to_interp = time[int(indmin-1):int(indmin+2)]
11141116

qetpy/core/didv/_base_didv.py

Lines changed: 194 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
"compleximpedance",
1212
"complexadmittance",
1313
"squarewaveresponse",
14+
"estimate_dt0_fft",
15+
"estimate_dt0_cross_spectrum"
1416
]
1517

1618

@@ -293,6 +295,171 @@ def squarewaveresponse(t, sgamp, sgfreq, params, dutycycle=0.5,
293295

294296
return response
295297

298+
299+
def estimate_dt0_cross_spectrum(x, fs, sgfreq, duty=0.5, sgamp=1.0,
300+
nharm_max=9, band_bins=1, zpad=1):
301+
"""
302+
Estimate time delay dt0 between data x and an ideal square wave
303+
by fitting the slope of the cross-spectrum phase vs frequency.
304+
305+
x : 1D signal (average trace is best)
306+
fs : sample rate [Hz]
307+
sgfreq : square-wave frequency [Hz]
308+
duty : duty cycle in [0,1]
309+
sgamp : peak-to-peak square amplitude (only scales weights)
310+
nharm_max : use odd harmonics m = 1,3,5,... up to this m
311+
band_bins : include ±band_bins FFT bins around each harmonic
312+
zpad : zero-padding factor (e.g., 4 gives 4× frequency resolution)
313+
314+
returns dt0 in seconds wrapped to (-T/2, T/2], where T=1/sgfreq
315+
"""
316+
x = np.asarray(x)
317+
N = x.size
318+
T = 1.0 / sgfreq
319+
320+
# optional zero-padding for finer frequency grid
321+
Nz = int(zpad * N)
322+
def rfft(x):
323+
return np.fft.rfft(x, n=Nz)
324+
freqs = np.fft.rfftfreq(Nz, d=1/fs)
325+
326+
# build ideal square (same length), no delay
327+
t = np.arange(N) / fs
328+
# exact Fourier construction causes leakage if non-integer cycles;
329+
# time-domain square is fine for cross-spectrum:
330+
level_hi = +sgamp/2.0
331+
level_lo = -sgamp/2.0
332+
phase_in_period = np.mod(t, T)
333+
square = np.where(phase_in_period < duty*T, level_hi, level_lo)
334+
335+
# remove means to kill DC dominance
336+
X = rfft(x - np.mean(x))
337+
S = rfft(square - np.mean(square))
338+
339+
# cross-spectrum
340+
Gamma = X * np.conj(S)
341+
342+
# pick harmonic neighborhoods
343+
use_f = []
344+
use_phi = []
345+
use_w = []
346+
347+
# list of harmonic center freqs (odd only typical for duty=0.5, but we don?t rely on that)
348+
ms = list(range(1, nharm_max+1)) # 1..nharm_max
349+
for m in ms:
350+
f0 = m * sgfreq
351+
if f0 >= freqs[-1]:
352+
break
353+
k0 = int(np.round(f0 / (fs / Nz)))
354+
kmin = max(1, k0 - band_bins)
355+
kmax = min(len(freqs)-1, k0 + band_bins)
356+
idx = np.arange(kmin, kmax+1)
357+
# discard bins where either spectrum is too weak
358+
mag = np.abs(Gamma[idx])
359+
good = mag > (1e-12 * np.max(np.abs(Gamma))) # guard
360+
idx = idx[good]
361+
if idx.size == 0:
362+
continue
363+
use_f.append(freqs[idx])
364+
use_phi.append(np.angle(Gamma[idx]))
365+
use_w.append(mag[good])
366+
367+
if not use_f:
368+
return 0.0
369+
370+
f = np.concatenate(use_f)
371+
phi = np.concatenate(use_phi)
372+
w = np.concatenate(use_w)
373+
374+
# unwrap phase by sorting on frequency first
375+
order = np.argsort(f)
376+
f = f[order]
377+
phi = np.unwrap(phi[order])
378+
w = w[order]
379+
380+
# weighted linear fit: phi ? a*f + b => dt0 = -a/(2?)
381+
# Solve with weights
382+
W = np.diag(w)
383+
A = np.vstack([f, np.ones_like(f)]).T
384+
# (A^T W A)^{-1} A^T W y
385+
AtW = A.T @ W
386+
sol = np.linalg.lstsq(AtW @ A, AtW @ phi, rcond=None)[0]
387+
slope = sol[0]
388+
389+
dt0 = -slope / (2*np.pi)
390+
391+
# wrap into (-T/2, T/2]
392+
while dt0 <= -T/2: dt0 += T
393+
while dt0 > T/2: dt0 -= T
394+
return float(dt0)
395+
396+
397+
398+
def estimate_dt0_fft(x, t, sgfreq, duty=0.5):
399+
"""
400+
Estimate dt0 between average trace x(t) and an ideal square-wave template
401+
using circular FFT cross-correlation. Returns dt0 in seconds in (-T/2, T/2].
402+
"""
403+
x = np.asarray(x)
404+
t = np.asarray(t)
405+
N = x.size
406+
dt = t[1]-t[0]
407+
fs = 1.0/dt
408+
T = 1.0/sgfreq
409+
410+
# build ideal template that matches your model polarity
411+
phase = (t % T) / T
412+
tmpl = np.where(phase < duty, 1.0, -1.0)
413+
414+
# remove means (very important)
415+
x0 = x - x.mean()
416+
tmpl0 = tmpl - tmpl.mean()
417+
418+
# FFT xcorr (circular)
419+
X = np.fft.rfft(x0)
420+
H = np.fft.rfft(tmpl0)
421+
r = np.fft.irfft(X * np.conj(H), n=N)
422+
423+
# allow negative lags by rolling the correlation
424+
lags = np.arange(N)
425+
lags = np.where(lags <= N//2, lags, lags-N)
426+
kmax = int(np.argmax(r))
427+
lag_samples = lags[kmax]
428+
429+
# sub-sample refine by quadratic fit around peak (if possible)
430+
if 1 <= kmax < N-1:
431+
y_m, y_0, y_p = r[kmax-1], r[kmax], r[kmax+1]
432+
denom = (y_m - 2*y_0 + y_p)
433+
if denom != 0:
434+
delta = 0.5*(y_m - y_p)/denom # in samples
435+
lag_samples = lag_samples + delta
436+
437+
# wrap result into (-T/2, T/2] to fix period ambiguity
438+
dt0 = lag_samples / fs
439+
while dt0 <= -T/2: dt0 += T
440+
while dt0 > T/2: dt0 -= T
441+
442+
# resolve the 180° ambiguity for duty=0.5:
443+
# pick the sign that maximizes correlation with the *high* level
444+
# (i.e., prefer aligning +1 with the actual 'high' in x)
445+
t_shift = t - dt0
446+
phase_s = (t_shift % T) / T
447+
tmpl_s = np.where(phase_s < duty, 1.0, -1.0)
448+
if np.sum(x0 * tmpl_s) < 0:
449+
# flip by half-period if anti-correlated
450+
dt0 += (T/2 if dt0 <= 0 else -T/2)
451+
452+
# final wrap again to (-T/2, T/2]
453+
while dt0 <= -T/2: dt0 += T
454+
while dt0 > T/2: dt0 -= T
455+
return dt0
456+
457+
458+
459+
460+
461+
462+
296463
class _BaseDIDV(object):
297464
"""
298465
Class for fitting a didv curve for different types of models of the
@@ -304,7 +471,7 @@ class _BaseDIDV(object):
304471

305472
def __init__(self, rawtraces, fs, sgfreq, sgamp, rsh, tracegain=1.0,
306473
r0=0.3, rp=0.005, dutycycle=0.5, add180phase=False,
307-
dt0=10.0e-6, autoresample=False):
474+
dt0=None, autoresample=False):
308475
"""
309476
Initialization of the _BaseDIDV class object
310477
@@ -911,13 +1078,35 @@ def processtraces(self):
9111078
nbinsraw = len(self._rawtraces[0])
9121079
bins = np.arange(0, nbinsraw)
9131080

914-
# add half a period of the square wave frequency to the
915-
# initial offset if add180phase is True
916-
if (self._add180phase):
917-
self._dt0 = self._dt0 + 1/(2*self._sgfreq)
1081+
1082+
if self._dt0 is None:
1083+
1084+
raw_avg = np.mean(self._rawtraces, axis=0) # shape: (nbinsraw,)
1085+
raw_time = bins * dt
1086+
1087+
#dt0_estimate_cross = estimate_dt0_cross_spectrum(
1088+
# raw_avg.copy(), fs=self._fs, sgfreq=self._sgfreq,
1089+
# duty=getattr(self, "_dutycycle", 0.5),
1090+
# sgamp=self._sgamp, nharm_max= 9 or 11, band_bins=1,
1091+
# zpad=4)
1092+
1093+
dt0_estimate_fft = estimate_dt0_fft(
1094+
raw_avg.copy(), raw_time, self._sgfreq,
1095+
duty=getattr(self, "_dutycycle", 0.5))
1096+
1097+
#self._dt0 = (dt0_estimate_cross + dt0_estimate_fft)/2
1098+
self._dt0 = dt0_estimate_fft
1099+
1100+
else:
1101+
1102+
# add half a period of the square wave frequency to the
1103+
# initial offset if add180phase is True
1104+
if (self._add180phase):
1105+
self._dt0 = self._dt0 + 1/(2*self._sgfreq)
9181106

9191107
self._time = bins*dt - self._dt0
9201108

1109+
9211110
# figure out how many didv periods are in the trace, including
9221111
# the time offset
9231112
period = 1.0/self._sgfreq

qetpy/core/didv/_didv.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ class DIDV(_BaseDIDV, _PlotDIDV):
137137

138138
def __init__(self, rawtraces, fs, sgfreq, sgamp, rsh, tracegain=1.0,
139139
r0=0.3, rp=0.005, dutycycle=0.5, add180phase=False,
140-
dt0=1.5e-6, autoresample=False):
140+
dt0=None, autoresample=False):
141141
"""
142142
Initialization of the DIDV class object
143143
@@ -524,7 +524,7 @@ def dofit(self, poles, fcutoff=np.inf,
524524

525525
# time shift guess
526526
dt = self._dt0
527-
527+
528528
# overrite guessed values if provided by user
529529
if guess_params is not None:
530530
if len(guess_params) != 3:
@@ -595,7 +595,6 @@ def dofit(self, poles, fcutoff=np.inf,
595595
# time shift
596596
dt0 = self._dt0
597597

598-
599598
# overrite guessed values if provided by user
600599
if guess_params is not None:
601600
if len(guess_params) != 5:
@@ -693,12 +692,11 @@ def dofit(self, poles, fcutoff=np.inf,
693692
self._fit_results[2]['offset_err'] = self._offset_err
694693

695694
elif poles==3:
696-
695+
697696
if (self._fit_results[2] is None
698-
or 'param' not in self._fit_results[2]):
697+
or 'params' not in self._fit_results[2]):
699698

700-
# Guess the 3-pole fit starting parameters from
701-
# 2-pole fit guess
699+
# Guess the 3-pole fit
702700
A0, B0, tau10, tau20 = DIDV._guessdidvparams(
703701
self._tmean,
704702
self._tmean[self._flatinds],
@@ -720,8 +718,7 @@ def dofit(self, poles, fcutoff=np.inf,
720718
tau20 = self._fit_results[2]['params']['tau2']
721719
tau30 = 1.0e-3
722720
dt0 = self._fit_results[2]['params']['dt']
723-
724-
721+
725722
# is loop gain < 1
726723
isloopgainsub1 = DIDV._guessdidvparams(
727724
self._tmean,
@@ -759,6 +756,7 @@ def dofit(self, poles, fcutoff=np.inf,
759756
isloopgainsub1 = guess_isloopgainsub1
760757

761758
# 3 pole fitting
759+
762760
fitparams3, fitcov3, fitcost3 = DIDV._fitdidv(
763761
self._freq[fit_freqs],
764762
self._didvmean[fit_freqs],

0 commit comments

Comments
 (0)