From 32fa67815966ecb4c98d6434fab416e41a3bc60e Mon Sep 17 00:00:00 2001 From: "hongliang.yuan" Date: Wed, 9 Apr 2025 11:22:37 +0800 Subject: [PATCH] sync bloom model fix qlora --- .../pytorch/models/bloom/config.json | 37 + .../models/bloom/configuration_bloom.py | 55 +- .../pytorch/models/bloom/modeling_bloom.py | 745 +++++++++--------- 3 files changed, 436 insertions(+), 401 deletions(-) create mode 100644 nlp/llm/bloom-7b1/pytorch/models/bloom/config.json diff --git a/nlp/llm/bloom-7b1/pytorch/models/bloom/config.json b/nlp/llm/bloom-7b1/pytorch/models/bloom/config.json new file mode 100644 index 000000000..ca136ff59 --- /dev/null +++ b/nlp/llm/bloom-7b1/pytorch/models/bloom/config.json @@ -0,0 +1,37 @@ +{ + "apply_residual_connection_post_layernorm": false, + "architectures": [ + "BloomForCausalLM" + ], + "auto_map": { + "AutoConfig": "configuration_bloom.BloomConfig", + "AutoModelForCausalLM": "modeling_bloom.BloomForCausalLM", + "AutoModel": "modeling_bloom.BloomForCausalLM", + "AutoModelForSeq2SeqLM": "modeling_bloom.BloomForCausalLM" + }, + "attention_dropout": 0.0, + "attention_softmax_in_fp32": true, + "bias_dropout_fusion": true, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_dropout": 0.0, + "hidden_size": 4096, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "masked_softmax_fusion": true, + "model_type": "bloom", + "n_head": 32, + "n_inner": null, + "n_layer": 30, + "offset_alibi": 100, + "pad_token_id": 3, + "pretraining_tp": 1, + "skip_bias_add": true, + "skip_bias_add_qkv": false, + "slow_but_exact": false, + "torch_dtype": "float16", + "transformers_version": "4.22.2", + "unk_token_id": 0, + "use_cache": true, + "vocab_size": 250880 +} diff --git a/nlp/llm/bloom-7b1/pytorch/models/bloom/configuration_bloom.py b/nlp/llm/bloom-7b1/pytorch/models/bloom/configuration_bloom.py index c485331bc..424a016a1 100755 --- a/nlp/llm/bloom-7b1/pytorch/models/bloom/configuration_bloom.py +++ b/nlp/llm/bloom-7b1/pytorch/models/bloom/configuration_bloom.py @@ -12,29 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Bloom configuration""" +"""Bloom configuration""" + from collections import OrderedDict from typing import TYPE_CHECKING, Any, List, Mapping, Optional from packaging import version + if TYPE_CHECKING: - from ... 
import PreTrainedTokenizer, TensorType + from transformers import PreTrainedTokenizer, TensorType from transformers.configuration_utils import PretrainedConfig from transformers.onnx import OnnxConfigWithPast, PatchingSpec from transformers.utils import is_torch_available, logging -logger = logging.get_logger(__name__) -BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "bigscience/bloom": "https://huggingface.co/bigscience/bloom/resolve/main/config.json", - "bigscience/bloom-560m": "https://huggingface.co/bigscience/bloom-560m/blob/main/config.json", - "bigscience/bloom-1b1": "https://huggingface.co/bigscience/bloom-1b1/blob/main/config.json", - "bigscience/bloom-1b7": "https://huggingface.co/bigscience/bloom-1b7/blob/main/config.json", - "bigscience/bloom-3b": "https://huggingface.co/bigscience/bloom-3b/blob/main/config.json", - "bigscience/bloom-7b1": "https://huggingface.co/bigscience/bloom-7b1/blob/main/config.json", -} +logger = logging.get_logger(__name__) class BloomConfig(PretrainedConfig): @@ -136,9 +130,7 @@ class BloomConfig(PretrainedConfig): self.initializer_range = initializer_range self.use_cache = use_cache self.pretraining_tp = pretraining_tp - self.apply_residual_connection_post_layernorm = ( - apply_residual_connection_post_layernorm - ) + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm self.hidden_dropout = hidden_dropout self.attention_dropout = attention_dropout @@ -159,9 +151,7 @@ class BloomOnnxConfig(OnnxConfigWithPast): patching_specs: List[PatchingSpec] = None, use_past: bool = False, ): - super().__init__( - config, task=task, patching_specs=patching_specs, use_past=use_past - ) + super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) if not getattr(self._config, "pad_token_id", None): # TODO: how to do that better? self._config.pad_token_id = 0 @@ -171,13 +161,8 @@ class BloomOnnxConfig(OnnxConfigWithPast): common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) if self.use_past: # BLOOM stores values on dynamic axis 2. For more details see: https://github.com/huggingface/transformers/pull/18344 - self.fill_with_past_key_values_( - common_inputs, direction="inputs", inverted_values_shape=True - ) - common_inputs["attention_mask"] = { - 0: "batch", - 1: "past_sequence + sequence", - } + self.fill_with_past_key_values_(common_inputs, direction="inputs", inverted_values_shape=True) + common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} else: common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} @@ -204,11 +189,7 @@ class BloomOnnxConfig(OnnxConfigWithPast): framework: Optional["TensorType"] = None, ) -> Mapping[str, Any]: common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, - batch_size=batch_size, - seq_length=seq_length, - is_pair=is_pair, - framework=framework, + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework ) # We need to order the input in the way they appears in the forward() @@ -217,9 +198,7 @@ class BloomOnnxConfig(OnnxConfigWithPast): # Need to add the past_keys if self.use_past: if not is_torch_available(): - raise ValueError( - "Cannot generate dummy past_keys inputs without PyTorch installed." 
- ) + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") else: import torch @@ -238,19 +217,14 @@ class BloomOnnxConfig(OnnxConfigWithPast): head_dim, ) ordered_inputs["past_key_values"] = [ - (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) - for _ in range(self.num_layers) + (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(self.num_layers) ] ordered_inputs["attention_mask"] = common_inputs["attention_mask"] if self.use_past: mask_dtype = ordered_inputs["attention_mask"].dtype ordered_inputs["attention_mask"] = torch.cat( - [ - ordered_inputs["attention_mask"], - torch.ones(batch, past_key_values_length, dtype=mask_dtype), - ], - dim=1, + [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 ) return ordered_inputs @@ -258,3 +232,6 @@ class BloomOnnxConfig(OnnxConfigWithPast): @property def default_onnx_opset(self) -> int: return 13 + + +__all__ = ["BloomConfig", "BloomOnnxConfig"] diff --git a/nlp/llm/bloom-7b1/pytorch/models/bloom/modeling_bloom.py b/nlp/llm/bloom-7b1/pytorch/models/bloom/modeling_bloom.py index 9328319b0..8ec68b93b 100755 --- a/nlp/llm/bloom-7b1/pytorch/models/bloom/modeling_bloom.py +++ b/nlp/llm/bloom-7b1/pytorch/models/bloom/modeling_bloom.py @@ -23,12 +23,11 @@ import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss from torch.nn import functional as F -from transformers.file_utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, -) -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -37,29 +36,19 @@ from transformers.modeling_outputs import ( TokenClassifierOutput, ) from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging - +from transformers.utils import is_torchdynamo_compiling, logging from .configuration_bloom import BloomConfig +from apex.normalization import FusedLayerNorm + +LayerNorm = FusedLayerNorm logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "bigscience/bloom-560m" _CONFIG_FOR_DOC = "BloomConfig" -BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "bigscience/bigscience-small-testing", - "bigscience/bloom-560m", - "bigscience/bloom-1b1", - "bigscience/bloom-1b7", - "bigscience/bloom-3b", - "bigscience/bloom-7b1", - "bigscience/bloom", -] - -def build_alibi_tensor( - attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype -) -> torch.Tensor: +def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: """ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value @@ -71,7 +60,7 @@ def build_alibi_tensor( Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) attention_mask (`torch.Tensor`): Token-wise attention mask, this should be of shape (batch_size, max_seq_len). 
- num_heads (`int`, *required*): + num_heads (`int`): number of heads dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): dtype of the output tensor @@ -79,29 +68,17 @@ def build_alibi_tensor( batch_size, seq_length = attention_mask.shape closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) base = torch.tensor( - 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), - device=attention_mask.device, - dtype=torch.float32, - ) - powers = torch.arange( - 1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32 + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 ) + powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) slopes = torch.pow(base, powers) if closest_power_of_2 != num_heads: extra_base = torch.tensor( - 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), - device=attention_mask.device, - dtype=torch.float32, + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 ) num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) - extra_powers = torch.arange( - 1, - 1 + 2 * num_remaining_heads, - 2, - device=attention_mask.device, - dtype=torch.int32, - ) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32) slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) # Note: alibi will added to the attention bias that will be applied to the query, key product of attention @@ -115,20 +92,18 @@ def build_alibi_tensor( return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) -def dropout_add( - x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool -) -> torch.Tensor: +def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: """ Dropout add function Args: - x (`torch.tensor`, *required*): + x (`torch.tensor`): input tensor - residual (`torch.tensor`, *required*): + residual (`torch.tensor`): residual tensor - prob (`float`, *required*): + prob (`float`): dropout probability - training (`bool`, *required*): + training (`bool`): training mode """ out = F.dropout(x, p=prob, training=training) @@ -142,7 +117,7 @@ def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor: make the model jitable. 
Args: - x (`torch.tensor`, *required*): + x (`torch.tensor`): input hidden states """ return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) @@ -154,17 +129,15 @@ def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor: 0.3989423 * x * torch.exp(-0.5 * x * x) Args: - g (`torch.tensor`, *required*): + g (`torch.tensor`): gradient output tensor - x (`torch.tensor`, *required*): + x (`torch.tensor`): input tensor """ x = x[0] # x is a tuple of 1 element, needs to unpack it first tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ( - (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x) - ) + 0.5 * (1 + tanh_out) + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) return ff * g @@ -201,7 +174,7 @@ class BloomGelu(nn.Module): class BloomAttention(nn.Module): - def __init__(self, config: BloomConfig): + def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None): super().__init__() self.pretraining_tp = config.pretraining_tp @@ -222,39 +195,44 @@ class BloomAttention(nn.Module): # Layer-wise attention scaling self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) self.beta = 1.0 + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) - self.query_key_value = nn.Linear( - self.hidden_size, 3 * self.hidden_size, bias=True - ) + self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) self.dense = nn.Linear(self.hidden_size, self.hidden_size) self.attention_dropout = nn.Dropout(config.attention_dropout) - def _split_heads( - self, fused_qkv: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def _reshape(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory - storage as `fused_qkv` + Split the last dimension into (num_heads, head_dim) and reshapes to (bs, heads, len, dim) shape + without making any copies, results share same memory storage as `fused_qkv` Args: - fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim] + fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim] Returns: - query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] - value: [batch_size, seq_length, num_heads, head_dim] + query: [batch_size, num_heads, seq_length, head_dim] + key: [batch_size, num_heads, seq_length, head_dim] + value: [batch_size, num_heads, seq_length, head_dim] """ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape - fused_qkv = fused_qkv.view( - batch_size, seq_length, self.num_heads, 3, self.head_dim - ) - return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] + fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) + query_layer = fused_qkv[..., 0, :].transpose(1, 2) + key_layer = fused_qkv[..., 1, :].transpose(1, 2) + value_layer = fused_qkv[..., 2, :].transpose(1, 2) + return query_layer, key_layer, value_layer def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: """ Merge heads together over the last 
dimension Args: - x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim] + x (`torch.tensor`): [batch_size * num_heads, seq_length, head_dim] Returns: torch.tensor: [batch_size, seq_length, num_heads * head_dim] @@ -280,47 +258,28 @@ class BloomAttention(nn.Module): residual: torch.Tensor, alibi: torch.Tensor, attention_mask: torch.Tensor, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + layer_past: Optional[Cache] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, ): - fused_qkv = self.query_key_value( - hidden_states - ) # [batch_size, seq_length, 3 x hidden_size] - - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + batch_size, q_length, _ = hidden_states.shape + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + # 3 x [batch_size, num_heads, seq_length, head_dim] + query_layer, key_layer, value_layer = self._reshape(fused_qkv) - batch_size, q_length, _, _ = query_layer.shape - - query_layer = query_layer.transpose(1, 2).reshape( - batch_size * self.num_heads, q_length, self.head_dim - ) - key_layer = key_layer.permute(0, 2, 3, 1).reshape( - batch_size * self.num_heads, self.head_dim, q_length - ) - value_layer = value_layer.transpose(1, 2).reshape( - batch_size * self.num_heads, q_length, self.head_dim - ) if layer_past is not None: - past_key, past_value = layer_past - # concatenate along seq_length dimension: - # - key: [batch_size * self.num_heads, head_dim, kv_length] - # - value: [batch_size * self.num_heads, kv_length, head_dim] - key_layer = torch.cat((past_key, key_layer), dim=2) - value_layer = torch.cat((past_value, value_layer), dim=1) + cache_kwargs = {"cache_position": cache_position} + key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) - _, _, kv_length = key_layer.shape - - if use_cache is True: - present = (key_layer, value_layer) - else: - present = None + # reshape qkv for further computations + query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) + key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2) + value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) # [batch_size * num_heads, q_length, kv_length] - # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 - matmul_result = alibi.baddbmm( + attention_scores = alibi.baddbmm( batch1=query_layer, batch2=key_layer, beta=self.beta, @@ -328,21 +287,13 @@ class BloomAttention(nn.Module): ) # change view to [batch_size, num_heads, q_length, kv_length] - attention_scores = matmul_result.view( - batch_size, self.num_heads, q_length, kv_length - ) + attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]] + attn_weights = attn_weights + causal_mask - # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] - input_dtype = attention_scores.dtype - # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` - if input_dtype == torch.float16: - attention_scores = 
attention_scores.to(torch.float) - attn_weights = torch.masked_fill( - attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min - ) - attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - input_dtype - ) + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype) # [batch_size, num_heads, q_length, kv_length] attention_probs = self.attention_dropout(attention_probs) @@ -351,9 +302,7 @@ class BloomAttention(nn.Module): attention_probs = attention_probs * head_mask # change view [batch_size x num_heads, q_length, kv_length] - attention_probs_reshaped = attention_probs.view( - batch_size * self.num_heads, q_length, kv_length - ) + attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1) # matmul: [batch_size * num_heads, q_length, head_dim] context_layer = torch.bmm(attention_probs_reshaped, value_layer) @@ -373,11 +322,9 @@ class BloomAttention(nn.Module): else: output_tensor = self.dense(context_layer) - output_tensor = dropout_add( - output_tensor, residual, self.hidden_dropout, self.training - ) + output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) - outputs = (output_tensor, present) + outputs = (output_tensor, layer_past) if output_attentions: outputs += (attention_probs,) @@ -396,9 +343,7 @@ class BloomMLP(nn.Module): self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size) self.hidden_dropout = config.hidden_dropout - def forward( - self, hidden_states: torch.Tensor, residual: torch.Tensor - ) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states)) if self.pretraining_tp > 1 and self.slow_but_exact: @@ -407,37 +352,29 @@ class BloomMLP(nn.Module): for i in range(self.pretraining_tp): intermediate_output = intermediate_output + F.linear( hidden_states[:, :, int(i * slices) : int((i + 1) * slices)], - self.dense_4h_to_h.weight[ - :, int(i * slices) : int((i + 1) * slices) - ], + self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)], ) else: intermediate_output = self.dense_4h_to_h(hidden_states) - output = dropout_add( - intermediate_output, residual, self.hidden_dropout, self.training - ) + output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training) return output class BloomBlock(nn.Module): - def __init__(self, config: BloomConfig): + def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None): super().__init__() hidden_size = config.hidden_size self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.num_heads = config.n_head - self.self_attention = BloomAttention(config) - self.post_attention_layernorm = LayerNorm( - hidden_size, eps=config.layer_norm_epsilon - ) + self.self_attention = BloomAttention(config, layer_idx) + self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = BloomMLP(config) - self.apply_residual_connection_post_layernorm = ( - config.apply_residual_connection_post_layernorm - ) + self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm self.hidden_dropout = config.hidden_dropout def forward( @@ -445,15 +382,17 @@ class BloomBlock(nn.Module): hidden_states: torch.Tensor, alibi: torch.Tensor, attention_mask: torch.Tensor, - layer_past: 
Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + layer_past: Optional[Cache] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, ): # hidden_states: [batch_size, seq_length, hidden_size] # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states).to(hidden_states.dtype) + layernorm_output = self.input_layernorm(hidden_states) + # Layer norm post the self attention. if self.apply_residual_connection_post_layernorm: residual = layernorm_output @@ -470,15 +409,14 @@ class BloomBlock(nn.Module): head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = attn_outputs[0] outputs = attn_outputs[1:] - layernorm_output = self.post_attention_layernorm(attention_output).to( - attention_output.dtype - ) + layernorm_output = self.post_attention_layernorm(attention_output) # Get residual if self.apply_residual_connection_post_layernorm: @@ -494,7 +432,7 @@ class BloomBlock(nn.Module): else: outputs = (output,) + outputs[1:] - return outputs # hidden_states, present, attentions + return outputs # hidden_states, past_kv, attentions class BloomPreTrainedModel(PreTrainedModel): @@ -503,6 +441,9 @@ class BloomPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["BloomBlock"] _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_static_cache = True + _supports_quantized_cache = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -523,45 +464,6 @@ class BloomPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - @staticmethod - def _convert_to_standard_cache( - past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: - """ - Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, - num_heads, ...])) - """ - batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape - num_heads = batch_size_times_num_heads // batch_size - # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] - # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim] - return tuple( - ( - layer_past[0].view(batch_size, num_heads, head_dim, seq_length), - layer_past[1].view(batch_size, num_heads, seq_length, head_dim), - ) - for layer_past in past_key_value - ) - - @staticmethod - def _convert_to_bloom_cache( - past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]] - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: - """ - Converts the cache to the format expected by Bloom, i.e. 
to tuple(tuple([batch_size * num_heads, ...])) - """ - batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape - batch_size_times_num_heads = batch_size * num_heads - # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length] - # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim] - return tuple( - ( - layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length), - layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim), - ) - for layer_past in past_key_value - ) - BLOOM_START_DOCSTRING = r""" @@ -591,14 +493,24 @@ BLOOM_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): - Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see - `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have - their past given to this model should not be passed as `input_ids` as they have already been computed. - - Each element of `past_key_values` is a tuple (past_key, past_value): - - past_key: [batch_size * num_heads, head_dim, kv_length] - - past_value: [batch_size * num_heads, kv_length, head_dim] + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: @@ -630,6 +542,10 @@ BLOOM_INPUTS_DOCSTRING = r""" more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
""" @@ -646,14 +562,10 @@ class BloomModel(BloomPreTrainedModel): # Embedding + LN Embedding self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) - self.word_embeddings_layernorm = LayerNorm( - self.embed_dim, eps=config.layer_norm_epsilon - ) + self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) # Transformer blocks - self.h = nn.ModuleList( - [BloomBlock(config) for _ in range(config.num_hidden_layers)] - ) + self.h = nn.ModuleList([BloomBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) # Final Layer Norm self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -663,9 +575,7 @@ class BloomModel(BloomPreTrainedModel): # Initialize weights and apply final processing self.post_init() - def build_alibi_tensor( - self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype - ) -> torch.Tensor: + def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: return build_alibi_tensor(attention_mask, num_heads, dtype) def get_input_embeddings(self): @@ -683,7 +593,7 @@ class BloomModel(BloomPreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, @@ -691,6 +601,7 @@ class BloomModel(BloomPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: if deprecated_arguments.pop("position_ids", False) is not False: @@ -703,85 +614,68 @@ class BloomModel(BloomPreTrainedModel): if len(deprecated_arguments) > 0: raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) + batch_size, seq_length, _ = inputs_embeds.shape + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + seq_length_with_past = seq_length + past_length + if cache_position is None: + cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) + hidden_states = self.word_embeddings_layernorm(inputs_embeds) - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - hidden_states = self.word_embeddings_layernorm(inputs_embeds).to( - inputs_embeds.dtype - ) - - presents = () if use_cache else None + next_decoder_cache = None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # Compute alibi tensor: check build_alibi_tensor documentation - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), device=hidden_states.device - ) + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) else: attention_mask = attention_mask.to(hidden_states.device) - alibi = self.build_alibi_tensor( - attention_mask, self.num_heads, dtype=hidden_states.dtype + alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - causal_mask = _prepare_4d_causal_attention_mask( - attention_mask, - input_shape=(batch_size, seq_length), - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) - causal_mask = causal_mask.bool() - - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + for i, block in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -791,30 +685,30 @@ class BloomModel(BloomPreTrainedModel): hidden_states, alibi, causal_mask, - layer_past, + past_key_values, head_mask[i], use_cache, output_attentions, + cache_position, ) else: outputs = block( hidden_states, - layer_past=layer_past, + layer_past=past_key_values, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, + cache_position=cache_position, ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) + if use_cache: + next_decoder_cache = outputs[1] if output_attentions: - all_self_attentions = all_self_attentions + ( - outputs[2 if use_cache else 1], - ) + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) # Add last hidden state hidden_states = self.ln_f(hidden_states) @@ -822,25 +716,147 @@ class BloomModel(BloomPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + if not return_dict: return tuple( - v - for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - ] - if v is not None + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + @add_start_docstrings( """ @@ -849,7 +865,7 @@ class BloomModel(BloomPreTrainedModel): """, BLOOM_START_DOCSTRING, ) -class BloomForCausalLM(BloomPreTrainedModel): +class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: BloomConfig): @@ -868,39 +884,61 @@ class BloomForCausalLM(BloomPreTrainedModel): def prepare_inputs_for_generation( self, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, **kwargs, - ) -> dict: - # only last tokens for input_ids if past is not None + ): + # Overwriten because of the fixed-shape attention mask creation + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and + # generate the first token for each sequence. Later use the generated Input ids for continuation. if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - # the cache may be in the stardard format (e.g. 
in contrastive search), convert to bloom's format if needed - if past_key_values[0][0].shape[0] == input_ids.shape[0]: - past_key_values = self._convert_to_bloom_cache(past_key_values) + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif ( + inputs_embeds is not None # Exception 1 + or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 + ): + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the + # input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in + # the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + # This part differs from other models because BLOOM needs a 2D mask to construct alibi tensor + # The only difference is the usage of 2D instead of 4D mask, but the shape will be static + if isinstance(past_key_values, StaticCache) and attention_mask is not None: + target_length = past_key_values.get_max_cache_shape() + batch_size, seq_length = attention_mask.shape + diff = target_length - seq_length + + new_attn_mask = torch.zeros(batch_size, diff, device=attention_mask.device, dtype=attention_mask.dtype) + attention_mask = torch.cat( + [attention_mask, new_attn_mask], + dim=-1, + ) model_inputs.update( { + "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, } ) @@ -915,7 +953,7 @@ class BloomForCausalLM(BloomPreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -924,6 +962,7 @@ class BloomForCausalLM(BloomPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" @@ -932,6 +971,8 @@ class BloomForCausalLM(BloomPreTrainedModel): `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` """ + # Bloom has deprecated kwargs, so we need to pop num_items_in_batch explicitly + num_items_in_batch = 
deprecated_arguments.pop("num_items_in_batch", None) if deprecated_arguments.pop("position_ids", False) is not False: # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` warnings.warn( @@ -942,9 +983,7 @@ class BloomForCausalLM(BloomPreTrainedModel): if len(deprecated_arguments) > 0: raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( input_ids, @@ -956,6 +995,7 @@ class BloomForCausalLM(BloomPreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = transformer_outputs[0] @@ -965,15 +1005,12 @@ class BloomForCausalLM(BloomPreTrainedModel): if labels is not None: # move labels to correct device to enable model parallelism labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - batch_size, seq_length, vocab_size = shift_logits.shape # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(batch_size * seq_length, vocab_size), - shift_labels.view(batch_size * seq_length), + loss = self.loss_function( + lm_logits, + labels, + vocab_size=self.config.vocab_size, + num_items_in_batch=num_items_in_batch, ) if not return_dict: @@ -989,9 +1026,7 @@ class BloomForCausalLM(BloomPreTrainedModel): ) def _reorder_cache( - self, - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], - beam_idx: torch.LongTensor, + self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or @@ -1000,24 +1035,18 @@ class BloomForCausalLM(BloomPreTrainedModel): Output shares the same memory storage as `past`. """ - standardized_past = self._convert_to_standard_cache( - past, batch_size=len(beam_idx) - ) - # Get a copy of `beam_idx` on all the devices where we need those indices. 
device_to_beam_idx = { - past_state.device: beam_idx.to(past_state.device) - for layer_past in past - for past_state in layer_past + past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past } reordered_past = tuple( ( layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), ) - for layer_past in standardized_past + for layer_past in past ) - return self._convert_to_bloom_cache(reordered_past) + return reordered_past @add_start_docstrings( @@ -1054,7 +1083,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -1081,9 +1110,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel): if len(deprecated_arguments) > 0: raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( input_ids, @@ -1106,35 +1133,29 @@ class BloomForSequenceClassification(BloomPreTrainedModel): batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: - raise ValueError( - "Cannot handle batch sizes > 1 if no padding token is defined." - ) + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") if self.config.pad_token_id is None: - sequence_lengths = -1 + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) else: - if input_ids is not None: - sequence_lengths = ( - torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 - ).to(logits.device) - else: - sequence_lengths = -1 - logger.warning( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) - pooled_logits = logits[ - torch.arange(batch_size, device=logits.device), sequence_lengths - ] + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] loss = None if labels is not None: if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" - elif self.num_labels > 1 and ( - labels.dtype == torch.long or labels.dtype == torch.int - ): + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" @@ -1177,10 +1198,7 @@ class BloomForTokenClassification(BloomPreTrainedModel): self.num_labels = config.num_labels self.transformer = BloomModel(config) - if ( - hasattr(config, "classifier_dropout") - and config.classifier_dropout is not None - ): + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: classifier_dropout = config.classifier_dropout elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: classifier_dropout = config.hidden_dropout @@ -1201,7 +1219,7 @@ class BloomForTokenClassification(BloomPreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -1228,9 +1246,7 @@ class BloomForTokenClassification(BloomPreTrainedModel): if len(deprecated_arguments) > 0: raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( input_ids, @@ -1255,8 +1271,7 @@ class BloomForTokenClassification(BloomPreTrainedModel): batch_size, seq_length = labels.shape loss_fct = CrossEntropyLoss() loss = loss_fct( - logits.view(batch_size * seq_length, self.num_labels), - labels.view(batch_size * seq_length), + logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) ) if not return_dict: @@ -1287,9 +1302,7 @@ class BloomForQuestionAnswering(BloomPreTrainedModel): # Initialize weights and apply final processing self.post_init() - @add_start_docstrings_to_model_forward( - BLOOM_INPUTS_DOCSTRING.format("batch_size, sequence_length") - ) + @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1313,9 +1326,7 @@ class BloomForQuestionAnswering(BloomPreTrainedModel): Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. 
""" - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.transformer( input_ids, @@ -1363,3 +1374,13 @@ class BloomForQuestionAnswering(BloomPreTrainedModel): hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +__all__ = [ + "BloomForCausalLM", + "BloomModel", + "BloomPreTrainedModel", + "BloomForSequenceClassification", + "BloomForTokenClassification", + "BloomForQuestionAnswering", +] -- Gitee