diff --git a/alphafold_pytorch_jit/backbones.py b/alphafold_pytorch_jit/backbones.py index 28d4e194733cdea18f2375e85dde53913d3c47b0..67470cc9dd48572f02c302bf53b419a4a716e16d 100644 --- a/alphafold_pytorch_jit/backbones.py +++ b/alphafold_pytorch_jit/backbones.py @@ -8,7 +8,7 @@ from alphafold_pytorch_jit.basics import ( import os bf16 = (os.environ.get('AF2_BF16') == '1') - +fp16 = (os.environ.get('AF2_FP16') == '1') class Transition(nn.Module): def __init__(self,config, global_config, act_dim): @@ -108,6 +108,18 @@ class TriangleMultiplication(nn.Module): self.gating_linear = nn.Linear(act_dim,act_dim) def forward(self, act, mask): + profile_trianglemultiplication = (os.environ.get('PROFILE_TRIANGLEMULTIPLICATION') == '1') + if profile_trianglemultiplication : + prof_trianglemultiplication = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU], + schedule=torch.profiler.schedule(wait=0, warmup=0, active=1,repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler("log_tensor",f"920f-trianglemultiplication-thread{torch.get_num_threads()}"), + record_shapes=True, + profile_memory=True, + with_stack=False, + use_cuda=False + ) + prof_trianglemultiplication.start() mask = mask[..., None] act = self.layer_norm_input(act) input_act = act # For gate @@ -125,6 +137,8 @@ class TriangleMultiplication(nn.Module): act = self.center_layer_norm(act) act = self.output_projection(act) act *= torch.sigmoid(self.gating_linear(input_act)) + if profile_trianglemultiplication: + prof_trianglemultiplication.stop() return act @@ -191,8 +205,8 @@ class TriangleAttention(nn.Module): class NoExtraEvoformerIteration(nn.Module): - - def __init__(self, config, global_config, is_extra_msa,a_dim, m_dim, pa_dim): + noextra_evoformer_count = 0 + def __init__(self, config, global_config, is_extra_msa,a_dim, m_dim, pa_dim, evoformer_num_block): super().__init__() """Builds EvoformerIteration module. @@ -209,6 +223,9 @@ class NoExtraEvoformerIteration(nn.Module): Returns: Outputs, same shape/type as act. """ + NoExtraEvoformerIteration.noextra_evoformer_count += 1 + self.noextra_iteration = NoExtraEvoformerIteration.noextra_evoformer_count + self.evoformer_num_block = evoformer_num_block self.config = config self.global_config = global_config c = config @@ -237,6 +254,12 @@ class NoExtraEvoformerIteration(nn.Module): msa_mask = msa_mask.to(torch.bfloat16) pair_mask = pair_mask.to(torch.bfloat16) + if fp16 == True: + msa_act = msa_act.to(torch.float16) + pair_act = pair_act.to(torch.float16) + msa_mask = msa_mask.to(torch.float16) + pair_mask = pair_mask.to(torch.float16) + #msa_act, pair_act = activations['msa'], activations['pair'] #msa_mask, pair_mask = masks['msa'], masks['pair'] msa_act = msa_act + self.msa_row_attention_with_pair_bias(msa_act, msa_mask, pair_act=pair_act) @@ -261,8 +284,8 @@ class NoExtraEvoformerIteration(nn.Module): class ExtraEvoformerIteration(nn.Module): - - def __init__(self, config, global_config, is_extra_msa,a_dim, m_dim, pa_dim): + extra_evoformer_count = 0 + def __init__(self, config, global_config, is_extra_msa,a_dim, m_dim, pa_dim, extra_msa_stack_num_block): super().__init__() """Builds EvoformerIteration module. @@ -279,6 +302,9 @@ class ExtraEvoformerIteration(nn.Module): Returns: Outputs, same shape/type as act. 
""" + ExtraEvoformerIteration.extra_evoformer_count += 1 + self.extra_iteration = ExtraEvoformerIteration.extra_evoformer_count + self.extra_msa_stack_num_block = extra_msa_stack_num_block self.config = config self.global_config = global_config c = config diff --git a/alphafold_pytorch_jit/basics.py b/alphafold_pytorch_jit/basics.py index 004bf0fae75a81d107ac9a28f041aee37455df54..35fe843d0fddb9eaa4b71c93700dd7ae275f887f 100644 --- a/alphafold_pytorch_jit/basics.py +++ b/alphafold_pytorch_jit/basics.py @@ -4,6 +4,7 @@ from torch.nn import functional as F import numpy as np import pdb import time +import os def mask_mean(mask, value, axis=torch.Tensor(), drop_mask_channel=torch.Tensor([False]), eps=torch.Tensor([1e-10])): @@ -523,6 +524,19 @@ class GatingAttention(nn.Module): Returns: A float32 tensor of shape [batch_size, N_queries, output_dim]. """ + profile_gatingattention = (os.environ.get('PROFILE_GATINGATTENTION') == '1') + if profile_gatingattention: + prof_gatingattention = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU], + schedule=torch.profiler.schedule(wait=0,warmup=0,active=1,repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler("log_tensor",f"920f-gatingattention-thread{torch.get_num_threads()}"), + record_shapes=True, + profile_memory=True, + with_stack=False, + use_cuda=False + ) + prof_gatingattention.start() + # get query, key, value q = torch.einsum('bqa,ahc->bqhc', q_data, self.query_w) * self.key_dim**(-0.5) k = torch.einsum('bka,ahc->bkhc', m_data, self.key_w) @@ -544,5 +558,7 @@ class GatingAttention(nn.Module): weighted_avg *= gate_values # linear(res_gated) -> output output = torch.einsum('bqhc,hco->bqo', weighted_avg, self.output_w) + self.output_b + if profile_gatingattention: + prof_gatingattention.stop() return output diff --git a/alphafold_pytorch_jit/embeddings.py b/alphafold_pytorch_jit/embeddings.py index a829f9e0495a824f066b98ba739317e599ae4a65..b559a6ac34363be5a1cba1f6d875c7369c5fcd80 100644 --- a/alphafold_pytorch_jit/embeddings.py +++ b/alphafold_pytorch_jit/embeddings.py @@ -8,11 +8,17 @@ from alphafold_pytorch_jit.backbones import TriangleAttention, TriangleMultiplic import os bf16 = (os.environ.get('AF2_BF16') == '1') +fp16 = (os.environ.get('AF2_FP16') == '1') + class TemplatePairSubStack(nn.Module): + template_pairsubstack_count = 0 """Pair stack for the templates.""" - def __init__(self, config, global_config, pa_dim): + def __init__(self, config, global_config, pa_dim, num_block): super().__init__() + TemplatePairSubStack.template_pairsubstack_count += 1 + self.template_pairsubstack_iteration = TemplatePairSubStack.template_pairsubstack_count + self.num_block = num_block c = config gc = global_config self.c = config @@ -42,7 +48,7 @@ class TemplatePairStack(nn.Module): self.gc = global_config self.num_block = self.c['num_block'] self.template_pair_sub_stack = nn.ModuleList([ - TemplatePairSubStack(self.c,self.gc,a_dim) + TemplatePairSubStack(self.c,self.gc,a_dim,self.c['num_block']) for i in range(self.c['num_block']) ]) @@ -119,6 +125,11 @@ class SingleTemplateEmbedding(nn.Module): if bf16 == True: mask_2d = mask_2d.to(torch.bfloat16) + if fp16 == True: + mask_2d = mask_2d.to(torch.float16) + query_embedding = query_embedding.to(torch.float16) + + assert mask_2d.dtype == query_embedding.dtype num_res = template_aatype.shape[0] template_mask_2d = template_pseudo_beta_mask[:, None] * template_pseudo_beta_mask[None, :] diff --git a/alphafold_pytorch_jit/net.py b/alphafold_pytorch_jit/net.py index 
9efa198a5dd3f0469b6effe3ba38c553b42f9534..2581c7bfc089fc725c6d3c010366efc5d7f8e511 100644 --- a/alphafold_pytorch_jit/net.py +++ b/alphafold_pytorch_jit/net.py @@ -2,6 +2,7 @@ from typing import Any, Union, Mapping from alphafold_pytorch_jit import features import tensorflow.compat.v1 as tf from torch import nn +import time import os import jax import numpy as np diff --git a/alphafold_pytorch_jit/subnets.py b/alphafold_pytorch_jit/subnets.py index 8233e397d3d257534af0eef8502a9440cd7d3ef5..a9038e3b2a55042945dbedb76dda1521a6862e1b 100644 --- a/alphafold_pytorch_jit/subnets.py +++ b/alphafold_pytorch_jit/subnets.py @@ -18,8 +18,22 @@ from alphafold_pytorch_jit.weight_io import filtered_pth_params from alphafold_pytorch_jit.utils import detached, list2tensor import jax import time +import datetime import os import pickle +import kpex + +global fast_test +global fast_test_msax2_evox12 +if os.environ.get("FAST_TEST"): + fast_test = 1 +else: + fast_test = 0 + +if os.environ.get("FAST_TEST_MSAX2_EVOX12"): + fast_test_msax2_evox12 = 1 +else: + fast_test_msax2_evox12 = 0 class EmbeddingsAndEvoformer(nn.Module): @@ -44,24 +58,35 @@ class EmbeddingsAndEvoformer(nn.Module): self.right_single = nn.Linear(init_dims['target_feat'], self.c['pair_channel']) self.extra_msa_activations = nn.Linear(25,self.c['extra_msa_channel']) print('### [INFO] build evoformer network') + + if fast_test == 1: + self.c['extra_msa_stack_num_block'] = 1 + self.c['evoformer_num_block'] = 1 + + if fast_test_msax2_evox12 == 1: + self.c['extra_msa_stack_num_block'] = 2 + self.c['evoformer_num_block'] = 6 + self.extra_msa_stack = nn.ModuleList([ - ExtraEvoformerIteration( - self.c['evoformer'], - self.gc, - True, # is_extra_msa is True - self.c['extra_msa_channel'], - self.c['extra_msa_channel'], - self.c['pair_channel']) - for i in range(self.c['extra_msa_stack_num_block'])]) + ExtraEvoformerIteration( + self.c['evoformer'], + self.gc, + True, # is_extra_msa is True + self.c['extra_msa_channel'], + self.c['extra_msa_channel'], + self.c['pair_channel'], + self.c['extra_msa_stack_num_block']) + for i in range(self.c['extra_msa_stack_num_block'])]) self.evoformer_iteration = nn.ModuleList([ - NoExtraEvoformerIteration( - self.c['evoformer'], - self.gc, - False, # is_extra_msa is False - self.c['msa_channel'], - self.c['msa_channel'], - self.c['pair_channel']) - for i in range(self.c['evoformer_num_block'])]) + NoExtraEvoformerIteration( + self.c['evoformer'], + self.gc, + False, # is_extra_msa is False + self.c['msa_channel'], + self.c['msa_channel'], + self.c['pair_channel'], + self.c['evoformer_num_block']) + for i in range(self.c['evoformer_num_block'])]) self.single_activations = nn.Linear(self.c['msa_channel'],self.c['seq_channel']) self.prev_pos_linear = nn.Linear(15,self.c['pair_channel']) self.prev_msa_first_row_norm = nn.LayerNorm(normalized_shape=256,elementwise_affine=True) @@ -101,7 +126,33 @@ class EmbeddingsAndEvoformer(nn.Module): torsion_angles_sin_cos, alt_torsion_angles_sin_cos, torsion_angles_mask - ): + ): + profile_embedding = (os.environ.get('PROFILE_EMBEDDING') == '1') + if profile_embedding: + prof_embedding = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU], + schedule=torch.profiler.schedule(wait=0,warmup=0,active=1,repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler("log_tensor","920f-embedding"), + record_shapes=True, + profile_memory=True, + with_stack=False, + use_cuda=False + ) + prof_embedding.start() + + profile_embeddingsandevoformer = 
(os.environ.get('PROFILE_EMBEDDINGSANDEVOFORMER') == '1') + if profile_embeddingsandevoformer: + prof_embeddingsandevoformer = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU], + schedule=torch.profiler.schedule(wait=0,warmup=0,active=1,repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler("log_tensor","920f-embeddingsandevoformer_iterX1"), + record_shapes=True, + profile_memory=True, + with_stack=True, + use_cuda=False + ) + prof_embeddingsandevoformer.start() + ### start here: computing stuck at 2nd alphafold_iteration t1_embedding = self.read_time() print(' # [INFO] linear embedding of features') @@ -131,25 +182,25 @@ class EmbeddingsAndEvoformer(nn.Module): offset = pos[:, None] - pos[None, :] rel_sub_input = offset.long() + self.max_relative_feature rel_pos = F.one_hot( - rel_sub_input.clip( - min=0, - max=int(2 * self.max_relative_feature) - ).long(), - 2 * self.max_relative_feature + 1) + rel_sub_input.clip( + min=0, + max=int(2 * self.max_relative_feature) + ).long(), + 2 * self.max_relative_feature + 1) pair_activations += self.pair_activiations(rel_pos.float()) ### stop here: computing stuck at 2nd alphafold_iteration if self.template_enabled: print(' ## [INFO] execute template embedding') template_pair_representation = self.template_embedding( - pair_activations, - template_mask, - template_aatype, - template_pseudo_beta_mask, - template_pseudo_beta, - template_all_atom_positions, - template_all_atom_masks, - mask_2d) + pair_activations, + template_mask, + template_aatype, + template_pseudo_beta_mask, + template_pseudo_beta, + template_all_atom_positions, + template_all_atom_masks, + mask_2d) pair_activations += template_pair_representation # Embed extra MSA features. print(' ## [INFO] execute extra_msa_activations') @@ -158,6 +209,8 @@ class EmbeddingsAndEvoformer(nn.Module): # Extra MSA Stack. print(' ## [INFO] execute extra_msa_iterations') n_msa_iters = len(self.extra_msa_stack) + + print(f"# [INFO] execute extra_msa_iterations count is {self.c['extra_msa_stack_num_block']}") for i, extra_msa_iter in enumerate(self.extra_msa_stack): print(' # [INFO] execute extra_msa_iter %d/%d' % (i+1, n_msa_iters)) extra_msa_output = extra_msa_iter( @@ -190,28 +243,46 @@ class EmbeddingsAndEvoformer(nn.Module): torch.reshape(ret['torsion_angles_sin_cos'], [num_templ, num_res, 14]), torch.reshape(ret['alt_torsion_angles_sin_cos'], [num_templ, num_res, 14]), ret['torsion_angles_mask']] - , dim=-1) + , dim=-1) template_activations = self.template_single_embedding(template_features) template_activations = F.relu(template_activations) template_activations = self.template_projection(template_activations) # Concatenate the templates to the msa. evoformer_input['msa'] = torch.cat( - [evoformer_input['msa'], - template_activations], - dim=0) + [evoformer_input['msa'], + template_activations], + dim=0) # Concatenate templates masks to the msa masks. # Use mask from the psi angle, as it only depends on the backbone atoms # from a single residue. torsion_angle_mask = ret['torsion_angles_mask'][:, :, 2] # Concatenate the templates to the msa. 
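# For reference, the relative-position feature built earlier in this forward pass
# (offset -> clip -> one_hot) reduces to the toy computation below. A minimal sketch;
# the sequence length is arbitrary and max_relative_feature=32 is assumed to match the
# model config:
import torch
from torch.nn import functional as F

max_relative_feature = 32
pos = torch.arange(5)                                  # toy residue indices
offset = pos[:, None] - pos[None, :]                   # pairwise offsets, shape (5, 5)
clipped = (offset + max_relative_feature).clip(0, 2 * max_relative_feature)
rel_pos = F.one_hot(clipped, 2 * max_relative_feature + 1)   # shape (5, 5, 65)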
evoformer_masks['msa'] = torch.cat( - [evoformer_masks['msa'], - torsion_angle_mask], - dim=0) + [evoformer_masks['msa'], + torsion_angle_mask], + dim=0) t2_embedding = self.read_time() print(' # [TIME] total embedding duration =', (t2_embedding - t1_embedding), 'sec') + if profile_embedding: + prof_embedding.stop() print(' ## [INFO] execute evoformer_iterations') + print(f" ## [INFO] evoformer_iterations count is {self.c['evoformer_num_block']}") + + profile_evoformer = (os.environ.get('PROFILE_EVOFORMER') == '1') + if profile_evoformer: + prof_evoformer = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU], + schedule=torch.profiler.schedule(wait=0,warmup=0,active=1,repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler("log_tensor","920f-evoformer"), + record_shapes=True, + profile_memory=True, + with_stack=True, + use_cuda=False + ) + prof_evoformer.start() + + t1_evoformer= self.read_time() for i, evoformer_iter in enumerate(self.evoformer_iteration): evoformer_input = evoformer_iter( @@ -233,6 +304,14 @@ class EmbeddingsAndEvoformer(nn.Module): output['msa_first_row'] = msa_activations[0] t2_evoformer = self.read_time() print(' # [TIME] total evoformer duration =', (t2_evoformer - t1_evoformer), 'sec') + + if profile_evoformer: + prof_evoformer.stop() + + + if profile_embeddingsandevoformer: + prof_embeddingsandevoformer.stop() + return output @@ -259,7 +338,8 @@ class AlphaFold(nn.Module): self.gc, evo_init_dims, struct_apply, - af2iter_params + af2iter_params, + struct_params, ) ### filter input params if af2iter_params is not None: @@ -294,14 +374,12 @@ class AlphaFold(nn.Module): num_ensemble = batch_size // (self.config['num_recycle'] + 1) ### helper func: slice out the batch of current recycle def slice_recycle_idx(x): - x = x.detach().cpu().numpy() + x = x.detach() start = recycle_idx * num_ensemble size = num_ensemble - res = torch.tensor(np.array( - jax.lax.dynamic_slice_in_dim(x, start, size, axis=0))) + res = x[start:start + size] return res - ensembled_batch = jax.tree_map(slice_recycle_idx, batch) - ensembled_batch = jax.tree_map(torch.tensor, ensembled_batch) + ensembled_batch = {k: slice_recycle_idx(v) for k, v in batch.items()} del batch # else: num_ensemble = batch_size @@ -323,7 +401,7 @@ class AlphaFold(nn.Module): batch, compute_loss=False, ensemble_representations=False, - return_representations=False + return_representations=(os.environ.get('COMPARE_ACTIVE_REPRESENTATIONS') == '1') ): print('### [INFO] jit cocmpilation') # [issue] PyTorch 1.11 has a bug at 2nd alphafold iter self.impl.compile() # use jit.script to compile alphafolditeration @@ -347,6 +425,9 @@ class AlphaFold(nn.Module): num_iter = self.config['num_recycle'] ### recycling loop #num_iter = 0 # [inc TODO] debug for INC, plz remove this flag after debug finished + if fast_test == 1: + num_iter = 0 # debug + print(f"### [INFO] AlphaFold Iteration count is {num_iter}") for i in range(0, num_iter+1): print('### [INFO] start AlphaFold Iteration-%d' % (i+1)) t0 = time.time() @@ -379,12 +460,12 @@ class AlphaFold(nn.Module): class AlphaFoldIteration(nn.Module): - def __init__(self, config, global_config, evo_init_dims,struct_apply, af2iter_params, name='alphafold_iteration'): + def __init__(self, config, global_config, evo_init_dims,struct_apply, af2iter_params, struct_params, name='alphafold_iteration'): super().__init__() self.c = config self.gc = global_config self.evoformer = EmbeddingsAndEvoformer( - self.c['embeddings_and_evoformer'], + 
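# The slice_recycle_idx rewrite above drops the torch -> numpy -> jax -> torch round
# trip and slices the ensembled batch directly. A small sketch of the equivalence,
# with toy shapes (recycle_idx and num_ensemble are illustrative):
import torch

batch_entry = torch.arange(12).reshape(4, 3)      # 4 ensembled copies of one feature
recycle_idx, num_ensemble = 2, 1
start = recycle_idx * num_ensemble
sliced = batch_entry[start:start + num_ensemble]  # same result jax.lax.dynamic_slice_in_dim gave
assert sliced.shape == (1, 3)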
self.c['embeddings_and_evoformer'], self.gc, evo_init_dims) self.heads = OrderedDict() @@ -392,20 +473,20 @@ class AlphaFoldIteration(nn.Module): if not head_config['weight'] or head_name in ['structure_module']: continue # Do not instantiate zero-weight heads. head_factory = { - 'masked_msa': MaskedMsaHead, - 'distogram': DistogramHead, - 'predicted_lddt': PredictedLDDTHead, - 'predicted_aligned_error': PredictedAlignedErrorHead, - 'experimentally_resolved': ExperimentallyResolvedHead, + 'masked_msa': MaskedMsaHead, + 'distogram': DistogramHead, + 'predicted_lddt': PredictedLDDTHead, + 'predicted_aligned_error': PredictedAlignedErrorHead, + 'experimentally_resolved': ExperimentallyResolvedHead, }[head_name] self.heads[head_name] = ( - head_factory(head_config, self.gc)) + head_factory(head_config, self.gc)) self.heads['structure_module'] = struct_apply if af2iter_params is not None: # need to load the parameters for plddt head seperately res = {} for name in list(self.heads['predicted_lddt'].state_dict().keys()): - res[name] = af2iter_params['predicted_lddt_head.' + name] + res[name] = af2iter_params['predicted_lddt_head.' + name] assert res.keys() == self.heads['predicted_lddt'].state_dict().keys() res = OrderedDict(res) self.heads['predicted_lddt'].load_state_dict(res) @@ -452,7 +533,7 @@ class AlphaFoldIteration(nn.Module): 'extra_has_deletion', 'extra_deletion_value', 'extra_msa_mask', - 'msa_mask'] + 'msa_mask'] ordered_values=[] for i in values_keys_order: ordered_values.append(batch[i]) @@ -477,12 +558,12 @@ class AlphaFoldIteration(nn.Module): return time.time() def forward(self, - ensembled_batch, - Struct_Params, - rng, - non_ensembled_batch=None, - ensemble_representations=True, - idx=-1 + ensembled_batch, + Struct_Params, + rng, + non_ensembled_batch=None, + ensemble_representations=True, + idx=-1 ): num_ensemble = ensembled_batch['seq_length'].shape[0] if not ensemble_representations: @@ -490,6 +571,8 @@ class AlphaFoldIteration(nn.Module): '''EmbeddingsAndEvoformer part''' batch0 = self._slice_batch(0, ensembled_batch, non_ensembled_batch) evo_input=self.batch_expand(batch0) + for i in [0, 1, 2, 3, 5, 6, 8, 9, 14, 17, 18]: + evo_input[i] = evo_input[i].to(kpex._C.device()) print(' # [INFO] start evoformer iteration',idx) representations = self.evoformer(*evo_input) #print(self.evoformer.graph) @@ -503,10 +586,28 @@ class AlphaFoldIteration(nn.Module): if k != 'msa': representations[k] /= num_ensemble representations['msa'] = msa_representation + for i in representations: + if i != 'pair': + representations[i] = representations[i].to('cpu') '''structure_module part''' ret = {} ret['representations'] = representations + + profile_heads = (os.environ.get('PROFILE_HEADS') == '1') + + if profile_heads: + prof_heads = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU], + schedule=torch.profiler.schedule(wait=0,warmup=0,active=1,repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler("log_tensor","920f-head"), + record_shapes=True, + profile_memory=True, + with_stack=False, + use_cuda=False + ) + prof_heads.start() + t1_head = self.read_time() for name, (module) in self.heads.items(): if name in ('predicted_lddt', 'predicted_aligned_error'): @@ -516,7 +617,9 @@ class AlphaFoldIteration(nn.Module): continue representations_hk = jax.tree_map(detached,representations) batch_hk = jax.tree_map(detached,batch0) + head_module = self.read_time() res_hk = module(Struct_Params,rng,representations_hk,batch_hk) + head_module2 = self.read_time() ret[name] = 
jax.tree_map(list2tensor,res_hk)
        del res_hk
        if 'representations' in ret[name].keys():
@@ -527,6 +630,7 @@
        #   f_tmp_plddt = f_tmp_plddt + '-1.pkl'
        #   with open(f_tmp_plddt, 'wb') as h_tmp:
        #     pickle.dump(representations['structure_module'], h_tmp, protocol=4)
+        print(' # [TIME] head module duration =', (head_module2 - head_module), 'sec')
      else:
        ret[name] = module(representations)
        if 'representations' in ret[name]:
@@ -547,6 +651,9 @@
        ret[name] = module(representations)
    t2_head = self.read_time()
    print(' # [TIME] total heads duration =', (t2_head - t1_head), 'sec')
+
+    if profile_heads:
+      prof_heads.stop()
    #del representations
    return ret
diff --git a/alphafold_pytorch_jit/utils.py b/alphafold_pytorch_jit/utils.py
index 2dd645c0c27e102985d08d91d26189f2cea8b786..31e1f23a2c501b8d8a001f416c7fba45ea1ec08c 100644
--- a/alphafold_pytorch_jit/utils.py
+++ b/alphafold_pytorch_jit/utils.py
@@ -90,7 +90,11 @@ def batched_gather(params, indices, axis=0, batch_dims=0):

 def detached(x:torch.Tensor):
-  return x.detach().cpu().numpy()
+  x = x.detach().cpu()
+  if x.dtype == torch.bfloat16:
+    x = x.to(torch.float)
+  return x.numpy()
+  # return x.detach().cpu().numpy()

 def unwrap_tensor(x:torch.Tensor):
diff --git a/psi_run_af2.sh b/psi_run_af2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..98439a454c1d9878e67d66ee6140b1342a3b2237
--- /dev/null
+++ b/psi_run_af2.sh
@@ -0,0 +1,58 @@
+ps aux | grep run_psi_af2 | awk '{ print $2}' | xargs kill -9
+
+# Replace these paths with your own installation
+bisheng_path=/root/pacific_ext/HPCKit_25.0.0_Linux-aarch64/HPCKit/25.0.0/compiler/bisheng/
+kml_path=/root/pacific_ext/HPCKit_25.0.0_Linux-aarch64/HPCKit/25.0.0/kml/bisheng/lib/
+export LD_PRELOAD=$bisheng_path/lib/jemalloc-4kbpage/libjemalloc.so:$LD_PRELOAD
+export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
+export LD_LIBRARY_PATH=$kml_path:$LD_LIBRARY_PATH
+export AF2_BF16=1
+#export USE_TPP=1
+
+# Switches for profiling the individual stages
+# export PROFILE_EMBEDDING=1
+# export PROFILE_EVOFORMER=1
+# export PROFILE_HEADS=1
+# export PROFILE_GATINGATTENTION=1
+# export PROFILE_TRIANGLEMULTIPLICATION=1
+# export PROFILE_EMBEDDINGSANDEVOFORMER=1
+# export PROFILE_ALL=1
+
+rm -f structure_align/tmalign*
+rm -f structure_align/vis*.html
+rm -f 920f-unrelaxed*.pdb
+rm -f *right_extra*.pt
+rm -f *right_noextra*.pt
+rm -r right_result*.pkl
+
+export PRINT_LDDT=0
+export DISTRIBUTED_MPI=1
+#export FAST_TEST=1
+export DISTRIBUTED_EXTRABEFORE_OUTER=1
+export DISTRIBUTED_EXTRAAFTER_OUTER=1
+export DISTRIBUTED_NOEXTRABEFORE_OUTER=1
+export DISTRIBUTED_NOEXTRAAFTER_OUTER=1
+export DISTRIBUTED_EMBEDDING_TRIANGLE=1
+
+#export SAVE_ACTIVE_VALUES=1
+#export COMPARE_ACTIVE_REPRESENTATIONS=1
+
+export SHMID=test
+thread=36
+rankfile=rankfile32corer
+rank_num=2
+
+export SEQ_TYPE=T1050
+
+if [ !
"$DISTRIBUTED_MPI" = "1" ]; then + export OMP_NUM_THREADS=$thread + start_cpu_id=0 + numactl -C $start_cpu_id-$((start_cpu_id+thread-1)) -m 0-3 python run_psi_af2.py +else + mpirun --allow-run-as-root -n $rank_num --map-by rankfile:file=$rankfile -x OMP_NUM_THREADS=$thread --report-bindings -x UCX_TLS=self,sm python -u run_psi_af2.py +fi + +# 在HBM-flat 模式运行 +# export OMP_NUM_THREADS=32 +# export UCX_TLS=self,sm +# mpirun --map-by rankfile:file=rankfile32core --report-bindings -np 1 numactl -m 16-17 python run_psi_af2.py : -np 1 numactl -m 18-19 python run_psi_af2.py : -np 1 numactl -m 20-21 python run_psi_af2.py : -np 1 numactl -m 22-23 python run_psi_af2.py : -np 1 numactl -m 24-25 python run_psi_af2.py : -np 1 numactl -m 26-27 python run_psi_af2.py : -np 1 numactl -m 28-29 python run_psi_af2.py : -np 1 numactl -m 30-31 python run_psi_af2.py \ No newline at end of file diff --git a/rankfile32core b/rankfile32core new file mode 100644 index 0000000000000000000000000000000000000000..d35e7d560148d0798695c67e607aefdf3d68a844 --- /dev/null +++ b/rankfile32core @@ -0,0 +1,16 @@ +rank 0 =localhost slot=0-31 +rank 1 =localhost slot=72-103 +rank 2 =localhost slot=144-175 +rank 3 =localhost slot=216-247 +rank 4 =localhost slot=288-319 +rank 5 =localhost slot=360-391 +rank 6 =localhost slot=432-463 +rank 7 =localhost slot=504-535 +rank 8 =localhost slot=36-67 +rank 9 =localhost slot=108-139 +rank 10 =localhost slot=180-211 +rank 11 =localhost slot=252-283 +rank 12 =localhost slot=324-355 +rank 13 =localhost slot=396-427 +rank 14 =localhost slot=468-499 +rank 15 =localhost slot=540-571 \ No newline at end of file diff --git a/rankfile32corer b/rankfile32corer new file mode 100644 index 0000000000000000000000000000000000000000..bf63df768f6ae6c14bde80f1ef0056184e1e62e1 --- /dev/null +++ b/rankfile32corer @@ -0,0 +1,16 @@ +rank 0 =localhost slot=0-35 +rank 1 =localhost slot=38-73 +rank 2 =localhost slot=76-111 +rank 3 =localhost slot=114-149 +rank 4 =localhost slot=152-187 +rank 5 =localhost slot=190-225 +rank 6 =localhost slot=228-263 +rank 7 =localhost slot=266-301 +rank 8 =localhost slot=304-339 +rank 9 =localhost slot=342-377 +rank 10 =localhost slot=380-415 +rank 11 =localhost slot=418-453 +rank 12 =localhost slot=456-491 +rank 13 =localhost slot=494-529 +rank 14 =localhost slot=532-567 +rank 15 =localhost slot=570-605 \ No newline at end of file diff --git a/rankfile64core b/rankfile64core new file mode 100644 index 0000000000000000000000000000000000000000..5cbb5a8afeae2994a7e5f7e3f6187a4c7b93a69b --- /dev/null +++ b/rankfile64core @@ -0,0 +1,8 @@ +rank 0 =localhost slot=144-207 +rank 1 =localhost slot=288-351 +rank 2 =localhost slot=0-63 +rank 3 =localhost slot=432-495 +rank 4 =localhost slot=72-135 +rank 5 =localhost slot=216-279 +rank 6 =localhost slot=360-423 +rank 7 =localhost slot=504-567 \ No newline at end of file diff --git a/run_pdb_eval.py b/run_pdb_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..73fab9d173afc6d147a1a79a93a117de2828ae1d --- /dev/null +++ b/run_pdb_eval.py @@ -0,0 +1,148 @@ +# PDB +import numpy as np +import subprocess + + + +def run_tm_align(model_pdb, ref_pdb = 'ref.pdb', result_dir = 'tmalign', tmalign_bin = 'TMalign', rank = 100): + p = subprocess.Popen( f'mkdir -p {result_dir}', shell=True ) + + p = subprocess.Popen( + f'{tmalign_bin} {model_pdb} {ref_pdb} -o {result_dir}/tmalign{rank}' , + stdout=subprocess.PIPE, + shell=True + ) + result = p.communicate()[0].decode() + # print(result) + + for line in result.split('\n'): + if 
'TM-score=' in line: + print(line) + + +def parse_align(filename): + coords = {'A':[],'B':[]} + + + with open(filename, 'r') as file: + for line in file: + if line.startswith("ATOM"): + parts = line.split() + x, y, z = float(parts[6]), float(parts[7]), float(parts[8]) + coords[parts[4]].append((x, y, z)) + + return np.array(coords['A']),np.array(coords['B']) + +def calculate_lDDT( + pred_pos, # (B, L, 3) or (L, 3) + true_pos, # (B, L, 3) or (L, 3) + per_atom=False, # if true, return lDDT for each atom (retain L axis) + true_dist_cutoff=15.0, + dist_diff_cutoffs=[0.5, 1.0, 2.0, 4.0], +): + # Calculate predicted and true distance matrices + def pos_to_dist(pos): + return ((pos[..., None, :, :] - pos[..., None, :]) ** 2).sum(-1) ** 0.5 + + pred_dist, true_dist = pos_to_dist(pred_pos), pos_to_dist(true_pos) + + # Get mask for distances to be scored (remove self distances) + L = pred_pos.shape[-2] + dist_mask = (true_dist < true_dist_cutoff) * (1.0 - np.eye(L)) + + # Calculate contribution of each distance to lDDT + dist_diff = np.abs(pred_dist - true_dist) + score = (dist_diff[..., None] < np.array(dist_diff_cutoffs)).mean(-1) + + # Return calculated lDDT + axis = (-1,) if per_atom else (-2, -1) + return (score * dist_mask).sum(axis) / dist_mask.sum(axis) + + +html_template = ''' + + + +
+ + + + + + + + + + + + + + +''' + +def visualize( + result_dir = 'result', + ref_pdb = 'ref.pdb', + rank = 1 +): + with open(ref_pdb) as f: + ref_pdb_data = f.read() + with open(f'{result_dir}/tmalign{rank}.pdb') as f: + model_pdb_data = f.read() + with open(f'{result_dir}/vis{rank}.html','w') as f: + f.write(html_template.replace('[ref_pdb_data]',ref_pdb_data).replace('[model_pdb_data]',model_pdb_data)) + + + +def main( + rank=100, + model_pdb = 'model.pdb', + ref_pdb = 'ref.pdb', + result_dir = 'structure_align', + tmalign_bin = 'TMalign' +): + run_tm_align(model_pdb,ref_pdb=ref_pdb,result_dir=result_dir,tmalign_bin=tmalign_bin,rank=rank) + coords1,coords2 = parse_align(f'{result_dir}/tmalign{rank}') + score = calculate_lDDT(coords1,coords2 ,) + print('Align len:',coords1.shape[0]) + print(f'rank{rank}_LDDT-Ca = {score}') + visualize(result_dir=result_dir,ref_pdb=ref_pdb,rank=rank) + + + +if __name__ == '__main__': + # import argparse + + # parser = + ppdb_path = "920f-unrelaxed.pdb" + main(model_pdb=ppdb_path, + ref_pdb='../af_input/casp14/T1050/T1050_ref.pdb', + result_dir='structure_align', + tmalign_bin='./tools/TM-align/TMalign') diff --git a/run_preprocess.sh b/run_preprocess.sh new file mode 100644 index 0000000000000000000000000000000000000000..fb971f44134811622e396f4df4c7801621b402f0 --- /dev/null +++ b/run_preprocess.sh @@ -0,0 +1,20 @@ +ps aux | grep run_preprocess | awk '{ print $2}' | xargs kill -9 +python run_preprocess.py \ + --fasta_paths=../af_input/fa_data/T1050.fasta \ + --output_dir=outdir --data_dir=/workspace/Alphafold2/dbs/ \ + --model_names=model_1 \ + --data_dir=/workspace/Alphafold2/dbs/ \ + --jackhmmer_binary_path=/home/psi/fxm/alphafold2/hmmer/hmmerinstall/bin/jackhmmer \ + --hhblits_binary_path=/home/psi/compiled/anaconda3/envs/iaf_/bin/hhblits \ + --hhsearch_binary_path=/home/psi/compiled/anaconda3/envs/iaf_/bin/hhsearch \ + --kalign_binary_path=/home/psi/compiled/anaconda3/envs/iaf_/bin/kalign \ + --preset reduced_dbs \ + --uniref90_database_path=/workspace/Alphafold2/dbs/uniref90/uniref90.fasta \ + --mgnify_database_path=/workspace/Alphafold2/dbs/mgnify/mgy_clusters_2022_05.fa \ + --small_bfd_database_path=/workspace/Alphafold2/dbs/small_bfd/bfd-first_non_consensus_sequences.fasta \ + --uniref30_database_path=/workspace/Alphafold2/dbs/uniref30/UniRef30_2021_03 \ + --pdb70_database_path=/workspace/Alphafold2/dbs/pdb70/pdb70 \ + --template_mmcif_dir=/workspace/Alphafold2/dbs/pdb_mmcif/mmcif_files/ \ + --max_template_date=2020-05-12 \ + --obsolete_pdbs_path=/workspace/Alphafold2/dbs/pdb_mmcif/obsolete.dat \ + --n_cpu=40 diff --git a/run_psi_af2.py b/run_psi_af2.py new file mode 100644 index 0000000000000000000000000000000000000000..a38c9768507abf2890e0a635baead7d0b3ea9e35 --- /dev/null +++ b/run_psi_af2.py @@ -0,0 +1,196 @@ +from alphafold_pytorch_jit import net as model +from alphafold.model import config +import jax +import torch +import numpy as np +import os +import time +import pickle + +import torch.distributed as dist +import kpex +import alphafold_pytorch_jit +#kpex.alphafold.optimize(alphafold_pytorch_jit) + +distributed_mpi = (os.environ.get('DISTRIBUTED_MPI') == '1') +print_lddt = (os.environ.get('PRINT_LDDT') == '1') +compare_active_representations = (os.environ.get('COMPARE_ACTIVE_REPRESENTATIONS') == '1') +seq_type = os.environ.get('SEQ_TYPE') + +from runners.timmer import Timmers + +try: + use_tpp = (os.environ.get('USE_TPP') == '1') + if use_tpp: + from alphafold_pytorch_jit.basics import GatingAttention + from 
tpp_pytorch_extension.alphafold.Alpha_Attention import GatingAttentionOpti_forward + GatingAttention.forward = GatingAttentionOpti_forward + from alphafold_pytorch_jit.backbones import TriangleMultiplication + from tpp_pytorch_extension.alphafold.Alpha_TriangleMultiplication import TriangleMultiplicationOpti_forward + TriangleMultiplication.forward = TriangleMultiplicationOpti_forward + is_tpp = True + print('Running with Intel Optimizations. TPP extension detected !!!!!!!!!!!!!!!!!!!!!!') + else: + is_tpp = False + print('[warning] No TPP extension detected !!!!!!!!!!!!!!!!!!!!!!') +except: + is_tpp = False + print('[warning] No TPP extension detected, will fallback to imperative mode !!!!!!!!!!!!!!!!!!!!!!') + + +def main(): + bf16 = (os.environ.get('AF2_BF16') == '1') + print("bf16 variable: ", bf16) + + fp16 = (os.environ.get('AF2_FP16') == '1') + print("fp16 variable: ", fp16) + + input_path = f'../af_input/casp14/{seq_type}/deepmind/intermediates/processed_features.npz' + data = np.load(input_path, allow_pickle=True) + processed_feature_dict = {} + for k in data.files: + processed_feature_dict[k] = data[k] + processed_feature_dict.keys() + + plddts = {} + + processed_feature_dict = jax.tree_map( + lambda x:torch.tensor(x), processed_feature_dict) + + + from runners.timmer import Timmers + h_timmer = Timmers('time-920f.txt') + + num_ensemble = 1 + random_seed = 0 + torch.manual_seed(random_seed) + model_runners = {} + model_list = ['model_1'] + print("List of models:", model_list) + for model_name in model_list: + model_config = config.model_config(model_name) + model_config['data']['eval']['num_ensemble'] = num_ensemble + root_params = '../af_input/weights/extracted/' + model_name + model_runner = model.RunModel( + model_config, + root_params, + h_timmer, + random_seed) + model_runners[model_name] = model_runner + + model_runner = model_runners['model_1'] + model_runner = kpex.tpp.alphafold.alphafold.kpex_alphafold(model_runner, model_config) + + print(f"model_runner.model===={model_runner.model}") + #print(f"model_runner.structure_module.model===={model_runner.model.impl.structure_module.model}") + h_timmer.add_timmer('model_1 infer') + with torch.no_grad(): + # with torch.cpu.amp.autocast(dtype=(torch.float16 if fp16 else torch.bfloat16 ),enabled=bf16): + with torch.cpu.amp.autocast(enabled=bf16): + prediction_result = model_runner(processed_feature_dict) + + print('### [INFO] post-assessment: plddt') + timmer_name = f'post-assessment by plddt: {model_name}' + # timmer.add_timmer(timmer_name) + plddts[model_name] = np.mean(prediction_result['plddt']) + print("plddts score = ", plddts[model_name]) + # print("PTM = ", prediction_result['ptm']) + + # 保存最终激活值 + if compare_active_representations: + import pickle + if distributed_mpi: + result_output_path = f'right_result{dist.get_rank()}.pkl' + else: + result_output_path = 'right_result1000.pkl' + + with open(result_output_path, 'wb') as f: + pickle.dump(prediction_result, f, protocol=4) + + + # PDB + from alphafold.common import protein + from alphafold.common import residue_constants + + b_factors = np.repeat( + prediction_result['plddt'][:, None], residue_constants.atom_type_num, axis=-1) + unrelaxed_protein = protein.from_prediction( + jax.tree_map(lambda x:x.detach().numpy(),processed_feature_dict), + prediction_result, + b_factors) + + if distributed_mpi: + unrelaxed_pdb_path = f'920f-unrelaxed_rank{dist.get_rank()}.pdb' + else: + unrelaxed_pdb_path = '920f-unrelaxed.pdb' + with open(unrelaxed_pdb_path, 'w') as h: + 
h.write(protein.to_pdb(unrelaxed_protein))
+
+
+if __name__ == '__main__':
+    if distributed_mpi:
+        dist.init_process_group(backend="mpi")
+        rank_nums = dist.get_world_size()
+        rank = dist.get_rank()
+        kpex._C.mpi.initialize(rank_nums, rank, 268_435_456)
+        print(f" # [INFO] rank/world_size: {rank}/{rank_nums}")
+    else:
+        rank=1000
+        rank_nums=1000
+
+    profile_all = (os.environ.get('PROFILE_ALL') == '1')
+    if profile_all :
+        prof_all = torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CPU],
+            schedule=torch.profiler.schedule(wait=0,warmup=0,active=1,repeat=1),
+            on_trace_ready=torch.profiler.tensorboard_trace_handler("log_tensor",f"920f-thread{torch.get_num_threads()}-{rank}_{rank_nums}-{seq_type}"),
+            record_shapes=True,
+            profile_memory=True,
+            with_stack=False,
+            use_cuda=False
+        )
+        prof_all.start()
+
+    t1 = time.time();
+
+    main()
+    t2 = time.time();
+    print(' # [TIME] total duration =', (t2 - t1), 'sec')
+    if distributed_mpi:
+        kpex._C.mpi.finalize()
+
+    if profile_all:
+        prof_all.stop()
+
+    if print_lddt:
+        import run_pdb_eval
+        if distributed_mpi:
+            ppdb_path=f'920f-unrelaxed_rank{dist.get_rank()}.pdb'
+            rank_id = dist.get_rank()
+        else:
+            ppdb_path = "920f-unrelaxed.pdb"
+            rank_id = 100
+
+        run_pdb_eval.main(rank=rank_id,
+            model_pdb=ppdb_path,
+            ref_pdb='../af_input/casp14/T1050/T1050_ref.pdb',
+            result_dir='structure_align',
+            tmalign_bin='./tools/TM-align/TMalign')
+
+    # Compare the final activations for equality to check that the inference logic is correct.
+    # First produce the known-good single-process activations right_result1000.pkl, move them to
+    # ./pt_file/toright/right_result1000_ipa.pkl, then run the parallel job and verify the results match.
+    if compare_active_representations:
+        with open('./pt_file/toright/right_result1000_ipa.pkl','rb') as f:
+            right_result = pickle.load(f)
+        if distributed_mpi:
+            with open(f'right_result{rank}.pkl','rb') as f:
+                result1 = pickle.load(f)
+        else:
+            with open(f'right_result1000.pkl','rb') as f:
+                result1 = pickle.load(f)
+
+        print(f"{rank}_{np.array_equal(right_result['representations']['msa'],result1['representations']['msa'])}")
+        print(f"{rank}_{np.array_equal(right_result['representations']['msa_first_row'],result1['representations']['msa_first_row'])}")
+        print(f"{rank}_{np.array_equal(right_result['representations']['pair'],result1['representations']['pair'])}")
+        print(f"{rank}_{np.array_equal(right_result['representations']['single'],result1['representations']['single'])}")
+        print(f"{rank}_{np.array_equal(right_result['representations']['structure_module'],result1['representations']['structure_module'])}")
diff --git a/tpp-pytorch-extension/build_tpp_bisheng.sh b/tpp-pytorch-extension/build_tpp_bisheng.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b70c277b755e8ae493c28c1fa12e1dc7223dcf9a
--- /dev/null
+++ b/tpp-pytorch-extension/build_tpp_bisheng.sh
@@ -0,0 +1,10 @@
+# https://github.com/libxsmm/parlooper/tree/630b6396369c2dab1fd96372c054cd1f34c35e7e
+# https://github.com/libxsmm/libxsmm/tree/b8b085ee9aa62c269f7beb2a585c90d2f28adfbe
+# export LD_LIBRARY_PATH=./tpp-pytorch-extension/parlooper/lib/:$LD_LIBRARY_PATH
+# export LD_LIBRARY_PATH=./tpp-pytorch-extension/libxsmm/lib/:$LD_LIBRARY_PATH
+pip uninstall -y tpp-pytorch-extension
+python setup.py clean
+rm -rf build dist
+export CFLAGS=" -stdlib=libc++ -lc++ -lc++abi"
+# export CFLAGS="-rtlib=compiler-rt"
+python setup.py install
\ No newline at end of file
diff --git a/tpp-pytorch-extension/setup.py b/tpp-pytorch-extension/setup.py
index a1f0f79dbc9a872ab6cb29dcf1dce274ecb8076f..aefaf33f937481630dd7279891c51ea35c9e672e 100644
--- a/tpp-pytorch-extension/setup.py
+++
b/tpp-pytorch-extension/setup.py
@@ -18,6 +18,7 @@ from subprocess import check_call, check_output
 import pathlib
 import torch
 import platform
+import subprocess

 cwd = os.path.dirname(os.path.realpath(__file__))
@@ -166,6 +167,15 @@
 print("extra_compile_args = ", extra_compile_args)
 print(sources)

+c_result = subprocess.run(['which', os.getenv('CC')], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout.strip()
+
+if 'clang' in c_result:
+    COMPILER_CC = 'clang'
+    COMPILER_CXX = 'clang++'
+else:
+    COMPILER_CC = 'gcc'
+    COMPILER_CXX = 'g++'
+
 setup(
     name="tpp-pytorch-extension",
     version="0.0.1",
@@ -186,18 +196,17 @@ setup(
     # install_requires=["torch>=1.4.0"],
     scripts=["utils/run_dist.sh", "utils/run_dist_ht.sh", "utils/run_dist_numa.sh"],
     libraries=[
-        ("xsmm", xsmm_makefile, ["CC=gcc", "CXX=g++", "AVX=2", "-j", "STATIC=1"]),
+        ("xsmm", xsmm_makefile, [f"CC={COMPILER_CC}", f"CXX={COMPILER_CXX}", "-j", "STATIC=1", "JIT=1", "LIBXSMM_JIT=1", "LIBXSMM_CONFIG_JIT=1"]),
         (
             "parlooper",
             parlooper_makefile,
             [
-                "CC=gcc",
-                "CXX=g++",
-                "AVX=2",
+                f"CC={COMPILER_CC}",
+                f"CXX={COMPILER_CXX}",
                 "-j",
                 "ROOTDIR = " + parlooper_root,
                 "LIBXSMM_ROOT=" + libxsmm_root,
-                "PARLOOPER_COMPILER=gcc",
+                "PARLOOPER_COMPILER=clang",
             ],
         ),
     ],
@@ -207,8 +216,8 @@ setup(
             sources,
             extra_compile_args=extra_compile_args,
             include_dirs=[xsmm_include, parlooper_include, "{}/src/csrc".format(cwd)],
-            # library_dirs=[xsmm_lib],
-            # libraries=["xsmm"],
+            # library_dirs=[xsmm_lib,parlooper_lib],
+            # libraries=["xsmm","parlooper"],
         )
     ],
     cmdclass={"build_ext": BuildExtension, "build_clib": BuildMakeLib},
diff --git a/tpp-pytorch-extension/src/csrc/bert/pad/fused_self_attention_bwd_tmpl.h b/tpp-pytorch-extension/src/csrc/bert/pad/fused_self_attention_bwd_tmpl.h
index b8da82c889723c8b446c1d9e2c28e1a156793bed..628ed72666777258d18a1152c2ee251aebba615e 100644
--- a/tpp-pytorch-extension/src/csrc/bert/pad/fused_self_attention_bwd_tmpl.h
+++ b/tpp-pytorch-extension/src/csrc/bert/pad/fused_self_attention_bwd_tmpl.h
@@ -213,8 +213,14 @@ auto t_Wv_TV = wt_tensor_for_bwd(N, H, N, H, t_Wv);
 for (int b = 0; b < B; b++) {
   for (int n = 0; n < N; n++) {
     for (int s11 = 0; s11 < S1; s11++) {
-      float dtAPD[S1][S2][S2] = {0};
-      T dtAPD_bf[S1][S2][S2] = {0};
+      // float dtAPD[S1][S2][S2] = {0};
+      // T dtAPD_bf[S1][S2][S2] = {0};
+      // Changed for the BiSheng compiler build: the VLAs cannot take brace initializers, so zero them with memset.
+      float dtAPD[S1][S2][S2];
+      T dtAPD_bf[S1][S2][S2];
+      memset(dtAPD, 0, sizeof(dtAPD));
+      memset(dtAPD_bf, 0, sizeof(dtAPD_bf));
+
       for (int s21 = 0; s21 < S1; s21++) {
         if (dAPO)
           a_convert_tpp(dAPO[b][s11][n][s21], dtAPD[s21][0]);
diff --git a/tpp-pytorch-extension/src/csrc/dlrm/embbag.cpp b/tpp-pytorch-extension/src/csrc/dlrm/embbag.cpp
index aaefcba727977046d44e329bc5a72c18c4ef777d..8cdcf87fd9432774ce643081e3d486f8397b8ab7 100644
--- a/tpp-pytorch-extension/src/csrc/dlrm/embbag.cpp
+++ b/tpp-pytorch-extension/src/csrc/dlrm/embbag.cpp
@@ -53,8 +53,9 @@ void tpp_embedding_bag_forward_tmpl(
   t_input = t_input.contiguous();
   t_offsets = t_offsets.contiguous();

-  DECL_VLA_PTR_PT(scalar_t, weight, [E], t_weight);
-  DECL_VLA_PTR_PT(out_scalar_t, output, [E], t_output);
+  // Changed for the BiSheng compiler build
+  // DECL_VLA_PTR_PT(scalar_t, weight, [E], t_weight);
+  // DECL_VLA_PTR_PT(out_scalar_t, output, [E], t_output);
   int64_t* input = t_input.data_ptr
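The setup.py hunk above picks clang or gcc for the libxsmm and parlooper sub-builds by resolving `which $CC`. A minimal sketch of that detection with a guard for the case where CC is not exported (the helper name and fallback defaults are illustrative, not part of the patch):

import os
import shutil

def pick_compilers(default_cc="gcc", default_cxx="g++"):
    # Prefer clang/clang++ when $CC resolves to a clang binary (e.g. the BiSheng compiler).
    cc = os.getenv("CC") or default_cc
    resolved = shutil.which(cc) or ""
    if "clang" in cc or "clang" in resolved:
        return "clang", "clang++"
    return default_cc, default_cxx

COMPILER_CC, COMPILER_CXX = pick_compilers()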