Cross-Correlation Histogram: Coding Exercise
Table of Contents
- 1. The question
- 2. The solution (Python)
1. The question
Write a Python function returning the cross-correlation histogram (or cross-correlogram) associated with two observed spike trains. We use the definition of Brillinger, Bryant and Segundo (1976), Eq. 13, p. 218 (see also Perkel, Gerstein and Moore (1967)). These cross-correlation histograms are estimators of what Cox and Lewis (1966) call a cross-intensity function. More precisely, given:
- a simulation/observation window \([0,t_{obs}]\), chosen by the experimentalist,
- the spike trains of two neurons, \(i\) and \(j\): \(\{T^i_1,\ldots,T^i_{N^i}\}\) and \(\{T^j_1,\ldots,T^j_{N^j}\}\), where \(N^i\) and \(N^j\) are the number of spikes of neurons \(i\) and \(j\) during the observation time window (notice that the two neurons are identical, \(i=j\), when we are dealing with an auto-correlation histogram),
- a maximum lag, \(0 < \mathcal{L} < t_{obs}\) (in practice \(\mathcal{L}\) must be much smaller than \(t_{obs}\) in order to have enough events, see below),
- \(K^i_{min} = \inf \left\{k \in \{1,\ldots, N^i\}: T^i_k \ge \mathcal{L}\right\}\),
- \(K^i_{max} = \sup \left\{k \in \{1,\ldots, N^i\}: T^i_k \le t_{obs}- \mathcal{L}\right\}\),
- an integer \(m \in \{m_{min} = -\mathcal{L}, \ldots, m_{max} = \mathcal{L}\}\),
the correlation histogram between \(i\) and \(j\) at \(m\) is defined by: \[ \widehat{cc}_{i \to j}(m) \equiv \frac{1}{\left(K^i_{max} -K^i_{min}+1\right)} \sum_{k = K^i_{min}}^{K^i_{max}} \sum_{l=1}^{N^j} \mathbb{1}_{\{T^j_l - T^i_k = m\}} \, . \]
We have by definition the following relation: \[\left(K^i_{max} -K^i_{min}+1\right) \widehat{cc}_{i \to j}(m) = \left(K^j_{max} -K^j_{min}+1\right) \widehat{cc}_{j \to i}(-m)\, .\] When we are dealing with auto-correlation histograms (\(i=j\)), we therefore have: \[\widehat{cc}_{i \to i}(m) = \widehat{cc}_{i \to i}(-m)\, .\] By construction, \(\widehat{cc}_{i \to i}(0) = 1\). The term "correlation" is a misnomer here, since the means are not subtracted (the estimated quantities therefore cannot be negative): following Glaser and Ruchkin (1976), Sec. 7.2, the term "covariance" would have been more appropriate. Because of the relation above, you can write a function that computes the correlation histogram only for lags that are null or positive.
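To fix ideas, here is a literal (and deliberately slow) translation of the definition. This is only a sketch: the name cch_naive and the explicit t_obs argument are ours, it assumes at least one reference spike falls in \([\mathcal{L}, t_{obs}-\mathcal{L}]\), and, unlike the function asked for, it also handles negative lags.

def cch_naive(ref, test, lag_max, t_obs):
    # Keep the ref spikes with indices between K_min and K_max, i.e. those
    # whose times lie in [lag_max, t_obs - lag_max]
    kept = [t for t in ref if lag_max <= t <= t_obs - lag_max]
    lags = range(-lag_max, lag_max + 1)
    counts = {m: 0 for m in lags}
    for t_ref in kept:       # sum over k
        for t_test in test:  # sum over l
            m = t_test - t_ref
            if m in counts:  # indicator of {T^j_l - T^i_k = m}
                counts[m] += 1
    return [counts[m] / len(kept) for m in lags]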
Your function should take the following parameters:
- a list or array of strictly increasing integers containing the discrete spike times of the "reference" neuron (neuron \(i\) above).
- a list or array of strictly increasing integers containing the discrete spike times of the "test" neuron (neuron \(j\) above).
- a positive integer specifying the maximum lag.
Your function should return a list or array with the \(\widehat{cc}_{i \to j}(m)\) for \(m\) going from 0 to the maximum lag.
You should apply basic tests to your code, such as:
- generate a perfectly periodic spike train, say with period 5, compute the auto-correlation histogram \(\widehat{cc}_{i \to i}(m)\) and check that you get what is expected (you must figure out what is expected first);
- generate a periodic spike train, say with period 5, that exhibits a spike (at the prescribed times) with a probability of, say, 0.5, compute the auto-correlation histogram \(\widehat{cc}_{i \to i}(m)\) and check that you get what is expected (you must figure out what is expected first);
- generate two "geometric" (that is, with inter-spike intervals following a geometric distribution) and, at first, independent trains, ref and test, that have a probability 0.05 of exhibiting a spike at any given time. For every spike in ref, add a spike in test (if there is not already one) 3 time steps later with a probability of 0.5. Compute \(\widehat{cc}_{\text{ref} \to \text{test}}(m)\) and \(\widehat{cc}_{\text{test} \to \text{ref}}(m)\) and check that you get what is expected (you must figure out what is expected first).
2. The solution (Python)
We start by defining the cch_discrete function, block by block, before running the tests. A Python file containing the code and commands described next is available for download.
cch_discrete outline
The code outline is:

def cch_discrete(ref, test, lag_max):
    <<cch_discrete-docstring>>
    <<cch_discrete-check-par>>
    <<cch_discrete-do-job>>
    return cch
cch_discrete-docstring
The docstring of the cch_discrete function is:

"""Auto/Cross-Correlation Histogram between `ref` and `test` up to `lag_max`

Parameters
----------
ref: a strictly increasing list (or any 'iterable') of integers,
     with the (discrete) times at which the 'reference' neuron spiked.
test: a strictly increasing list (or any 'iterable') of integers,
      with the (discrete) times at which the 'test' neuron spiked.
lag_max: a positive integer, the largest lag used.

Returns
-------
A list of floats containing the computed CCH. Only null or positive
lags are considered.
"""
cch_discrete-check-par
This block checks that the three parameters, ref, test and lag_max, fulfill the requirements specified in the docstring. It also defines two variables, n_ref and n_test: the number of elements of ref and test, respectively:

from collections.abc import Iterable
# Make sure ref is an iterable with strictly increasing integers
if not isinstance(ref, Iterable):  # Check that 'ref' is iterable
    raise TypeError('ref must be an iterable.')
n_ref = len(ref)
all_integers = all([isinstance(x, int) for x in ref])
if not all_integers:
    raise TypeError('ref must contain integers.')
diff = [ref[i+1]-ref[i] for i in range(n_ref-1)]
all_positive = sum([x > 0 for x in diff]) == n_ref-1
if not all_positive:  # Check that 'ref' elements are strictly increasing
    raise ValueError('ref must be increasing.')
# Make sure test is an iterable with strictly increasing integers
if not isinstance(test, Iterable):  # Check that 'test' is iterable
    raise TypeError('test must be an iterable.')
n_test = len(test)
all_integers = all([isinstance(x, int) for x in test])
if not all_integers:
    raise TypeError('test must contain integers.')
diff = [test[i+1]-test[i] for i in range(n_test-1)]
all_positive = sum([x > 0 for x in diff]) == n_test-1
if not all_positive:  # Check that 'test' elements are strictly increasing
    raise ValueError('test must be increasing.')
# Make sure lag_max is a positive integer
if not isinstance(lag_max, int):
    raise TypeError('lag_max must be an integer.')
if not lag_max > 0:
    raise ValueError('lag_max must be > 0.')
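As a quick illustration, once cch_discrete has been assembled from all the blocks, these checks reject malformed arguments (the argument values below are ours, chosen only to trigger the errors):

try:
    cch_discrete([0, 1.5, 3], [0, 2, 4], 5)  # 1.5 is not an integer
except TypeError as error:
    print(error)  # prints: ref must contain integers.
try:
    cch_discrete([0, 2, 4], [0, 2, 4], 0)  # lag_max is not > 0
except ValueError as error:
    print(error)  # prints: lag_max must be > 0.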
cch_discrete-do-job
This block takes care of the actual computation. Since only null or positive lags are considered, the code is designed to explore a minimal number of spike times (or elements) of test:

test_last = test[-1]           # the last spike time in test
idx_on_ref = 0                 # an index running on the spikes of the ref train
ref_time = ref[idx_on_ref]     # the time of the 'current' spike from ref
idx_on_test = 0                # an index running on the spikes of the test train
test_time = test[idx_on_test]  # the time of the 'current' spike from test
denominator = 0                # counts how many spikes from ref are considered
cch = [0]*(lag_max+1)          # this list will store the result
while ref_time + lag_max <= test_last:
    # Keep the spike at ref_time from ref
    denominator += 1
    while test_time < ref_time:
        # Advance test_time until it is equal to or larger than
        # the current ref_time
        idx_on_test += 1
        if idx_on_test >= n_test:
            # If the index is larger than or equal to the number
            # of spikes in test, quit the while loop
            break
        test_time = test[idx_on_test]
    # Now test_time is the first spike from test coming at the same time as
    # or later than the ref spike currently considered
    if idx_on_test < n_test:
        # Add an extra index on test that will run from 'idx_on_test'
        # to the last spike time from test smaller than or equal to
        # ref_time + lag_max
        running_idx = idx_on_test
        running_time = test[running_idx]
        while running_time <= ref_time + lag_max:
            # Add an event at the proper lag
            cch[running_time-ref_time] += 1
            running_idx += 1  # go to the next spike on test
            if running_idx >= n_test:
                # If the index is larger than or equal to the number
                # of spikes in test, quit the while loop
                break
            # Get the time of the next spike from test
            running_time = test[running_idx]
    # Go to the next spike on ref
    idx_on_ref += 1
    if idx_on_ref >= n_ref:
        # If the index is larger than or equal to the number
        # of spikes in ref, quit the while loop
        break
    # Get the time of the next spike from ref
    ref_time = ref[idx_on_ref]
cch = [x/denominator for x in cch]
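To gain confidence in this scanning logic, we can compare it with a brute-force double loop on a small random example. The helper cch_brute, the trains ref_check and test_check and the seed below are ours, introduced only for this check; cch_brute applies the same normalization rule as above (a reference spike is kept as long as its time plus lag_max does not exceed the last test spike time):

from random import seed, random

def cch_brute(ref, test, lag_max):
    # Same counting and normalization rules as cch_discrete, written as a
    # plain (and much slower) double loop
    kept = [t for t in ref if t + lag_max <= test[-1]]
    cch = [0]*(lag_max+1)
    for t_ref in kept:
        for t_test in test:
            if 0 <= t_test - t_ref <= lag_max:
                cch[t_test - t_ref] += 1
    return [x/len(kept) for x in cch]

seed(20230101)  # arbitrary seed, used for this check only
ref_check = [t for t in range(1, 2001) if random() <= 0.05]
test_check = [t for t in range(1, 2001) if random() <= 0.05]
print(cch_discrete(ref_check, test_check, 10) == cch_brute(ref_check, test_check, 10))
# expected to print True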
Testing the function
We can test our cch_discrete function by making it compute the auto-correlation histogram of a strictly periodic train with period five. In this case, if we look at an interval of length \(k \times 5\) (\(k \in \mathbb{N}\)) after a spike, we are sure to find another spike from the neuron; the auto-correlation histogram should therefore have the value 1 at these lags (and zero everywhere else):
ref_train1 = [i*5 for i in range(100)]
autoco1 = cch_discrete(ref_train1, ref_train1, 15)
print(", ".join([str(round(x,2)) for x in autoco1]))
1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0
We can refine this case by drawing an event with a probability 0.5 every five time steps (we do not forget to set the seed of our PRNG for reproducibility). In this case, if we look at an interval of length \(k \times 5\) (\(k \in \mathbb{N}\)) after a spike, we have a probability 0.5 of finding another spike from the neuron; the auto-correlation histogram should therefore have a value of 0.5 at these lags (and zero everywhere else):
from random import seed, random
seed(20061001)
ref_train2 = [i*5 for i in range(10000) if random() <= 0.5]
autoco2 = cch_discrete(ref_train2, ref_train2, 15)
print(", ".join([str(round(x,2)) for x in autoco2]))
1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.51
We next generate two spike trains, pre and post, whose intervals between successive events follow, at first, a geometric distribution; we then add to post extra spikes following the spikes from pre with a lag of 3 and a probability of 0.5. The probability of finding a spike from pre at an arbitrarily chosen time must be 0.05. The probability of finding a spike from post at an arbitrarily chosen time is a bit trickier to compute. Let us rather compute the probability of the complementary event: not having a spike at an arbitrarily chosen time. This event can occur in two ways:
- the basic generation mechanism (the geometric one) did not generate a spike (probability 0.95) and there was no spike from pre three time units earlier (probability 0.95): the probability of this compound event, made of two independent sub-events, is \(0.95 \times 0.95\);
- the basic generation mechanism (the geometric one) did not generate a spike (probability 0.95), there was a spike from pre three time units earlier (probability 0.05), but this spike did not trigger an extra spike in post (probability 0.5): the probability is \(0.95 \times 0.05 \times 0.5\).
The probability of observing a spike from post at an arbitrarily chosen time is therefore (result rounded to the fourth decimal):
\[1 - 0.95 \, (0.95 + 0.05 \times 0.5) = 0.0738\, .\]
Using a sample size of \(10^6\), as we do next, and the normal approximation, we get a standard deviation of:
\[\sigma_{p=0.0738} = \sqrt{\frac{0.0738 \, (1-0.0738)}{10^6}} = 0.0003\]
and an interval within which the estimated probability should fall with 95% probability:
\[[0.0738 - 1.96 \, \sigma_{p=0.0738} = 0.0732\, ,\, 0.0738 + 1.96 \, \sigma_{p=0.0738} = 0.0742]\, .\]
The probability of finding a spike from post 3 time units after a spike from pre is computed in the same way. For not having a spike there is only one possibility: the basic generation mechanism (the geometric one) did not generate a spike (probability 0.95) and the spike from pre three time units earlier did not trigger an extra spike in post (probability 0.5). The probability of having a spike is therefore:
\[1 - (0.95 \times 0.5) = 0.525 \]
The standard deviation with a sample size of \(10^6\) and the normal approximation is:
\[\sigma_{p=0.525} = \sqrt{\frac{0.525 \, (1-0.525)}{10^6}} = 0.0005\]
giving an interval within which the estimated probability should fall with 95% probability:
\[[0.525 - 1.96 \, \sigma_{p=0.525} = 0.524 \, , \, 0.525 + 1.96 \, \sigma_{p=0.525} = 0.526] \, .\]
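This arithmetic can be reproduced with a few lines of Python; the variable names (p_base, p_syn, base_interval, syn_interval) are ours and are reused in the closing check after the simulation below:

from math import sqrt
n_sample = 10**6
p_base = 1 - 0.95*(0.95 + 0.05*0.5)  # baseline probability of a post spike, approximately 0.0738
p_syn = 1 - 0.95*0.5                 # probability of a post spike 3 steps after a pre spike, 0.525
sd_base = sqrt(p_base*(1-p_base)/n_sample)  # approximately 0.0003
sd_syn = sqrt(p_syn*(1-p_syn)/n_sample)     # approximately 0.0005
base_interval = (p_base - 1.96*sd_base, p_base + 1.96*sd_base)  # 95% interval around p_base
syn_interval = (p_syn - 1.96*sd_syn, p_syn + 1.96*sd_syn)       # 95% interval around p_syn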
seed(20110928)
p_spike = 0.05      # probability of a spike in any time step
n_spikes = 10**6    # number of spikes to simulate
pre = [0]*n_spikes  # list that will contain the pre spikes
# Simulate a 'geometric' pre spike train
idx = 0
time = 0
while idx < n_spikes:
    time += 1
    if random() <= p_spike:
        pre[idx] = time
        idx += 1
post = [0]*n_spikes  # list that will contain the post spikes
# Simulate a 'geometric' post spike train
idx = 0
time = 0
while idx < n_spikes:
    time += 1
    if random() <= p_spike:
        post[idx] = time
        idx += 1
p_synaptic = 0.5  # probability that a pre spike causes a post one
# Add the pre effect onto post in a 'brute force' way
for time in pre:
    if random() <= p_synaptic:
        post.append(time+3)
post = sorted(set(post))  # set() eliminates times that occur twice
pre2post = cch_discrete(pre, post, 10)
print("Pre -> Post:")
print(", ".join([str(round(x,4)) for x in pre2post]))
post2pre = cch_discrete(post, pre, 10)
print("Post -> Pre:")
print(", ".join([str(round(x,4)) for x in post2pre]))
Pre -> Post:
0.0739, 0.0742, 0.0738, 0.5247, 0.0739, 0.0739, 0.0735, 0.0739, 0.0735, 0.0738, 0.0734
Post -> Pre:
0.0501, 0.0498, 0.0497, 0.0499, 0.0502, 0.0501, 0.0499, 0.05, 0.05, 0.0499, 0.0496
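As a closing check (reusing syn_interval and base_interval from the sketch above), we can verify that these estimates fall where the computation predicts: the Pre -> Post histogram lies within the interval around 0.525 at lag 3 and within the baseline interval at the other lags, while the Post -> Pre histogram stays flat around 0.05 at the null or positive lags considered here, since the extra post spikes come after, not before, the pre spikes:

lo, hi = syn_interval
print(lo <= pre2post[3] <= hi)  # effect of the 'synapse' at lag 3: expected to print True
lo, hi = base_interval
print(all(lo <= x <= hi for i, x in enumerate(pre2post) if i != 3))  # expected to print True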