diff --git a/README.md b/README.md index 661ad021d9a4ebbd4e7868c83ba91518da7826df..1b19f9d9767fcf5d75a83de25ee0856e79e41dc6 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ DeepSparkHub甄选上百个应用算法和模型,覆盖AI和通用计算各领 | [Llama2-13B](nlp/llm/llama2-13b/pytorch) | PyTorch | Megatron-DeepSpeed | Bookcorpus | 3.4.0 | | [Llama2-34B](nlp/llm/llama2-34b/pytorch) | PyTorch | Megatron-DeepSpeed | Bookcorpus | 3.4.0 | | [Llama3-8B](nlp/llm/llama3_8b/pytorch) | PyTorch | Megatron-DeepSpeed | Bookcorpus | 4.1.1 | +| [Llama3-8B](nlp/llm/llama3_8b/megatron-lm) | PyTorch | Megatron-LM | GPT Small-117M | 4.3.0 | | [Llama3-8B SFT](nlp/llm/llama3_8b_sft/pytorch) | PyTorch | ColossalAI | school_math_0.25M | 4.1.1 | | [Llama3-8B PPO](nlp/llm/llama3_8b/openrlhf) | PyTorch | OpenRLHF | Llama-3-8b-sft-mixture | 4.2.0 | | [Llama3-8B DPO](nlp/llm/llama3_8b/openrlhf) | PyTorch | OpenRLHF | Llama-3-8b-sft-mixture | 4.2.0 | @@ -51,6 +52,7 @@ DeepSparkHub甄选上百个应用算法和模型,覆盖AI和通用计算各领 | [Qwen2.5-7B SFT](nlp/llm/qwen2.5-7b/pytorch) | PyTorch | LLaMA-Factory | qwen2.5-7b | 4.1.1 | | [Qwen2.5-1.5B verl](nlp/llm/qwen2.5-1.5b/verl) | PyTorch | verl | qwen2.5-1.5b | 4.2.0 | | [Qwen2.5-7B verl](nlp/llm/qwen2.5-7b/verl) | PyTorch | verl | qwen2.5-7b | 4.2.0 | +| [Qwen2.5-3B](nlp/llm/qwen2.5-3b/pytorch) | PyTorch | ColossalAI | qwen2.5-3b | 4.3.0 | | [Yi-6B](nlp/llm/yi-6b/pytorch) | PyTorch | DeepSpeed | Yi-6B | 4.2.0 | | [Yi-1.5-6B](nlp/llm/yi-1.5-6b/pytorch) | PyTorch | DeepSpeed | Yi-1.5-6B | 4.2.0 | | [Yi-VL-6B](nlp/llm/yi-vl-6b/pytorch) | PyTorch | LLaMA-Factory | Yi-VL-6B-hf | 4.2.0 | diff --git a/README_en.md b/README_en.md index 81f6a9dcbc86918d159ab66e3c9ce4717439273a..e73d799b6816ba496ec07624a42b2b3718970469 100644 --- a/README_en.md +++ b/README_en.md @@ -39,6 +39,7 @@ individuals, healthcare, education, communication, energy, and more. | [Llama2-13B](nlp/llm/llama2-13b/pytorch) | PyTorch | Megatron-DeepSpeed | Bookcorpus | 3.4.0 | | [Llama2-34B](nlp/llm/llama2-34b/pytorch) | PyTorch | Megatron-DeepSpeed | Bookcorpus | 3.4.0 | | [Llama3-8B](nlp/llm/llama3_8b/pytorch) | PyTorch | Megatron-DeepSpeed | Bookcorpus | 4.1.1 | +| [Llama3-8B](nlp/llm/llama3_8b/megatron-lm) | PyTorch | Megatron-LM | GPT Small-117M | 4.3.0 | | [Llama3-8B SFT](nlp/llm/llama3_8b_sft/pytorch) | PyTorch | ColossalAI | school_math_0.25M | 4.1.1 | | [Llama3-8B PPO](nlp/llm/llama3_8b/openrlhf) | PyTorch | OpenRLHF | Llama-3-8b-sft-mixture | 4.2.0 | | [Llama3-8B DPO](nlp/llm/llama3_8b/openrlhf) | PyTorch | OpenRLHF | Llama-3-8b-sft-mixture | 4.2.0 | @@ -53,6 +54,7 @@ individuals, healthcare, education, communication, energy, and more. 
 | [Qwen2.5-7B SFT](nlp/llm/qwen2.5-7b/pytorch) | PyTorch | LLaMA-Factory | qwen2.5-7b | 4.1.1 |
 | [Qwen2.5-1.5B verl](nlp/llm/qwen2.5-1.5b/verl) | PyTorch | verl | qwen2.5-1.5b | 4.2.0 |
 | [Qwen2.5-7B verl](nlp/llm/qwen2.5-7b/verl) | PyTorch | verl | qwen2.5-7b | 4.2.0 |
+| [Qwen2.5-3B](nlp/llm/qwen2.5-3b/pytorch) | PyTorch | ColossalAI | qwen2.5-3b | 4.3.0 |
 | [Yi-6B](nlp/llm/yi-6b/pytorch) | PyTorch | DeepSpeed | Yi-6B | 4.2.0 |
 | [Yi-1.5-6B](nlp/llm/yi-1.5-6b/pytorch) | PyTorch | DeepSpeed | Yi-1.5-6B | 4.2.0 |
 | [Yi-VL-6B](nlp/llm/yi-vl-6b/pytorch) | PyTorch | LLaMA-Factory | Yi-VL-6B-hf | 4.2.0 |
diff --git a/nlp/llm/llama3_8b/megatron-lm/README.md b/nlp/llm/llama3_8b/megatron-lm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f4e3de103a8fd047c5273208295a5c6b03c0ebb0
--- /dev/null
+++ b/nlp/llm/llama3_8b/megatron-lm/README.md
@@ -0,0 +1,50 @@
+# Llama3-8B (Megatron-LM)
+
+## Model Description
+
+Llama3-8B is an advanced auto-regressive language model developed by Meta, featuring 8 billion parameters. It utilizes
+an optimized transformer architecture with Grouped-Query Attention (GQA) for improved inference efficiency. Trained on
+sequences of 8,192 tokens and using a 128K token vocabulary, it excels in various natural language tasks. The model
+incorporates supervised fine-tuning (SFT) and reinforcement learning with human feedback (RLHF) to align with human
+preferences, ensuring both helpfulness and safety in its responses. Llama3-8B offers state-of-the-art performance in
+language understanding and generation.
+
+## Supported Environments
+
+| GPU | [IXUCA SDK](https://gitee.com/deep-spark/deepspark#%E5%A4%A9%E6%95%B0%E6%99%BA%E7%AE%97%E8%BD%AF%E4%BB%B6%E6%A0%88-ixuca) | Release |
+| :----: | :----: | :----: |
+| BI-V150 | 4.3.0 | 25.09 |
+
+## Model Preparation
+
+### Prepare Resources
+
+```sh
+mkdir -p datasets
+pushd datasets
+# download and extract gpt_small_117M_llama3.tar
+wget http://files.deepspark.org.cn:880/deepspark/data/datasets/gpt_small_117M_llama3.tar
+tar -xf gpt_small_117M_llama3.tar
+rm -f gpt_small_117M_llama3.tar
+popd
+
+mkdir -p llama3-8b
+# download the tokenizer files (e.g. tokenizer.json) from LLM-Research/Meta-Llama-3-8B and put them into llama3-8b
+```
+
+### Install Dependencies
+
+The following package is not publicly available. Contact the Iluvatar administrator to obtain it, then install it with pip:
+- transformers-4.45.2+corex.4.3.0-py3-none-any.whl
+
+## Model Training
+
+```sh
+bash llama3_8b_dp2_pp8_tp1.sh
+```
+
+## Model Results
+
+## References
+
+- [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)
diff --git a/nlp/llm/llama3_8b/megatron-lm/llama3_8b_dp2_pp8_tp1.sh b/nlp/llm/llama3_8b/megatron-lm/llama3_8b_dp2_pp8_tp1.sh
new file mode 100644
index 0000000000000000000000000000000000000000..490033918e38d7671af5afc31cd54d0155d23ab3
--- /dev/null
+++ b/nlp/llm/llama3_8b/megatron-lm/llama3_8b_dp2_pp8_tp1.sh
@@ -0,0 +1,134 @@
+#!/bin/bash
+# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
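+
+# Single-node smoke test: 16 GPUs with tensor parallel = 1 and pipeline parallel = 8,
+# leaving data parallel = 2 (hence "dp2_pp8_tp1"). Trains 5 iterations on the small
+# GPT-117M dataset tokenized with the Llama3 tokenizer. Note: the *-per-stage flags
+# below appear to take (number of stages, value) pairs describing each pipeline stage.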
+ +GPUS_PER_NODE=16 +MASTER_ADDR=localhost +MASTER_PORT=6000 +NUM_NODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES)) + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_NET_SHARED_BUFFERS=0 +export NCCL_DEBUG=TRACE +export NCCL_ALGO=Ring +export OMP_NUM_THREADS=4 +export ENABLE_FLASH_ATTENTION_WITH_IXDNN=1 +export NCCL_USE_DIRECT=1 + +DATA_PATH=datasets/gpt_small_117M_llama3/gpt_small_117M_text_document +TOKENIZER_MODEL=llama3-8b + +TP=1 +PP=8 +NUM_LAYERS=32 +HIDDEN_SIZE=4096 +NUM_ATTN_HEADS=32 +INTERMEDIATE_SIZE=14336 +NUM_KEY_VALUE_HEADS=8 +SEQ_LEN=8192 +MAX_POSITION_EMBEDDINGS=8192 + +DISTRIBUTED_ARGS=( + --nproc_per_node $GPUS_PER_NODE + --nnodes $NUM_NODES + --master_addr $MASTER_ADDR + --master_port $MASTER_PORT +) + +GPT_MODEL_ARGS=( + --num-layers ${NUM_LAYERS} + --hidden-size ${HIDDEN_SIZE} + --num-attention-heads ${NUM_ATTN_HEADS} + --ffn-hidden-size ${INTERMEDIATE_SIZE} + --seq-length ${SEQ_LEN} + --max-position-embeddings ${MAX_POSITION_EMBEDDINGS} + --tokenizer-type HuggingFaceTokenizer + --tokenizer-model ${TOKENIZER_MODEL} + --group-query-attention + --num-query-groups ${NUM_KEY_VALUE_HEADS} + --attention-dropout 0.0 + --hidden-dropout 0.0 + --attention-softmax-in-fp32 + --normalization RMSNorm + --position-embedding-type rope + --rotary-base 500000 + --rotary-percent 1.0 + --untie-embeddings-and-output-weights + --disable-bias-linear + --transformer-impl transformer_engine + --swiglu + --bf16 + --use-legacy-models + --ckpt-format torch +) + +TRAINING_ARGS=( + --adam-beta1 0.9 + --adam-beta2 0.95 + --weight-decay 0.1 + --clip-grad 1.0 + --lr 6.0e-5 + --min-lr 6.0e-6 + --lr-decay-style cosine + --lr-warmup-fraction .001 + --lr-decay-iters 430000 + --micro-batch-size 1 + --global-batch-size 64 + --train-iters 5 + --init-method-std 0.006 + --no-load-optim + --no-load-rng + --no-create-attention-mask-in-dataloader + --initial-loss-scale 65536 + --use-flash-attn + --num-layers-per-stage 1 3 4 4 2 5 1 3 + --recompute-granularity=full + --recompute-method=uniform + --recompute-num-layers 1 + --recompute-method-per-stage 8 1 + --recompute-num-layers-per-stage 5 2 3 0 + --use-distributed-optimizer +) + +MODEL_PARALLEL_ARGS=( + --tensor-model-parallel-size ${TP} + --pipeline-model-parallel-size ${PP} +) + +DATA_ARGS=( + --data-path $DATA_PATH + --split 99,1,0 +) + +EVAL_AND_LOGGING_ARGS=( + --log-interval 1 + --save-interval 10000 + --eval-interval 1000 + --eval-iters 0 +) + +torchrun ${DISTRIBUTED_ARGS[@]} ./pretrain_gpt.py \ + ${GPT_MODEL_ARGS[@]} \ + ${TRAINING_ARGS[@]} \ + ${MODEL_PARALLEL_ARGS[@]} \ + ${DATA_ARGS[@]} \ + ${EVAL_AND_LOGGING_ARGS[@]} \ No newline at end of file diff --git a/nlp/llm/llama3_8b/megatron-lm/pretrain_gpt.py b/nlp/llm/llama3_8b/megatron-lm/pretrain_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..26ef2e68dd42dd62bea5678097b265abfc38e8da --- /dev/null +++ b/nlp/llm/llama3_8b/megatron-lm/pretrain_gpt.py @@ -0,0 +1,306 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
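+#
+# This file is based on pretrain_gpt.py from NVIDIA Megatron-LM
+# (https://github.com/NVIDIA/Megatron-LM), included here for the Llama3-8B example.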
+"""Pretrain GPT.""" + +import os +import torch +from functools import partial +from contextlib import nullcontext +import inspect + +from typing import List, Optional, Tuple, Union +from megatron.training import get_args +from megatron.training import print_rank_0 +from megatron.training import get_timers +from megatron.training import get_tokenizer +from megatron.core import mpu +from megatron.core.enums import ModelType +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.gpt_dataset import GPTDatasetConfig +from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset +from megatron.core.rerun_state_machine import get_rerun_state_machine +import megatron.legacy.model +from megatron.core.models.gpt import GPTModel +from megatron.training import pretrain +from megatron.core.utils import StragglerDetector +from megatron.core.transformer.spec_utils import import_module +from megatron.training.utils import ( + get_batch_on_this_cp_rank, + get_batch_on_this_tp_rank, + get_blend_and_blend_per_split, +) +from megatron.training.arguments import core_transformer_config_from_args +from megatron.training.yaml_arguments import core_transformer_config_from_yaml +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_decoder_block_spec, + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) + + +stimer = StragglerDetector() + +def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]: + """Builds the model. + + If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. + + Args: + pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True. + post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True. 
+ + + Returns: + Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model + """ + args = get_args() + use_te = args.transformer_impl == "transformer_engine" + + if args.record_memory_history: + torch.cuda.memory._record_memory_history(True, + # keep 100,000 alloc/free events from before the snapshot + trace_alloc_max_entries=100000, + + # record stack information for the trace events + trace_alloc_record_context=True) + + print_rank_0('building GPT model ...') + # Experimental loading arguments from yaml + if args.yaml_cfg is not None: + config = core_transformer_config_from_yaml(args, "language_model") + else: + config = core_transformer_config_from_args(args) + + if args.use_legacy_models: + model = megatron.legacy.model.GPTModel( + config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + ) + else: # using core models + if args.spec is not None: + transformer_layer_spec = import_module(args.spec) + else: + if args.num_experts: + # Define the decoder block spec + transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te) + else: + # Define the decoder layer spec + if use_te: + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + args.num_experts, args.moe_grouped_gemm, + args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm) + else: + transformer_layer_spec = get_gpt_layer_local_spec( + args.num_experts, args.moe_grouped_gemm, + args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm) + + build_model_context = nullcontext + build_model_context_args = {} + if args.fp8_param_gather: + try: + from transformer_engine.pytorch import fp8_model_init + + build_model_context = fp8_model_init + build_model_context_args["enabled"] = True + + # Check if fp8_model_init supports preserve_high_precision_init_val + if "preserve_high_precision_init_val" in inspect.signature(fp8_model_init).parameters: + build_model_context_args["preserve_high_precision_init_val"] = True + except: + raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.") + + with build_model_context(**build_model_context_args): + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, + rope_scaling=args.use_rope_scaling + ) + + return model + + +def get_batch(data_iterator): + """Generate a batch.""" + + # TODO: this is pretty hacky, find a better way + if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): + return None, None, None, None, None + + # get batches based on the TP rank you are on + batch = get_batch_on_this_tp_rank(data_iterator) + + # slice batch along sequence dimension for context parallelism + batch = get_batch_on_this_cp_rank(batch) + + return batch.values() + + +# define spiky loss as a variation of 20% or more +SPIKY_LOSS_PERC = 0.2 + + +def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor): + """Loss function. 
+
+    Args:
+        loss_mask (torch.Tensor): Used to mask out some portions of the loss
+        output_tensor (torch.Tensor): The tensor with the losses
+
+    Returns:
+        the loss scalar for this micro-batch
+        the number of non-padded tokens in this microbatch
+        a dict containing reporting metrics on the loss and number of tokens across
+            the data parallel ranks
+    """
+    args = get_args()
+
+    losses = output_tensor.float()
+    loss_mask = loss_mask.view(-1).float()
+    total_tokens = loss_mask.sum()
+    loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
+
+    if args.context_parallel_size > 1:
+        torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
+
+    # Check that individual rank losses are not NaN prior to DP all-reduce.
+    rerun_state_machine = get_rerun_state_machine()
+    if args.check_for_nan_in_loss_and_grad:
+        rerun_state_machine.validate_result(
+            result=loss[0],
+            rejection_func=torch.isnan,
+            message="found NaN in local forward loss calculation",
+            tolerance=0.0,  # forward pass calculations are deterministic
+            fatal=True,
+        )
+    # Check for spiky loss
+    if args.check_for_spiky_loss:
+        rerun_state_machine.validate_result(
+            result=loss[0],
+            rejection_func=partial(rerun_state_machine.is_spiky_loss, threshold=SPIKY_LOSS_PERC),
+            message="Spiky loss",
+            tolerance=0.0,  # forward pass calculations are deterministic
+            fatal=False,
+        )
+    # Reduce loss for logging.
+    reporting_loss = loss.clone().detach()
+    torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
+
+    local_num_tokens = loss[1].clone().detach().to(torch.int)
+    return (
+        loss[0] * args.context_parallel_size,
+        local_num_tokens,
+        {'lm loss': (reporting_loss[0], reporting_loss[1])},
+    )
+
+
+def forward_step(data_iterator, model: GPTModel, config=None):
+    """Forward training step.
+
+    Args:
+        data_iterator : Input data iterator
+        model (GPTModel): The GPT Model
+    """
+    args = get_args()
+    timers = get_timers()
+
+    # Get the batch.
+    timers('batch-generator', log_level=2).start()
+    global stimer
+    with stimer(bdata=True):
+        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
+            data_iterator)
+    timers('batch-generator').stop()
+
+    with stimer:
+        output_tensor = model(tokens, position_ids, attention_mask,
+                              labels=labels)
+
+    return output_tensor, partial(loss_func, loss_mask)
+
+
+def is_dataset_built_on_rank():
+    return (
+        mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
+    ) and mpu.get_tensor_model_parallel_rank() == 0
+
+
+def core_gpt_dataset_config_from_args(args):
+    tokenizer = get_tokenizer()
+
+    # Sometimes --data-path is too long; in that case, we parse it from a file.
+ blend: Optional[Tuple[List[str], Optional[List[float]]]] + blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]] + blend, blend_per_split = get_blend_and_blend_per_split(args) + + return GPTDatasetConfig( + random_seed=args.seed, + sequence_length=args.seq_length, + blend=blend, + blend_per_split=blend_per_split, + split=args.split, + num_dataset_builder_threads=args.num_dataset_builder_threads, + path_to_cache=args.data_cache_path, + mmap_bin_files=args.mmap_bin_files, + tokenizer=tokenizer, + reset_position_ids=args.reset_position_ids, + reset_attention_mask=args.reset_attention_mask, + eod_mask_loss=args.eod_mask_loss, + create_attention_mask=args.create_attention_mask_in_dataloader, + s3_cache_path=args.s3_cache_path, + do_cache_local=args.data_cache_local, + ) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build the train test and validation datasets. + + Args: + train_val_test_num_samples : A list containing the number of samples in train test and validation. + """ + args = get_args() + + config = core_gpt_dataset_config_from_args(args) + + if args.mock_data: + dataset_type = MockGPTDataset + else: + dataset_type = GPTDataset + + print_rank_0("> building train, validation, and test datasets for GPT ...") + + train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( + dataset_type, + train_val_test_num_samples, + is_dataset_built_on_rank, + config + ).build() + + print_rank_0("> finished creating GPT datasets ...") + + return train_ds, valid_ds, test_ds + + +if __name__ == "__main__": + + # Temporary for transition to core datasets + train_valid_test_datasets_provider.is_distributed = True + + pretrain( + train_valid_test_datasets_provider, + model_provider, + ModelType.encoder_or_decoder, + forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + ) \ No newline at end of file diff --git a/tests/model_info.json b/tests/model_info.json index 1df2e6536255c12436f8d4773a0222d4bdfda111..205942a0f19216b4d0e284823174c763cedd6743 100644 --- a/tests/model_info.json +++ b/tests/model_info.json @@ -7343,6 +7343,54 @@ "github_branch": "", "github_path": "", "priority": "P4" + }, + { + "model_name": "qwen2.5-3b", + "framework": "pytorch", + "release_version": "25.09", + "release_sdk": "4.3.0", + "release_gpgpu": "BI-V150", + "latest_sdk": "", + "latest_gpgpu": "", + "category": "nlp/llm", + "toolbox": "", + "mdims": "", + "dataset": "", + "license": "", + "model_path": "deepsparkhub/nlp/llm/qwen2.5-3b/pytorch/", + "readme_file": "deepsparkhub/nlp/llm/qwen2.5-3b/pytorch/README.md", + "bitbucket_repo": "", + "bitbucket_branch": "", + "bitbucket_path": "", + "develop_owner": "", + "github_repo": "", + "github_branch": "", + "github_path": "", + "priority": "P4" + }, + { + "model_name": "llama3_8b", + "framework": "megatron-lm", + "release_version": "25.09", + "release_sdk": "4.3.0", + "release_gpgpu": "BI-V150", + "latest_sdk": "", + "latest_gpgpu": "", + "category": "nlp/llm", + "toolbox": "", + "mdims": "", + "dataset": "", + "license": "", + "model_path": "deepsparkhub/nlp/llm/llama3_8b/megatron-lm/", + "readme_file": "deepsparkhub/nlp/llm/llama3_8b/megatron-lm/README.md", + "bitbucket_repo": "", + "bitbucket_branch": "", + "bitbucket_path": "", + "develop_owner": "", + "github_repo": "", + "github_branch": "", + "github_path": "", + "priority": "P4" } ] } \ No newline at end of file