# Source code for itrails.combine_states

import numba as nb
import numpy as np


def _split_halves(key):
    """Split a state tuple into its first half and its second half."""
    half = len(key) // 2
    return key[:half], key[half:]


def _combine_labels(left_1, left_2, right_1, right_2):
    """
    Concatenate ``left_1 + left_2 + right_1 + right_2`` and relabel the values.

    Values are renumbered 1, 2, 3, ... in order of first appearance. Values
    coming from the first state (``left_1``/``right_1``) and from the second
    state (``left_2``/``right_2``) live in separate label namespaces, so an
    equal raw value in both states still receives two different labels.

    :return: The relabelled combined state.
    :rtype: tuple of int
    """
    labels_1 = {}
    labels_2 = {}
    combined = []
    for segment, labels in (
        (left_1, labels_1),
        (left_2, labels_2),
        (right_1, labels_1),
        (right_2, labels_2),
    ):
        for value in segment:
            if value not in labels:
                # The next free label is one past the number of labels
                # handed out so far across BOTH namespaces.
                labels[value] = len(labels_1) + len(labels_2) + 1
            combined.append(labels[value])
    return tuple(combined)


def combine_states(
    state_dict_1, state_dict_2, state_dict_sum, final_probs_1, final_probs_2
):
    """
    Combine two CTMCs' final probabilities into starting probabilities of a combined CTMC.

    Every pair of states (one key from each state dictionary) is merged into a
    single combined state. The first halves of both keys are concatenated,
    then the second halves, and the values are relabelled 1, 2, 3, ... in
    order of first appearance. Labels of the two source states are tracked
    separately. The probability assigned to a combined state is the product of
    the two source probabilities. If several source pairs relabel to the same
    combined state, their probabilities are summed.

    :param state_dict_1: Dictionary of states and indices for the first CTMC, can be 1 sequence or 2 sequence CTMC.
    :type state_dict_1: Numba Dictionary of Key: Tuple of int64 and Value: int64.
    :param state_dict_2: Dictionary of states and indices for the second CTMC, can only be 1 sequence CTMC.
    :type state_dict_2: Numba Dictionary of Key: Tuple(int64, int64) and Value: int64
    :param state_dict_sum: Dictionary of states and indices for the combined CTMC.
    :type state_dict_sum: Numba Dictionary of Key: Tuple of int64 and Value: int64.
    :param final_probs_1: Final probabilities for a single key of the first CTMC; row 0 is indexed by the values of ``state_dict_1``.
    :type final_probs_1: Array of type: float64[:, :]
    :param final_probs_2: Final probabilities for the only key of the second CTMC; row 0 is indexed by the values of ``state_dict_2``.
    :type final_probs_2: Array of type: float64[:, :]
    :raises KeyError: If a combined state is not a key of ``state_dict_sum``.
    :return: Starting probabilities indexed by the values of ``state_dict_sum``. Combined states that are never produced keep probability 0.
    :rtype: Array of type: float64[:]
    """
    init_comb_all = np.zeros(len(state_dict_sum), dtype=np.float64)

    # Split the second CTMC's keys once instead of once per outer iteration.
    split_2 = [
        (_split_halves(key_2), index_2) for key_2, index_2 in state_dict_2.items()
    ]

    for key_1, index_1 in state_dict_1.items():
        left_1, right_1 = _split_halves(key_1)
        prob_1 = final_probs_1[0, index_1]
        for (left_2, right_2), index_2 in split_2:
            comb_state = _combine_labels(left_1, left_2, right_1, right_2)
            # Accumulate rather than overwrite: distinct source pairs that
            # relabel to the same combined state are the same state, so their
            # probability mass must add up.
            init_comb_all[state_dict_sum[comb_state]] += (
                prob_1 * final_probs_2[0, index_2]
            )
    return init_comb_all
def combine_states_wrapper(
    state_dict_1,
    state_dict_2,
    state_dict_sum,
    final_probs_1,
    final_probs_2,
):
    """
    Combine final-probability dictionaries of two CTMCs into starting probabilities of the combined CTMC.

    Dispatches on the number of keys in each final-probability dictionary:

    * ``final_probs_1`` has several keys and ``final_probs_2`` has exactly one:
      each entry of ``final_probs_1`` is combined with the single entry of
      ``final_probs_2``. The result is stored under the same key as in
      ``final_probs_1``.
    * Both have exactly one key: the two entries, both looked up under the
      placeholder key ``((-1, -1, -1), (-1, -1, -1))``, are combined and stored
      under that placeholder key.

    :param state_dict_1: Dictionary of states and indices for the first CTMC, can be 1 sequence or 2 sequence CTMC.
    :type state_dict_1: Numba Dictionary of Key: Tuple of int64 and Value: int64.
    :param state_dict_2: Dictionary of states and indices for the second CTMC, can only be 1 sequence CTMC.
    :type state_dict_2: Numba Dictionary of Key: Tuple(int64, int64) and Value: int64.
    :param state_dict_sum: Dictionary of states and indices for the combined CTMC.
    :type state_dict_sum: Numba Dictionary of Key: Tuple of int64 and Value: int64.
    :param final_probs_1: Final probability dictionary for the first CTMC.
    :type final_probs_1: Numba Dictionary of Key: UniTuple(nb.types.UniTuple(int64, 3), 2) and Value: float64[:, :].
    :param final_probs_2: Final probability dictionary for the second CTMC.
    :type final_probs_2: Numba Dictionary of Key: UniTuple(nb.types.UniTuple(int64, 3), 2) and Value: float64[:, :].
    :raises NotImplementedError: If both ``final_probs_1`` and ``final_probs_2`` have more than one key.
    :raises ValueError: For any other key-count combination (e.g. an empty dictionary, or ``final_probs_1`` with one key and ``final_probs_2`` with several). ``ValueError`` subclasses ``Exception``, so existing ``except Exception`` handlers still catch it.
    :raises KeyError: If a single-key dictionary does not use the placeholder key.
    :return: Dictionary mapping each key to a 1 x N row vector of starting probabilities for the combined CTMC.
    :rtype: Numba Dictionary of Key: UniTuple(nb.types.UniTuple(int64, 3), 2) and Value: float64[:, :].
    """
    pi_dict = nb.typed.Dict.empty(
        key_type=nb.types.UniTuple(nb.types.UniTuple(nb.types.int64, 3), 2),
        value_type=nb.types.float64[:, :],
    )
    # Key under which a single-entry final-probability dictionary is expected
    # to store its vector.
    start_placeholder = ((-1, -1, -1), (-1, -1, -1))
    n_keys_1 = len(final_probs_1)
    n_keys_2 = len(final_probs_2)

    if n_keys_1 > 1 and n_keys_2 > 1:
        raise NotImplementedError(
            "Combining final_probs_1 and final_probs_2 when both have more "
            "than one key is not implemented"
        )

    if n_keys_1 > 1 and n_keys_2 == 1:
        prob2 = final_probs_2[start_placeholder]
        for path, prob1 in final_probs_1.items():
            pi_vector_combined = combine_states(
                state_dict_1, state_dict_2, state_dict_sum, prob1, prob2
            )
            # combine_states returns a 1-D vector; store it as a single row.
            pi_dict[path] = pi_vector_combined.reshape(1, -1)
        return pi_dict

    if n_keys_1 == 1 and n_keys_2 == 1:
        prob1 = final_probs_1[start_placeholder]
        prob2 = final_probs_2[start_placeholder]
        pi_vector_combined = combine_states(
            state_dict_1, state_dict_2, state_dict_sum, prob1, prob2
        )
        pi_dict[start_placeholder] = pi_vector_combined.reshape(1, -1)
        return pi_dict

    raise ValueError("Invalid final_probs_1 or final_probs_2")