#!/bin/python3
# All in one :noexport:

# Put everything in a =Python= script:

# #+NAME: cch_discrete.py
# #+HEADER: :tangle code/cch_discrete.py :comments both
# #+HEADER: :shebang "#!/bin/python3"

# [[file:../cross-correlation-histogram-solution.org::cch_discrete.py][cch_discrete.py]]
def _as_spike_train(train, name):
    """Validate a discrete spike train and return it as a list.

    Parameters
    ----------
    train: any iterable (list, tuple, generator, ...) of spike times.
    name: the argument name ('ref' or 'test') used in error messages.

    Returns
    -------
    A list with the elements of `train`, so that it can be measured,
    indexed and traversed several times even if `train` is a one-shot
    iterator.

    Raises
    ------
    TypeError if `train` is not iterable or contains non-integers.
    ValueError if `train` is empty or not strictly increasing.
    """
    from collections.abc import Iterable
    if not isinstance(train, Iterable):
        raise TypeError('{} must be an iterable.'.format(name))
    # Materialize once: a generator passes the Iterable check but supports
    # neither len() nor indexing, and would be exhausted by validation.
    train = list(train)
    if not all(isinstance(x, int) for x in train):
        raise TypeError('{} must contain integers.'.format(name))
    if not train:
        raise ValueError('{} must not be empty.'.format(name))
    if any(later <= earlier for earlier, later in zip(train, train[1:])):
        raise ValueError('{} must be increasing.'.format(name))
    return train


def cch_discrete(ref, test, lag_max):
    """Auto/Cross-Correlation Histogram between `ref` and `test` up to `lag_max`

    Parameters
    ----------
    ref: a strictly increasing list (or any 'iterable') of integers, with
         the (discrete) times at which the 'reference' neuron spiked.
    test: a strictly increasing list (or any 'iterable') of integers, with
          the (discrete) times at which the 'test' neuron spiked.
    lag_max: a positive integer the largest lag used.

    Returns
    -------
    A list of `lag_max + 1` floats containing the computed CCH. Element `k`
    is the number of (ref, test) spike pairs such that
    test_time - ref_time == k, divided by the number of ref spikes used.
    Only null or positive lags are considered. A ref spike is used only if
    its whole window [ref_time, ref_time + lag_max] ends at or before the
    last test spike, so every used ref spike contributes to every lag.

    Raises
    ------
    TypeError if `ref` or `test` is not an iterable of integers, or if
    `lag_max` is not an integer.
    ValueError if `ref` or `test` is empty or not strictly increasing, if
    `lag_max` is not positive, or if no ref spike has a complete window
    (that is, ref[0] + lag_max > test[-1]).
    """
    # Make sure ref and test are strictly increasing integer sequences
    ref = _as_spike_train(ref, 'ref')
    test = _as_spike_train(test, 'test')
    # Make sure lag_max is a positive integer
    if not isinstance(lag_max, int):
        raise TypeError('lag_max must be an integer.')
    if not lag_max > 0:
        raise ValueError('lag_max must be > 0.')

    n_test = len(test)
    # Largest ref time whose window [ref_time, ref_time + lag_max] is fully
    # covered by the test train.
    last_usable_ref_time = test[-1] - lag_max
    cch = [0] * (lag_max + 1)  # raw pair counts, one entry per lag
    n_ref_used = 0  # how many ref spikes were considered (the denominator)
    idx_on_test = 0  # first test spike at or after the current ref spike
    for ref_time in ref:
        if ref_time > last_usable_ref_time:
            # ref is increasing: no later ref spike has a complete window.
            break
        n_ref_used += 1
        # Both trains are increasing, so this pointer only moves forward
        # over the whole loop (linear total cost).
        while idx_on_test < n_test and test[idx_on_test] < ref_time:
            idx_on_test += 1
        # Count every test spike falling in [ref_time, ref_time + lag_max].
        running_idx = idx_on_test
        while (running_idx < n_test
               and test[running_idx] - ref_time <= lag_max):
            cch[test[running_idx] - ref_time] += 1
            running_idx += 1
    if n_ref_used == 0:
        # Without this check the normalization would divide by zero.
        raise ValueError('no ref spike has a complete window in test '
                         '(need ref[0] + lag_max <= test[-1]).')
    return [count / n_ref_used for count in cch]


def _format_row(values, ndigits):
    """Return `values` rounded to `ndigits` decimals and joined by ', '."""
    return ", ".join(str(round(x, ndigits)) for x in values)


def _geometric_train(n_spikes, p_spike):
    """Simulate `n_spikes` spike times of a discrete-time Bernoulli train.

    At each time step 1, 2, ... a spike occurs with probability `p_spike`,
    drawn with the global `random.random` (so `random.seed` controls it).
    """
    from random import random
    train = []
    t = 0
    while len(train) < n_spikes:
        t += 1
        if random() <= p_spike:
            train.append(t)
    return train


def _demo():
    """Run three sanity checks of cch_discrete and print the results."""
    from math import sqrt
    from random import seed, random

    print("Testing cch_discrete with deterministic periodic spike train.")
    print("The result should be:")
    ref_1 = [1.0] + 3 * (4 * [0.0] + [1.0])
    print(_format_row(ref_1, 2))
    print("We get:")
    ref_train1 = [i * 5 for i in range(100)]
    autoco1 = cch_discrete(ref_train1, ref_train1, 15)
    print(_format_row(autoco1, 2))
    print("")

    print("Testing cch_discrete with stochastic periodic spike train.")
    print("The result should be:")
    ref_2 = [1.0] + 3 * (4 * [0.0] + [0.5])
    print(_format_row(ref_2, 2))
    print("We get:")
    seed(20061001)
    # Periodic train (period 5) in which each spike is kept with prob. 0.5.
    ref_train2 = [i * 5 for i in range(10000) if random() <= 0.5]
    autoco2 = cch_discrete(ref_train2, ref_train2, 15)
    print(_format_row(autoco2, 2))
    print("")

    print("Testing cch_discrete with pre- and post spike trains.")
    ref_pre2post_l = [0.0732] * 3 + [0.524] + [0.0732] * 7
    ref_pre2post_u = [0.0742] * 3 + [0.526] + [0.0742] * 7
    ref_post2pre_l = [0.05 - 1.96 * sqrt(0.05 * 0.95) / 10**3] * 11
    ref_post2pre_u = [0.05 + 1.96 * sqrt(0.05 * 0.95) / 10**3] * 11
    print("The result should be (lower and upper bounds of 95% prob.):")
    print("Pre -> Post")
    print(_format_row(ref_pre2post_l, 4))
    print(_format_row(ref_pre2post_u, 4))
    print("Post -> Pre")
    print(_format_row(ref_post2pre_l, 4))
    print(_format_row(ref_post2pre_u, 4))
    print("We get:")
    seed(20110928)
    p_spike = 0.05  # probability of a spike in any time step
    n_spikes = 10**6  # number of spikes to simulate
    pre = _geometric_train(n_spikes, p_spike)
    post = _geometric_train(n_spikes, p_spike)
    p_synaptic = 0.5  # probability that a pre spike causes a post one
    # Add the pre effect onto post in a 'brute force' way: each pre spike
    # independently evokes a post spike 3 time steps later.
    post += [t + 3 for t in pre if random() <= p_synaptic]
    post = sorted(set(post))  # set() eliminates times that occur twice
    pre2post = cch_discrete(pre, post, 10)
    print("Pre -> Post:")
    print(_format_row(pre2post, 4))
    post2pre = cch_discrete(post, pre, 10)
    print("Post -> Pre:")
    print(_format_row(post2pre, 4))
    print("")


if __name__ == "__main__":
    _demo()
# cch_discrete.py ends here