# Source code for itrails.run_markov_chain_ABC

import pickle

import numpy as np
from joblib import Parallel, delayed

import itrails.ncpu as ncpu
from itrails.deepest_ti import deep_identify_wrapper, deepest_ti
from itrails.expm import expm
from itrails.helper_omegas import remove_absorbing_indices, translate_to_omega
from itrails.vanloan import vanloan, vanloan_identify_wrapper


def compute_matrix_start_end(
    prob_mat, exponential_time, omega_start_mask, omega_end_mask
):
    """
    Propagate a probability matrix through one time interval, restricted to
    the requested start and end states.

    The matrix exponential is first sandwiched between the two masks (the
    start mask on the left, the end mask on the right); the probability
    matrix is then multiplied on the left of that masked matrix.

    :param prob_mat: Matrix of probabilities, multiplied on the left.
    :type prob_mat: Numpy array
    :param exponential_time: Exponential of the transition matrix multiplied
        by time.
    :type exponential_time: Numpy array
    :param omega_start_mask: Matrix applied on the left of
        ``exponential_time``. Callers in this module pass ``np.diag`` of a
        boolean vector, i.e. a 0/1 diagonal matrix selecting rows.
    :type omega_start_mask: Numpy array
    :param omega_end_mask: Matrix applied on the right of
        ``exponential_time``. Callers in this module pass ``np.diag`` of a
        boolean vector, i.e. a 0/1 diagonal matrix selecting columns.
    :type omega_end_mask: Numpy array
    :return: ``prob_mat @ (omega_start_mask @ exponential_time @ omega_end_mask)``.
    :rtype: Numpy array
    """
    masked = omega_start_mask @ exponential_time
    masked = masked @ omega_end_mask
    return prob_mat @ masked
def compute_matrices_start_end_wrapper(
    prob_mats, exponential_time, omega_start_masks, omega_end_masks, num_combinations
):
    """
    Evaluate :func:`compute_matrix_start_end` for many combinations in
    parallel.

    For every index ``idx`` in ``range(num_combinations)`` the call
    ``compute_matrix_start_end(prob_mats[idx], exponential_time,
    omega_start_masks[idx], omega_end_masks[idx])`` is scheduled through
    ``joblib.Parallel`` with ``ncpu.N_CPU_GLOBAL`` workers. Only the first
    ``num_combinations`` entries of each sequence are read.

    :param prob_mats: Probability matrices, one per combination.
    :type prob_mats: list or np.ndarray
    :param exponential_time: Exponential of the transition matrix multiplied
        by time; shared by all combinations.
    :type exponential_time: np.ndarray
    :param omega_start_masks: Left-hand masks, one per combination.
    :type omega_start_masks: list or np.ndarray
    :param omega_end_masks: Right-hand masks, one per combination.
    :type omega_end_masks: list or np.ndarray
    :param num_combinations: Number of leading combinations to evaluate.
    :type num_combinations: int
    :return: The per-combination results stacked with ``np.array``.
    :rtype: np.ndarray
    """
    runner = Parallel(n_jobs=ncpu.N_CPU_GLOBAL)
    jobs = (
        delayed(compute_matrix_start_end)(
            prob_mats[idx],
            exponential_time,
            omega_start_masks[idx],
            omega_end_masks[idx],
        )
        for idx in range(num_combinations)
    )
    stacked = runner(jobs)
    return np.array(stacked)
def vanloan_worker_inner(
    idx_i,
    idx_j,
    trans_mat,
    omega_dict_serialized,
    key,
    time,
    paths_array_j,
    omega_start_mask,
    omega_end_mask,
    prob_mat,
):
    """
    Worker for one subpath combination when multiple coalescents happen in
    the same time interval (Van Loan).

    It is scheduled through ``joblib`` by :func:`vanloan_parallel_inner`. The
    Van Loan matrices of all subpaths are summed, then masked and multiplied
    by ``prob_mat``.

    :param idx_i: Index of the path within all Van Loan paths; returned
        unchanged.
    :type idx_i: int
    :param idx_j: Index of the subpath key within that path; returned
        unchanged.
    :type idx_j: int
    :param trans_mat: Transition matrix.
    :type trans_mat: Numpy array
    :param omega_dict_serialized: Pickled dictionary mapping omega indices
        (keys) to boolean state vectors (values).
    :type omega_dict_serialized: bytes
    :param key: Key row. Its last entry is the number of subpaths to
        evaluate; the remaining entries are returned as the updated key.
    :type key: Numpy array
    :param time: End time of the interval, forwarded to ``vanloan``.
    :type time: float
    :param paths_array_j: Subpath table. Entry ``k`` stores the length of
        subpath ``k`` in ``[k][0][0]`` and its transitions in rows
        ``1 .. length``.
    :type paths_array_j: Numpy array
    :param omega_start_mask: 2D matrix applied on the left of the summed
        Van Loan matrix (built with ``np.diag`` by the caller).
    :type omega_start_mask: Numpy array
    :param omega_end_mask: 2D matrix applied on the right of the summed
        Van Loan matrix (built with ``np.diag`` by the caller).
    :type omega_end_mask: Numpy array
    :param prob_mat: Probabilities for each state at the start of the time
        interval.
    :type prob_mat: Numpy array
    :return: ``(idx_i, idx_j, prob_mat @ (start @ sum @ end), key[:-1])``.
    :rtype: Tuple(int, int, Numpy array, Numpy array)
    """
    # NOTE(review): pickle.loads is only safe because the payload is produced
    # by vanloan_parallel_inner in this module, not by external input.
    omega_dict = pickle.loads(omega_dict_serialized)
    n_subpaths = int(key[-1])

    # Accumulate in place instead of materialising one full matrix per subpath
    # and summing afterwards. The shape is the one the mask products require,
    # so no state-space size is hard-coded.
    total = np.zeros((omega_start_mask.shape[1], omega_end_mask.shape[0]))
    for k in range(n_subpaths):
        entry = paths_array_j[k]
        path_length = entry[0][0]
        total += vanloan(trans_mat, entry[1 : path_length + 1], time, omega_dict)

    sliced_mat = omega_start_mask @ total @ omega_end_mask
    final_result = prob_mat @ sliced_mat
    return idx_i, idx_j, final_result, key[:-1]
def vanloan_parallel_inner(
    vl_idx,
    time,
    trans_mat,
    omega_dict,
    vl_keys_acc_array,
    vl_paths_acc_array,
    vl_omega_masks_start,
    vl_omega_masks_end,
    vl_prob_mats,
):
    """
    Schedule :func:`vanloan_worker_inner` for every staged Van Loan subpath
    key and collect the results.

    The omega dictionary is converted to a plain ``dict`` and pickled once.
    One task is built per key of each of the first ``vl_idx`` staged entries,
    stopping at the first key whose last entry (its subpath count) is 0.
    Tasks run through ``joblib.Parallel`` with ``ncpu.N_CPU_GLOBAL`` workers.

    :param vl_idx: Number of staged entries to read from the accumulators.
    :type vl_idx: int
    :param time: End time of the current interval, forwarded to the workers.
    :type time: float
    :param trans_mat: Transition matrix used for the Van Loan computation.
    :type trans_mat: numpy array
    :param omega_dict: Dictionary mapping omega keys to boolean state vectors.
    :type omega_dict: dict
    :param vl_keys_acc_array: Per entry, the rows of subpath keys.
    :type vl_keys_acc_array: list or numpy array
    :param vl_paths_acc_array: Per entry, the subpath tables matching each key.
    :type vl_paths_acc_array: list or numpy array
    :param vl_omega_masks_start: Per entry, the left-hand mask matrix.
    :type vl_omega_masks_start: list or numpy array
    :param vl_omega_masks_end: Per entry, the right-hand mask matrix.
    :type vl_omega_masks_end: list or numpy array
    :param vl_prob_mats: Per entry, the probability matrix at interval start.
    :type vl_prob_mats: list or numpy array
    :return: ``(flattened_keys, flattened_results, total_valid)``. The first
        two are the workers' updated keys and probability matrices stacked in
        task order; ``total_valid`` is the number of tasks run.
    :rtype: tuple
    """
    # Serialize once; every task ships the same bytes.
    omega_dict_serialized = pickle.dumps(dict(omega_dict))

    tasks = []
    for i in range(vl_idx):
        paths_array = vl_paths_acc_array[i]
        omega_start_mask = vl_omega_masks_start[i]
        omega_end_mask = vl_omega_masks_end[i]
        prob_mat = vl_prob_mats[i]
        for j, key in enumerate(vl_keys_acc_array[i]):
            # A zero subpath count marks the end of the filled key rows.
            if key[-1] == 0:
                break
            tasks.append(
                (
                    i,
                    j,
                    trans_mat,
                    omega_dict_serialized,
                    key,
                    time,
                    paths_array[j],
                    omega_start_mask,
                    omega_end_mask,
                    prob_mat,
                )
            )

    results = Parallel(n_jobs=ncpu.N_CPU_GLOBAL)(
        delayed(vanloan_worker_inner)(*args) for args in tasks
    )

    total_valid = len(results)
    if total_valid:
        # Shapes come from the worker outputs, so no state-space size or key
        # width is hard-coded.
        flattened_keys = np.array([res[3] for res in results], dtype=np.int64)
        flattened_results = np.array([res[2] for res in results], dtype=np.float64)
    else:
        # Nothing was scheduled: derive the trailing shapes from the
        # accumulators (workers drop the last key entry, hence the -1).
        key_shape = np.shape(vl_keys_acc_array)
        prob_shape = np.shape(vl_prob_mats)
        flattened_keys = np.zeros((0, max(key_shape[-1] - 1, 0)), dtype=np.int64)
        flattened_results = np.zeros((0,) + tuple(prob_shape[1:]))
    return flattened_keys, flattened_results, total_valid
def deepest_worker_inner(
    idx_i,
    idx_j,
    trans_mat_noabs,
    omega_dict_noabs_serialized,  # Serialized dictionary
    key,
    paths_array_j,
    acc_prob_mat_noabs,
    path_lengths_j,
):
    """
    Worker for one subpath key when multiple coalescents happen in the last
    time interval (Deepest TI).

    It is scheduled through ``joblib`` by :func:`deepest_parallel_inner`. The
    ``deepest_ti`` matrices of all subpaths are summed and multiplied on the
    left by ``acc_prob_mat_noabs``.

    :param idx_i: Index of the path within all Deepest TI paths; returned
        unchanged.
    :type idx_i: int
    :param idx_j: Index of the subpath key within that path; returned
        unchanged.
    :type idx_j: int
    :param trans_mat_noabs: Transition matrix without absorbing states.
    :type trans_mat_noabs: Numpy array
    :param omega_dict_noabs_serialized: Pickled dictionary of omega indices
        (keys) to boolean state vectors (values), without absorbing states.
    :type omega_dict_noabs_serialized: bytes
    :param key: Key identifying this subpath combination; returned unchanged.
    :type key: Numpy array
    :param paths_array_j: Subpath table. Row ``k`` holds the transitions of
        subpath ``k``, of which the first ``path_lengths_j[k]`` are used.
    :type paths_array_j: Numpy array
    :param acc_prob_mat_noabs: Probabilities for each non-absorbing state at
        the start of the time interval.
    :type acc_prob_mat_noabs: Numpy array
    :param path_lengths_j: Length of each subpath; its length is the number of
        subpaths evaluated.
    :type path_lengths_j: Numpy array
    :return: ``(idx_i, idx_j, acc_prob_mat_noabs @ sum_of_deepest_ti, key)``.
    :rtype: Tuple(int, int, Numpy array, Numpy array)
    """
    # NOTE(review): pickle.loads is only safe because the payload is produced
    # by deepest_parallel_inner in this module, not by external input.
    omega_dict_noabs = pickle.loads(omega_dict_noabs_serialized)

    # Running sum sized from the probability row it is multiplied with, so no
    # state-space size is hard-coded and no per-subpath matrices are stored.
    n_states = acc_prob_mat_noabs.shape[-1]
    deep_ti_sum = np.zeros((n_states, n_states))
    for k in range(len(path_lengths_j)):
        path = paths_array_j[k][: path_lengths_j[k]]
        deep_ti_sum += deepest_ti(trans_mat_noabs, omega_dict_noabs, path)

    final_result = acc_prob_mat_noabs @ deep_ti_sum
    return idx_i, idx_j, final_result, key
def deepest_parallel_inner(
    deepest_idx,
    trans_mat_noabs,
    omega_dict_noabs,
    deepest_keys_acc_array,
    deepest_paths_acc_array,
    deepest_path_lengths_array,
    acc_prob_mats_noabs,
):
    """
    Schedule :func:`deepest_worker_inner` for every staged deepest-interval
    subpath key and collect the results.

    The omega dictionary (without absorbing states) is converted to a plain
    ``dict`` and pickled once. One task is built per key of each of the first
    ``deepest_idx`` staged entries, stopping at the first all-zero key. Each
    task only receives as many subpaths as its path-lengths row has non-zero
    entries. Tasks run through ``joblib.Parallel`` with
    ``ncpu.N_CPU_GLOBAL`` workers.

    :param deepest_idx: Number of staged entries to read from the accumulators.
    :type deepest_idx: int
    :param trans_mat_noabs: Transition matrix without absorbing states.
    :type trans_mat_noabs: numpy array
    :param omega_dict_noabs: Dictionary mapping omega keys (without absorbing
        states) to boolean state vectors.
    :type omega_dict_noabs: dict
    :param deepest_keys_acc_array: Per entry, the rows of subpath keys.
    :type deepest_keys_acc_array: list or numpy array
    :param deepest_paths_acc_array: Per entry, the subpath tables per key.
    :type deepest_paths_acc_array: list or numpy array
    :param deepest_path_lengths_array: Per entry, the subpath lengths per key.
    :type deepest_path_lengths_array: list or numpy array
    :param acc_prob_mats_noabs: Per entry, the probability matrix restricted
        to non-absorbing states.
    :type acc_prob_mats_noabs: list or numpy array
    :return: ``(flattened_keys, flattened_results, total_valid)``. The first
        two are the workers' keys and probability matrices stacked in task
        order; ``total_valid`` is the number of tasks run.
    :rtype: tuple
    """
    # Serialize once; every task ships the same bytes.
    omega_dict_noabs_serialized = pickle.dumps(dict(omega_dict_noabs))

    tasks = []
    for i in range(deepest_idx):
        paths_array = deepest_paths_acc_array[i]
        acc_prob_mat_noabs = acc_prob_mats_noabs[i]
        for j, key in enumerate(deepest_keys_acc_array[i]):
            # An all-zero key marks the end of the filled key rows.
            if all(x == 0 for x in key):
                break
            path_lengths_j = deepest_path_lengths_array[i][j]
            # Assumes filled lengths form a prefix of the row (zero-padded).
            n_subpaths = np.count_nonzero(path_lengths_j)
            tasks.append(
                (
                    i,
                    j,
                    trans_mat_noabs,
                    omega_dict_noabs_serialized,
                    key,
                    paths_array[j][:n_subpaths],
                    acc_prob_mat_noabs,
                    path_lengths_j[:n_subpaths],
                )
            )

    results = Parallel(n_jobs=ncpu.N_CPU_GLOBAL)(
        delayed(deepest_worker_inner)(*args) for args in tasks
    )

    total_valid = len(results)
    if total_valid:
        # Shapes come from the worker outputs, so no state-space size or key
        # width is hard-coded.
        flattened_keys = np.array([res[3] for res in results], dtype=np.int64)
        flattened_results = np.array([res[2] for res in results], dtype=np.float64)
    else:
        # Nothing was scheduled: derive the trailing shapes from the
        # accumulators (deepest keys are returned untrimmed).
        key_shape = np.shape(deepest_keys_acc_array)
        prob_shape = np.shape(acc_prob_mats_noabs)
        flattened_keys = np.zeros((0, max(key_shape[-1], 0)), dtype=np.int64)
        flattened_results = np.zeros((0,) + tuple(prob_shape[1:]))
    return flattened_keys, flattened_results, total_valid
def _candidate_rows(side, step):
    """Return the six candidate states (one per row) for one side of a path key during interval ``step``."""
    rows = np.empty((6, 3), dtype=np.int64)
    # Row 0 (and every row not overridden below) is the side unchanged.
    rows[:] = side
    if side[0] == -1:
        rows[1] = (side[0], step, step)
        rows[2] = (1, step, side[2])
        rows[3] = (2, step, side[2])
        rows[4] = (3, step, side[2])
    elif side[2] == -1:
        rows[5] = (side[0], side[1], step)
    return rows


def _side_status(side):
    """Classify a side as 'complete' (no -1), 'half' (only the last entry is -1), 'empty' (all -1) or None."""
    if all(x != -1 for x in side):
        return "complete"
    if side[2] == -1 and all(x != -1 for x in side[:2]):
        return "half"
    if all(x == -1 for x in side):
        return "empty"
    return None


def _deepest_side(side, status, last):
    """Integer-valued side with every unset slot filled with the last interval index ``last``."""
    if status == "complete":
        return (int(side[0]), int(side[1]), int(side[2]))
    if status == "half":
        return (int(side[0]), int(side[1]), last)
    return (int(side[0]), last, last)


def _key_from_row(row):
    """Split a 6-entry key row into the ``((l0, l1, l2), (r0, r1, r2))`` dictionary key."""
    return (tuple(row[:3]), tuple(row[3:6]))


def run_markov_chain_ABC(
    trans_mat,
    times,
    omega_dict,
    prob_dict,
    omega_nonrev_counts,
    inverted_omega_nonrev_counts,
    n_int_ABC,
    species,
    absorbing_state=(7, 7),
):
    """
    Run the discrete-time Markov chain for species A, B and C.

    For each of the first ``n_int_ABC - 1`` intervals, every path in
    ``prob_dict`` is expanded into its candidate successor keys.
    Successors for which either side has ``side[1] == side[2] != -1`` and
    ``side[0] != 0`` go through the Van Loan workers; all others are
    propagated with ``expm(trans_mat * times[step])``. The last interval is
    then resolved: paths whose sides are all 'complete'/'half' are summed
    directly; paths with an all ``-1`` side go through the deepest-interval
    workers.

    :param trans_mat: Transition matrix.
    :type trans_mat: Numpy array
    :param times: Per-interval times; ``times[step]`` multiplies ``trans_mat``
        when propagating through interval ``step``.
    :type times: Numpy array
    :param omega_dict: Dictionary of omega indices (key) and boolean state
        vectors (value).
    :type omega_dict: dict-like (possibly a Numba typed dictionary)
    :param prob_dict: Dictionary of each path (keys) and probabilities for
        each state (values). MODIFIED IN PLACE: entries are added,
        re-keyed and removed.
    :type prob_dict: dict-like (possibly a Numba typed dictionary)
    :param omega_nonrev_counts: Dictionary of omega indices (key) and number
        of non-reversible transitions (value).
    :type omega_nonrev_counts: dict-like
    :param inverted_omega_nonrev_counts: Inverted version of
        ``omega_nonrev_counts``.
    :type inverted_omega_nonrev_counts: dict-like
    :param n_int_ABC: Number of intervals for species A, B, and C.
    :type n_int_ABC: int
    :param species: Number of species, forwarded to
        ``remove_absorbing_indices``.
    :type species: int
    :param absorbing_state: Omega key of the state where all coalescents have
        happened, defaults to (7, 7).
    :type absorbing_state: tuple, optional
    :return: Plain dictionary mapping each final path key to the summed
        probability (``np.sum``) of its probability matrix.
    :rtype: dict
    """
    n_states = trans_mat.shape[0]
    last = n_int_ABC - 1

    for step in range(last):
        exponential_time = expm(trans_mat * times[step])
        og_keys = list(prob_dict.keys())
        # Set copy for O(1) membership tests; iteration still follows og_keys.
        og_key_set = set(og_keys)

        for path in og_keys:
            prob_mat = prob_dict[path]
            l_path, r_path = path[0], path[1]
            l_rows = _candidate_rows(l_path, step)
            r_rows = _candidate_rows(r_path, step)

            # Each (l_row, r_row) pair stages at most one entry, so this is
            # an upper bound for both result_idx and vl_idx.
            capacity = len(l_rows) * len(r_rows)
            prob_mats = np.zeros((capacity, 1, n_states), dtype=np.float64)
            omega_masks_start = np.zeros((capacity, n_states, n_states), dtype=np.float64)
            omega_masks_end = np.zeros((capacity, n_states, n_states), dtype=np.float64)
            keys = np.zeros((capacity, 6), dtype=np.int64)
            vl_prob_mats = np.zeros((capacity, 1, n_states), dtype=np.float64)
            vl_omega_masks_start = np.zeros((capacity, n_states, n_states), dtype=np.float64)
            vl_omega_masks_end = np.zeros((capacity, n_states, n_states), dtype=np.float64)
            vl_keys_acc_array = np.zeros((capacity, 9, 7), dtype=np.int64)
            vl_paths_acc_array = np.zeros((capacity, 9, 15, 16, 2), dtype=np.int64)
            result_idx = 0
            vl_idx = 0

            # The start state is the same for every successor of ``path``.
            omega_start = translate_to_omega(path)
            start_diag = np.diag(omega_dict[omega_start])

            for l_row in l_rows:
                l_tuple = (int(l_row[0]), int(l_row[1]), int(l_row[2]))
                for r_row in r_rows:
                    r_tuple = (int(r_row[0]), int(r_row[1]), int(r_row[2]))
                    new_key = (l_tuple, r_tuple)
                    # Keys present at the start of the step are only
                    # recomputed for the path itself (no transition).
                    if new_key in og_key_set and not (
                        np.array_equal(l_row, l_path) and np.array_equal(r_row, r_path)
                    ):
                        continue

                    omega_end = translate_to_omega(new_key)
                    end_diag = np.diag(omega_dict[omega_end])
                    van_loan = (
                        l_tuple[0] != 0 and l_tuple[1] == l_tuple[2] and l_tuple[1] != -1
                    ) or (
                        r_tuple[0] != 0 and r_tuple[1] == r_tuple[2] and r_tuple[1] != -1
                    )

                    if van_loan:
                        key_array, paths_array = vanloan_identify_wrapper(
                            np.array([omega_start[0], omega_start[1]]),
                            np.array([omega_end[0], omega_end[1]]),
                            omega_nonrev_counts,
                            inverted_omega_nonrev_counts,
                            l_tuple,
                            r_tuple,
                            l_row,
                            r_row,
                            max_num_keys=10,
                            max_num_subpaths_per_key=20,
                            max_path_length=15,
                            max_total_subpaths=200,
                        )
                        num_keys = key_array.shape[0]
                        num_paths = paths_array.shape[0]
                        vl_keys_acc_array[vl_idx, :num_keys] = key_array[:num_keys]
                        vl_paths_acc_array[vl_idx, :num_paths] = paths_array[:num_paths]
                        vl_omega_masks_start[vl_idx] = start_diag
                        vl_prob_mats[vl_idx] = prob_mat
                        vl_omega_masks_end[vl_idx] = end_diag
                        vl_idx += 1
                    else:
                        keys[result_idx, :3] = l_row
                        keys[result_idx, 3:] = r_row
                        prob_mats[result_idx] = prob_mat
                        omega_masks_start[result_idx] = start_diag
                        omega_masks_end[result_idx] = end_diag
                        result_idx += 1

            flattened_keys, flattened_results, total_valid = vanloan_parallel_inner(
                vl_idx,
                times[step],
                trans_mat,
                omega_dict,
                vl_keys_acc_array,
                vl_paths_acc_array,
                vl_omega_masks_start,
                vl_omega_masks_end,
                vl_prob_mats,
            )
            results_novl = compute_matrices_start_end_wrapper(
                prob_mats,
                exponential_time,
                omega_masks_start,
                omega_masks_end,
                result_idx,
            )
            for i in range(result_idx):
                prob_dict[_key_from_row(keys[i])] = results_novl[i]
            for i in range(total_valid):
                prob_dict[_key_from_row(flattened_keys[i])] = flattened_results[i]

    og_keys = list(prob_dict.keys())
    noabs_mask = np.logical_not(omega_dict[absorbing_state])
    trans_mat_noabs = trans_mat[noabs_mask][:, noabs_mask]
    omega_dict_noabs = remove_absorbing_indices(
        omega_dict=omega_dict, absorbing_key=absorbing_state, species=species
    )
    n_noabs = int(np.count_nonzero(noabs_mask))

    prob_dict_sum = {}
    for path in og_keys:
        l_path, r_path = path
        prob_mat = prob_dict[path]
        l_status = _side_status(l_path)
        r_status = _side_status(r_path)

        if l_status is None or r_status is None:
            # NOTE(review): sides such as (-1, s, s), produced by
            # _candidate_rows row 1, match none of the patterns, so such paths
            # contribute nothing to the result (same as the previous if/elif
            # chain). TODO confirm this is intended.
            continue

        if "empty" not in (l_status, r_status):
            # No side has pending coalescents: sum directly, closing any
            # 'half' side at the last interval.
            if l_status == "complete" and r_status == "complete":
                prob_dict_sum[path] = np.sum(prob_mat)
            else:
                new_key = (
                    (l_path[0], l_path[1], last) if l_status == "half" else l_path,
                    (r_path[0], r_path[1], last) if r_status == "half" else r_path,
                )
                prob_dict_sum[new_key] = np.sum(prob_mat)
                prob_dict[new_key] = prob_dict.pop(path)
            continue

        # At least one side is all -1: resolve it in the deepest interval.
        new_path = (
            _deepest_side(l_path, l_status, last),
            _deepest_side(r_path, r_status, last),
        )
        omega_start = translate_to_omega(path)
        (keys_array, paths_array, path_lengths_array, max_subpaths) = (
            deep_identify_wrapper(
                omega_start,
                absorbing_state,
                omega_nonrev_counts,
                inverted_omega_nonrev_counts,
                new_path,
            )
        )
        num_keys = keys_array.shape[0]
        num_paths = paths_array.shape[0]
        keys_per_path = paths_array.shape[1]
        subpaths_per_path = path_lengths_array.shape[1]

        # Exactly one deepest entry is staged per path.
        deepest_keys_acc_array = np.zeros((1, 9, 6), dtype=np.int64)
        deepest_paths_acc_array = np.zeros((1, 9, 15, 16, 2), dtype=np.int64)
        deepest_path_lengths_array = np.zeros((1, 9, 9), dtype=np.int64)
        acc_prob_mats_noabs = np.zeros((1, 1, n_noabs), dtype=np.float64)
        deepest_keys_acc_array[0, :num_keys] = keys_array[:num_keys]
        deepest_paths_acc_array[0, :num_paths, :keys_per_path, :max_subpaths] = paths_array
        deepest_path_lengths_array[0, :num_keys, :subpaths_per_path] = path_lengths_array
        acc_prob_mats_noabs[0] = prob_mat[:, noabs_mask]
        prob_dict.pop(path)

        flattened_keys, flattened_results, total_valid = deepest_parallel_inner(
            1,
            trans_mat_noabs,
            omega_dict_noabs,
            deepest_keys_acc_array,
            deepest_paths_acc_array,
            deepest_path_lengths_array,
            acc_prob_mats_noabs,
        )
        for i in range(total_valid):
            prob_dict_sum[_key_from_row(flattened_keys[i])] = np.sum(flattened_results[i])

    return prob_dict_sum