import itertools as it
import numba as nb
import numpy as np
[docs]
@nb.jit(nopython=True)
def bell_numbers(n):
    """Return the n-th Bell number, i.e. the number of set partitions of n items.

    Used to size the state arrays (one row per set partition). The value is
    obtained from the Bell triangle, built one row at a time: each new row
    starts with the last entry of the previous row, and every further entry
    is its left neighbour plus the entry directly above that neighbour. The
    first entry of row n is B(n).

    :param n: index of the Bell number to return (B(0) = 1).
    :type n: int.
    :return: the n-th Bell number.
    :rtype: int."""
    triangle_row = [1]
    for _ in range(n):
        next_row = [triangle_row[-1]]
        for above in triangle_row:
            next_row.append(next_row[-1] + above)
        triangle_row = next_row
    return triangle_row[0]
[docs]
def partition(collection):
    """Yield every set partition of *collection* as a list of blocks.

    The partitions of ``collection[1:]`` are generated recursively. For each
    one, the first element is inserted into each existing block in turn, and
    finally placed in a new singleton block at the front. A one-element
    collection yields the single partition ``[collection]``.

    :param collection: list of integers ranging from 1 to n.
    :type collection: list[int].
    :return: generator yielding lists of lists representing the partitions.
    :rtype: generator."""
    if len(collection) == 1:
        yield [collection]
        return
    head = collection[0]
    rest = collection[1:]
    for blocks in partition(rest):
        # Put ``head`` into each existing block, one block at a time ...
        for pos, block in enumerate(blocks):
            before = blocks[:pos]
            after = blocks[pos + 1 :]
            yield before + [[head] + block] + after
        # ... then give ``head`` a block of its own, placed first.
        yield [[head]] + blocks
[docs]
def set_partitions(species):
    """Enumerate every CTMC state for ``species`` species as a label array.

    There are ``2 * species`` positions, numbered 1..2*species. Each set
    partition of those positions (from ``partition``) becomes one row. The
    partition's blocks are sorted, and every position receives the 1-based
    rank of the block that contains it. For example, the partition
    [[1, 2, 3], [4, 5, 6]] becomes the row [1, 1, 1, 2, 2, 2]. The number of
    rows is ``bell_numbers(2 * species)``.

    :param species: number of species for which partitions are generated.
    :type species: int.
    :return: integer array of shape (bell_numbers(2*species), 2*species).
    :rtype: np.array."""
    width = 2 * species
    state_array = np.zeros((bell_numbers(width), width), dtype=int)
    positions = list(range(1, width + 1))
    for row_idx, blocks in enumerate(partition(positions)):
        for label, block in enumerate(sorted(blocks), start=1):
            for position in block:
                state_array[row_idx, position - 1] = label
    return state_array
[docs]
@nb.jit(nopython=True)
def translate_to_minimum(array):
    """Relabel a set partition so its labels are consecutive integers from 1.

    Every distinct value is replaced by its 1-based rank among the distinct
    values (smallest value -> 1). For example, [1, 2, 2, 2, 4, 4] becomes
    [1, 2, 2, 2, 3, 3].

    :param array: set partition to relabel. It must be one-dimensional; the
        callers in this module pass a single state of length n_species*2.
    :type array: np.array of int with shape (n_species*2,).
    :return: relabelled set partition with consecutive labels.
    :rtype: np.array of int with shape (n_species*2,)."""
    # np.unique already returns the distinct values in sorted order, so the
    # position of each element inside that array is its 0-based rank. No
    # extra sort or lookup dictionary is needed.
    unique_values = np.unique(array)
    return np.searchsorted(unique_values, array) + 1
[docs]
def find_revcoal_recomb(state_array, species, state_dict):
    """Enumerate reversible coalescences and their reverse recombinations.

    For each state (row of ``state_array``), the left half is the first
    ``species`` entries and the right half is the last ``species`` entries.
    Every label ``i`` found only in the right half is paired with every label
    ``j`` found only in the left half. For each pair, all ``i`` entries of the
    right half are replaced by ``j``, and the result is relabelled with
    ``translate_to_minimum``. Two rows are emitted per pair:

    * the coalescence: state -> merged state, event type 1;
    * its reverse (recombination): merged state -> state, event type 2.

    Each output row has ``4 * species + 3`` columns:

    * ``[0, 2*species)``: the state before the event;
    * ``[2*species, 4*species)``: the state after the event;
    * the before-state index and the after-state index, both from
      ``state_dict``;
    * the event type.

    :param state_array: set partition array for the given number of species.
    :type state_array: np.array of int with shape (Bell number, n_species*2).
    :param species: number of species.
    :type species: int.
    :param state_dict: dictionary mapping each state tuple to its corresponding number.
    :type state_dict: dict.
    :return: array with reversible coalescences and recombinations; shape
        (0, 4*species + 3) when there are none.
    :rtype: np.array."""
    n_cols = species * 4 + 3
    # Collect event rows and stack once at the end; stacking inside the loop
    # would re-copy the whole result for every event.
    events = []
    for row in state_array:
        l_nucl = row[:species]
        r_nucl = row[species:]
        l_set = set(l_nucl)
        r_set = set(r_nucl)
        l_diff = l_set.difference(r_set)
        r_diff = r_set.difference(l_set)
        # An event needs a label private to each half.
        if not l_diff or not r_diff:
            continue
        for i in r_diff:
            for j in l_diff:
                merged = np.empty(species * 2, dtype=int)
                merged[:species] = l_nucl
                merged[species:] = np.where(r_nucl == i, j, r_nucl)
                ordered = translate_to_minimum(merged)
                row_id = state_dict[tuple(row)]
                merged_id = state_dict[tuple(ordered)]
                pair_rows = np.empty((2, n_cols), dtype=int)
                # Coalescence: state -> merged state (type 1).
                pair_rows[0, : species * 2] = row
                pair_rows[0, species * 2 : species * 4] = ordered
                pair_rows[0, species * 4 :] = (row_id, merged_id, 1)
                # Reverse event: merged state -> state (type 2).
                pair_rows[1, : species * 2] = ordered
                pair_rows[1, species * 2 : species * 4] = row
                pair_rows[1, species * 4 :] = (merged_id, row_id, 2)
                events.append(pair_rows)
    if not events:
        return np.empty((0, n_cols), dtype=int)
    return np.vstack(events)
[docs]
def find_norevcoal(state_array, species, state_dict):
    """Enumerate the non-reversible coalescence events out of every state.

    For each state (row of ``state_array``), consider each pair of distinct
    labels that occur together within the left half (first ``species``
    entries) or within the right half (last ``species`` entries). The two
    labels are merged across the whole row, with both becoming the smaller
    label, and the result is relabelled with ``translate_to_minimum``. Pairs
    are taken from the left half first (in position order), then from the
    right half. A label pair already merged for the current row is not
    emitted again.

    Each output row has ``4 * species + 3`` columns:

    * ``[0, 2*species)``: the state before the event;
    * ``[2*species, 4*species)``: the state after the event;
    * the before-state index and the after-state index, both from
      ``state_dict``;
    * the event type, which is always 1.

    :param state_array: set partition array for the given number of species.
    :type state_array: np.array of int with shape (Bell number, n_species*2).
    :param species: number of species.
    :type species: int.
    :param state_dict: dictionary mapping each state tuple to its corresponding number.
    :type state_dict: dict.
    :return: array with non-reversible coalescences; shape
        (0, 4*species + 3) when there are none.
    :rtype: np.array."""
    n_cols = species * 4 + 3
    # Collect event rows and stack once at the end; stacking inside the loop
    # would re-copy the whole result for every event.
    events = []
    for row in state_array:
        # Shared by both halves, so a label pair present in the left and the
        # right half yields a single event.
        merged_pairs = set()
        for half in (row[:species], row[species:]):
            # combinations() yields (half[a], half[b]) for a < b in position order.
            for num1, num2 in it.combinations(half, 2):
                if num1 == num2:
                    continue
                pair = (min(num1, num2), max(num1, num2))
                if pair in merged_pairs:
                    continue
                merged_pairs.add(pair)
                merged = np.where((row == num1) | (row == num2), pair[0], row)
                ordered = translate_to_minimum(merged)
                event = np.empty((1, n_cols), dtype=int)
                event[0, : species * 2] = row
                event[0, species * 2 : species * 4] = ordered
                event[0, species * 4] = state_dict[tuple(row)]
                event[0, species * 4 + 1] = state_dict[tuple(ordered)]
                event[0, species * 4 + 2] = 1
                events.append(event)
    if not events:
        return np.empty((0, n_cols), dtype=int)
    return np.vstack(events)
[docs]
@nb.jit(nopython=True)
def number_array_1(
    state_array,
    species,
    mss,
    tuple_omegas=nb.types.Tuple((nb.types.int64, nb.types.int64)),
    tuple_states=nb.types.Tuple((nb.types.int64, nb.types.int64)),
):
    """For the 1-species CTMC, index every state and group the states by omega value.

    For each row ``i`` of ``state_array``:

    * ``state_dict`` maps the tuple ``(row[0], row[1])`` to ``i``.
    * The row is split into a left half ``row[:species]`` and a right half
      ``row[species:]``. ``l_omega`` is the sum of ``mss[k]`` over every left
      position ``k`` whose label occurs more than once in the left half.
      ``r_omega`` is computed the same way for the right half. With one
      species each half has a single position, so both values are 0 here.
    * ``omega_dict[(l_omega, r_omega)]`` is a boolean mask of length
      ``bell_numbers(2 * species)`` that is set to True at index ``i``.

    :param state_array: array with every set partition for 1 species.
    :type state_array: np.array of int with shape (Bell number, n_species*2).
    :param species: number of species. The state tuple is hard-coded to two
        entries, so this is expected to be 1.
    :type species: int.
    :param mss: weight for each position within a half, indexed by that
        position. The wrappers in this module pass
        ``[2**i for i in range(species)]``.
    :type mss: list[int].
    :param tuple_omegas: nb.types.Tuple for the omega dictionary keys.
    :type tuple_omegas: nb.types.Tuple.
    :param tuple_states: nb.types.Tuple for the state dictionary keys.
    :type tuple_states: nb.types.Tuple.
    :return: omega dictionary and state dictionary.
    :rtype: tuple."""
    total_states = bell_numbers(2 * species)
    state_dict = nb.typed.Dict.empty(key_type=tuple_states, value_type=nb.types.int64)
    # NOTE(review): value_type receives an array instance rather than a numba
    # type such as nb.types.boolean[:]. Presumably numba derives the value type
    # from it; confirm this is intended.
    omega_dict = nb.typed.Dict.empty(
        key_type=tuple_omegas, value_type=np.zeros(total_states, dtype=nb.types.boolean)
    )
    # Arrays of this length are indexed directly by label. This assumes labels
    # lie in 1..2*species, as produced by set_partitions.
    max_index = 2 * species + 1
    for i, row in enumerate(state_array):
        state_tuple = (row[0], row[1])
        # Left half: map each label to the positions where it occurs.
        l_nucl = row[:species]
        l_nucl_counts = nb.typed.Dict.empty(
            key_type=nb.types.int64, value_type=nb.types.int64[:]
        )
        # index_tracker[label, n] holds the position of the n-th (0-based)
        # occurrence of label, or -1 if unused. counts[label] holds the
        # number of occurrences seen so far.
        index_tracker = np.zeros((max_index, len(l_nucl)), dtype=nb.types.int64) - 1
        # NOTE(review): int32 here but int64 for the right half below. Both
        # only hold small occurrence counts, but the mismatch looks accidental.
        counts = np.zeros(max_index, dtype=np.int32)
        for index, number in enumerate(l_nucl):
            current_count = counts[number]
            index_tracker[number, current_count] = index
            counts[number] += 1
        # Gather (label, positions) for every label present, in label order.
        number_entries = np.sum(counts > 0)
        results_numbers = np.zeros(number_entries, dtype=nb.types.int64)
        results_indices = []
        result_index = 0
        for number in range(max_index):
            if counts[number] > 0:
                results_numbers[result_index] = number
                results_indices.append(index_tracker[number, : counts[number]])
                result_index += 1
        for number, indices in zip(results_numbers, results_indices):
            l_nucl_counts[number] = indices
        # Right half: same bookkeeping.
        r_nucl = row[species:]
        r_nucl_counts = nb.typed.Dict.empty(
            key_type=nb.types.int64, value_type=nb.types.int64[:]
        )
        index_tracker = np.zeros((max_index, len(r_nucl)), dtype=nb.types.int64) - 1
        counts = np.zeros(max_index, dtype=nb.types.int64)
        for index, number in enumerate(r_nucl):
            current_count = counts[number]
            index_tracker[number, current_count] = index
            counts[number] += 1
        number_entries = np.sum(counts > 0)
        results_numbers = np.zeros(number_entries, dtype=nb.types.int64)
        results_indices = []
        result_index = 0
        for number in range(max_index):
            if counts[number] > 0:
                results_numbers[result_index] = number
                results_indices.append(index_tracker[number, : counts[number]])
                result_index += 1
        for number, indices in zip(results_numbers, results_indices):
            r_nucl_counts[number] = indices
        # Each omega sums the mss weights of the positions whose label is
        # shared with at least one other position in the same half.
        l_omega = 0
        r_omega = 0
        for j in set(l_nucl):
            if len(l_nucl_counts[j]) > 1:
                for k in l_nucl_counts[j]:
                    l_omega += mss[k]
        for j in set(r_nucl):
            if len(r_nucl_counts[j]) > 1:
                for k in r_nucl_counts[j]:
                    r_omega += mss[k]
        state_dict[state_tuple] = i
        # The first state seen with this omega pair allocates its mask.
        if (l_omega, r_omega) not in omega_dict:
            omega_dict[(l_omega, r_omega)] = np.zeros(
                total_states, dtype=nb.types.boolean
            )
            omega_dict[(l_omega, r_omega)][i] = True
        else:
            omega_dict[(l_omega, r_omega)][i] = True
    return omega_dict, state_dict
[docs]
@nb.jit(nopython=True)
def number_array_2(
    state_array,
    species,
    mss,
    tuple_omegas=nb.types.Tuple((nb.types.int64, nb.types.int64)),
    tuple_states=nb.types.Tuple(
        (nb.types.int64, nb.types.int64, nb.types.int64, nb.types.int64)
    ),
):
    """For the 2-species CTMC, index every state and group the states by omega value.

    For each row ``i`` of ``state_array``:

    * ``state_dict`` maps the tuple ``(row[0], row[1], row[2], row[3])`` to ``i``.
    * The row is split into a left half ``row[:species]`` and a right half
      ``row[species:]``. ``l_omega`` is the sum of ``mss[k]`` over every left
      position ``k`` whose label occurs more than once in the left half.
      ``r_omega`` is computed the same way for the right half.
    * ``omega_dict[(l_omega, r_omega)]`` is a boolean mask of length
      ``bell_numbers(2 * species)`` that is set to True at index ``i``.

    :param state_array: array with every set partition for 2 species.
    :type state_array: np.array of int with shape (Bell number, n_species*2).
    :param species: number of species. The state tuple is hard-coded to four
        entries, so this is expected to be 2.
    :type species: int.
    :param mss: weight for each position within a half, indexed by that
        position. The wrappers in this module pass
        ``[2**i for i in range(species)]``.
    :type mss: list[int].
    :param tuple_omegas: nb.types.Tuple for omega dictionary keys.
    :type tuple_omegas: nb.types.Tuple.
    :param tuple_states: nb.types.Tuple for state dictionary keys.
    :type tuple_states: nb.types.Tuple.
    :return: omega dictionary and state dictionary.
    :rtype: tuple."""
    total_states = bell_numbers(2 * species)
    state_dict = nb.typed.Dict.empty(key_type=tuple_states, value_type=nb.types.int64)
    # NOTE(review): value_type receives an array instance rather than a numba
    # type such as nb.types.boolean[:]. Presumably numba derives the value type
    # from it; confirm this is intended.
    omega_dict = nb.typed.Dict.empty(
        key_type=tuple_omegas, value_type=np.zeros(total_states, dtype=nb.types.boolean)
    )
    # Arrays of this length are indexed directly by label. This assumes labels
    # lie in 1..2*species, as produced by set_partitions.
    max_index = 2 * species + 1
    for i, row in enumerate(state_array):
        state_tuple = (row[0], row[1], row[2], row[3])
        # Left half: map each label to the positions where it occurs.
        l_nucl = row[:species]
        l_nucl_counts = nb.typed.Dict.empty(
            key_type=nb.types.int64, value_type=nb.types.int64[:]
        )
        # index_tracker[label, n] holds the position of the n-th (0-based)
        # occurrence of label, or -1 if unused. counts[label] holds the
        # number of occurrences seen so far.
        index_tracker = np.zeros((max_index, len(l_nucl)), dtype=nb.types.int64) - 1
        # NOTE(review): int32 here but int64 for the right half below. Both
        # only hold small occurrence counts, but the mismatch looks accidental.
        counts = np.zeros(max_index, dtype=np.int32)
        for index, number in enumerate(l_nucl):
            current_count = counts[number]
            index_tracker[number, current_count] = index
            counts[number] += 1
        # Gather (label, positions) for every label present, in label order.
        number_entries = np.sum(counts > 0)
        results_numbers = np.zeros(number_entries, dtype=nb.types.int64)
        results_indices = []
        result_index = 0
        for number in range(max_index):
            if counts[number] > 0:
                results_numbers[result_index] = number
                results_indices.append(index_tracker[number, : counts[number]])
                result_index += 1
        for number, indices in zip(results_numbers, results_indices):
            l_nucl_counts[number] = indices
        # Right half: same bookkeeping.
        r_nucl = row[species:]
        r_nucl_counts = nb.typed.Dict.empty(
            key_type=nb.types.int64, value_type=nb.types.int64[:]
        )
        index_tracker = np.zeros((max_index, len(r_nucl)), dtype=nb.types.int64) - 1
        counts = np.zeros(max_index, dtype=nb.types.int64)
        for index, number in enumerate(r_nucl):
            current_count = counts[number]
            index_tracker[number, current_count] = index
            counts[number] += 1
        number_entries = np.sum(counts > 0)
        results_numbers = np.zeros(number_entries, dtype=nb.types.int64)
        results_indices = []
        result_index = 0
        for number in range(max_index):
            if counts[number] > 0:
                results_numbers[result_index] = number
                results_indices.append(index_tracker[number, : counts[number]])
                result_index += 1
        for number, indices in zip(results_numbers, results_indices):
            r_nucl_counts[number] = indices
        # Each omega sums the mss weights of the positions whose label is
        # shared with at least one other position in the same half.
        l_omega = 0
        r_omega = 0
        for j in set(l_nucl):
            if len(l_nucl_counts[j]) > 1:
                for k in l_nucl_counts[j]:
                    l_omega += mss[k]
        for j in set(r_nucl):
            if len(r_nucl_counts[j]) > 1:
                for k in r_nucl_counts[j]:
                    r_omega += mss[k]
        state_dict[state_tuple] = i
        # The first state seen with this omega pair allocates its mask.
        if (l_omega, r_omega) not in omega_dict:
            omega_dict[(l_omega, r_omega)] = np.zeros(
                total_states, dtype=nb.types.boolean
            )
            omega_dict[(l_omega, r_omega)][i] = True
        else:
            omega_dict[(l_omega, r_omega)][i] = True
    return omega_dict, state_dict
[docs]
@nb.jit(nopython=True)
def number_array_3(
    state_array,
    species,
    mss,
    tuple_omegas=nb.types.Tuple((nb.types.int64, nb.types.int64)),
    tuple_states=nb.types.Tuple(
        (
            nb.types.int64,
            nb.types.int64,
            nb.types.int64,
            nb.types.int64,
            nb.types.int64,
            nb.types.int64,
        )
    ),
):
    """For the 3-species CTMC, index every state and group the states by omega value.

    For each row ``i`` of ``state_array``:

    * ``state_dict`` maps the tuple ``(row[0], ..., row[5])`` to ``i``.
    * The row is split into a left half ``row[:species]`` and a right half
      ``row[species:]``. ``l_omega`` is the sum of ``mss[k]`` over every left
      position ``k`` whose label occurs more than once in the left half.
      ``r_omega`` is computed the same way for the right half.
    * ``omega_dict[(l_omega, r_omega)]`` is a boolean mask of length
      ``bell_numbers(2 * species)`` that is set to True at index ``i``.

    :param state_array: array with every set partition for 3 species.
    :type state_array: np.array of int with shape (Bell number, n_species*2).
    :param species: number of species. The state tuple is hard-coded to six
        entries, so this is expected to be 3.
    :type species: int.
    :param mss: weight for each position within a half, indexed by that
        position. The wrappers in this module pass
        ``[2**i for i in range(species)]``.
    :type mss: list[int].
    :param tuple_omegas: nb.types.Tuple for omega dictionary keys.
    :type tuple_omegas: nb.types.Tuple.
    :param tuple_states: nb.types.Tuple for state dictionary keys.
    :type tuple_states: nb.types.Tuple.
    :return: omega dictionary and state dictionary.
    :rtype: tuple."""
    total_states = bell_numbers(2 * species)
    state_dict = nb.typed.Dict.empty(key_type=tuple_states, value_type=nb.types.int64)
    # NOTE(review): value_type receives an array instance rather than a numba
    # type such as nb.types.boolean[:]. Presumably numba derives the value type
    # from it; confirm this is intended.
    omega_dict = nb.typed.Dict.empty(
        key_type=tuple_omegas, value_type=np.zeros(total_states, dtype=nb.types.boolean)
    )
    # Arrays of this length are indexed directly by label. This assumes labels
    # lie in 1..2*species, as produced by set_partitions.
    max_index = 2 * species + 1
    for i, row in enumerate(state_array):
        state_tuple = (row[0], row[1], row[2], row[3], row[4], row[5])
        # Left half: map each label to the positions where it occurs.
        l_nucl = row[:species]
        l_nucl_counts = nb.typed.Dict.empty(
            key_type=nb.types.int64, value_type=nb.types.int64[:]
        )
        # index_tracker[label, n] holds the position of the n-th (0-based)
        # occurrence of label, or -1 if unused. counts[label] holds the
        # number of occurrences seen so far.
        index_tracker = np.zeros((max_index, len(l_nucl)), dtype=nb.types.int64) - 1
        # NOTE(review): int32 here but int64 for the right half below. Both
        # only hold small occurrence counts, but the mismatch looks accidental.
        counts = np.zeros(max_index, dtype=np.int32)
        for index, number in enumerate(l_nucl):
            current_count = counts[number]
            index_tracker[number, current_count] = index
            counts[number] += 1
        # Gather (label, positions) for every label present, in label order.
        number_entries = np.sum(counts > 0)
        results_numbers = np.zeros(number_entries, dtype=nb.types.int64)
        results_indices = []
        result_index = 0
        for number in range(max_index):
            if counts[number] > 0:
                results_numbers[result_index] = number
                results_indices.append(index_tracker[number, : counts[number]])
                result_index += 1
        for number, indices in zip(results_numbers, results_indices):
            l_nucl_counts[number] = indices
        # Right half: same bookkeeping.
        r_nucl = row[species:]
        r_nucl_counts = nb.typed.Dict.empty(
            key_type=nb.types.int64, value_type=nb.types.int64[:]
        )
        index_tracker = np.zeros((max_index, len(r_nucl)), dtype=nb.types.int64) - 1
        counts = np.zeros(max_index, dtype=nb.types.int64)
        for index, number in enumerate(r_nucl):
            current_count = counts[number]
            index_tracker[number, current_count] = index
            counts[number] += 1
        number_entries = np.sum(counts > 0)
        results_numbers = np.zeros(number_entries, dtype=nb.types.int64)
        results_indices = []
        result_index = 0
        for number in range(max_index):
            if counts[number] > 0:
                results_numbers[result_index] = number
                results_indices.append(index_tracker[number, : counts[number]])
                result_index += 1
        for number, indices in zip(results_numbers, results_indices):
            r_nucl_counts[number] = indices
        # Each omega sums the mss weights of the positions whose label is
        # shared with at least one other position in the same half.
        l_omega = 0
        r_omega = 0
        for j in set(l_nucl):
            if len(l_nucl_counts[j]) > 1:
                for k in l_nucl_counts[j]:
                    l_omega += mss[k]
        for j in set(r_nucl):
            if len(r_nucl_counts[j]) > 1:
                for k in r_nucl_counts[j]:
                    r_omega += mss[k]
        state_dict[state_tuple] = i
        # The first state seen with this omega pair allocates its mask.
        if (l_omega, r_omega) not in omega_dict:
            omega_dict[(l_omega, r_omega)] = np.zeros(
                total_states, dtype=nb.types.boolean
            )
            omega_dict[(l_omega, r_omega)][i] = True
        else:
            omega_dict[(l_omega, r_omega)][i] = True
    return omega_dict, state_dict
[docs]
def get_trans_mat(transition_mat, species, coal, rho):
    """Build the CTMC rate (generator) matrix from a table of transition events.

    The matrix is square, with one row and one column per state; its side is
    ``bell_numbers(2 * species)``. For every event row of ``transition_mat``:

    * column ``4*species`` gives the before-state index (matrix row);
    * column ``4*species + 1`` gives the after-state index (matrix column);
    * the entry at that position is set to ``rho`` when the event type in
      column ``4*species + 2`` is 2 (recombination), and to ``coal``
      otherwise (every coalescence, reversible or not).

    Each diagonal entry is then set to minus the sum of its row, so every row
    sums to zero up to rounding. The result is therefore a rate matrix, not a
    matrix of probabilities.

    :param transition_mat: array of transition events with 4*species + 3 columns.
    :type transition_mat: np.array.
    :param species: number of species.
    :type species: int.
    :param coal: rate used for all coalescence events (event type other than 2).
    :type coal: float.
    :param rho: rate used for recombination events (event type 2).
    :type rho: float.
    :return: rate matrix whose rows sum to zero.
    :rtype: np.array of float64."""
    mat_size = bell_numbers(2 * species)
    rate_matrix = np.zeros((mat_size, mat_size), dtype=np.float64)
    from_col = species * 4
    to_col = from_col + 1
    kind_col = from_col + 2
    for event in transition_mat:
        rate = rho if event[kind_col] == 2 else coal
        rate_matrix[event[from_col], event[to_col]] = rate
    for i in range(mat_size):
        # Builtin sum keeps the same left-to-right summation order as before.
        rate_matrix[i, i] = -sum(rate_matrix[i])
    return rate_matrix
[docs]
def get_omega_nonrev_counts(species):
    """Map each reachable omega value to its number of non-reversible coalescences.

    Position ``p`` within a half carries weight ``2**p``. An omega value is
    the sum of the weights of a group of at least two positions, and it maps
    to the group size minus one. Omega 0 (no grouped positions) maps to 0.
    Single-position groups are skipped, since one position cannot be grouped
    with itself. Entries are inserted as 0 first, then groups in increasing
    size, each size in ``itertools.combinations`` order.

    :param species: number of species.
    :type species: int.
    :return: dictionary mapping omega values to non-reversible counts.
    :rtype: nb.typed.Dict."""
    counts = nb.typed.Dict.empty(key_type=nb.types.int64, value_type=nb.types.int64)
    counts[0] = 0
    weights = [1 << position for position in range(species)]
    for group_size in range(2, species + 1):
        for group in it.combinations(weights, group_size):
            counts[sum(group)] = group_size - 1
    return counts
[docs]
def wrapper_state_1():
    """Return the transition matrix, omega dictionary, state dictionary, and
    omega non-reversible counts for 1 species.

    Delegates to ``wrapper_state_general(1)``, which runs the same steps:
    set partitions, ``number_array_1``, omega non-reversible counts, then
    reversible and non-reversible events. Delegating keeps the per-species
    wrappers from drifting apart.

    :return: tuple ``(transitions, omega_dict, state_dict, omega_nonrev_counts)`` for 1 species.
    :rtype: tuple."""
    return wrapper_state_general(1)
[docs]
def wrapper_state_2():
    """Return the transition matrix, omega dictionary, state dictionary, and
    omega non-reversible counts for 2 species.

    Delegates to ``wrapper_state_general(2)``, which runs the same steps:
    set partitions, ``number_array_2``, omega non-reversible counts, then
    reversible and non-reversible events. Delegating keeps the per-species
    wrappers from drifting apart.

    :return: tuple ``(transitions, omega_dict, state_dict, omega_nonrev_counts)`` for 2 species.
    :rtype: tuple."""
    return wrapper_state_general(2)
[docs]
def wrapper_state_3():
    """Return the transition matrix, omega dictionary, state dictionary, and
    omega non-reversible counts for 3 species.

    Delegates to ``wrapper_state_general(3)``, which runs the same steps:
    set partitions, ``number_array_3``, omega non-reversible counts, then
    reversible and non-reversible events. Delegating keeps the per-species
    wrappers from drifting apart.

    :return: tuple ``(transitions, omega_dict, state_dict, omega_nonrev_counts)`` for 3 species.
    :rtype: tuple."""
    return wrapper_state_general(3)
[docs]
def wrapper_state_general(species):
    """Return the transition matrix, omega dictionary, state dictionary, and
    omega non-reversible counts for ``species`` species.

    ``species`` is validated before any work is done. An unsupported value
    therefore raises ``ValueError`` immediately, instead of failing inside
    (or only after) the enumeration of all Bell(2*species) set partitions.
    That enumeration is slow for moderate values and infeasible for large
    ones.

    :param species: number of species (must be 1, 2, or 3).
    :type species: int.
    :return: tuple ``(transitions, omega_dict, state_dict, omega_nonrev_counts)``.
    :rtype: tuple.
    :raises ValueError: if ``species`` is not 1, 2 or 3."""
    if species not in (1, 2, 3):
        raise ValueError("Species must be 1, 2 or 3")
    # Each number_array_* builds state tuples of a fixed length (2*species),
    # so one specialised function exists per supported species count.
    number_array = {1: number_array_1, 2: number_array_2, 3: number_array_3}[species]
    mss = [2**i for i in range(species)]
    state_array = set_partitions(species)
    omega_dict, state_dict = number_array(state_array, species, mss)
    omega_nonrev_counts = get_omega_nonrev_counts(species)
    coal_and_rec = find_revcoal_recomb(state_array, species, state_dict)
    norev_coals = find_norevcoal(state_array, species, state_dict)
    transitions = np.vstack((coal_and_rec, norev_coals))
    return transitions, omega_dict, state_dict, omega_nonrev_counts