Skip to content

Commit 434eb4e

Browse files
committed
add v Dict in augment for type stable .v[n] passing
1 parent 267f2ec commit 434eb4e

File tree

3 files changed

+28
-30
lines changed

3 files changed

+28
-30
lines changed

pkg/VoltoMapSim/src/conntest.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11

22
function calc_STA(VI_sig, presynaptic_spikes, p::ExpParams)
3-
Δt::Float64 = p.sim.general.Δt # explicit type annotation needed
4-
win_size = round(Int, p.conntest.STA_window_length / Δt)
3+
win_size = STA_win_size(p)
54
STA = zeros(eltype(VI_sig), win_size)
6-
win_starts = round.(Int, presynaptic_spikes / Δt)
5+
win_starts = round.(Int, presynaptic_spikes / p.sim.general.Δt::Float64)
6+
# Explicit type annotation on Δt needed, as typeof(sim) unknown.
77
num_wins = 0
88
for a in win_starts
99
b = a + win_size - 1
@@ -17,7 +17,7 @@ function calc_STA(VI_sig, presynaptic_spikes, p::ExpParams)
1717
end
1818

1919
calc_STA((from, to), s::SimData, p::ExpParams) =
20-
calc_STA(s.signals[to].v, s.spike_times[from], p)
20+
calc_STA(s.v[to], s.spike_times[from], p)
2121

2222

2323
to_ISIs(spiketimes) = [first(spiketimes); diff(spiketimes)] # copying

pkg/VoltoMapSim/src/plot.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function plotSTA(
4545
end
4646

4747
plotSTA(from::Int, to::Int, s #= simdata =#, p::ExpParams; kw...) =
48-
plotSTA(s.signals[to].v, s.spike_times[from], p; kw...)
48+
plotSTA(s.v[to], s.spike_times[from], p; kw...)
4949

5050

5151
function rasterplot(spiketimes; tlim, ms = 1)

pkg/VoltoMapSim/src/sim/post.jl

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,15 @@ function augment(s::SimData, p::ExpParams)
3131
non_inputs = [[n for n in all if n inputs[m] && n != m] for m in all]
3232
num_inputs = [(exc = length(exc_inputs[m]), inh = length(inh_inputs[m])) for m in all]
3333

34-
return (;
35-
s.data...,
34+
v = Dict{Int, Vector{Float64}}()
35+
# Purely for type inference (`s.signals` is {Int, Any}).
36+
@unpack record_v, record_all = p.sim.network
37+
for n in unique!([record_v; record_all])
38+
v[n] = s.signals[n].v
39+
end
40+
41+
return s = (;
42+
s...,
3643
num_spikes_per_neuron,
3744
spike_rates,
3845
pre_post_pairs,
@@ -41,38 +48,29 @@ function augment(s::SimData, p::ExpParams)
4148
inh_inputs,
4249
non_inputs,
4350
num_inputs,
51+
v,
4452
)
4553
end
4654

47-
function calc_avg_STA(s::SimData, p::ExpParams, postsyn_neurons, inputs)
48-
acc = nothing
49-
N = 0
50-
# @showprogress(
51-
for n in postsyn_neurons
52-
for m in inputs[n]
53-
STA = calc_STA(m => n, s, p)
54-
if isnothing(acc) acc = STA
55-
else acc .+= STA end
56-
N += 1
57-
end
58-
end
59-
# end)
60-
return avgSTA = acc ./ N
61-
end
62-
63-
function calc_avg_STA_v2(s::SimData, p::ExpParams, postsyn_neurons, inputs)
64-
Δt::Float64 = p.sim.general.Δt
65-
win_size = round(Int, p.conntest.STA_window_length / Δt)
66-
acc = zeros(Float64, win_size)
55+
function calc_avg_STA(s::SimData, p::ExpParams; postsyn_neurons, input_type::Symbol)
56+
if (input_type == :exc) inputs = s.exc_inputs
57+
else inputs = s.inh_inputs end
58+
acc = zeros(Float64, STA_win_size(p))
6759
N = 0
68-
# @showprogress(
60+
@showprogress(
6961
for n in postsyn_neurons
7062
for m in inputs[n]
7163
STA = calc_STA(m => n, s, p)
7264
acc .+= STA
7365
N += 1
7466
end
75-
# end)
76-
end
67+
end)
7768
return avgSTA = acc ./ N
7869
end
70+
# We don't use the more compact generator form
71+
# `mean(calc_STA(m => n, s, p) for n in 1:40 for m in s.exc_inputs[n])`
72+
# ..as then we don't get a progress report. (And `@showprogress` on a `reduce` errors here).
73+
74+
STA_win_size(p::ExpParams) =
75+
round(Int, p.conntest.STA_window_length / p.sim.general.Δt::Float64)
76+
# Explicit type annotation on Δt needed, as typeof(sim) unknown.

0 commit comments

Comments
 (0)