| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| | import random as rd |
| |
|
| | from diffusion import Diffusion |
| | from scoring.scoring_functions import ScoringFunctions |
| | from utils.filter import PeptideAnalyzer |
| | import noise_schedule |
| |
|
| | """" |
| | Notes: store rolled out sequence? |
| | path of node objects or strings? |
| | should we only select valid expandable leaf nodes? |
| | calculate similarity between sibling nodes? |
| | should we evaluate generated sequences? |
| | """ |
class Node:
    """
    A node in the MCTS tree: a partially unmasked SMILES token sequence.

    Attributes:
        parentNode: Node at the previous time step (None for the root).
        childNodes: list of child Nodes produced by distinct unmasking schemes.
        scoreVector: per-objective scores for this node's sequence, if scored.
        totalReward: cumulative reward vector over all K objectives.
        visits: number of times the node has been visited by an iteration.
        timestep: diffusion time step at which the sequence was sampled.
        sampleProb: probability of sampling the sequence from the diffusion model.
        tokens: dict with 'input_ids' / 'attention_mask' for the sequence.
    """

    def __init__(self, config, tokens=None, parentNode=None, childNodes=None,
                 scoreVector=None, totalReward=None, timestep=None, sampleProb=None):
        self.config = config
        self.parentNode = parentNode
        # Bugfix: the original used a mutable default (childNodes=[]), so every
        # Node constructed without an explicit list shared ONE list object.
        self.childNodes = childNodes if childNodes is not None else []
        self.scoreVector = scoreVector

        if totalReward is not None:
            self.totalReward = totalReward
        else:
            self.totalReward = np.zeros(self.config.mcts.num_objectives)

        # A node counts its own creation as the first visit.
        self.visits = 1
        self.timestep = timestep
        self.sampleProb = sampleProb
        self.tokens = tokens

    def selectNode(self, num_func):
        """
        Select a node to move to among the children nodes.

        For an expanded node (status 3), build a Pareto front of the select
        scores of all non-terminal children and pick one uniformly at random;
        otherwise return this node itself.

        Returns:
            (node, status) where status is the chosen node's expand status.
        """
        nodeStatus = self.getExpandStatus()

        if (nodeStatus == 3):
            paretoFront = {}
            for childNode in self.childNodes:
                childStatus = childNode.getExpandStatus()
                # Only expandable leaves (2) and expanded nodes (3) compete.
                if childStatus == 2 or childStatus == 3:
                    selectScore = childNode.calcSelectScore()
                    paretoFront = updateParetoFront(paretoFront, childNode, selectScore, num_func)

            selected = rd.choice(list(paretoFront.keys()))
            return selected, selected.getExpandStatus()

        return self, nodeStatus

    def addChildNode(self, tokens, totalReward, prob=None):
        """
        Create a child Node one time step ahead, append it to childNodes,
        and return it.
        """
        child = Node(config=self.config,
                     tokens=tokens,
                     parentNode=self,
                     childNodes=[],
                     totalReward=totalReward,
                     timestep=self.timestep + 1,
                     sampleProb=prob)
        self.childNodes.append(child)
        return child

    def updateNode(self, rewards):
        """
        Update the cumulative rewards vector with the reward vector from a
        descendant leaf node and increment the visit count.
        """
        self.visits += 1
        self.totalReward += rewards

    def calcSelectScore(self):
        """
        Calculate the select score: mean reward per visit, plus (when the
        model's sample probability is known) an exploration bonus scaled by
        config.mcts.sample_prob, the root's visit count, and sampleProb.
        """
        normRewards = self.totalReward / self.visits
        if self.sampleProb is not None:
            print("Sample Prob")
            print(self.sampleProb)
            # Bugfix: the original read `self.root`, an attribute Node never
            # defines (AttributeError). Walk parent links to reach the root.
            root = self
            while root.parentNode is not None:
                root = root.parentNode
            return normRewards + (self.config.mcts.sample_prob * self.sampleProb
                                  * np.sqrt(root.visits) / self.visits)
        return normRewards

    def getExpandStatus(self):
        """
        Returns an integer indicating whether the node is a:
        1. terminal node (sequence is fully unmasked)
        2. legal leaf node (partially unmasked sequence that can be expanded)
        3. legal non-leaf node (already expanded sequence with child nodes)
        """
        if self.timestep == self.config.sampling.steps:
            return 1
        elif (self.timestep < self.config.sampling.steps) and (len(self.childNodes) == 0):
            return 2
        return 3
| | |
| | """END OF NODE CLASS""" |
| |
|
def updateParetoFront(paretoFront, node, scoreVector, num_func):
    """
    Update a Pareto front (dict: node -> score vector) with a candidate.

    Only the first num_func objectives participate in the dominance test.
    Members strictly dominated by scoreVector are removed; the candidate is
    added only when it weakly dominates every member on those objectives.
    Returns the (mutated) paretoFront.
    """
    if not paretoFront:
        # Empty front: the candidate is trivially non-dominated.
        paretoFront[node] = scoreVector
        return paretoFront

    frontSize = len(paretoFront)
    weaklyDominatesAll = True
    dominatedMembers = []

    for member, memberScores in paretoFront.items():
        atLeast = scoreVector >= np.asarray(memberScores)
        strictly = scoreVector > np.asarray(memberScores)

        # Entries shorter than num_func are skipped entirely, matching the
        # original guard `if num_func <= len(nondominated)`.
        if num_func > len(atLeast):
            continue

        if atLeast[:num_func].all():
            if strictly[:num_func].any():
                # Candidate strictly dominates this member on the active
                # objectives: queue it for removal.
                dominatedMembers.append(member)
        else:
            weaklyDominatesAll = False

    if weaklyDominatesAll:
        paretoFront[node] = scoreVector

    # Evict strictly dominated members (whether or not the candidate was
    # admitted), never removing more entries than the front originally held.
    removable = frontSize
    while removable > 0 and dominatedMembers:
        del paretoFront[dominatedMembers.pop(0)]
        removable -= 1

    return paretoFront
| | |
| | """BEGINNING OF MCTS CLASS""" |
| |
|
class MCTS:
    """
    Multi-objective Monte-Carlo tree search over partially unmasked peptide
    SMILES sequences sampled from a masked diffusion language model (MDLM).

    The search keeps a bounded Pareto front (`peptideParetoFront`) mapping
    decoded sequences to {'scores': score vector, 'token_ids': token ids}.
    """

    def __init__(self, config, max_sequence_length=None, mdlm=None,
                 score_func_names=None, prot_seqs=None, num_func=None):
        """
        Args:
            config: experiment configuration; mcts.* and sampling.* are read.
            max_sequence_length: overrides config.sampling.seq_length if given.
            mdlm: pre-trained diffusion model; supplies tokenizer and device.
            score_func_names: names of scoring functions (one per objective).
            prot_seqs: protein sequences forwarded to the scoring functions.
            num_func: iteration thresholds that progressively enable more
                objectives in Pareto comparisons (see updateParetoFront).
        """
        self.config = config
        self.noise = noise_schedule.get_noise(config)
        self.time_conditioning = self.config.time_conditioning

        self.peptideParetoFront = {}
        self.num_steps = config.sampling.steps
        self.num_sequences = config.sampling.num_sequences

        self.mdlm = mdlm
        self.tokenizer = mdlm.tokenizer
        self.device = mdlm.device

        if max_sequence_length is None:
            self.sequence_length = self.config.sampling.seq_length
        else:
            self.sequence_length = max_sequence_length

        self.num_iter = config.mcts.num_iter
        self.num_child = config.mcts.num_children

        # Bugfix: the original used mutable default arguments ([]) for
        # score_func_names and num_func; normalize None -> fresh list instead.
        self.score_functions = ScoringFunctions(
            [] if score_func_names is None else score_func_names, prot_seqs)
        self.num_func = [] if num_func is None else num_func
        self.iter_num = 0
        self.curr_num_func = 1
        self.analyzer = PeptideAnalyzer()

        # Per-expansion logs of objective scores and validity statistics.
        self.valid_fraction_log = []
        self.affinity1_log = []
        self.affinity2_log = []
        self.permeability_log = []
        self.sol_log = []
        self.hemo_log = []
        self.nf_log = []

    def reset(self):
        """Clear the iteration counter, all logs, and the Pareto front."""
        self.iter_num = 0
        self.valid_fraction_log = []
        self.affinity1_log = []
        self.affinity2_log = []
        self.permeability_log = []
        self.sol_log = []
        self.hemo_log = []
        self.nf_log = []
        self.peptideParetoFront = {}

    def forward(self, rootNode):
        """Run num_iter select/expand iterations; return the Pareto front."""
        self.reset()
        while (self.iter_num < self.num_iter):
            self.iter_num += 1
            leafNode, _ = self.select(rootNode)
            self.expand(leafNode)
        return self.peptideParetoFront

    def updateParetoFront(self, sequence, scoreVector, tokens):
        """
        Update the bounded Pareto front with a fully unmasked sequence.

        Only the first curr_num_func objectives take part in the dominance
        test; curr_num_func grows as iter_num passes the thresholds in
        self.num_func. Strictly dominated members are queued for deletion,
        but deletions happen only while the front exceeds num_sequences.

        Args:
            sequence: decoded peptide SMILES string.
            scoreVector: per-objective scores for the sequence.
            tokens: token ids stored alongside the scores.

        Returns:
            Per-objective reward vector: for a non-empty front, the fraction
            of members matched-or-beaten on each objective; all ones for the
            first entry.
        """
        paretoSize = len(self.peptideParetoFront)

        self.curr_num_func = 1
        for i in range(len(self.num_func)):
            if self.iter_num >= self.num_func[i]:
                self.curr_num_func = i + 1

        if paretoSize == 0:
            self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}
            rewardVector = np.ones(len(scoreVector))
        else:
            nondominate = []
            delete = []
            rewardVector = np.zeros(len(scoreVector))
            for k, v in self.peptideParetoFront.items():
                nondominated = scoreVector >= np.asarray(v['scores'])
                dominant = scoreVector > np.asarray(v['scores'])

                # Reward counts members matched-or-beaten per objective,
                # over ALL objectives (not just the active prefix).
                rewardVector += nondominated

                if self.curr_num_func <= len(nondominated):
                    attn_nondominated = nondominated[:self.curr_num_func]
                    attn_dominant = dominant[:self.curr_num_func]

                    if attn_nondominated.all() and attn_dominant.any():
                        # Strictly dominated on the active objectives.
                        delete.append(k)
                        nondominate.append(True)
                    elif attn_nondominated.all():
                        nondominate.append(True)
                    else:
                        nondominate.append(False)

            assert len(nondominate) == paretoSize
            nondominate = np.asarray(nondominate)

            # Admit if it weakly dominates every member, or while the front
            # is still below capacity.
            if nondominate.all() or paretoSize < self.num_sequences:
                self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}

            rewardVector = rewardVector / paretoSize

            # Evict dominated members only while over capacity.
            while (paretoSize > self.num_sequences) and (len(delete) > 0):
                del self.peptideParetoFront[delete[0]]
                del delete[0]
                paretoSize -= 1

        return rewardVector

    def isPathEnd(self, path, maxDepth):
        """
        True if the last sequence in `path` is completely unmasked (no token
        equals the mask token) or the path has reached maxDepth.
        """
        if (path[-1] != self.config.mcts.mask_token).all():
            return True
        elif len(path) >= maxDepth:
            return True
        return False

    def select(self, currNode):
        """Traverse from currNode until reaching a non-expanded (status != 3)
        node; return that node and its status."""
        while True:
            currNode, nodeStatus = currNode.selectNode(self.curr_num_func)
            if nodeStatus != 3:
                return currNode, nodeStatus

    def expand(self, parentNode, eps=1e-5, checkSimilarity=True):
        """
        Expand parentNode with num_children one-step unmaskings sampled from
        the MDLM, roll each out to a complete sequence, score the valid
        peptides, update the Pareto front, and backpropagate the summed
        child rewards (penalized for invalid rollouts).
        """
        num_children = self.config.mcts.num_children
        allChildReward = np.zeros_like(parentNode.totalReward)

        num_rollout_steps = self.num_steps - parentNode.timestep
        rollout_t = torch.linspace(1, eps, num_rollout_steps, device=self.device)
        # NOTE(review): dt is derived from the full schedule length
        # (self.num_steps), not the remaining num_rollout_steps — confirm
        # this is intentional.
        dt = (1 - eps) / self.num_steps
        p_x0_cache = None

        x = parentNode.tokens['input_ids'].to(self.device)
        attn_mask = parentNode.tokens['attention_mask'].to(self.device)

        t = rollout_t[0] * torch.ones(num_children, 1, device=self.device)

        print("token array:")
        print(x)
        # First reverse step: branch the parent into num_children children.
        p_x0_cache, x_children = self.mdlm.batch_cached_reverse_step(token_array=x,
                                                                     t=t, dt=dt,
                                                                     batch_size=num_children,
                                                                     attn_mask=attn_mask)
        x_rollout = x_children

        # Roll every child out to the end of the schedule.
        for i in range(1, num_rollout_steps):
            t = rollout_t[i] * torch.ones(num_children, 1, device=self.device)
            p_x0_cache, x_next = self.mdlm.cached_reverse_step(x=x_rollout,
                                                               t=t, dt=dt, p_x0=p_x0_cache,
                                                               attn_mask=attn_mask)
            # Invalidate the cached logits once tokens change (or when the
            # model is explicitly time-conditioned).
            if (not torch.allclose(x_next, x) or self.time_conditioning):
                p_x0_cache = None
            x_rollout = x_next

        if self.config.sampling.noise_removal:
            # Final greedy denoising pass at the last time step.
            t = rollout_t[-1] * torch.ones(x.shape[0], 1, device=self.device)
            time_cond = self.noise(t)[0]
            x_rollout = self.mdlm.forward(x_rollout, attn_mask, time_cond).argmax(dim=-1)

        childSequences = self.tokenizer.batch_decode(x_rollout)

        validSequences = []
        maskedTokens = []
        unmaskedTokens = []
        for i in range(num_children):
            childSeq = childSequences[i]
            rewardVector = np.zeros(self.config.mcts.num_objectives)
            if self.analyzer.is_peptide(childSeq):
                validSequences.append(childSeq)
                maskedTokens.append(x_children[i])
                unmaskedTokens.append(x_rollout[i])
            else:
                # Invalid rollouts still become children, with zero reward.
                childTokens = {'input_ids': x_children[i], 'attention_mask': attn_mask}
                parentNode.addChildNode(tokens=childTokens,
                                        totalReward=rewardVector)

        if (len(validSequences) != 0):
            scoreVectors = self.score_functions(input_seqs=validSequences)
            average_scores = scoreVectors.T
            if self.config.mcts.single:
                self.permeability_log.append(average_scores[0])
            else:
                self.affinity1_log.append(average_scores[0])
                self.sol_log.append(average_scores[1])
                self.hemo_log.append(average_scores[2])
                self.nf_log.append(average_scores[3])
                if self.config.mcts.perm:
                    self.permeability_log.append(average_scores[4])
                elif self.config.mcts.dual:
                    self.affinity2_log.append(average_scores[4])
        else:
            # No valid peptides this iteration: log zero placeholders so the
            # logs stay aligned across iterations.
            self.affinity1_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
            self.sol_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
            self.hemo_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
            self.nf_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))

            if self.config.mcts.perm:
                self.permeability_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
            elif self.config.mcts.dual:
                self.affinity2_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))

        for i, validSeq in enumerate(validSequences):
            scoreVector = scoreVectors[i]
            rewardVector = self.updateParetoFront(validSeq, scoreVector, unmaskedTokens[i])
            print(scoreVector)
            print(rewardVector)
            allChildReward += rewardVector
            # Children store the PARTIALLY unmasked tokens (first step), not
            # the fully rolled-out sequence.
            childTokens = {'input_ids': maskedTokens[i], 'attention_mask': attn_mask}
            parentNode.addChildNode(tokens=childTokens,
                                    totalReward=rewardVector)

        invalid = (num_children - len(validSequences)) / num_children
        valid_fraction = len(validSequences) / num_children
        print(f"Valid fraction: {valid_fraction}")
        self.valid_fraction_log.append(valid_fraction)

        print(self.config.mcts.invalid_penalty)
        # Penalize the subtree for the fraction of invalid rollouts.
        allChildReward = allChildReward - (self.config.mcts.invalid_penalty * invalid)

        self.backprop(parentNode, allChildReward)

    def backprop(self, node, reward_vector):
        """Propagate a reward vector from `node` up to the root."""
        while node:
            node.updateNode(reward_vector)
            node = node.parentNode

    def getSequenceForObjective(self, objective_index, k):
        """
        Return the top-k sequences in the Pareto front ranked by one
        objective's score, mapped to their full Pareto-front entries.

        Bugfixes vs. the original: the loop variable no longer shadows the
        parameter `k`; ranking reads 'scores' (not 'token_ids'); and the
        torch.topk result is iterated via .indices (topk has no .items()).
        """
        topk = {}
        peptides = []
        objectiveScores = []
        for seq, entry in self.peptideParetoFront.items():
            peptides.append(seq)
            objectiveScores.append(entry['scores'][objective_index])

        objectiveScores = torch.tensor(objectiveScores)
        # Clamp so an undersized front does not make topk raise.
        num_top = min(k, len(peptides))
        topKScores = torch.topk(objectiveScores, num_top)
        for index in topKScores.indices.tolist():
            seq = peptides[index]
            topk[seq] = self.peptideParetoFront.get(seq)

        return topk
| | |
| |
|