From 1c5daadd610baed3bba2591621c221e5c7e06016 Mon Sep 17 00:00:00 2001 From: sanghui_ilu Date: Mon, 11 Aug 2025 23:02:06 +0800 Subject: [PATCH 1/2] add qwen2-7b grpo by verl link #ICJQ4R add qwen2-7b grpo by verl --- nlp/llm/qwen2-7b/verl/README.md | 44 + nlp/llm/qwen2-7b/verl/run_qwen2_7b_gsm8k.sh | 49 + toolbox/verl/pyproject.toml | 81 +- toolbox/verl/setup.py | 101 +- toolbox/verl/v0.5.0/requirements.txt | 26 + toolbox/verl/v0.5.0/setup.py | 99 + toolbox/verl/v0.5.0/verl/__init__.py | 64 + toolbox/verl/v0.5.0/verl/base_config.py | 91 + .../verl/v0.5.0/verl/experimental/__init__.py | 13 + .../verl/experimental/agent_loop/__init__.py | 21 + .../experimental/agent_loop/agent_loop.py | 543 ++++++ .../agent_loop/single_turn_agent_loop.py | 55 + .../agent_loop/tool_agent_loop.py | 166 ++ .../experimental/agent_loop/tool_parser.py | 106 + .../verl/experimental/dataset/__init__.py | 13 + .../verl/experimental/dataset/sampler.py | 40 + .../experimental/dynamic_dataset/__init__.py | 13 + .../dynamic_dataset/dynamicgen_dataset.py | 112 ++ .../verl/v0.5.0/verl/interactions/__init__.py | 15 + toolbox/verl/v0.5.0/verl/interactions/base.py | 72 + .../verl/interactions/gsm8k_interaction.py | 87 + .../verl/interactions/utils/__init__.py | 14 + .../utils/interaction_registry.py | 85 + .../verl/v0.5.0/verl/model_merger/__init__.py | 13 + .../verl/v0.5.0/verl/model_merger/__main__.py | 73 + .../verl/model_merger/base_model_merger.py | 345 ++++ .../verl/model_merger/fsdp_model_merger.py | 265 +++ .../model_merger/megatron_model_merger.py | 537 ++++++ toolbox/verl/v0.5.0/verl/models/README.md | 35 + toolbox/verl/v0.5.0/verl/models/__init__.py | 13 + .../verl/v0.5.0/verl/models/llama/__init__.py | 13 + .../verl/models/llama/megatron/__init__.py | 34 + .../megatron/checkpoint_utils/__init__.py | 13 + .../megatron/checkpoint_utils/llama_loader.py | 317 +++ .../llama_loader_depracated.py | 458 +++++ .../megatron/checkpoint_utils/llama_saver.py | 442 +++++ .../models/llama/megatron/layers/__init__.py | 34 + .../megatron/layers/parallel_attention.py | 460 +++++ .../llama/megatron/layers/parallel_decoder.py | 150 ++ .../llama/megatron/layers/parallel_linear.py | 106 + .../llama/megatron/layers/parallel_mlp.py | 74 + .../llama/megatron/layers/parallel_rmsnorm.py | 48 + .../llama/megatron/modeling_llama_megatron.py | 688 +++++++ .../verl/v0.5.0/verl/models/mcore/__init__.py | 30 + .../verl/models/mcore/config_converter.py | 392 ++++ .../verl/v0.5.0/verl/models/mcore/loader.py | 492 +++++ .../verl/v0.5.0/verl/models/mcore/mbridge.py | 23 + .../v0.5.0/verl/models/mcore/model_forward.py | 148 ++ .../verl/models/mcore/model_forward_fused.py | 327 ++++ .../verl/models/mcore/model_initializer.py | 263 +++ .../v0.5.0/verl/models/mcore/patch_v012.py | 215 +++ .../verl/models/mcore/qwen2_5_vl/__init__.py | 21 + .../verl/models/mcore/qwen2_5_vl/attention.py | 221 +++ .../verl/models/mcore/qwen2_5_vl/model.py | 340 ++++ .../models/mcore/qwen2_5_vl/rope_utils.py | 266 +++ .../models/mcore/qwen2_5_vl/vision_config.py | 85 + .../models/mcore/qwen2_5_vl/vision_model.py | 309 +++ .../qwen2_5_vl/vision_transformer_block.py | 265 +++ .../verl/v0.5.0/verl/models/mcore/readme.md | 99 + .../verl/v0.5.0/verl/models/mcore/registry.py | 237 +++ .../verl/v0.5.0/verl/models/mcore/saver.py | 497 +++++ toolbox/verl/v0.5.0/verl/models/mcore/util.py | 240 +++ .../verl/models/mcore/weight_converter.py | 479 +++++ .../verl/v0.5.0/verl/models/qwen2/__init__.py | 13 + .../verl/models/qwen2/megatron/__init__.py | 34 + .../megatron/checkpoint_utils/__init__.py | 13 + .../megatron/checkpoint_utils/qwen2_loader.py | 337 ++++ .../qwen2_loader_depracated.py | 475 +++++ .../megatron/checkpoint_utils/qwen2_saver.py | 448 +++++ .../models/qwen2/megatron/layers/__init__.py | 26 + .../megatron/layers/parallel_attention.py | 399 ++++ .../qwen2/megatron/layers/parallel_decoder.py | 150 ++ .../qwen2/megatron/layers/parallel_linear.py | 79 + .../qwen2/megatron/layers/parallel_mlp.py | 74 + .../qwen2/megatron/layers/parallel_rmsnorm.py | 48 + .../qwen2/megatron/modeling_qwen2_megatron.py | 737 +++++++ toolbox/verl/v0.5.0/verl/models/registry.py | 58 + .../verl/models/transformers/__init__.py | 13 + .../verl/models/transformers/dense_common.py | 193 ++ .../verl/models/transformers/kimi_vl.py | 185 ++ .../v0.5.0/verl/models/transformers/llama.py | 239 +++ .../verl/models/transformers/monkey_patch.py | 340 ++++ .../verl/models/transformers/npu_patch.py | 50 + .../v0.5.0/verl/models/transformers/qwen2.py | 241 +++ .../verl/models/transformers/qwen2_5_vl.py | 288 +++ .../verl/models/transformers/qwen2_vl.py | 559 ++++++ .../verl/models/weight_loader_registry.py | 56 + toolbox/verl/v0.5.0/verl/protocol.py | 964 ++++++++++ toolbox/verl/v0.5.0/verl/py.typed | 0 .../v0.5.0/verl/single_controller/__init__.py | 26 + .../verl/single_controller/base/__init__.py | 18 + .../verl/single_controller/base/decorator.py | 527 +++++ .../base/megatron/__init__.py | 13 + .../single_controller/base/megatron/worker.py | 106 + .../base/megatron/worker_group.py | 55 + .../base/register_center/__init__.py | 13 + .../base/register_center/ray.py | 37 + .../verl/single_controller/base/worker.py | 301 +++ .../single_controller/base/worker_group.py | 252 +++ .../verl/single_controller/ray/__init__.py | 29 + .../v0.5.0/verl/single_controller/ray/base.py | 893 +++++++++ .../verl/single_controller/ray/megatron.py | 77 + .../verl/v0.5.0/verl/third_party/__init__.py | 13 + .../verl/third_party/sglang/__init__.py | 26 + .../verl/third_party/sglang/parallel_state.py | 328 ++++ .../v0.5.0/verl/third_party/torch/__init__.py | 87 + .../third_party/torch/distributed/__init__.py | 87 + .../torch/distributed/_state_dict_utils.py | 840 ++++++++ .../torch/distributed/checkpoint/__init__.py | 87 + .../distributed/checkpoint/state_dict.py | 1493 +++++++++++++++ .../v0.5.0/verl/third_party/vllm/__init__.py | 59 + toolbox/verl/v0.5.0/verl/tools/__init__.py | 14 + toolbox/verl/v0.5.0/verl/tools/base_tool.py | 92 + toolbox/verl/v0.5.0/verl/tools/geo3k_tool.py | 99 + toolbox/verl/v0.5.0/verl/tools/gsm8k_tool.py | 106 + .../verl/v0.5.0/verl/tools/mcp_base_tool.py | 116 ++ .../verl/v0.5.0/verl/tools/mcp_search_tool.py | 69 + .../v0.5.0/verl/tools/sandbox_fusion_tools.py | 193 ++ toolbox/verl/v0.5.0/verl/tools/schemas.py | 89 + toolbox/verl/v0.5.0/verl/tools/search_tool.py | 278 +++ .../verl/v0.5.0/verl/tools/utils/__init__.py | 14 + .../utils/mcp_clients/McpClientManager.py | 97 + .../verl/tools/utils/mcp_clients/utils.py | 58 + .../verl/tools/utils/search_r1_like_utils.py | 243 +++ .../v0.5.0/verl/tools/utils/tool_registry.py | 107 ++ toolbox/verl/v0.5.0/verl/trainer/__init__.py | 13 + .../v0.5.0/verl/trainer/config/__init__.py | 26 + .../_generated_ppo_megatron_trainer.yaml | 418 ++++ .../config/_generated_ppo_trainer.yaml | 368 ++++ .../verl/trainer/config/actor/actor.yaml | 111 ++ .../verl/trainer/config/actor/dp_actor.yaml | 73 + .../trainer/config/actor/megatron_actor.yaml | 120 ++ .../v0.5.0/verl/trainer/config/algorithm.py | 114 ++ .../verl/v0.5.0/verl/trainer/config/config.py | 126 ++ .../verl/trainer/config/critic/critic.yaml | 94 + .../verl/trainer/config/critic/dp_critic.yaml | 95 + .../config/critic/megatron_critic.yaml | 130 ++ .../verl/trainer/config/data/legacy_data.yaml | 109 ++ .../verl/trainer/config/evaluation.yaml | 14 + .../verl/trainer/config/generation.yaml | 55 + .../config/npu_profile/npu_profile.yaml | 29 + .../trainer/config/ppo_megatron_trainer.yaml | 135 ++ .../verl/trainer/config/ppo_trainer.yaml | 336 ++++ .../verl/trainer/config/ref/dp_ref.yaml | 38 + .../verl/trainer/config/ref/megatron_ref.yaml | 51 + .../v0.5.0/verl/trainer/config/ref/ref.yaml | 21 + .../config/reward_model/dp_reward_model.yaml | 51 + .../reward_model/megatron_reward_model.yaml | 61 + .../config/reward_model/reward_model.yaml | 81 + .../verl/trainer/config/rollout/rollout.yaml | 215 +++ .../verl/trainer/config/sft_trainer.yaml | 85 + .../verl/v0.5.0/verl/trainer/constants_ppo.py | 37 + .../v0.5.0/verl/trainer/fsdp_sft_trainer.py | 825 ++++++++ toolbox/verl/v0.5.0/verl/trainer/main_eval.py | 80 + .../v0.5.0/verl/trainer/main_generation.py | 148 ++ toolbox/verl/v0.5.0/verl/trainer/main_ppo.py | 338 ++++ .../verl/v0.5.0/verl/trainer/ppo/__init__.py | 13 + .../v0.5.0/verl/trainer/ppo/core_algos.py | 1148 +++++++++++ .../v0.5.0/verl/trainer/ppo/metric_utils.py | 446 +++++ .../v0.5.0/verl/trainer/ppo/ray_trainer.py | 1421 ++++++++++++++ .../verl/v0.5.0/verl/trainer/ppo/reward.py | 179 ++ .../verl/v0.5.0/verl/trainer/runtime_env.yaml | 5 + toolbox/verl/v0.5.0/verl/utils/__init__.py | 19 + .../v0.5.0/verl/utils/activation_offload.py | 558 ++++++ .../v0.5.0/verl/utils/checkpoint/__init__.py | 13 + .../utils/checkpoint/checkpoint_manager.py | 237 +++ .../checkpoint/fsdp_checkpoint_manager.py | 350 ++++ .../checkpoint/megatron_checkpoint_manager.py | 525 +++++ toolbox/verl/v0.5.0/verl/utils/config.py | 65 + .../verl/v0.5.0/verl/utils/dataset/README.md | 16 + .../v0.5.0/verl/utils/dataset/__init__.py | 19 + .../utils/dataset/multiturn_sft_dataset.py | 334 ++++ .../v0.5.0/verl/utils/dataset/rl_dataset.py | 338 ++++ .../v0.5.0/verl/utils/dataset/rm_dataset.py | 144 ++ .../v0.5.0/verl/utils/dataset/sft_dataset.py | 183 ++ .../v0.5.0/verl/utils/dataset/vision_utils.py | 117 ++ .../verl/v0.5.0/verl/utils/debug/__init__.py | 17 + .../v0.5.0/verl/utils/debug/performance.py | 17 + .../verl/utils/debug/trajectory_tracker.py | 109 ++ toolbox/verl/v0.5.0/verl/utils/device.py | 86 + toolbox/verl/v0.5.0/verl/utils/distributed.py | 42 + .../verl/utils/experimental/__init__.py | 13 + .../utils/experimental/torch_functional.py | 216 +++ .../verl/v0.5.0/verl/utils/flops_counter.py | 312 +++ toolbox/verl/v0.5.0/verl/utils/fs.py | 292 +++ toolbox/verl/v0.5.0/verl/utils/fsdp_utils.py | 562 ++++++ toolbox/verl/v0.5.0/verl/utils/hdfs_io.py | 149 ++ .../verl/v0.5.0/verl/utils/import_utils.py | 156 ++ .../verl/v0.5.0/verl/utils/kernel/__init__.py | 31 + .../verl/v0.5.0/verl/utils/kernel/kernels.py | 1560 +++++++++++++++ .../verl/utils/kernel/linear_cross_entropy.py | 117 ++ .../verl/v0.5.0/verl/utils/logger/__init__.py | 32 + .../verl/utils/logger/aggregate_logger.py | 140 ++ .../verl/v0.5.0/verl/utils/logging_utils.py | 32 + .../v0.5.0/verl/utils/megatron/__init__.py | 13 + .../verl/utils/megatron/dist_checkpointing.py | 56 + .../verl/v0.5.0/verl/utils/megatron/memory.py | 38 + .../v0.5.0/verl/utils/megatron/optimizer.py | 80 + .../verl/utils/megatron/pipeline_parallel.py | 71 + .../verl/utils/megatron/sequence_parallel.py | 52 + .../verl/utils/megatron/tensor_parallel.py | 186 ++ .../verl/v0.5.0/verl/utils/megatron_utils.py | 1017 ++++++++++ .../verl/v0.5.0/verl/utils/memory_buffer.py | 218 +++ .../verl/v0.5.0/verl/utils/metric/__init__.py | 17 + .../verl/v0.5.0/verl/utils/metric/utils.py | 54 + toolbox/verl/v0.5.0/verl/utils/model.py | 664 +++++++ toolbox/verl/v0.5.0/verl/utils/net_utils.py | 61 + .../v0.5.0/verl/utils/profiler/__init__.py | 41 + .../verl/v0.5.0/verl/utils/profiler/config.py | 61 + .../verl/utils/profiler/empty_annotations.py | 40 + .../verl/utils/profiler/mstx_profile.py | 224 +++ .../verl/utils/profiler/nvtx_profile.py | 191 ++ .../v0.5.0/verl/utils/profiler/performance.py | 205 ++ .../v0.5.0/verl/utils/profiler/profile.py | 227 +++ .../verl/v0.5.0/verl/utils/py_functional.py | 317 +++ toolbox/verl/v0.5.0/verl/utils/ray_utils.py | 81 + .../v0.5.0/verl/utils/rendezvous/__init__.py | 13 + .../verl/utils/rendezvous/ray_backend.py | 73 + .../verl/utils/reward_score/__init__.py | 134 ++ .../v0.5.0/verl/utils/reward_score/geo3k.py | 36 + .../v0.5.0/verl/utils/reward_score/gsm8k.py | 72 + .../v0.5.0/verl/utils/reward_score/math.py | 224 +++ .../verl/utils/reward_score/math_batch.py | 26 + .../verl/utils/reward_score/math_dapo.py | 272 +++ .../verl/utils/reward_score/math_verify.py | 39 + .../utils/reward_score/prime_code/README.md | 16 + .../utils/reward_score/prime_code/__init__.py | 73 + .../reward_score/prime_code/testing_util.py | 683 +++++++ .../utils/reward_score/prime_code/utils.py | 60 + .../utils/reward_score/prime_math/__init__.py | 411 ++++ .../utils/reward_score/prime_math/grader.py | 384 ++++ .../reward_score/prime_math/math_normalize.py | 192 ++ .../reward_score/sandbox_fusion/__init__.py | 117 ++ .../reward_score/sandbox_fusion/utils.py | 570 ++++++ .../reward_score/search_r1_like_qa_em.py | 156 ++ .../verl/v0.5.0/verl/utils/rollout_trace.py | 237 +++ .../v0.5.0/verl/utils/seqlen_balancing.py | 375 ++++ toolbox/verl/v0.5.0/verl/utils/tokenizer.py | 88 + .../verl/v0.5.0/verl/utils/torch_dtypes.py | 80 + .../v0.5.0/verl/utils/torch_functional.py | 771 ++++++++ toolbox/verl/v0.5.0/verl/utils/tracking.py | 425 +++++ toolbox/verl/v0.5.0/verl/utils/ulysses.py | 328 ++++ toolbox/verl/v0.5.0/verl/utils/vllm_utils.py | 203 ++ toolbox/verl/v0.5.0/verl/version/version | 1 + toolbox/verl/v0.5.0/verl/workers/__init__.py | 13 + .../v0.5.0/verl/workers/actor/__init__.py | 18 + .../verl/v0.5.0/verl/workers/actor/base.py | 66 + .../v0.5.0/verl/workers/actor/dp_actor.py | 486 +++++ .../verl/workers/actor/megatron_actor.py | 658 +++++++ .../v0.5.0/verl/workers/critic/__init__.py | 18 + .../verl/v0.5.0/verl/workers/critic/base.py | 40 + .../v0.5.0/verl/workers/critic/dp_critic.py | 256 +++ .../verl/workers/critic/megatron_critic.py | 334 ++++ .../v0.5.0/verl/workers/engine/__init__.py | 17 + .../verl/v0.5.0/verl/workers/engine/base.py | 235 +++ .../verl/workers/engine/fsdp/__init__.py | 16 + .../verl/workers/engine/fsdp/engine_impl.py | 727 +++++++ .../v0.5.0/verl/workers/engine/fsdp/utils.py | 61 + .../verl/workers/engine/megatron/__init__.py | 13 + .../workers/engine/megatron/engine_impl.py | 166 ++ .../verl/v0.5.0/verl/workers/fsdp_workers.py | 1700 +++++++++++++++++ .../v0.5.0/verl/workers/megatron_workers.py | 1159 +++++++++++ .../verl/workers/reward_manager/__init__.py | 29 + .../verl/workers/reward_manager/batch.py | 122 ++ .../verl/workers/reward_manager/dapo.py | 146 ++ .../verl/workers/reward_manager/naive.py | 120 ++ .../verl/workers/reward_manager/prime.py | 186 ++ .../verl/workers/reward_manager/registry.py | 51 + .../verl/workers/reward_model/__init__.py | 17 + .../v0.5.0/verl/workers/reward_model/base.py | 44 + .../workers/reward_model/megatron/__init__.py | 17 + .../reward_model/megatron/reward_model.py | 350 ++++ .../v0.5.0/verl/workers/roles/__init__.py | 17 + .../verl/v0.5.0/verl/workers/roles/actor.py | 51 + .../verl/v0.5.0/verl/workers/roles/critic.py | 183 ++ .../v0.5.0/verl/workers/rollout/__init__.py | 19 + .../verl/workers/rollout/async_server.py | 282 +++ .../verl/v0.5.0/verl/workers/rollout/base.py | 28 + .../verl/workers/rollout/chat_scheduler.py | 444 +++++ .../v0.5.0/verl/workers/rollout/hf_rollout.py | 175 ++ .../verl/workers/rollout/naive/__init__.py | 17 + .../workers/rollout/naive/naive_rollout.py | 120 ++ .../v0.5.0/verl/workers/rollout/schemas.py | 675 +++++++ .../rollout/sglang_rollout/__init__.py | 16 + .../sglang_rollout/async_sglang_server.py | 95 + .../rollout/sglang_rollout/sglang_rollout.py | 1402 ++++++++++++++ .../workers/rollout/sglang_rollout/utils.py | 108 ++ .../v0.5.0/verl/workers/rollout/tokenizer.py | 163 ++ .../workers/rollout/vllm_rollout/__init__.py | 46 + .../rollout/vllm_rollout/vllm_async_server.py | 338 ++++ .../rollout/vllm_rollout/vllm_rollout_spmd.py | 501 +++++ .../verl/workers/sharding_manager/__init__.py | 13 + .../verl/workers/sharding_manager/base.py | 35 + .../workers/sharding_manager/fsdp_sglang.py | 263 +++ .../workers/sharding_manager/fsdp_ulysses.py | 72 + .../workers/sharding_manager/fsdp_vllm.py | 342 ++++ .../sharding_manager/megatron_sglang.py | 266 +++ .../workers/sharding_manager/megatron_vllm.py | 219 +++ 298 files changed, 61200 insertions(+), 76 deletions(-) create mode 100644 nlp/llm/qwen2-7b/verl/README.md create mode 100644 nlp/llm/qwen2-7b/verl/run_qwen2_7b_gsm8k.sh create mode 100644 toolbox/verl/v0.5.0/requirements.txt create mode 100644 toolbox/verl/v0.5.0/setup.py create mode 100644 toolbox/verl/v0.5.0/verl/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/base_config.py create mode 100644 toolbox/verl/v0.5.0/verl/experimental/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/experimental/agent_loop/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/experimental/agent_loop/agent_loop.py create mode 100644 toolbox/verl/v0.5.0/verl/experimental/agent_loop/single_turn_agent_loop.py create mode 100644 toolbox/verl/v0.5.0/verl/experimental/agent_loop/tool_agent_loop.py create mode 100644 toolbox/verl/v0.5.0/verl/experimental/agent_loop/tool_parser.py create mode 100644 toolbox/verl/v0.5.0/verl/experimental/dataset/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/experimental/dataset/sampler.py create mode 100644 toolbox/verl/v0.5.0/verl/experimental/dynamic_dataset/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/experimental/dynamic_dataset/dynamicgen_dataset.py create mode 100644 toolbox/verl/v0.5.0/verl/interactions/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/interactions/base.py create mode 100644 toolbox/verl/v0.5.0/verl/interactions/gsm8k_interaction.py create mode 100644 toolbox/verl/v0.5.0/verl/interactions/utils/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/interactions/utils/interaction_registry.py create mode 100644 toolbox/verl/v0.5.0/verl/model_merger/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/model_merger/__main__.py create mode 100644 toolbox/verl/v0.5.0/verl/model_merger/base_model_merger.py create mode 100644 toolbox/verl/v0.5.0/verl/model_merger/fsdp_model_merger.py create mode 100644 toolbox/verl/v0.5.0/verl/model_merger/megatron_model_merger.py create mode 100644 toolbox/verl/v0.5.0/verl/models/README.md create mode 100644 toolbox/verl/v0.5.0/verl/models/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/models/llama/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/models/llama/megatron/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/llama_loader.py create mode 100644 toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py create mode 100644 toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/llama_saver.py create mode 100644 toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_attention.py create mode 100644 toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_decoder.py create mode 100644 toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_linear.py create mode 100644 toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_mlp.py create mode 100644 toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_rmsnorm.py create mode 100644 toolbox/verl/v0.5.0/verl/models/llama/megatron/modeling_llama_megatron.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/config_converter.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/loader.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/mbridge.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/model_forward.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/model_forward_fused.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/model_initializer.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/patch_v012.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/attention.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/model.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/rope_utils.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/vision_config.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/vision_model.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/readme.md create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/registry.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/saver.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/util.py create mode 100644 toolbox/verl/v0.5.0/verl/models/mcore/weight_converter.py create mode 100644 toolbox/verl/v0.5.0/verl/models/qwen2/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/models/qwen2/megatron/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py create mode 100644 toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py create mode 100644 toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py create mode 100644 toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_attention.py create mode 100644 toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_decoder.py create mode 100644 toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_linear.py create mode 100644 toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_mlp.py create mode 100644 toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py create mode 100644 toolbox/verl/v0.5.0/verl/models/qwen2/megatron/modeling_qwen2_megatron.py create mode 100644 toolbox/verl/v0.5.0/verl/models/registry.py create mode 100644 toolbox/verl/v0.5.0/verl/models/transformers/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/models/transformers/dense_common.py create mode 100644 toolbox/verl/v0.5.0/verl/models/transformers/kimi_vl.py create mode 100644 toolbox/verl/v0.5.0/verl/models/transformers/llama.py create mode 100644 toolbox/verl/v0.5.0/verl/models/transformers/monkey_patch.py create mode 100644 toolbox/verl/v0.5.0/verl/models/transformers/npu_patch.py create mode 100644 toolbox/verl/v0.5.0/verl/models/transformers/qwen2.py create mode 100644 toolbox/verl/v0.5.0/verl/models/transformers/qwen2_5_vl.py create mode 100644 toolbox/verl/v0.5.0/verl/models/transformers/qwen2_vl.py create mode 100644 toolbox/verl/v0.5.0/verl/models/weight_loader_registry.py create mode 100644 toolbox/verl/v0.5.0/verl/protocol.py create mode 100644 toolbox/verl/v0.5.0/verl/py.typed create mode 100644 toolbox/verl/v0.5.0/verl/single_controller/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/single_controller/base/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/single_controller/base/decorator.py create mode 100644 toolbox/verl/v0.5.0/verl/single_controller/base/megatron/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/single_controller/base/megatron/worker.py create mode 100644 toolbox/verl/v0.5.0/verl/single_controller/base/megatron/worker_group.py create mode 100644 toolbox/verl/v0.5.0/verl/single_controller/base/register_center/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/single_controller/base/register_center/ray.py create mode 100644 toolbox/verl/v0.5.0/verl/single_controller/base/worker.py create mode 100644 toolbox/verl/v0.5.0/verl/single_controller/base/worker_group.py create mode 100644 toolbox/verl/v0.5.0/verl/single_controller/ray/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/single_controller/ray/base.py create mode 100644 toolbox/verl/v0.5.0/verl/single_controller/ray/megatron.py create mode 100644 toolbox/verl/v0.5.0/verl/third_party/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/third_party/sglang/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/third_party/sglang/parallel_state.py create mode 100644 toolbox/verl/v0.5.0/verl/third_party/torch/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/third_party/torch/distributed/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/third_party/torch/distributed/_state_dict_utils.py create mode 100644 toolbox/verl/v0.5.0/verl/third_party/torch/distributed/checkpoint/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/third_party/torch/distributed/checkpoint/state_dict.py create mode 100644 toolbox/verl/v0.5.0/verl/third_party/vllm/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/tools/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/tools/base_tool.py create mode 100644 toolbox/verl/v0.5.0/verl/tools/geo3k_tool.py create mode 100644 toolbox/verl/v0.5.0/verl/tools/gsm8k_tool.py create mode 100644 toolbox/verl/v0.5.0/verl/tools/mcp_base_tool.py create mode 100644 toolbox/verl/v0.5.0/verl/tools/mcp_search_tool.py create mode 100644 toolbox/verl/v0.5.0/verl/tools/sandbox_fusion_tools.py create mode 100644 toolbox/verl/v0.5.0/verl/tools/schemas.py create mode 100644 toolbox/verl/v0.5.0/verl/tools/search_tool.py create mode 100644 toolbox/verl/v0.5.0/verl/tools/utils/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/tools/utils/mcp_clients/McpClientManager.py create mode 100644 toolbox/verl/v0.5.0/verl/tools/utils/mcp_clients/utils.py create mode 100644 toolbox/verl/v0.5.0/verl/tools/utils/search_r1_like_utils.py create mode 100644 toolbox/verl/v0.5.0/verl/tools/utils/tool_registry.py create mode 100644 toolbox/verl/v0.5.0/verl/trainer/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/_generated_ppo_megatron_trainer.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/_generated_ppo_trainer.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/actor/actor.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/actor/dp_actor.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/actor/megatron_actor.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/algorithm.py create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/config.py create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/critic/critic.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/critic/dp_critic.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/critic/megatron_critic.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/data/legacy_data.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/evaluation.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/generation.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/npu_profile/npu_profile.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/ppo_megatron_trainer.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/ppo_trainer.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/ref/dp_ref.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/ref/megatron_ref.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/ref/ref.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/reward_model/dp_reward_model.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/reward_model/megatron_reward_model.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/reward_model/reward_model.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/rollout/rollout.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/config/sft_trainer.yaml create mode 100644 toolbox/verl/v0.5.0/verl/trainer/constants_ppo.py create mode 100644 toolbox/verl/v0.5.0/verl/trainer/fsdp_sft_trainer.py create mode 100644 toolbox/verl/v0.5.0/verl/trainer/main_eval.py create mode 100644 toolbox/verl/v0.5.0/verl/trainer/main_generation.py create mode 100644 toolbox/verl/v0.5.0/verl/trainer/main_ppo.py create mode 100644 toolbox/verl/v0.5.0/verl/trainer/ppo/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/trainer/ppo/core_algos.py create mode 100644 toolbox/verl/v0.5.0/verl/trainer/ppo/metric_utils.py create mode 100644 toolbox/verl/v0.5.0/verl/trainer/ppo/ray_trainer.py create mode 100644 toolbox/verl/v0.5.0/verl/trainer/ppo/reward.py create mode 100644 toolbox/verl/v0.5.0/verl/trainer/runtime_env.yaml create mode 100644 toolbox/verl/v0.5.0/verl/utils/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/activation_offload.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/checkpoint/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/checkpoint/checkpoint_manager.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/checkpoint/fsdp_checkpoint_manager.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/checkpoint/megatron_checkpoint_manager.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/config.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/dataset/README.md create mode 100644 toolbox/verl/v0.5.0/verl/utils/dataset/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/dataset/multiturn_sft_dataset.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/dataset/rl_dataset.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/dataset/rm_dataset.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/dataset/sft_dataset.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/dataset/vision_utils.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/debug/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/debug/performance.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/debug/trajectory_tracker.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/device.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/distributed.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/experimental/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/experimental/torch_functional.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/flops_counter.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/fs.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/fsdp_utils.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/hdfs_io.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/import_utils.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/kernel/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/kernel/kernels.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/kernel/linear_cross_entropy.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/logger/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/logger/aggregate_logger.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/logging_utils.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/megatron/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/megatron/dist_checkpointing.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/megatron/memory.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/megatron/optimizer.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/megatron/pipeline_parallel.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/megatron/sequence_parallel.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/megatron/tensor_parallel.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/megatron_utils.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/memory_buffer.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/metric/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/metric/utils.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/model.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/net_utils.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/profiler/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/profiler/config.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/profiler/empty_annotations.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/profiler/mstx_profile.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/profiler/nvtx_profile.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/profiler/performance.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/profiler/profile.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/py_functional.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/ray_utils.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/rendezvous/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/rendezvous/ray_backend.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/geo3k.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/gsm8k.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/math.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/math_batch.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/math_dapo.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/math_verify.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/prime_code/README.md create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/prime_code/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/prime_code/testing_util.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/prime_code/utils.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/prime_math/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/prime_math/grader.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/prime_math/math_normalize.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/sandbox_fusion/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/sandbox_fusion/utils.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/reward_score/search_r1_like_qa_em.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/rollout_trace.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/seqlen_balancing.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/tokenizer.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/torch_dtypes.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/torch_functional.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/tracking.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/ulysses.py create mode 100644 toolbox/verl/v0.5.0/verl/utils/vllm_utils.py create mode 100644 toolbox/verl/v0.5.0/verl/version/version create mode 100644 toolbox/verl/v0.5.0/verl/workers/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/actor/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/actor/base.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/actor/dp_actor.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/actor/megatron_actor.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/critic/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/critic/base.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/critic/dp_critic.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/critic/megatron_critic.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/engine/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/engine/base.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/engine/fsdp/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/engine/fsdp/engine_impl.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/engine/fsdp/utils.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/engine/megatron/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/engine/megatron/engine_impl.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/fsdp_workers.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/megatron_workers.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/reward_manager/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/reward_manager/batch.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/reward_manager/dapo.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/reward_manager/naive.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/reward_manager/prime.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/reward_manager/registry.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/reward_model/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/reward_model/base.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/reward_model/megatron/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/reward_model/megatron/reward_model.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/roles/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/roles/actor.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/roles/critic.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/async_server.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/base.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/chat_scheduler.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/hf_rollout.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/naive/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/naive/naive_rollout.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/schemas.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/sglang_rollout/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/sglang_rollout/async_sglang_server.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/sglang_rollout/sglang_rollout.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/sglang_rollout/utils.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/tokenizer.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/vllm_rollout/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/vllm_rollout/vllm_async_server.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/sharding_manager/__init__.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/sharding_manager/base.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/sharding_manager/fsdp_sglang.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/sharding_manager/fsdp_ulysses.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/sharding_manager/fsdp_vllm.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/sharding_manager/megatron_sglang.py create mode 100644 toolbox/verl/v0.5.0/verl/workers/sharding_manager/megatron_vllm.py diff --git a/nlp/llm/qwen2-7b/verl/README.md b/nlp/llm/qwen2-7b/verl/README.md new file mode 100644 index 000000000..b34bc86f5 --- /dev/null +++ b/nlp/llm/qwen2-7b/verl/README.md @@ -0,0 +1,44 @@ +# Qwen2.5-7B grpo (verl) + +## Model Description + +Qwen2 is the new series of Qwen large language models. For Qwen2, we release a number of base language models and instruction-tuned language models ranging from 0.5 to 72 billion parameters, including a Mixture-of-Experts model. Compared with the state-of-the-art opensource language models, including the previous released Qwen1.5, Qwen2 has generally surpassed most opensource models and demonstrated competitiveness against proprietary models across a series of benchmarks targeting for language understanding, language generation, multilingual capability, coding, mathematics, reasoning, etc. Qwen2-7B-Instruct supports a context length of up to 131,072 tokens, enabling the processing of extensive inputs. + +## Supported Environments + +| GPU | [IXUCA SDK](https://gitee.com/deep-spark/deepspark#%E5%A4%A9%E6%95%B0%E6%99%BA%E7%AE%97%E8%BD%AF%E4%BB%B6%E6%A0%88-ixuca) | Release | +| :----: | :----: | :----: | +| BI-V150 | 4.3.0 | 25.07 | + +## Environment Preparation + +### Install Dependencies +```bash +cd toolbox/verl/v0.5.0 +pip3 install -r requirements.txt +python3 setup.py install +``` + +### Prepare Resources + +```bash +python3 examples/data_preprocess/gsm8k.py +mv ~/data/gsm8k /home/datasets/verl/gsm8k + +# download Qwen2.5-7B-Instruct and put to /home/model_zoos/verl/Qwen2.5-7B-Instruct + +``` + +## Model Training + +### train on gsm8k +```bash +cd nlp/llm/qwen2-7b/verl +bash run_qwen2_7B_gsm8k.sh +``` + +## Model Results + +## References + +- [verl](https://github.com/volcengine/verl/tree/v0.5.0) diff --git a/nlp/llm/qwen2-7b/verl/run_qwen2_7b_gsm8k.sh b/nlp/llm/qwen2-7b/verl/run_qwen2_7b_gsm8k.sh new file mode 100644 index 000000000..7c1a56dfa --- /dev/null +++ b/nlp/llm/qwen2-7b/verl/run_qwen2_7b_gsm8k.sh @@ -0,0 +1,49 @@ +set -x + + +HOME=$(cd "$(dirname "$0")"; pwd) +echo "HOME:$HOME" +PATH_DATASETS=/home/datasets/verl +PATH_MODEL=/home/model_zoos/verl + +export VLLM_USE_V1=0 + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$PATH_DATASETS/gsm8k/train.parquet \ + data.val_files=$PATH_DATASETS/gsm8k/test.parquet \ + data.train_batch_size=16 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=$PATH_MODEL/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.free_cache_engine=False \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_7b_function_rm' \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=100 \ + trainer.total_epochs=1 $@ diff --git a/toolbox/verl/pyproject.toml b/toolbox/verl/pyproject.toml index 2c27268b4..78273d8ee 100644 --- a/toolbox/verl/pyproject.toml +++ b/toolbox/verl/pyproject.toml @@ -16,50 +16,54 @@ name = "verl" # We'll mark the version as "dynamic" because it's read from the file "verl/version/version" # (PEP 621 calls this "dynamic version"). # The actual version is specified in the [tool.setuptools.dynamic] section below. -dynamic = ["version"] +dynamic = ["version", "dependencies", "optional-dependencies", "urls"] description = "verl: Volcano Engine Reinforcement Learning for LLM" -license = {file = "LICENSE"} # or "Apache-2.0", if you prefer an SPDX identifier +license = {text = "Apache-2.0"} # Changed from file to text format readme = {file = "README.md", content-type = "text/markdown"} -requires-python = ">=3.8" +requires-python = ">=3.10" -authors = [ - { name = "Bytedance - Seed - MLSys", email = "zhangchi.usc1992@bytedance.com" }, - { name = "Bytedance - Seed - MLSys", email = "gmsheng@connect.hku.hk" }, -] +# ------------------------------- +# tool.ruff - Linting configuration +# ------------------------------- +[tool.ruff] +# Note: While the formatter will attempt to format lines such that they remain within the line-length, +# it isn't a hard upper bound, and formatted lines may exceed the line-length. +line-length = 120 +exclude = ["tests/workers/rollout/test_sglang_async_rollout_sf_tools.py", "scripts/legacy_model_merger.py"] -# Dependencies corresponding to install_requires in setup.py -dependencies = [ - "accelerate", - "codetiming", - "datasets", - "dill", - "hydra-core", - "numpy", - "pandas", - "peft", - "pyarrow>=15.0.0", - "pybind11", - "pylatexenc", - "ray>=2.10", - "tensordict<0.6", - "transformers", - # "vllm<=0.6.3", - "peft", - # "liger-kernel", +[tool.ruff.lint] +isort = {known-first-party = ["verl"]} +# c.f. https://github.com/vllm-project/vllm/blob/ce8d6b75fc0586045df75ee1568a5b5f9957251b/pyproject.toml +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # isort + "I", + "G", ] - -# Optional dependencies (extras_require in setup.py) -[project.optional-dependencies] -test = [ - "pytest", "yapf", "py-spy", +ignore = [ + # star imports + "F405", "F403", + # lambda expression assignment + "E731", + # Loop control variable not used within loop body + "B007", + # f-string format + "UP032", + # `.log()` statement uses f-string + "G004", + # X | None for type annotations + "UP045", + # deprecated import + "UP035", ] -prime = ["pyext"] -gpu = ["liger-kernel", "flash-attn"] - -# URLs -[project.urls] -Homepage = "https://github.com/volcengine/verl" # ------------------------------- # tool.setuptools - Additional config @@ -82,5 +86,6 @@ version = {file = "verl/version/version"} [tool.setuptools.package-data] verl = [ "version/*", - "trainer/config/*.yaml" + "trainer/config/*.yaml", + "trainer/config/*/*.yaml", ] diff --git a/toolbox/verl/setup.py b/toolbox/verl/setup.py index 4133d6e70..1671d8488 100644 --- a/toolbox/verl/setup.py +++ b/toolbox/verl/setup.py @@ -13,62 +13,87 @@ # limitations under the License. # setup.py is the fallback installation script when pyproject.toml does not work -from setuptools import setup, find_packages import os +from pathlib import Path + +from setuptools import find_packages, setup version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) -with open(os.path.join(version_folder, 'verl/version/version')) as f: +with open(os.path.join(version_folder, "verl/version/version")) as f: __version__ = f.read().strip() install_requires = [ - 'accelerate', - 'codetiming', - 'datasets', - 'dill', - 'hydra-core', - 'numpy', - 'pandas', - 'peft', - 'pyarrow>=15.0.0', - 'pybind11', - 'pylatexenc', - 'ray>=2.10', - 'tensordict<0.6', - 'transformers', - 'vllm<=0.6.3', - 'wandb', + "accelerate", + "codetiming", + "datasets", + "dill", + "hydra-core", + "numpy<2.0.0", + "pandas", + "peft", + "pyarrow>=19.0.0", + "pybind11", + "pylatexenc", + "ray[default]>=2.41.0", + "torchdata", + "tensordict>=0.8.0,<=0.9.1,!=0.9.0", + "transformers", + "wandb", + "packaging>=20.0", ] -TEST_REQUIRES = ['pytest', 'yapf', 'py-spy'] -PRIME_REQUIRES = ['pyext'] -GPU_REQUIRES = ['liger-kernel', 'flash-attn'] +TEST_REQUIRES = ["pytest", "pre-commit", "py-spy", "pytest-asyncio"] +PRIME_REQUIRES = ["pyext"] +GEO_REQUIRES = ["mathruler", "torchvision", "qwen_vl_utils"] +GPU_REQUIRES = ["liger-kernel", "flash-attn"] +MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency +VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.9.1,!=0.9.0", "vllm>=0.7.3,<=0.8.5"] +SGLANG_REQUIRES = [ + "tensordict>=0.8.0,<=0.9.1,!=0.9.0", + "sglang[srt,openai]==0.4.6.post5", + "torch-memory-saver>=0.0.5", + "torch==2.6.0", +] +TRL_REQUIRES = ["trl<=0.9.6"] +MCORE_REQUIRES = ["mbridge"] extras_require = { - 'test': TEST_REQUIRES, - 'prime': PRIME_REQUIRES, - 'gpu': GPU_REQUIRES, + "test": TEST_REQUIRES, + "prime": PRIME_REQUIRES, + "geo": GEO_REQUIRES, + "gpu": GPU_REQUIRES, + "math": MATH_REQUIRES, + "vllm": VLLM_REQUIRES, + "sglang": SGLANG_REQUIRES, + "trl": TRL_REQUIRES, + "mcore": MCORE_REQUIRES, } -from pathlib import Path + this_directory = Path(__file__).parent long_description = (this_directory / "README.md").read_text() +if "VERL_LOCAL_VERSION_IDENTIFIER" in os.environ: + __version__ += "+" + str(os.environ['VERL_LOCAL_VERSION_IDENTIFIER']) + setup( - name='verl', + name="verl", version=__version__, - package_dir={'': '.'}, - packages=find_packages(where='.'), - url='https://github.com/volcengine/verl', - license='Apache 2.0', - author='Bytedance - Seed - MLSys', - author_email='zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk', - description='verl: Volcano Engine Reinforcement Learning for LLM', - install_requires=install_requires, + package_dir={"": "."}, + packages=find_packages(where="."), + # url="https://github.com/volcengine/verl", + # license="Apache 2.0", + # author="Bytedance - Seed - MLSys", + # author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk", + # description="verl: Volcano Engine Reinforcement Learning for LLM", + # install_requires=install_requires, extras_require=extras_require, - package_data={'': ['version/*'], - 'verl': ['trainer/config/*.yaml'],}, + package_data={ + "": ["version/*"], + "verl": ["trainer/config/**/*", "trainer/config/*"], + }, include_package_data=True, long_description=long_description, - long_description_content_type='text/markdown' -) \ No newline at end of file + long_description_content_type="text/markdown", +) diff --git a/toolbox/verl/v0.5.0/requirements.txt b/toolbox/verl/v0.5.0/requirements.txt new file mode 100644 index 000000000..acba9e39e --- /dev/null +++ b/toolbox/verl/v0.5.0/requirements.txt @@ -0,0 +1,26 @@ +# requirements.txt records the full set of dependencies for development +accelerate +codetiming +datasets +dill +#flash-attn +hydra-core +#liger-kernel +numpy==1.26.4 +pandas +peft +pyarrow>=19.0.0 +pybind11 +pylatexenc +pre-commit +ray[default] +tensordict==0.6.2 +torchdata +#transformers==4.52.0 +# vllm==0.9.1 +wandb +packaging>=20.0 +uvicorn +fastapi +latex2sympy2_extended +math_verify diff --git a/toolbox/verl/v0.5.0/setup.py b/toolbox/verl/v0.5.0/setup.py new file mode 100644 index 000000000..1671d8488 --- /dev/null +++ b/toolbox/verl/v0.5.0/setup.py @@ -0,0 +1,99 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# setup.py is the fallback installation script when pyproject.toml does not work +import os +from pathlib import Path + +from setuptools import find_packages, setup + +version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) + +with open(os.path.join(version_folder, "verl/version/version")) as f: + __version__ = f.read().strip() + +install_requires = [ + "accelerate", + "codetiming", + "datasets", + "dill", + "hydra-core", + "numpy<2.0.0", + "pandas", + "peft", + "pyarrow>=19.0.0", + "pybind11", + "pylatexenc", + "ray[default]>=2.41.0", + "torchdata", + "tensordict>=0.8.0,<=0.9.1,!=0.9.0", + "transformers", + "wandb", + "packaging>=20.0", +] + +TEST_REQUIRES = ["pytest", "pre-commit", "py-spy", "pytest-asyncio"] +PRIME_REQUIRES = ["pyext"] +GEO_REQUIRES = ["mathruler", "torchvision", "qwen_vl_utils"] +GPU_REQUIRES = ["liger-kernel", "flash-attn"] +MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency +VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.9.1,!=0.9.0", "vllm>=0.7.3,<=0.8.5"] +SGLANG_REQUIRES = [ + "tensordict>=0.8.0,<=0.9.1,!=0.9.0", + "sglang[srt,openai]==0.4.6.post5", + "torch-memory-saver>=0.0.5", + "torch==2.6.0", +] +TRL_REQUIRES = ["trl<=0.9.6"] +MCORE_REQUIRES = ["mbridge"] + +extras_require = { + "test": TEST_REQUIRES, + "prime": PRIME_REQUIRES, + "geo": GEO_REQUIRES, + "gpu": GPU_REQUIRES, + "math": MATH_REQUIRES, + "vllm": VLLM_REQUIRES, + "sglang": SGLANG_REQUIRES, + "trl": TRL_REQUIRES, + "mcore": MCORE_REQUIRES, +} + + +this_directory = Path(__file__).parent +long_description = (this_directory / "README.md").read_text() + +if "VERL_LOCAL_VERSION_IDENTIFIER" in os.environ: + __version__ += "+" + str(os.environ['VERL_LOCAL_VERSION_IDENTIFIER']) + +setup( + name="verl", + version=__version__, + package_dir={"": "."}, + packages=find_packages(where="."), + # url="https://github.com/volcengine/verl", + # license="Apache 2.0", + # author="Bytedance - Seed - MLSys", + # author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk", + # description="verl: Volcano Engine Reinforcement Learning for LLM", + # install_requires=install_requires, + extras_require=extras_require, + package_data={ + "": ["version/*"], + "verl": ["trainer/config/**/*", "trainer/config/*"], + }, + include_package_data=True, + long_description=long_description, + long_description_content_type="text/markdown", +) diff --git a/toolbox/verl/v0.5.0/verl/__init__.py b/toolbox/verl/v0.5.0/verl/__init__.py new file mode 100644 index 000000000..593f3dc61 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/__init__.py @@ -0,0 +1,64 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import logging +import os +from importlib.metadata import PackageNotFoundError +from importlib.metadata import version as get_version + +from packaging.version import parse as parse_version + +from .protocol import DataProto +from .utils.device import is_npu_available +from .utils.logging_utils import set_basic_config + +version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) + +with open(os.path.join(version_folder, "version/version")) as f: + __version__ = f.read().strip() + + +set_basic_config(level=logging.WARNING) + + +__all__ = ["DataProto", "__version__"] + +if os.getenv("VERL_USE_MODELSCOPE", "False").lower() == "true": + if importlib.util.find_spec("modelscope") is None: + raise ImportError("You are using the modelscope hub, please install modelscope by `pip install modelscope -U`") + # Patch hub to download models from modelscope to speed up. + from modelscope.utils.hf_util import patch_hub + + patch_hub() + +if is_npu_available: + from .models.transformers import npu_patch as npu_patch + + package_name = "transformers" + required_version_spec = "4.52.4" + try: + installed_version = get_version(package_name) + installed = parse_version(installed_version) + required = parse_version(required_version_spec) + + if installed < required: + raise ValueError( + f"{package_name} version >= {required_version_spec} is required on ASCEND NPU, current version is " + f"{installed}." + ) + except PackageNotFoundError as e: + raise ImportError( + f"package {package_name} is not installed, please run pip install {package_name}=={required_version_spec}" + ) from e diff --git a/toolbox/verl/v0.5.0/verl/base_config.py b/toolbox/verl/v0.5.0/verl/base_config.py new file mode 100644 index 000000000..0cd117bb6 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/base_config.py @@ -0,0 +1,91 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +from dataclasses import ( + dataclass, + field, + fields, # Import the fields function to inspect dataclass fields +) +from typing import Any + + +# BaseConfig class inherits from collections.abc.Mapping, which means it can act like a dictionary +@dataclass +class BaseConfig(collections.abc.Mapping): + """The BaseConfig provides omegaconf DictConfig-like interface for a dataclass config. + + The BaseConfig class implements the Mapping Abstract Base Class. + This allows instances of this class to be used like dictionaries. + """ + + extra: dict[str, Any] = field(default_factory=dict) + + def __setattr__(self, name: str, value): + # if the field already exists (i.e. was set in __init__) + # and is in our frozen list, block assignment + if hasattr(self, "_frozen_fields") and name in self._frozen_fields and name in self.__dict__: + from dataclasses import FrozenInstanceError + + raise FrozenInstanceError(f"Field '{name}' is frozen and cannot be modified") + # otherwise do the normal thing + super().__setattr__(name, value) + + def get(self, key: str, default: Any = None) -> Any: + """Get the value associated with the given key. If the key does not exist, return the default value. + + Args: + key (str): The attribute name to retrieve. + default (Any, optional): The value to return if the attribute does not exist. Defaults to None. + + Returns: + Any: The value of the attribute or the default value. + """ + try: + return getattr(self, key) + except AttributeError: + return default + + def __getitem__(self, key: str): + """Implement the [] operator for the class. Allows accessing attributes like dictionary items. + + Args: + key (str): The attribute name to retrieve. + + Returns: + Any: The value of the attribute. + + Raises: + AttributeError: If the attribute does not exist. + TypeError: If the key type is not string + """ + return getattr(self, key) + + def __iter__(self): + """Implement the iterator protocol. Allows iterating over the attribute names of the instance. + + Yields: + str: The name of each field in the dataclass. + """ + for f in fields(self): + yield f.name + + def __len__(self): + """ + Return the number of fields in the dataclass. + + Returns: + int: The number of fields in the dataclass. + """ + return len(fields(self)) diff --git a/toolbox/verl/v0.5.0/verl/experimental/__init__.py b/toolbox/verl/v0.5.0/verl/experimental/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/experimental/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/experimental/agent_loop/__init__.py b/toolbox/verl/v0.5.0/verl/experimental/agent_loop/__init__.py new file mode 100644 index 000000000..a39171db7 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/experimental/agent_loop/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .agent_loop import AgentLoopBase, AgentLoopManager +from .single_turn_agent_loop import SingleTurnAgentLoop +from .tool_agent_loop import ToolAgentLoop + +_ = [SingleTurnAgentLoop, ToolAgentLoop] + +__all__ = ["AgentLoopBase", "AgentLoopManager"] diff --git a/toolbox/verl/v0.5.0/verl/experimental/agent_loop/agent_loop.py b/toolbox/verl/v0.5.0/verl/experimental/agent_loop/agent_loop.py new file mode 100644 index 000000000..ef8638102 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/experimental/agent_loop/agent_loop.py @@ -0,0 +1,543 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import heapq +import logging +import os +import random +from abc import ABC, abstractmethod +from typing import Any + +import hydra +import numpy as np +import ray +import torch +from cachetools import LRUCache +from omegaconf import DictConfig, OmegaConf +from pydantic import BaseModel +from tensordict import TensorDict +from transformers import AutoTokenizer + +from verl.protocol import DataProto +from verl.single_controller.ray.base import RayWorkerGroup +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_to_local +from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op +from verl.workers.rollout.async_server import async_server_class + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class AsyncLLMServerManager: + """ + A class to manage multiple OpenAI compatible LLM servers. This class provides + - Load balance: least requests load balancing + - Sticky session: send multi-turn chat completions to same server for automatic prefix caching + """ + + def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], max_cache_size: int = 10000): + """Initialize the AsyncLLMServerManager. + + Args: + config (DictConfig): YAML config. + server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles. + max_cache_size (int, optional): max cache size for request_id to server mapping. Defaults to 10000. + """ + self.config = config + self.server_handles = server_handles + random.shuffle(self.server_handles) + + # Least requests load balancing + self.weighted_serveres = [[0, (hash(server), server)] for server in server_handles] + heapq.heapify(self.weighted_serveres) + + # LRU cache to map request_id to server + self.request_id_to_server = LRUCache(maxsize=max_cache_size) + + def _choose_server(self, request_id: str) -> ray.actor.ActorHandle: + # TODO: implement server pressure awareness load balancing + if request_id in self.request_id_to_server: + return self.request_id_to_server[request_id] + + server = self.weighted_serveres[0][1][1] + self.weighted_serveres[0][0] += 1 + heapq.heapreplace(self.weighted_serveres, self.weighted_serveres[0]) + self.request_id_to_server[request_id] = server + return server + + @rollout_trace_op + async def generate( + self, + request_id, + *, + prompt_ids: list[int], + sampling_params: dict[str, Any], + ) -> list[int]: + """Generate tokens from prompt ids. + + Args: + request_id (str): request id for sticky session. + prompt_ids (List[int]): List of prompt token ids. + sampling_params (Dict[str, Any]): Sampling parameters for the chat completion. + + Returns: + List[int]: List of generated token ids. + """ + server = self._choose_server(request_id) + output = await server.generate.remote( + request_id=request_id, + prompt_ids=prompt_ids, + sampling_params=sampling_params, + ) + return output + + +class AgentLoopMetrics(BaseModel): + """Agent loop performance metrics.""" + + generate_sequences: float = 0.0 + tool_calls: float = 0.0 + + +class AgentLoopOutput(BaseModel): + """Agent loop output.""" + + prompt_ids: list[int] + """Prompt token ids.""" + response_ids: list[int] + """Response token ids including LLM generated token, tool response token.""" + response_mask: list[int] + """Response mask, 1 for LLM generated token, 0 for tool response token.""" + num_turns: int = 0 + """Number of chat turns, including user, assistant, tool.""" + metrics: AgentLoopMetrics + """Auxiliary performance metrics""" + + +# make hydra.utils.instantiate happy +class _DummyConfig: + def __init__(self, config: DictConfig) -> None: + self.config = config + + +class AgentLoopBase(ABC): + """An agent loop takes a input message, chat with OpenAI compatible LLM server and interact with various + environments.""" + + _class_initialized = False + + def __init__( + self, trainer_config: _DummyConfig, server_manager: AsyncLLMServerManager, tokenizer: AutoTokenizer, **kwargs + ): + """Initialize agent loop, each sample will have its own loop instance. + + Args: + trainer_config (_DummyConfig): trainer config. + server_manager (AsyncLLMServerManager): OpenAI compatible LLM server manager. + tokenizer (AutoTokenizer): Tokenizer for tokenize messages. + """ + self.init_class(trainer_config.config, tokenizer, **kwargs) + self.config = trainer_config.config + self.server_manager = server_manager + self.tokenizer = tokenizer + self.loop = asyncio.get_running_loop() + + @classmethod + def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer, **kwargs): + """This is used to do heavy initialization work that should shared across all instances. It's only called once. + + Args: + config (DictConfig): trainer config. + tokenizer (AutoTokenizer): Tokenizer for tokenize messages. + **kwargs: extra kwargs from config file passed in by `hydra.utils.instantiate`. + """ + if cls._class_initialized: + return + cls._class_initialized = True + + @abstractmethod + async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: + """Run agent loop to interact with LLM server and environment. + + Args: + messages (List[Dict[str, Any]]): Input messages. + sampling_params (Dict[str, Any]): LLM sampling params. + + Returns: + AgentLoopOutput: Agent loop output. + """ + raise NotImplementedError + + +"""Agent loop registry: key is agent_name, value is a dict of agent loop config +used by hydra.utils.instantiate to initialize agent loop instance. + +https://hydra.cc/docs/advanced/instantiate_objects/overview/ +""" +_agent_loop_registry: dict[str, dict] = {} + + +def register(agent_name: str): + """Register agent loop class.""" + + def decorator(subclass: type[AgentLoopBase]) -> type[AgentLoopBase]: + fqdn = f"{subclass.__module__}.{subclass.__qualname__}" + _agent_loop_registry[agent_name] = {"_target_": fqdn} + return subclass + + return decorator + + +@ray.remote +class AgentLoopWorker: + """Agent loop worker takes a batch of messages and run each message in an agent loop.""" + + def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle]): + """Initialize agent loop manager. + + Args: + config (DictConfig): YAML config. + server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles. + """ + self.config = config + self.server_manager = AsyncLLMServerManager(config, server_handles) + + model_path = config.actor_rollout_ref.model.path + self.model_name = "/".join(model_path.split("/")[-2:]) + local_path = copy_to_local(config.actor_rollout_ref.model.path) + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True) + + agent_loop_config_path = config.actor_rollout_ref.rollout.agent.agent_loop_config_path + if agent_loop_config_path: + agent_loop_configs = OmegaConf.load(agent_loop_config_path) + for agent_loop_config in agent_loop_configs: + _agent_loop_registry[agent_loop_config.name] = agent_loop_config + + trace_config = config.trainer.get("rollout_trace", {}) + trace_config = self.config.actor_rollout_ref.rollout.get("trace", {}) + RolloutTraceConfig.init( + self.config.trainer.project_name, + self.config.trainer.experiment_name, + trace_config.get("backend"), + trace_config.get("token2text", False), + ) + + async def generate_sequences(self, batch: DataProto) -> DataProto: + """Generate sequences from agent loop. + + Args: + batch (DataProto): Input batch. + + Returns: + DataProto: Output batch. + - prompts: [bsz, prompt_length], prompt token ids from dataset. + - responses: [bsz, response_length], output token ids include response tokens + from LLM generation and observation tokens from tool_calls. + - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens. + - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens + and response tokens. + - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens. + - position_ids: [bsz, prompt_length + response_length], incremental position ids. + + For multi-turn conversations: + responses: |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->| + response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0| + """ + config = self.config.actor_rollout_ref.rollout + sampling_params = dict( + temperature=config.temperature, + top_p=config.top_p, + repetition_penalty=1.0, + ) + + # override sampling params for validation + if batch.meta_info.get("validate", False): + sampling_params["top_p"] = config.val_kwargs.top_p + sampling_params["temperature"] = config.val_kwargs.temperature + + # by default, we assume it's a single turn agent + if "agent_name" not in batch.non_tensor_batch: + batch.non_tensor_batch["agent_name"] = np.array(["single_turn_agent"] * len(batch), dtype=object) + + tasks = [] + agent_names = batch.non_tensor_batch["agent_name"] + raw_prompts = batch.non_tensor_batch["raw_prompt"] + if "index" in batch.non_tensor_batch: + index = batch.non_tensor_batch["index"] + else: + index = np.arange(len(raw_prompts)) + + trajectory_info = await get_trajectory_info( + batch.meta_info.get("global_steps", -1), index, batch.meta_info.get("validate", False) + ) + + for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info, strict=True): + tasks.append( + asyncio.create_task(self._run_agent_loop(agent_name, messages.tolist(), sampling_params, trajectory)) + ) + outputs = await asyncio.gather(*tasks) + + output = self._postprocess(outputs) + return output + + async def _run_agent_loop( + self, + agent_name: str, + messages: list[dict[str, Any]], + sampling_params: dict[str, Any], + trajectory: dict[str, Any], + ) -> AgentLoopOutput: + with rollout_trace_attr( + step=trajectory["step"], + sample_index=trajectory["sample_index"], + rollout_n=trajectory["rollout_n"], + validate=trajectory["validate"], + name="agent_loop", + ): + assert agent_name in _agent_loop_registry, ( + f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}" + ) + + agent_loop_config = _agent_loop_registry[agent_name] + agent_loop = hydra.utils.instantiate( + config=agent_loop_config, + trainer_config=_DummyConfig(config=self.config), + server_manager=self.server_manager, + tokenizer=self.tokenizer, + ) + output = await agent_loop.run(messages, sampling_params) + return output + + def _postprocess(self, inputs: list[AgentLoopOutput]) -> DataProto: + # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py + # prompts: left pad + # responses: right pad + # input_ids: prompt + response + # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] + # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] + + # prompts + self.tokenizer.padding_side = "left" + outputs = self.tokenizer.pad( + [{"input_ids": input.prompt_ids} for input in inputs], + padding="max_length", + max_length=self.config.actor_rollout_ref.rollout.prompt_length, + return_tensors="pt", + return_attention_mask=True, + ) + prompt_ids, prompt_attention_mask = outputs["input_ids"], outputs["attention_mask"] + + # responses + self.tokenizer.padding_side = "right" + outputs = self.tokenizer.pad( + [{"input_ids": input.response_ids} for input in inputs], + padding="max_length", + max_length=self.config.actor_rollout_ref.rollout.response_length, + return_tensors="pt", + return_attention_mask=True, + ) + response_ids, response_attention_mask = outputs["input_ids"], outputs["attention_mask"] + + # response_mask + outputs = self.tokenizer.pad( + [{"input_ids": input.response_mask} for input in inputs], + padding="max_length", + max_length=self.config.actor_rollout_ref.rollout.response_length, + return_tensors="pt", + return_attention_mask=False, + ) + response_mask = outputs["input_ids"] + assert response_ids.shape == response_mask.shape, ( + f"mismatch in response_ids and response_mask shape: {response_ids.shape} vs {response_mask.shape}" + ) + response_mask = response_mask * response_attention_mask + + input_ids = torch.cat([prompt_ids, response_ids], dim=1) + attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=1) + position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask + + batch = TensorDict( + { + "prompts": prompt_ids, # [bsz, prompt_length] + "responses": response_ids, # [bsz, response_length] + "response_mask": response_mask, # [bsz, response_length] + "input_ids": input_ids, # [bsz, prompt_length + response_length] + "attention_mask": attention_mask, # [bsz, prompt_length + response_length] + "position_ids": position_ids, # [bsz, prompt_length + response_length] + }, + batch_size=len(input_ids), + ) + + num_turns = np.array([input.num_turns for input in inputs], dtype=np.int32) + metrics = [input.metrics.model_dump() for input in inputs] + return DataProto(batch=batch, non_tensor_batch={"__num_turns__": num_turns}, meta_info={"metrics": metrics}) + + +async def get_trajectory_info(step, index, validate): + """Get trajectory info. + + Args: + step (int): global steps in the trainer. + index (list): form datastore extra_info.index column. + validate (bool): whether is a validate step. + + Returns: + list: trajectory. + """ + trajectory_info = [] + rollout_n = 0 + for i in range(len(index)): + if i > 0 and index[i - 1] == index[i]: + rollout_n += 1 + else: + rollout_n = 0 + trajectory_info.append({"step": step, "sample_index": index[i], "rollout_n": rollout_n, "validate": validate}) + return trajectory_info + + +class AgentLoopManager: + """Agent loop manager that manages a group of agent loop workers.""" + + def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): + """Initialize agent loop manager. + + Args: + config (DictConfig): trainer config. + worker_group (RayWorkerGroup): ActorRolloutRef worker group. + """ + self.config = config + self.worker_group = worker_group + + self._initialize_llm_servers() + self._init_agent_loop_workers() + + # Initially we're in sleep mode. + self.sleep() + + def _initialize_llm_servers(self): + self.rollout_tp_size = self.config.actor_rollout_ref.rollout.tensor_model_parallel_size + self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size + + register_center = ray.get_actor(f"{self.worker_group.name_prefix}_register_center") + workers_info = ray.get(register_center.get_worker_info.remote()) + assert len(workers_info) == self.worker_group.world_size + + self.async_llm_servers = [None] * self.rollout_dp_size + self.server_addresses = [None] * self.rollout_dp_size + + if self.config.actor_rollout_ref.rollout.agent.custom_async_server: + server_class = async_server_class( + rollout_backend=self.config.actor_rollout_ref.rollout.name, + rollout_backend_module=self.config.actor_rollout_ref.rollout.agent.custom_async_server.path, + rollout_backend_class=self.config.actor_rollout_ref.rollout.agent.custom_async_server.name, + ) + else: + server_class = async_server_class(rollout_backend=self.config.actor_rollout_ref.rollout.name) + + # Start all server instances, restart if address already in use. + unready_dp_ranks = set(range(self.rollout_dp_size)) + while len(unready_dp_ranks) > 0: + servers = { + rollout_dp_rank: server_class.options( + # make sure AsyncvLLMServer colocates with its corresponding workers + scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( + node_id=workers_info[rollout_dp_rank * self.rollout_tp_size], + soft=False, + ), + name=f"async_llm_server_{rollout_dp_rank}", + ).remote(self.config, self.rollout_dp_size, rollout_dp_rank, self.worker_group.name_prefix) + for rollout_dp_rank in unready_dp_ranks + } + + for rollout_dp_rank, server in servers.items(): + try: + address = ray.get(server.get_server_address.remote()) + self.server_addresses[rollout_dp_rank] = address + self.async_llm_servers[rollout_dp_rank] = server + unready_dp_ranks.remove(rollout_dp_rank) + except Exception: + ray.kill(server) + print(f"rollout server {rollout_dp_rank} failed, maybe address already in use, restarting...") + + # All server instances are ready, init AsyncLLM engine. + ray.get([server.init_engine.remote() for server in self.async_llm_servers]) + + def _init_agent_loop_workers(self): + self.agent_loop_workers = [] + for i in range(self.config.actor_rollout_ref.rollout.agent.num_workers): + self.agent_loop_workers.append( + AgentLoopWorker.options( + name=f"agent_loop_worker_{i}", + ).remote(self.config, self.async_llm_servers) + ) + + def generate_sequences(self, prompts: DataProto) -> DataProto: + """Split input batch and dispatch to agent loop workers. + + Args: + prompts (DataProto): Input batch. + + Returns: + DataProto: Output batch. + """ + if self.config.actor_rollout_ref.rollout.free_cache_engine: + self.wake_up() + chunkes = prompts.chunk(len(self.agent_loop_workers)) + outputs = ray.get( + [ + worker.generate_sequences.remote(chunk) + for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True) + ] + ) + output = DataProto.concat(outputs) + if self.config.actor_rollout_ref.rollout.free_cache_engine: + self.sleep() + + # calculate performance metrics + metrics = [output.meta_info["metrics"] for output in outputs] # List[List[Dict[str, str]]] + timing = self._performance_metrics(metrics, output) + + output.meta_info = {"timing": timing} + return output + + def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]: + timing = {} + t_generate_sequences = np.array([metric["generate_sequences"] for chunk in metrics for metric in chunk]) + t_tool_calls = np.array([metric["tool_calls"] for chunk in metrics for metric in chunk]) + timing["agent_loop/generate_sequences/min"] = t_generate_sequences.min() + timing["agent_loop/generate_sequences/max"] = t_generate_sequences.max() + timing["agent_loop/generate_sequences/mean"] = t_generate_sequences.mean() + timing["agent_loop/tool_calls/min"] = t_tool_calls.min() + timing["agent_loop/tool_calls/max"] = t_tool_calls.max() + timing["agent_loop/tool_calls/mean"] = t_tool_calls.mean() + + # batch sequence generation is bounded by the slowest sample + slowest = np.argmax(t_generate_sequences + t_tool_calls) + attention_mask = output.batch["attention_mask"][slowest] + prompt_length = output.batch["prompts"].shape[1] + timing["agent_loop/slowest/generate_sequences"] = t_generate_sequences[slowest] + timing["agent_loop/slowest/tool_calls"] = t_tool_calls[slowest] + timing["agent_loop/slowest/prompt_length"] = attention_mask[:prompt_length].sum().item() + timing["agent_loop/slowest/response_length"] = attention_mask[prompt_length:].sum().item() + + return timing + + def wake_up(self): + """Wake up all rollout server instances.""" + ray.get([server.wake_up.remote() for server in self.async_llm_servers]) + + def sleep(self): + """Sleep all rollout server instances.""" + ray.get([server.sleep.remote() for server in self.async_llm_servers]) diff --git a/toolbox/verl/v0.5.0/verl/experimental/agent_loop/single_turn_agent_loop.py b/toolbox/verl/v0.5.0/verl/experimental/agent_loop/single_turn_agent_loop.py new file mode 100644 index 000000000..411388e73 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/experimental/agent_loop/single_turn_agent_loop.py @@ -0,0 +1,55 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from typing import Any +from uuid import uuid4 + +from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register +from verl.utils.profiler import simple_timer + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +@register("single_turn_agent") +class SingleTurnAgentLoop(AgentLoopBase): + """Naive agent loop that only do single turn chat completion.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length + self.response_length = self.config.actor_rollout_ref.rollout.response_length + + async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: + metrics = {} + request_id = uuid4().hex + prompt_ids = await self.loop.run_in_executor( + None, lambda: self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) + ) + + with simple_timer("generate_sequences", metrics): + response_ids = await self.server_manager.generate( + request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params + ) + response_mask = [1] * len(response_ids) + + output = AgentLoopOutput( + prompt_ids=prompt_ids, + response_ids=response_ids[: self.response_length], + response_mask=response_mask[: self.response_length], + num_turns=2, + metrics=metrics, + ) + return output diff --git a/toolbox/verl/v0.5.0/verl/experimental/agent_loop/tool_agent_loop.py b/toolbox/verl/v0.5.0/verl/experimental/agent_loop/tool_agent_loop.py new file mode 100644 index 000000000..3437c0be5 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/experimental/agent_loop/tool_agent_loop.py @@ -0,0 +1,166 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import json +import logging +import os +from typing import Any +from uuid import uuid4 + +from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register +from verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser +from verl.tools.utils.tool_registry import initialize_tools_from_config +from verl.utils.profiler import simple_timer +from verl.utils.rollout_trace import rollout_trace_op + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +@register("tool_agent") +class ToolAgentLoop(AgentLoopBase): + @classmethod + def init_class(cls, config, tokenizer, **kwargs): + if cls._class_initialized: + return + cls._class_initialized = True + print("Performing class-level ToolAgentLoop initialization") + + # Initialize tools from config file + cls.tokenizer = tokenizer + cls.max_user_turns = config.actor_rollout_ref.rollout.multi_turn.max_user_turns + cls.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns + cls.max_parallel_calls = config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls + cls.max_tool_response_length = config.actor_rollout_ref.rollout.multi_turn.max_tool_response_length + cls.tool_response_truncate_side = config.actor_rollout_ref.rollout.multi_turn.tool_response_truncate_side + tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path + tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else [] + cls.tools = {tool.name: tool for tool in tool_list} + cls.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list] + cls.tool_parser = ToolParser.get_tool_parser(config.actor_rollout_ref.rollout.multi_turn.format, cls.tokenizer) + print(f"Initialized tools: {cls.tools}") + + cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length + cls.response_length = config.actor_rollout_ref.rollout.response_length + cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True) + + @rollout_trace_op + async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: + metrics = {} + request_id = uuid4().hex + prompt_ids = await self.loop.run_in_executor( + None, + lambda: self.tokenizer.apply_chat_template( + messages, tools=self.tool_schemas, add_generation_prompt=True, tokenize=True + ), + ) + response_mask = [] + + user_turns, assistant_turns = 0, 0 + while True: + with simple_timer("generate_sequences", metrics): + response_ids = await self.server_manager.generate( + request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params + ) + prompt_ids += response_ids + response_mask += [1] * len(response_ids) + assistant_turns += 1 + + # reach max response length + if len(response_mask) >= self.response_length: + break + + # reach max assistant turns + if self.max_assistant_turns and assistant_turns >= self.max_assistant_turns: + break + + # reach max user turns + if self.max_user_turns and user_turns >= self.max_user_turns: + break + + # no tool calls + _, tool_calls = await self.tool_parser.extract_tool_calls(response_ids) + if not tool_calls: + break + + # call tools + tasks = [] + for tool_call in tool_calls[: self.max_parallel_calls]: + tasks.append(self._call_tool(tool_call)) + with simple_timer("tool_calls", metrics): + tool_responses = await asyncio.gather(*tasks) + if any(isinstance(item, Exception) for item in tool_responses): + break + + # append tool_response_ids + tool_response_ids = await self.loop.run_in_executor( + None, + lambda messages=tool_responses: self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True + ), + ) + tool_response_ids = tool_response_ids[len(self.system_prompt) :] + + # NOTE: last turn should not be user turn, or the EOS token reward + # can't be propagated to previous token in GAE. + if len(response_mask) + len(tool_response_ids) >= self.response_length: + break + + prompt_ids += tool_response_ids + response_mask += [0] * len(tool_response_ids) + user_turns += 1 + + response_ids = prompt_ids[-len(response_mask) :] + prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)] + + output = AgentLoopOutput( + prompt_ids=prompt_ids, + response_ids=response_ids[: self.response_length], + response_mask=response_mask[: self.response_length], + num_turns=user_turns + assistant_turns + 1, + metrics=metrics, + ) + return output + + async def _call_tool(self, tool_call: FunctionCall) -> dict[str, str]: + """Call tool and return tool response.""" + tool, instance_id = None, None + try: + # TODO: append malformed tool_call to the prompt: invalid function name or arguments + tool_name = tool_call.name + tool_args = json.loads(tool_call.arguments) + tool = self.tools[tool_name] + + instance_id = await tool.create() + tool_response, _, _ = await tool.execute(instance_id, tool_args) + except Exception as e: + logger.exception(f"Error when executing tool: {e}") + return e + finally: + if tool and instance_id: + await tool.release(instance_id) + + if len(tool_response) > self.max_tool_response_length: + if self.tool_response_truncate_side == "left": + tool_response = tool_response[: self.max_tool_response_length] + "...(truncated)" + elif self.tool_response_truncate_side == "right": + tool_response = "(truncated)..." + tool_response[-self.max_tool_response_length :] + else: + length = self.max_tool_response_length // 2 + tool_response = tool_response[:length] + "...(truncated)..." + tool_response[-length:] + + return { + "role": "tool", + "content": tool_response, + } diff --git a/toolbox/verl/v0.5.0/verl/experimental/agent_loop/tool_parser.py b/toolbox/verl/v0.5.0/verl/experimental/agent_loop/tool_parser.py new file mode 100644 index 000000000..5b4de4a8e --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/experimental/agent_loop/tool_parser.py @@ -0,0 +1,106 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import json +import logging +import os +from abc import ABC, abstractmethod + +import regex as re +from pydantic import BaseModel + +from verl.utils.rollout_trace import rollout_trace_op + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class FunctionCall(BaseModel): + arguments: str + """ + The arguments to call the function with, as generated by the model in JSON + format. Note that the model does not always generate valid JSON, and may + hallucinate parameters not defined by your function schema. Validate the + arguments in your code before calling your function. + """ + + name: str + """The name of the function to call.""" + + +class ToolParser(ABC): + _registry: dict[str, type["ToolParser"]] = {} + + def __init__(self, tokenizer) -> None: + self.tokenizer = tokenizer + + @abstractmethod + async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]: + """Extract tool calls from the responses. + + Args: + responses_ids (List[int]): The ids of the responses. + + Returns: + Tuple[str, List[FunctionCall]]: Content and extracted tool calls. + """ + raise NotImplementedError + + @classmethod + def get_tool_parser(cls, name: str, tokenizer): + if name not in cls._registry: + raise ValueError(f"Unknown tool parser: {name}") + return cls._registry[name](tokenizer) + + @classmethod + def register(cls, name: str): + def decorator(subclass: type[ToolParser]) -> type[ToolParser]: + cls._registry[name] = subclass + return subclass + + return decorator + + +@ToolParser.register("hermes") +class HermesToolParser(ToolParser): + """Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py""" + + def __init__(self, tokenizer) -> None: + super().__init__(tokenizer) + + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + self.tool_call_regex = re.compile(r"(.*?)", re.DOTALL) + + @rollout_trace_op + async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]: + loop = asyncio.get_running_loop() + text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids) + if self.tool_call_start_token not in text or self.tool_call_end_token not in text: + return text, [] + + matches = self.tool_call_regex.findall(text) + function_calls = [] + for match in matches: + try: + function_call = json.loads(match) + name, arguments = function_call["name"], function_call["arguments"] + function_calls.append(FunctionCall(name=name, arguments=json.dumps(arguments, ensure_ascii=False))) + except Exception as e: + logger.error(f"Failed to decode tool call: {e}") + + # remaing text exclude tool call tokens + content = self.tool_call_regex.sub("", text) + + return content, function_calls diff --git a/toolbox/verl/v0.5.0/verl/experimental/dataset/__init__.py b/toolbox/verl/v0.5.0/verl/experimental/dataset/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/experimental/dataset/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/experimental/dataset/sampler.py b/toolbox/verl/v0.5.0/verl/experimental/dataset/sampler.py new file mode 100644 index 000000000..b7b15b422 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/experimental/dataset/sampler.py @@ -0,0 +1,40 @@ +# Copyright 2025 Amazon.com Inc and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import abstractmethod +from collections.abc import Sized + +from omegaconf import DictConfig +from torch.utils.data import Sampler + +from verl import DataProto + + +class AbstractSampler(Sampler[int]): + """Abstract interface for custom samplers.""" + + @abstractmethod + def __init__( + self, + data_source: Sized, + data_config: DictConfig, + ): + pass + + +class AbstractCurriculumSampler(AbstractSampler): + """Experimental interface for curriculum learning samplers.""" + + @abstractmethod + def update(self, batch: DataProto) -> None: + pass diff --git a/toolbox/verl/v0.5.0/verl/experimental/dynamic_dataset/__init__.py b/toolbox/verl/v0.5.0/verl/experimental/dynamic_dataset/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/experimental/dynamic_dataset/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/experimental/dynamic_dataset/dynamicgen_dataset.py b/toolbox/verl/v0.5.0/verl/experimental/dynamic_dataset/dynamicgen_dataset.py new file mode 100644 index 000000000..a9532aa03 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/experimental/dynamic_dataset/dynamicgen_dataset.py @@ -0,0 +1,112 @@ +# Copyright 2025 Amazon.com Inc and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Dataset class that enables dynamic data generation strategies between iterations of training. +This class extends RLHFDataset and uses an AbstractDataGen instance to generate data. + +This is especially useful in settings where proposer model generates new tasks based +on rollout data. +""" + +import logging +from abc import ABC, abstractmethod +from typing import Optional + +import datasets +from omegaconf import DictConfig +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer, ProcessorMixin + +from verl import DataProto +from verl.utils.dataset import RLHFDataset +from verl.utils.import_utils import load_extern_type + +logger = logging.getLogger(__name__) + + +class AbstractDataGenerator(ABC): + def __init__(self, config: DictConfig): + self.config = config + + @abstractmethod + def generate(self, dataset: Dataset) -> datasets.Dataset: + """ + Generate method must be implemented by subclasses. + Args: + dataset: The dataset to generate from. + Returns: + Processed data or result as implemented by the subclass. + """ + pass + + +class MockDataGenerator(AbstractDataGenerator): + """ + A noop data gen class that only reappends the first datapoint. + This class is useful as a placeholder and testing. + """ + + def __init__(self, config: DictConfig = None): + super().__init__(config) + + def generate(self, dataset: Dataset) -> datasets.Dataset: + print("MockDataGenerator: No operation performed on the dataset.") + return dataset.dataframe.select([0]) + + +class DynamicGenDataset(RLHFDataset): + """ + A dataset class that uses a data generation strategy to process data. + This class extends RLHFDataset and uses an AbstractDataGen instance to generate data. + """ + + def __init__( + self, + data_files: str | list[str], + tokenizer: PreTrainedTokenizer, + config: DictConfig, + processor: Optional[ProcessorMixin] = None, + ): + super().__init__(data_files, tokenizer, config, processor) + self.datagen: AbstractDataGenerator = config.datagen + assert "datagen" in config and config.datagen.get("path", None) is not None, ( + f"datagen path is not set in config: {config}" + ) + # Dynamically load the custom datagen class + datagen_cls = load_extern_type(config.datagen.path, config.datagen.name) + + # Verify that the custom datagen class inherits from AbstractDataGenerator + abs_cls = AbstractDataGenerator + if not issubclass(datagen_cls, abs_cls): + raise TypeError( + f"The custom datagen class '{config.datagen.name}' from '{config.datagen.path}'" + + " must inherit from {abs_cls}" + ) + + self.data_generator = datagen_cls(config.datagen) + self.on_batch_end() + + def append_dataframe(self, new_dataframe: datasets.Dataset): + new_dataframe = self.maybe_filter_out_long_prompts(new_dataframe) + self.dataframe = datasets.concatenate_datasets([self.dataframe, new_dataframe]) + + logger.info(f"new dataset len: {len(self.dataframe)}") + + def on_batch_end(self, batch: DataProto) -> None: + """ + Generate data using the provided data generation strategy. + Note: This method is intended to change the dataset after each training batch. + """ + new_data = self.data_generator.generate(self) + self.append_dataframe(new_data) diff --git a/toolbox/verl/v0.5.0/verl/interactions/__init__.py b/toolbox/verl/v0.5.0/verl/interactions/__init__.py new file mode 100644 index 000000000..b6db0fcef --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/interactions/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/interactions/base.py b/toolbox/verl/v0.5.0/verl/interactions/base.py new file mode 100644 index 000000000..7c5d200ab --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/interactions/base.py @@ -0,0 +1,72 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional +from uuid import uuid4 + + +class BaseInteraction: + def __init__(self, config: dict[str, Any]): + self.config = config + self.name: str = config.get("name", "interaction_agent") # More general agent default role name + + async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) -> str: + """Create a tool instance. + + Args: + instance_id: The instance id of the tool. + + Returns: + The instance id of the tool. + """ + if instance_id is None: + return str(uuid4()) + else: + return instance_id + + async def generate_response( + self, instance_id: str, messages: list[dict[str, Any]], **kwargs + ) -> tuple[bool, str, float, dict[str, Any]]: # More clear response generation method + """ + Generates a response for the current turn of interaction. + Returns a tuple containing: + - should_terminate_sequence (bool): True if the interaction sequence should end. + - response_content (str): The textual content of the response. + - current_turn_score (float): The score for this specific turn/response. + - additional_data (dict): Any extra information or metadata. + """ + should_terminate_sequence: bool = False # if True, end rollout + response_content: str = "Your current result seems acceptable." + current_turn_score: float = 0.8 + additional_data: dict[str, Any] = {} + return should_terminate_sequence, response_content, current_turn_score, additional_data + + async def calculate_score(self) -> float: # More clear score calculation method + """ + Calculates a score for the interaction, + potentially considering aspects like partial exposure & in-context task switching. + should be invoke at turn-level + """ + # ...implement the logic to calculate turn-level score... + score = 0.0 + return score + + async def finalize_interaction(self) -> None: # More clear interaction end and resource release method + """ + Finalizes the interaction session and releases any associated state or resources. + Simulates: release state + """ + # ...implement the logic to release state... + pass diff --git a/toolbox/verl/v0.5.0/verl/interactions/gsm8k_interaction.py b/toolbox/verl/v0.5.0/verl/interactions/gsm8k_interaction.py new file mode 100644 index 000000000..c30768391 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/interactions/gsm8k_interaction.py @@ -0,0 +1,87 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from verl.utils.reward_score import gsm8k + +from .base import BaseInteraction + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class Gsm8kInteraction(BaseInteraction): + """A demo interaction for calculating the reward of gsm8k. + + - `start_interaction`: start a interaction instance for a trajectory. + - `generate_response`: generate the response of the user. + - `calculate_score`: calculate the score of the interaction. + - `finalize_interaction`: finalize the interaction instance. + """ + + def __init__(self, config: dict): + super().__init__(config) + self._instance_dict = {} + + async def start_interaction( + self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs + ) -> str: + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "ground_truth": ground_truth, + "reward": 0.0, + } + return instance_id + + async def generate_response( + self, instance_id: str, messages: list[dict[str, Any]], **kwargs + ) -> tuple[bool, str, float, dict]: + content = "" + for i in range(len(messages) - 1, -1, -1): + item = messages[i] + if item.get("role") == "assistant": + content = item.get("content") + break + + self._instance_dict[instance_id]["response"] = content + + reward = await self.calculate_score(instance_id) + if reward == 1.0: + response = "Your response is correct!" + should_terminate_sequence = True + else: + response = "Your response is incorrect! You need to reflect on your answer and try again." + should_terminate_sequence = False + + return should_terminate_sequence, response, reward, {} + + async def calculate_score(self, instance_id: str, **kwargs) -> float: + return gsm8k.compute_score( + self._instance_dict[instance_id]["response"], + self._instance_dict[instance_id]["ground_truth"], + method="strict", + format_score=0.0, + score=1.0, + ) + + async def finalize_interaction(self, instance_id: str, **kwargs) -> None: + del self._instance_dict[instance_id] diff --git a/toolbox/verl/v0.5.0/verl/interactions/utils/__init__.py b/toolbox/verl/v0.5.0/verl/interactions/utils/__init__.py new file mode 100644 index 000000000..c4b932b1a --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/interactions/utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/interactions/utils/interaction_registry.py b/toolbox/verl/v0.5.0/verl/interactions/utils/interaction_registry.py new file mode 100644 index 000000000..df747af11 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/interactions/utils/interaction_registry.py @@ -0,0 +1,85 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +import logging +import os +import sys + +from omegaconf import OmegaConf + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def get_interaction_class(cls_name): + """Dynamically import and return the interaction class.""" + module_name, class_name = cls_name.rsplit(".", 1) + if module_name not in sys.modules: + spec = importlib.util.find_spec(module_name) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + else: + module = sys.modules[module_name] + + interaction_cls = getattr(module, class_name) + return interaction_cls + + +def initialize_interactions_from_config(interaction_config_file): + """Initialize interactions from configuration file. + + Args: + interaction_config_file: Path to the interaction configuration file. + + Returns: + dict: A dictionary mapping interaction names to BaseInteraction instances. + """ + interaction_config = OmegaConf.load(interaction_config_file) + interaction_map = {} + + for interaction_item in interaction_config.interaction: + cls_name = interaction_item.class_name + interaction_cls = get_interaction_class(cls_name) + + # Extract config and name + config = OmegaConf.to_container(interaction_item.config, resolve=True) + + # Get the interaction name - either from config or derive from class name + name = interaction_item.get("name", None) + if name is None: + # If no name is specified, use the class name as default + class_simple_name = cls_name.split(".")[-1] + # Remove "Interaction" suffix if present, otherwise use full class name + if class_simple_name.endswith("Interaction"): + name = class_simple_name[:-11].lower() # Remove "Interaction" (11 chars) + else: + name = class_simple_name.lower() + + # Check for duplicate names + if name in interaction_map: + raise ValueError(f"Duplicate interaction name '{name}' found. Each interaction must have a unique name.") + + # Inject the name into the config + config["name"] = name + + # Create the interaction instance + interaction = interaction_cls(config=config) + interaction_map[name] = interaction + + logger.info(f"Initialized interaction '{name}' with class '{cls_name}'") + + return interaction_map diff --git a/toolbox/verl/v0.5.0/verl/model_merger/__init__.py b/toolbox/verl/v0.5.0/verl/model_merger/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/model_merger/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/model_merger/__main__.py b/toolbox/verl/v0.5.0/verl/model_merger/__main__.py new file mode 100644 index 000000000..f3ab5b9c2 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/model_merger/__main__.py @@ -0,0 +1,73 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module is used to merge huggingface model and test verl checkpoints from FSDP and Megatron backends. + +To merge FSDP checkpoints: +```sh +python -m verl.model_merger merge \ + --backend fsdp \ + --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \ + --target_dir /path/to/merged_hf_model +``` + +To merge Megatron checkpoints: +```sh +python -m verl.model_merger merge \ + --backend megatron \ + --tie-word-embedding \ + --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ + --target_dir /path/to/merged_hf_model +``` + +or use distribtued merge for large models like dpskv3 671B + +```sh +torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} -m verl.model_merger merge\ + --backend megatron \ + --local_dir ./checkpoints/global_step_1/actor \ + --target_dir /path/to/merged_hf_model +``` + + +For more details, please refer to documentation: +https://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model +""" + +from .base_model_merger import generate_config_from_args, parse_args + + +def main(): + args = parse_args() + config = generate_config_from_args(args) + print(f"config: {config}") + + if config.backend == "fsdp": + from .fsdp_model_merger import FSDPModelMerger + + merger = FSDPModelMerger(config) + elif config.backend == "megatron": + from .megatron_model_merger import MegatronModelMerger + + merger = MegatronModelMerger(config) + else: + raise NotImplementedError(f"Unknown backend: {config.backend}") + + merger.merge_and_save() + merger.cleanup() + + +if __name__ == "__main__": + main() diff --git a/toolbox/verl/v0.5.0/verl/model_merger/base_model_merger.py b/toolbox/verl/v0.5.0/verl/model_merger/base_model_merger.py new file mode 100644 index 000000000..08859cc55 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/model_merger/base_model_merger.py @@ -0,0 +1,345 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Optional + +import torch +from accelerate import init_empty_weights +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForTokenClassification, + AutoModelForVision2Seq, + GenerationConfig, +) + +from verl.utils import hf_processor, hf_tokenizer + + +def parse_args(): + parser = argparse.ArgumentParser(description="verl model merger") + subparsers = parser.add_subparsers(dest="operation", required=True, help="Specify 'merge' or 'test' operation.") + + base_op_parser = argparse.ArgumentParser(add_help=False) + base_op_parser.add_argument( + "--backend", type=str, required=True, choices=["fsdp", "megatron"], help="The backend of the model" + ) + base_op_parser.add_argument("--local_dir", type=str, default=None, help="Path to the saved model checkpoints.") + base_op_parser.add_argument( + "--tie-word-embedding", + action="store_true", + help="Whether to tie word embedding weights (currently only Megatron supported)", + ) + base_op_parser.add_argument("--trust-remote-code", action="store_true", help="Whether to trust remote code") + base_op_parser.add_argument( + "--is-value-model", + action="store_true", + help="Whether the model is a value model (currently only Megatron supported)", + ) + base_op_parser.add_argument( + "--use_cpu_initialization", + action="store_true", + help="Whether to use CPU initialization for the model. This is useful for large models that cannot " + "fit into GPU memory during initialization.", + ) + + merge_parser = subparsers.add_parser("merge", parents=[base_op_parser], help="Merge model checkpoints and save.") + merge_parser.add_argument( + "--target_dir", default="tmp", type=str, help="Directory to save the merged huggingface model" + ) + merge_parser.add_argument( + "--hf_upload_path", default=None, type=str, help="Hugging Face repository ID to upload the model" + ) + merge_parser.add_argument( + "--private", action="store_true", help="Whether to upload the model to a private Hugging Face repository" + ) + + test_parser = subparsers.add_parser( + "test", parents=[base_op_parser], help="Test merged model against a reference Hugging Face model" + ) + test_parser.add_argument( + "--test_hf_dir", type=str, required=True, help="Path to the reference Hugging Face model directory for testing" + ) + + args = parser.parse_args() + return args + + +@dataclass +class ModelMergerConfig: + """Configuration for model merger operations. + + Args: + operation (str): Operation type - 'merge' or 'test'. + backend (str): Backend type for the model ('fsdp' or 'megatron'). + target_dir (Optional[str]): Directory to save the merged huggingface model. Defaults to "tmp". + hf_upload_path (Optional[str]): Hugging Face repository ID to upload the model. Defaults to None. + private (bool): Whether to upload the model to a private Hugging Face repository. Defaults to False. + test_hf_dir (Optional[str]): Path to the reference Hugging Face model directory for testing. Defaults to None. + tie_word_embedding (bool): Whether to tie word embedding weights (currently only Megatron + supported). Defaults to False. + trust_remote_code (bool): Whether to trust remote code. Defaults to False. + is_value_model (bool): Whether the model is a value model (currently only Megatron + supported). Defaults to False. + local_dir (Optional[str]): Path to the saved model checkpoints. Defaults to None. + hf_model_config_path (Optional[str]): Path to HuggingFace model configuration files. Defaults to None. + hf_upload (bool): Whether to upload to HuggingFace (computed automatically). Not for initialization. + use_cpu_initialization (bool): Whether to use CPU initialization for large models. Defaults to False. + """ + + operation: str # 'merge' or 'test' + backend: str + target_dir: Optional[str] = "tmp" + hf_upload_path: Optional[str] = None + private: bool = False + test_hf_dir: Optional[str] = None + tie_word_embedding: bool = False + trust_remote_code: bool = False + is_value_model: bool = False + local_dir: Optional[str] = None + hf_model_config_path: Optional[str] = None + hf_upload: bool = field(init=False) + use_cpu_initialization: bool = False + + def __post_init__(self): + self.hf_upload = self.operation == "merge" and bool(self.hf_upload_path) + if self.operation == "test": + self.target_dir = None + self.hf_upload_path = None + self.private = False + + +def generate_config_from_args(args: argparse.Namespace) -> ModelMergerConfig: + common_config_args = { + "operation": args.operation, + "backend": args.backend, + "tie_word_embedding": args.tie_word_embedding, + "trust_remote_code": args.trust_remote_code, + "is_value_model": args.is_value_model, + "local_dir": args.local_dir, + "hf_model_config_path": os.path.join(args.local_dir, "huggingface"), + "use_cpu_initialization": args.use_cpu_initialization, + } + + if args.operation == "merge": + config = ModelMergerConfig( + **common_config_args, + target_dir=args.target_dir, + hf_upload_path=args.hf_upload_path, + private=args.private, + test_hf_dir=None, + ) + os.makedirs(config.target_dir, exist_ok=True) + elif args.operation == "test": + config = ModelMergerConfig( + **common_config_args, + test_hf_dir=args.test_hf_dir, + # the following args are not used by test operation + target_dir=None, + hf_upload_path=None, + private=False, + ) + else: + raise NotImplementedError(f"Unknown operation: {args.operation}") + return config + + +class BaseModelMerger(ABC): + """ + Abstract base class for merging distributed model checkpoints into HuggingFace format. + + This class provides common functionality for converting model checkpoints from different + distributed training backends (FSDP, Megatron) into standard HuggingFace format that + can be easily loaded and used for inference or further training. + + The merger supports two main operations: + - merge: Convert and save checkpoints to HuggingFace format + - test: Validate merged checkpoints against a reference model + + Args: + config (ModelMergerConfig): Configuration object containing paths, backend type, + and operation parameters. + + Attributes: + config (ModelMergerConfig): The configuration object passed during initialization. + hf_model_config_path (str): Path to the HuggingFace model configuration files. + model_config (PretrainedConfig): Loaded HuggingFace model configuration. + """ + + def __init__(self, config: ModelMergerConfig): + self.config = config + self.hf_model_config_path = config.hf_model_config_path + self.model_config = AutoConfig.from_pretrained( + self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code + ) + + def get_transformers_auto_model_class(self): + if "ForTokenClassification" in self.model_config.architectures[0]: + return AutoModelForTokenClassification + elif "ForCausalLM" in self.model_config.architectures[0]: + return AutoModelForCausalLM + elif "ForConditionalGeneration" in self.model_config.architectures[0]: + return AutoModelForVision2Seq + + raise NotImplementedError(f"Unknown architecture {self.model_config.architectures}") + + def patch_model_generation_config(self, model): + """ + The generation_config created from model config may be different to the pretrained model, + this may lead to error when generating: https://github.com/volcengine/verl/issues/1246 + + This function patch the generation_config created from model config to the pretrained model. + """ + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path) + except OSError: + print( + f"Warning: Generation config file not found in {self.hf_model_config_path}, using a " + f"generation config created from the model config." + ) + return model + + def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]): + """ + Save lora adapter to safetensors. + + Returns: + lora_path: str, the path to the lora adapter. None if no lora adapter found. + + Note: + This function change the 'state_dict' in place. + """ + lora_params_names = [name for name in state_dict.keys() if "lora_" in name] + + if len(lora_params_names) == 0: + return None + + import json + from typing import OrderedDict + + import peft + from safetensors.torch import save_file + + lora_params = OrderedDict() + target_modules = set() + lora_key = None + + for name in lora_params_names: + lora_key = name.replace(".default.weight", ".weight") + target_modules.add(lora_key.split(".")[-3]) + lora_params[lora_key] = state_dict.pop(name) + + lora_rank = min(lora_params[lora_key].shape[0], lora_params[lora_key].shape[1]) + peft_dict = { + "r": lora_rank, + "lora_alpha": 0, # lora_alpha is not set. An error should be raised to inform the user to set it manually. + "target_modules": list(target_modules), + } + peft_config = peft.LoraConfig(**peft_dict).to_dict() + peft_config["task_type"] = peft_config["task_type"].value if peft_config["task_type"] else None + peft_config["peft_type"] = peft_config["peft_type"].value if peft_config["peft_type"] else None + peft_config["target_modules"] = list(peft_config["target_modules"]) + + lora_path = os.path.join(self.config.target_dir, "lora_adapter") + os.makedirs(lora_path, exist_ok=True) + with open(os.path.join(lora_path, "adapter_config.json"), "w", encoding="utf-8") as f: + json.dump(peft_config, f, ensure_ascii=False, indent=4) + save_file(lora_params, os.path.join(lora_path, "adapter_model.safetensors")) + + for name in list(state_dict.keys()): + key = ( + name.replace("base_model.model.", "") + .replace(".base_layer.weight", ".weight") + .replace(".base_layer.bias", ".bias") + ) + state_dict[key] = state_dict.pop(name) + + return lora_path + + def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): + auto_model_class = self.get_transformers_auto_model_class() + with init_empty_weights(): + model = auto_model_class.from_config( + self.model_config, torch_dtype=torch.bfloat16, trust_remote_code=self.config.trust_remote_code + ) + model.to_empty(device="cpu") + model = self.patch_model_generation_config(model) + + lora_path = self.save_lora_adapter(state_dict) + if lora_path: + print(f"Saving lora adapter to {lora_path}") + + print(f"Saving model to {self.config.target_dir}") + model.save_pretrained(self.config.target_dir, state_dict=state_dict) + del state_dict + del model + + processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) + tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) + if processor is not None: + print(f"Saving processor to {self.config.target_dir}") + processor.save_pretrained(self.config.target_dir) + if tokenizer is not None: + print(f"Saving tokenizer to {self.config.target_dir}") + tokenizer.save_pretrained(self.config.target_dir) + + def upload_to_huggingface(self): + import requests + from huggingface_hub import HfApi + from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError + + api = HfApi() + try: + # Attempt to create repository + api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True) + except HfHubHTTPError as e: + # Handle authentication/API errors + if e.response.status_code == 401: + raise PermissionError( + "Hugging Face authentication failed. Verify your token is valid and has write permissions." + ) from e + elif e.response.status_code == 404: + raise RepositoryNotFoundError(f"Repository path not found: {self.config.hf_upload_path}") from e + else: + raise ConnectionError(f"Failed to create repository ({e.response.status_code}): {e}") from e + except requests.exceptions.ConnectionError as e: + raise ConnectionError("Network connection failed. Check your internet connection.") from e + + try: + # Attempt folder upload + api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type="model") + except HfHubHTTPError as e: + if e.response.status_code == 401: + raise PermissionError("Authentication failed during upload. Token may have expired.") from e + else: + raise RuntimeError(f"Upload failed ({e.response.status_code}): {e}") from e + except requests.exceptions.ConnectionError as e: + raise ConnectionError("Network interruption during upload. Try again with stable connection.") from e + except OSError as e: + raise FileNotFoundError(f"Local folder error: {self.config.target_dir} - {str(e)}") from e + except Exception as e: + raise RuntimeError(f"Unexpected error during upload: {str(e)}") from e + + @abstractmethod + def merge_and_save(self): + raise NotImplementedError("Subclasses should implement this method") + + @abstractmethod + def cleanup(self): + raise NotImplementedError("Subclasses should implement this method to clean up resources if needed") diff --git a/toolbox/verl/v0.5.0/verl/model_merger/fsdp_model_merger.py b/toolbox/verl/v0.5.0/verl/model_merger/fsdp_model_merger.py new file mode 100644 index 000000000..7853b2b79 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/model_merger/fsdp_model_merger.py @@ -0,0 +1,265 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +import numpy as np +import torch +from torch.distributed._tensor import Placement, Shard + +try: + # for torch 2.5+ + from torch.distributed.tensor import DTensor +except ImportError: + from torch.distributed._tensor import DTensor + +from tqdm import tqdm + +from .base_model_merger import BaseModelMerger + + +class FSDPModelMerger(BaseModelMerger): + """ + Model merger for FSDP (Fully Sharded Data Parallel) checkpoints. + + This class handles the conversion of FSDP distributed checkpoints into HuggingFace format. + FSDP shards model parameters across multiple processes, and this merger reconstructs + the full model by loading and concatenating the sharded parameters from all ranks. + + The merger supports various FSDP configurations including: + - Pure FSDP (single dimension sharding) + - FSDP + DDP (data parallel + fully sharded data parallel) + - DTensor-based sharding with custom device meshes + + Key features: + - Automatic detection of world size from checkpoint filenames + - Support for DTensor and non-DTensor checkpoints + - Parallel loading of checkpoint shards for efficiency + - Validation against reference HuggingFace models + + Example: + To merge FSDP checkpoints: + ```python + config = ModelMergerConfig( + operation="merge", + backend="fsdp", + local_dir="path/to/fsdp/checkpoints", + target_dir="path/to/output" + ) + merger = FSDPModelMerger(config) + merger.merge_and_save() + ``` + """ + + def _get_world_size(self) -> int: + """_summary_ + From FSDP json config file, extract the world size. + + Returns: + int: world size + """ + config_path = Path(self.config.local_dir) / "fsdp_config.json" + if not config_path.exists(): + raise FileNotFoundError(f"Config file {config_path} does not exist.") + + with open(config_path) as f: + config = json.load(f) + + # Extract world size from the config + world_size = config.get("world_size", None) + if world_size is None: + raise ValueError("World size not found in the config file.") + + return world_size + + def _load_rank_zero_state_dict(self, world_size: int) -> dict: + return torch.load( + Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_0.pt", + map_location="cpu", + weights_only=False, + ) + + def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]: + """ + Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict. + If no DTensor is found, infers a simple FSDP mesh based on world_size. + """ + pivot_key = sorted(list(state_dict.keys()))[0] + weight = state_dict[pivot_key] + + if isinstance(weight, DTensor): + # get sharding info + device_mesh = weight.device_mesh + mesh = device_mesh.mesh + mesh_dim_names = device_mesh.mesh_dim_names + else: + # for non-DTensor + mesh = np.array([world_size], dtype=np.int64) + mesh_dim_names = ("fsdp",) + + return mesh, mesh_dim_names + + def _calculate_shard_configuration( + self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...] + ) -> tuple[int, tuple[int, ...]]: + """Calculates the total number of shards and the shape of the device mesh.""" + assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}" + + if "tp" in mesh_dim_names: + # TODO: "tp" is not supported yet due to the above assert + total_shards = mesh.shape[-1] * mesh.shape[-2] + mesh_shape = (mesh.shape[-2], mesh.shape[-1]) + else: + total_shards = mesh.shape[-1] + mesh_shape = (mesh.shape[-1],) + + return total_shards, mesh_shape + + def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor: + """Merges a list of tensors based on their DTensor placement""" + if placement.is_replicate(): + return tensors[0] + elif placement.is_partial(): + raise NotImplementedError("Partial placement is not supported yet") + elif placement.is_shard(): + return torch.cat(tensors, dim=placement.dim).contiguous() + + raise NotImplementedError(f"Unsupported placement: {placement}") + + def _load_and_merge_state_dicts( + self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...] + ) -> dict[str, torch.Tensor]: + model_state_dict_lst = [None] * total_shards + + def process_one_shard(rank: int, model_state_dict_lst: list): + model_path = Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_{rank}.pt" + state_dict = torch.load(model_path, map_location="cpu", weights_only=False) + model_state_dict_lst[rank] = state_dict + return state_dict + + with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: + futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)] + for future in tqdm(futures, desc=f"Loading {total_shards} FSDP shards", total=total_shards): + future.result() + + # Merge state dicts from all shards + state_dict = {} + param_placements: dict[str, list] = {} + + for key in set(model_state_dict_lst[0].keys()): + state_dict[key] = [] + for model_state_shard in model_state_dict_lst: + # add tensor shard in order of rank to state_dict[key] + tensor = model_state_shard.pop(key) + if isinstance(tensor, DTensor): + state_dict[key].append(tensor._local_tensor.bfloat16()) + + placements = tuple(tensor.placements) + # replicated placement at dp dimension can be discarded + if mesh_dim_names[0] in ("dp", "ddp"): + placements = placements[1:] + + if key not in param_placements: + param_placements[key] = placements + else: + assert param_placements[key] == placements + else: + state_dict[key].append(tensor.bfloat16()) + + del model_state_dict_lst + + # Merge tensors + for key in sorted(state_dict): + if not isinstance(state_dict[key], list): + print(f"No need to merge key {key}") + continue + if key in param_placements: + # merge shards + placements: tuple[Shard] = param_placements[key] + if len(mesh_shape) == 1: + # 1-D list, FSDP without TP + assert len(placements) == 1 + shards = state_dict[key] + state_dict[key] = self._merge_by_placement(shards, placements[0]) + else: + # 2-D list, FSDP + TP + raise NotImplementedError("FSDP + TP is not supported yet") + else: + state_dict[key] = torch.cat(state_dict[key], dim=0) + + return state_dict + + def merge_and_save(self): + world_size = self._get_world_size() + rank_zero_state_dict = self._load_rank_zero_state_dict(world_size) + + mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size) + print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") + + total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names) + print(f"Processing model shards with {total_shards} {mesh_shape} in total") + + merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names) + + if self.config.operation == "test": + if not self.config.test_hf_dir: + raise ValueError("test_hf_dir must be provided for test operation") + self._validate_state_dict(merged_state_dict) + elif self.config.operation == "merge": + self.save_hf_model_and_tokenizer(merged_state_dict) + if self.config.hf_upload: + self.upload_to_huggingface() + else: + raise ValueError(f"Unknown operation: {self.config.operation}") + + def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]): + auto_model_class = self.get_transformers_auto_model_class() + + hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16) + hf_state_dict = hf_model.state_dict() + del hf_model + + hf_model_keys = set(hf_state_dict.keys()) + collected_keys = set(state_dict.keys()) + + missing_keys = hf_model_keys - collected_keys + assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}" + + extra_keys = collected_keys - hf_model_keys + assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}" + + for key in hf_model_keys: + hf_shape = hf_state_dict[key].shape + collected_shape = state_dict[key].shape + assert hf_shape == collected_shape, ( + f"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}" + ) + + hf_dtype = hf_state_dict[key].dtype + collected_dtype = state_dict[key].dtype + assert hf_dtype == collected_dtype, ( + f"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}" + ) + + torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6) + + print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.") + + def cleanup(self): + """Cleanup temporary files if needed.""" + # FSDP merger does not create temporary files, so no cleanup is needed. + pass diff --git a/toolbox/verl/v0.5.0/verl/model_merger/megatron_model_merger.py b/toolbox/verl/v0.5.0/verl/model_merger/megatron_model_merger.py new file mode 100644 index 000000000..5be281681 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/model_merger/megatron_model_merger.py @@ -0,0 +1,537 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import warnings +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Callable, ContextManager + +import numpy as np +import torch +import torch.distributed as dist +from accelerate import init_empty_weights +from megatron.core import mpu +from megatron.core.models.gpt.gpt_model import ModelType +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from safetensors.torch import load_file +from transformers import ( + AutoConfig, + PretrainedConfig, +) + +from verl.models.mcore import hf_to_mcore_config +from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device +from verl.utils.megatron.dist_checkpointing import load_dist_checkpointing +from verl.utils.megatron_utils import get_model +from verl.utils.tokenizer import hf_processor, hf_tokenizer + +from .base_model_merger import BaseModelMerger, ModelMergerConfig + + +@contextmanager +def noop_context() -> Any: + yield + + +def get_dynamic_pipeline_shards(layer_num: int, pp_size: int) -> list[int]: + """Calculate the pipeline sharding configuration for Megatron-LM. + + Args: + layer_num: Total number of layers in the model. + pp_size: Number of pipeline parallel ranks. + + Returns: + layer number of each pp rank. Make the sharding of the pipeline as uniform as possible. + """ + if layer_num < pp_size: + raise ValueError(f"layer_num {layer_num} must be greater than pp_size {pp_size}.") + + if pp_size < 1: + raise ValueError(f"pp_size must be at least 1, got {pp_size}.") + if pp_size == 1: + return [layer_num] + + if pp_size == 2: + return [ + layer_num // 2, + layer_num - layer_num // 2, + ] + + middle_size = pp_size - 2 + shards_strategy = [] + for middle_layer_num in range(layer_num): + first_last_layer_num = layer_num - middle_layer_num * middle_size + first_layer_num = first_last_layer_num // 2 + last_layer_num = first_last_layer_num - first_last_layer_num // 2 + if 0 < first_layer_num <= middle_layer_num and 0 < last_layer_num <= middle_layer_num: + shards_strategy.append( + ( + [first_layer_num] + [middle_layer_num] * middle_size + [last_layer_num], + abs(first_layer_num - middle_layer_num), + ) + ) + + # sort by diff of layer_num, to make it as uniform as possible + res = sorted(shards_strategy, key=lambda x: x[1])[0][0] + assert sum(res) == layer_num, f"sum(res)={sum(res)} != layer_num={layer_num}, pp_size={pp_size}" + return res + + +class MegatronModelMerger(BaseModelMerger): + """ + Model merger for Megatron-LM distributed checkpoints. + + This class handles the conversion of Megatron-LM distributed checkpoints into HuggingFace format. + Megatron-LM uses tensor parallelism, pipeline parallelism, and data parallelism to distribute + large language models across multiple GPUs. This merger reconstructs the full model by + loading distributed checkpoints and applying the necessary transformations. + + Key features: + - Support for tensor parallel, pipeline parallel, and data parallel configurations + - Automatic parameter name mapping from Megatron to HuggingFace conventions + - Handling of QKV and gate-up tensor splitting/merging + - Support for tied word embeddings and value models + - Integration with Megatron's distributed checkpointing system + + The merger handles various model architectures and configurations: + - Standard transformer models (GPT-style) + - Models with tied word embeddings + - Value models for reinforcement learning + - Multi-layer attention (MLA) architectures + - Mixture of Experts (MoE) models + + Args: + config (ModelMergerConfig): Configuration object with Megatron-specific settings + including tie_word_embedding and is_value_model flags. + + Example: + To merge Megatron checkpoints: + ```python + config = ModelMergerConfig( + operation="merge", + backend="megatron", + local_dir="path/to/megatron/checkpoints", + target_dir="path/to/output", + tie_word_embedding=True + ) + merger = MegatronModelMerger(config) + merger.merge_and_save() + ``` + """ + + def __init__(self, config: ModelMergerConfig): + super().__init__(config) + # Currently we use only 1 rank to merge the dist_ckpt, we will move to multi-process save shortly afterwards + if "WORLD_SIZE" not in os.environ: + os.environ["RANK"] = "0" + os.environ["LOCAL_RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + torch.distributed.init_process_group(get_nccl_backend()) + + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + local_rank = os.environ.get("LOCAL_RANK", 0) + get_torch_device().set_device(f"{get_device_name()}:{local_rank}") + + mpu.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=self.world_size, + virtual_pipeline_model_parallel_size=None, + context_parallel_size=1, + expert_model_parallel_size=1, + ) + model_parallel_cuda_manual_seed(0) + self.hf_config = AutoConfig.from_pretrained( + self.config.hf_model_config_path, trust_remote_code=self.config.trust_remote_code + ) + print(self.hf_config, flush=True) + + self.params_mapping = { + # megatron core gpt model name, huggingface model name + # NOTICE: It's a little bit tricky, when 2 keys have the same prefix, we need to make sure the + # longer key within the containing relationship is processed first. + "embedding.word_embeddings": "model.embed_tokens", + # input layer norm for dpskv3 + "input_layernorm.weight": "input_layernorm.weight", + "input_layernorm.bias": "input_layernorm.bias", + # attn + "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight", + "self_attention.linear_qkv.layer_norm_bias": "input_layernorm.bias", + "self_attention.linear_qkv": "self_attn.qkv_proj", + "self_attention.q_layernorm": "self_attn.q_norm", + "self_attention.k_layernorm": "self_attn.k_norm", + "self_attention.linear_proj": "self_attn.o_proj", + # mla + "self_attention.linear_q_proj": "self_attn.q_proj", + "self_attention.linear_q_down_proj": "self_attn.q_a_proj", + "self_attention.linear_q_up_proj.layer_norm_weight": "self_attn.q_a_layernorm.weight", + "self_attention.linear_q_up_proj": "self_attn.q_b_proj", + "self_attention.linear_kv_down_proj": "self_attn.kv_a_proj_with_mqa", + "self_attention.linear_kv_up_proj.layer_norm_weight": "self_attn.kv_a_layernorm.weight", + "self_attention.linear_kv_up_proj": "self_attn.kv_b_proj", + # mlp + "pre_mlp_layernorm": "post_attention_layernorm", + "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", + "mlp.linear_fc1.layer_norm_bias": "post_attention_layernorm.bias", + "mlp.linear_fc1": "mlp.gate_up_proj", + "mlp.linear_fc2": "mlp.down_proj", + # moe + "mlp.router.expert_bias": "mlp.gate.e_score_correction_bias", + "mlp.router": "mlp.gate", + "mlp.shared_experts.linear_fc1": "mlp.shared_experts.gate_up_proj", + "mlp.shared_experts.linear_fc2": "mlp.shared_experts.down_proj", + "linear_fc1": "gate_up_proj", + "linear_fc2": "down_proj", + # output + "final_layernorm": "norm", + "output_layer": "lm_head", + } + + if "Qwen2MoeForCausalLM" in self.hf_config.architectures: + self.params_mapping["mlp.shared_experts.linear_fc1"] = "mlp.shared_expert.gate_up_proj" + self.params_mapping["mlp.shared_experts.linear_fc2"] = "mlp.shared_expert.down_proj" + self.params_mapping["mlp.shared_experts.gate_weight"] = "mlp.shared_expert_gate.weight" + + def _load_state_dicts(self, model_ckpt_path: str) -> dict[str, Any]: + """_summary_ + Use Megatron dist_checkpointing to load the model state dicts from the checkpoint directory. + + Args: + model_ckpt_path (str): Path to the model checkpoint directory. + + Returns: + State dict containing the model parameters. + """ + + # init hf config + self.pipeline_shards = get_dynamic_pipeline_shards(self.hf_config.num_hidden_layers, self.world_size) + print(f"Pipeline shards: {self.pipeline_shards}, total layers: {sum(self.pipeline_shards)}") + + tf_config = hf_to_mcore_config( + self.hf_config, + torch.bfloat16, + num_layers_in_first_pipeline_stage=self.pipeline_shards[0] if len(self.pipeline_shards) > 1 else None, + num_layers_in_last_pipeline_stage=self.pipeline_shards[-1] if len(self.pipeline_shards) > 2 else None, + ) + tf_config.use_cpu_initialization = self.config.use_cpu_initialization + tie_word_embeddings = getattr(self.hf_config, "tie_word_embeddings", False) + + # init megatron model + def megatron_model_provider(pre_process, post_process): + from verl.models.mcore import init_mcore_model + + parallel_model = init_mcore_model( + tf_config, + self.hf_config, + pre_process, + post_process, + share_embeddings_and_output_weights=tie_word_embeddings, + value=False, + ) + return parallel_model + + context: Callable[..., ContextManager] = ( + init_empty_weights if self.config.use_cpu_initialization else noop_context + ) + with context(): + whole_model = get_model( + model_provider_func=megatron_model_provider, + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=False, + transformer_config=tf_config, + ) + + if self.config.use_cpu_initialization: + # convert meta device to empty tensor so it can use `copy_` function + whole_model[0].module = whole_model[0].module.to_empty(device="cpu") + + # load state dicts + sharded_state_dict = {} + for vpp_rank, model in enumerate(whole_model): + key = f"model{vpp_rank}" if len(whole_model) > 1 else "model" + mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) + sharded_state_dict[key] = model.sharded_state_dict() + model_state_dict = load_dist_checkpointing(sharded_state_dict, model_ckpt_path) + model_state_dict_list = [] + for vpp_rank, model in enumerate(whole_model): + key = f"model{vpp_rank}" if len(whole_model) > 1 else "model" + mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) + model_state_dict_list.append(model_state_dict[key]) + + return model_state_dict_list + + def _check_megatron_state_key(self, key: str) -> bool: + """ + Checks if the key is a valid Megatron state key. + + Now the model merger only supports keys that start with "decoder/embedding/output_layer" in TransformerLayer. + Shall not use key starts with "model." + """ + if key.startswith("model."): + raise ValueError( + f"Invalid key {key} in Megatron state_dict. Expected keys to start with " + f"'decoder/embedding/output_layer' in TransformerLayer." + ) + + skip_checking_keys = ["embedding.word_embeddings", "output_layer"] + for skip_key in skip_checking_keys: + if skip_key in key: + print(f"skip checking key {key}") + return + + # Exclude extra state keys + if not key.startswith("decoder"): + raise ValueError( + f"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder' in TransformerLayer." + ) + + def _split_tensors( + self, key: str, tensor: torch.Tensor, config: PretrainedConfig, is_value_model: bool = False + ) -> list[torch.Tensor]: + """ + Splits a tensor into multiple tensors based on the name. + This is used to handle qkv and gate_up tensors. + """ + if "linear_fc1.weight" in key: + # if the tensor is gate and proj + gate_lst = [] + up_lst = [] + gate, up = tensor.chunk(2) + gate_lst.append(gate) + up_lst.append(up) + gate = torch.cat(gate_lst, dim=0) + up = torch.cat(up_lst, dim=0) + return [gate, up] + elif "self_attention.linear_qkv." in key and "layer_norm" not in key: + # if the tensor is qkv, for each param on tp, split into q, k, v + # concat q, k, v separately. + q_lst, k_lst, v_lst = [], [], [] + assert config.num_attention_heads % config.num_key_value_heads == 0 + num_q_per_kv = config.num_attention_heads // config.num_key_value_heads + assert tensor.shape[0] % (num_q_per_kv + 2) == 0, ( + f"Tensor shape {tensor.shape} is not divisible by {num_q_per_kv + 2}" + ) + kv_size = tensor.shape[0] // (num_q_per_kv + 2) + split_size = [kv_size * num_q_per_kv, kv_size, kv_size] + + num_query_groups_per_partition = config.num_key_value_heads + for chunk in tensor.chunk(num_query_groups_per_partition): + split_size = [ + kv_size * num_q_per_kv // num_query_groups_per_partition, + kv_size // num_query_groups_per_partition, + kv_size // num_query_groups_per_partition, + ] + q, k, v = chunk.split(split_size) + q_lst.append(q) + k_lst.append(k) + v_lst.append(v) + + return [torch.cat(q_lst, dim=0), torch.cat(k_lst, dim=0), torch.cat(v_lst, dim=0)] + else: + return [tensor] + + def _merge_state_dicts(self, model_state_dict_list: list[dict[str, Any]]) -> dict[str, torch.Tensor]: + state_dict = {} + layers_cum = 0 + if self.world_size > 1: + pipeline_cumsum = np.cumsum(self.pipeline_shards) + layers_cum = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1] + + print(f"{layers_cum=}") + for model_state_dict in model_state_dict_list: + layers_handled = 0 + keys = model_state_dict.keys() + for key in keys: + if "extra_state" in key: + continue + if self.config.tie_word_embedding and ("output_layer" in key): + print("skip lm_head and reward_head loading because of tie_word_embeddings") + continue + + self._check_megatron_state_key(key) + hf_name = self._replace_name(key, self.params_mapping) + assert hf_name is not None, f"Failed to convert layer name [{key}] from megatron to huggingface." + if "model.layers." in hf_name: + local_layer_no = int(hf_name.split(".")[2]) + layers_handled = max(local_layer_no, layers_handled) + global_layer_no = local_layer_no + layers_cum + new_key_list = hf_name.split(".") + new_key_list[2] = str(global_layer_no) + hf_name = ".".join(new_key_list) + else: + warnings.warn(f"hf_name {hf_name} will not be fixed with layer number", stacklevel=2) + + if "mlp.experts." in hf_name and ".weight" in hf_name: + name_prefix, expert_id = hf_name.split(".weight") + for proj in ["gate_up", "down"]: + if f"{proj}_proj" in hf_name: + hf_name = hf_name.replace( + f"mlp.experts.{proj}_proj.weight{expert_id}", + f"mlp.experts.{expert_id}.{proj}_proj.weight", + ) + + tensor = model_state_dict[key] + split_tensor = self._split_tensors( + key, tensor, self.hf_config, is_value_model=self.config.is_value_model + ) + + if len(split_tensor) == 1: + state_dict[hf_name] = split_tensor[0] + elif len(split_tensor) == 3: + # split qkv + for n, d in zip(["q", "k", "v"], split_tensor, strict=True): + state_dict[hf_name.replace("qkv", n)] = d + elif len(split_tensor) == 2: + # split gate up + state_dict[hf_name.replace("gate_up", "gate")] = split_tensor[0] + state_dict[hf_name.replace("gate_up", "up")] = split_tensor[1] + shape_info = ( + split_tensor.shape if isinstance(split_tensor, torch.Tensor) else [t.shape for t in split_tensor] + ) + print(f"converted {key} to {hf_name} with shape {shape_info}") + + layers_cum += layers_handled + 1 # zero based + + return state_dict + + def save_hf_model_and_tokenizer(self, merged_state_dict): + if self.world_size == 1: + return super().save_hf_model_and_tokenizer(merged_state_dict) + + from safetensors.torch import save_file + + layer_num = self.hf_config.num_hidden_layers + + # FIXME: make configurable + saves_per_layer = 1 if layer_num < 30 else 2 + saves_total = saves_per_layer * layer_num + saves_indexes = {} + + # calculate the layer start index and key chunks + layer_this_rank = self.pipeline_shards[self.rank] + pipeline_cumsum = np.cumsum(self.pipeline_shards) + layer_start = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1] + keys = list(merged_state_dict.keys()) + keys_chunk = np.array_split(np.array(keys), layer_this_rank * saves_per_layer) + numel = 0 + + assert len(keys_chunk) == layer_this_rank * saves_per_layer, ( + f"Expected {len(keys_chunk)} chunks, but got {layer_this_rank * saves_per_layer} for rank {self.rank}." + ) + + # save to model shards manually + target_dir = Path(self.config.target_dir) + for i, keys in enumerate(keys_chunk): + sd_to_save = {k: merged_state_dict[k] for k in keys} + numel += sum([sd_to_save[i].numel() for i in sd_to_save]) + save_idx = layer_start * saves_per_layer + i + save_path = target_dir / f"model-{save_idx + 1:05d}-of-{saves_total:05d}.safetensors" + + save_file(sd_to_save, save_path) + for k in keys: + saves_indexes[k] = str(save_path.name) + + tensor = torch.tensor([numel]).to(get_device_name()) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + numel = tensor.cpu().item() + + all_save_indexes = [{} for _ in range(self.world_size)] + dist.all_gather_object(all_save_indexes, saves_indexes) + saves_indexes = {k: v for i in all_save_indexes for k, v in i.items()} + if self.rank == 0: + with open(target_dir / "model.safetensors.index.json", "w") as f: + json.dump( + { + "metadata": { + "total_size": numel, + }, + "weight_map": saves_indexes, + }, + f, + indent=4, + ) + print(f"model saved to {target_dir} with {numel=}") + + self.model_config.save_pretrained(self.config.target_dir) + + processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) + tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) + if processor is not None: + print(f"Saving processor to {self.config.target_dir}") + processor.save_pretrained(self.config.target_dir) + if tokenizer is not None: + print(f"Saving tokenizer to {self.config.target_dir}") + tokenizer.save_pretrained(self.config.target_dir) + + def merge_and_save(self): + from verl.utils.megatron_utils import get_dist_checkpoint_path + + model_ckpt_path = get_dist_checkpoint_path(self.config.local_dir) + + model_state_dict = self._load_state_dicts(model_ckpt_path) + merged_state_dict = self._merge_state_dicts(model_state_dict) + del model_state_dict + + if self.config.operation == "test": + if not self.config.test_hf_dir: + raise ValueError("test_hf_dir must be provided for test operation") + self._validate_state_dict(merged_state_dict) + elif self.config.operation == "merge": + self.save_hf_model_and_tokenizer(merged_state_dict) + if self.config.hf_upload: + self.upload_to_huggingface() + else: + raise ValueError(f"Unknown operation: {self.config.operation}") + + def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]): + """ + Compares the merged Megatron state_dict against a reference safetensors model. + Applies necessary name mappings from Megatron to Hugging Face conventions using _replace_name. + """ + ref_state_dict = load_file(Path(self.config.test_hf_dir) / "model.safetensors") + + for name, loaded_weight in state_dict.items(): + # name = self._replace_name(original_name, self.params_mapping) + if not name or name.endswith(".bias") and name not in ref_state_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + if "lm_head.weight" in name: + if self.config.is_value_model or self.config.tie_word_embedding: + continue + if name not in ref_state_dict: + raise RuntimeError(f"key: {name} not exist in state_dict") + param = ref_state_dict[name] + assert loaded_weight.dtype == param.dtype + torch.testing.assert_close(loaded_weight.to("cpu"), param, atol=1e-2, rtol=5e-2) + + def _replace_name(self, megatron_name: str, name_mapping: dict[str, str]) -> str: + for m_name, v_name in name_mapping.items(): + if m_name not in megatron_name: + continue + + megatron_name = megatron_name.replace("decoder", "model") + param_name = megatron_name.replace(m_name, v_name) + + return param_name + + return None # Return None if no mapping found + + def cleanup(self): + torch.distributed.destroy_process_group() diff --git a/toolbox/verl/v0.5.0/verl/models/README.md b/toolbox/verl/v0.5.0/verl/models/README.md new file mode 100644 index 000000000..677b92f38 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/README.md @@ -0,0 +1,35 @@ +# Models +Common modelzoo such as huggingface/transformers stuggles when using Pytorch native model parallelism. Following the design principle of vLLM, we keep a simple, parallelizable, highly-optimized with packed inputs in verl. +## Adding a New Huggingface Model +### Step 1: Copy the model file from HF to verl +- Add a new file under verl/models/hf +- Copy ONLY the model file from huggingface/transformers/models to verl/models/hf + +### Step 2: Modify the model file to use packed inputs +- Remove all the code related to inference (kv cache) +- Modify the inputs to include only + - input_ids (total_nnz,) + - cu_seqlens (total_nnz + 1,) + - max_seqlen_in_batch: int +- Note that this requires using flash attention with causal mask. + +### Step 2.5: Add tests +- Add a test to compare this version and the huggingface version +- Following the infrastructure and add tests to tests/models/hf + +### Step 3: Add a function to apply tensor parallelism +- Please follow + - https://pytorch.org/docs/stable/distributed.tensor.parallel.html + - https://pytorch.org/tutorials/intermediate/TP_tutorial.html +- General comments + - Tensor Parallelism in native Pytorch is NOT auto-parallelism. The way it works is to specify how model parameters and input/output reshards using configs. These configs are then registered as hooks to perform input/output resharding before/after model forward. + +### Step 4: Add a function to apply data parallelism +- Please use FSDP2 APIs +- See demo here https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413 + +### Step 5: Add a function to apply pipeline parallelism +- Comes in Pytorch 2.4 +- Currently only in alpha in nightly version +- Check torchtitan for more details + diff --git a/toolbox/verl/v0.5.0/verl/models/__init__.py b/toolbox/verl/v0.5.0/verl/models/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/models/llama/__init__.py b/toolbox/verl/v0.5.0/verl/models/llama/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/llama/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/models/llama/megatron/__init__.py b/toolbox/verl/v0.5.0/verl/models/llama/megatron/__init__.py new file mode 100644 index 000000000..fc851ea43 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/llama/megatron/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling_llama_megatron import ( + ParallelLlamaForCausalLM, + # rmpad with megatron + ParallelLlamaForCausalLMRmPad, + # rmpad with megatron and pipeline parallelism + ParallelLlamaForCausalLMRmPadPP, + ParallelLlamaForValueRmPad, + ParallelLlamaForValueRmPadPP, + # original model with megatron + ParallelLlamaModel, +) + +__all__ = [ + "ParallelLlamaForCausalLM", + "ParallelLlamaForCausalLMRmPad", + "ParallelLlamaForCausalLMRmPadPP", + "ParallelLlamaForValueRmPad", + "ParallelLlamaForValueRmPadPP", + "ParallelLlamaModel", +] diff --git a/toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/__init__.py b/toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/llama_loader.py b/toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/llama_loader.py new file mode 100644 index 000000000..dafecfdf0 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/llama_loader.py @@ -0,0 +1,317 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist + +from verl.utils.device import get_device_id, get_torch_device + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}") + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_llama( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module + from torch.nn.parallel import DistributedDataParallel as torchDDP + + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def fetch_params(module): + for param in module.parameters(): + torch.distributed.fetch( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size " + f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.model.layers) == num_layers_per_model + + def _fetch_tensor(tensor, name) -> torch.Tensor: + """fetch tensor""" + nonlocal state_dict + if tensor is not None: + tensor.data.copy_(state_dict[name]) + + def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """fetch tensor in tp shards""" + nonlocal state_dict + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + if tensor is not None: + tensor.data.copy_(tensor_chunk[tp_rank]) + else: + print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """fetch tensor in tp shards""" + nonlocal state_dict + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + if tensor is not None: + tensor.data.copy_(tensor_chunk[tp_rank]) + else: + print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """fetch gate_up tensor in tp shards""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if gate_name in state_dict and up_name in state_dict: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + if tensor is not None: + tensor.data.copy_(tensor_chunk[tp_rank]) + else: + print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: + """fetch tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + assert q_name in state_dict and k_name in state_dict and v_name in state_dict + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + if tensor is not None: + tensor.data.copy_(tensor_chunk[tp_rank]) + + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + embed_tokens_weight = None + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.model.embed_tokens.weight + _fetch_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + num_layer_per_pp = config.num_hidden_layers // pp_size + vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() + + layer_list = [] + if vpp_size is not None: + for vpp_rank in range(vpp_size): + num_layer_vpp_chunk = num_layer_per_pp // vpp_size + num_layer_this_model = num_layer_vpp_chunk + offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + ( + mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk + ) + layer_list.extend(list(range(offset, offset + num_layer_this_model))) + else: + num_layer_this_model = num_layer_per_pp + offset = pp_rank * num_layer_per_pp + layer_list.extend(list(range(offset, offset + num_layer_this_model))) + + for layer in layer_list: + print_rank_0(f"loading layer #{layer}...") + layer_name = f"model.layers.{layer}" + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[dst_layer_idx] + + _fetch_tensor( + sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + _fetch_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + + _fetch_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + + _fetch_tensor( + sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _fetch_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) + + _fetch_tp_shard_tensor( + sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _fetch_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + ) + + print_rank_0("loading lm_head...") + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.lm_head.weight + + if is_value_model: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + _fetch_tensor(lm_head_weight, "lm_head.weight") + print_rank_0("load lm_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + _fetch_tensor(lm_head_weight, "reward_head.weight") + print_rank_0("load lm_head from value_head weight") + else: + _fetch_tensor(None, "lm_head.weight") + print_rank_0("fail to match lm_head in value_model") + else: + _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight") + + dist.barrier() + get_torch_device().empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py b/toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py new file mode 100644 index 000000000..2f65bc6b1 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py @@ -0,0 +1,458 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist + +from verl.utils.device import get_device_id, get_torch_device + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}") + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_llama( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module + from torch.nn.parallel import DistributedDataParallel as torchDDP + + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def broadcast_params(module): + for param in module.parameters(): + torch.distributed.broadcast( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size " + f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.model.layers) == num_layers_per_model + + def _broadcast_tensor(tensor, name) -> torch.Tensor: + """broadcast tensor from rank0 across mp_group""" + nonlocal state_dict + nonlocal mp_group + if torch.distributed.get_rank() == 0: + if name in state_dict: + weight = state_dict[name] + tensor_shape = weight.shape + else: + tensor_shape = None + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not in state_dict, skip load") + return + + if tensor is None: + tensor = torch.empty( + tensor_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + if torch.distributed.get_rank() == 0: + tensor.data.copy_(weight) + dist.broadcast(tensor, src=0, group=mp_group) + + def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + if name in state_dict: + full_weight = state_dict[name] + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape " + f"{tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + assert q_name in state_dict and k_name in state_dict and v_name in state_dict + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) + + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + embed_tokens_weight = None + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.model.embed_tokens.weight + _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + for layer in range(config.num_hidden_layers): + print_rank_0(f"loading layer #{layer}...") + layer_name = f"model.layers.{layer}" + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[dst_layer_idx] + + _broadcast_tensor( + sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + + _broadcast_tensor( + sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + ) + + print_rank_0("loading lm_head...") + lm_head_weight = None + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.lm_head.weight + + if is_value_model: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "lm_head.weight") + print_rank_0("load lm_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "reward_head.weight") + print_rank_0("load lm_head from value_head weight") + else: + _broadcast_tensor(None, "lm_head.weight") + print_rank_0("fail to match lm_head in value_model") + else: + _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") + dist.barrier() + # Broadcast weights inside data parallel groups + for wrapped_model in wrapped_models: + broadcast_params(wrapped_model) + + get_torch_device().empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/llama_saver.py b/toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/llama_saver.py new file mode 100644 index 000000000..595efcde3 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/llama/megatron/checkpoint_utils/llama_saver.py @@ -0,0 +1,442 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist +from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel as LocalDDP +from megatron.core.transformer.module import Float16Module +from torch.nn.parallel import DistributedDataParallel as torchDDP + +from verl.utils.device import get_device_id, get_torch_device +from verl.utils.logger import print_rank_0 +from verl.utils.megatron_utils import unwrap_model + + +def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): + """given TP,DP,PP rank to get the global rank.""" + + tp_size = mpu.get_tensor_model_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), ( + f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" + ) + # We only support TP-DP-PP grouping, for correctness when resharding + return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): + """Merge sharded parameters of a Megatron module into a merged checkpoint. + + Args: + wrapped_models (list of megatron.core.distributed.DistributedDataParallel): + The local DDP wrapped megatron modules. + config (str or None): + HF config for model + dtype: model params type + is_value_model: if model is value model + tie_word_embeddings: tie_word_embeddings, not used in llama, only to keep same interface with qwen2 + Returns: + state_dict (dict): + The merged state_dict in rank 0, and an empty dictionary in other ranks. + """ + start_time = time.time() + + def _get_gpt_model(model): + return model + + dp_rank = mpu.get_data_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if dist.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + assert len(models[i].model.layers) == num_layers_per_model, ( + "len model layers {} not equal to num_layers_per_model {}".format( + len(models[i].model.layers), num_layers_per_model + ) + ) + + state_dict = dict() + + def _get_cpu_tensor(tensor: torch.Tensor): + if tensor is None: + return None + if tensor.device == torch.device("cpu"): + return tensor.detach().clone() + return tensor.detach().cpu() + + def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: + """broadcast tensor across mp_group""" + nonlocal state_dict + nonlocal mp_group + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + if torch.distributed.get_rank() == src_rank: + if tensor is None: + weight = None + tensor_shape = None + else: + weight = tensor + tensor_shape = weight.shape + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not exist, skip collect") + return + + if weight is None: + weight = torch.empty( + tensor_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + dist.broadcast(weight, src=src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + state_dict[name] = _get_cpu_tensor(weight) + + def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=concat_dim) + if mutate_func is not None: + full_tensor = mutate_func(full_tensor) + state_dict[name] = full_tensor + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_list = [] + up_weight_list = [] + for i in range(tp_size): + gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] + gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] + up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] + gate_weight_list.append(gate_weight_tp) + up_weight_list.append(up_weight_tp) + + state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) + state_dict[up_name] = torch.cat(up_weight_list, dim=0) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + q_weight_list = [] + k_weight_list = [] + v_weight_list = [] + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_part = qkv_part[:q_size_tp] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] + q_weight_list.append(q_part) + k_weight_list.append(k_part) + v_weight_list.append(v_part) + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_part = qkv_part[:q_size_tp] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] + q_weight_list.append(q_part) + if i * config.num_key_value_heads % tp_size == 0: + k_weight_list.append(k_part) + v_weight_list.append(v_part) + + state_dict[q_name] = torch.cat(q_weight_list, dim=0) + state_dict[k_name] = torch.cat(k_weight_list, dim=0) + state_dict[v_name] = torch.cat(v_weight_list, dim=0) + + # empty cache before collecting weights + get_torch_device().empty_cache() + # Embeddings + # ------------------- + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("collecting embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + _broadcast_tp_shard_tensor( + gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None, + "model.embed_tokens.weight", + src_pp_rank=0, + ) + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + for layer in range(config.num_hidden_layers): + print_rank_0(f"collecting layer #{layer}...") + layer_name = f"model.layers.{layer}" + src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[src_layer_idx] + + _broadcast_tensor( + sync_layer.input_layernorm.weight, + f"{layer_name}.input_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight, + f"{layer_name}.self_attn.o_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + _broadcast_tensor( + sync_layer.post_attention_layernorm.weight, + f"{layer_name}.post_attention_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.down_proj.weight, + f"{layer_name}.mlp.down_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + # Final Layernorm + # ------------------- + print_rank_0("collecting final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + src_pp_rank=pp_size - 1, + ) + + print_rank_0("collecting lm_head...") + + if is_value_model: + if pp_rank == pp_size - 1: + print(f"gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}") + _broadcast_tensor( + gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + _broadcast_tensor( + gpt_model_module.reward_head.weight + if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None + else None, + "reward_head.weight", + src_pp_rank=pp_size - 1, + ) + + else: + _broadcast_tp_shard_tensor( + getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + + dist.barrier() + + get_torch_device().empty_cache() + if torch.distributed.get_rank() == 0: + if dtype not in [torch.float16, torch.bfloat16, torch.float32]: + print(f'Unknown/unsupported dtype to save: {dtype}"') + exit(1) + for k, v in state_dict.items(): + if dtype != v.dtype: + state_dict[k] = v.to(dtype) + + print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") + return state_dict diff --git a/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/__init__.py b/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/__init__.py new file mode 100644 index 000000000..352bc5608 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .parallel_attention import ParallelLlamaAttention +from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad +from .parallel_linear import ( + LinearForLastLayer, + MergedColumnParallelLinear, + QKVParallelLinear, +) +from .parallel_mlp import ParallelLlamaMLP +from .parallel_rmsnorm import ParallelLlamaRMSNorm + +__all__ = [ + "LinearForLastLayer", + "MergedColumnParallelLinear", + "QKVParallelLinear", + "ParallelLlamaAttention", + "ParallelLlamaDecoderLayer", + "ParallelLlamaDecoderLayerRmPad", + "ParallelLlamaMLP", + "ParallelLlamaRMSNorm", +] diff --git a/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_attention.py b/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_attention.py new file mode 100644 index 000000000..e8aacbdb7 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_attention.py @@ -0,0 +1,460 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange +from flash_attn.layers.rotary import apply_rotary_emb +from megatron.core import ModelParallelConfig, tensor_parallel +from megatron.core import parallel_state as mpu +from torch import nn +from transformers import LlamaConfig +from transformers.utils import is_flash_attn_2_available + +from verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear +from verl.utils.megatron import tensor_parallel as tp_utils + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding): + def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device=None): + super().__init__(dim, max_position_embeddings, base, device) + + self.factor = config.rope_scaling["factor"] # `8` in the original implementation + self.high_freq_factor = config.rope_scaling["high_freq_factor"] # `1` in the original implementation + self.low_freq_factor = config.rope_scaling["low_freq_factor"] # `4` in the original implementation + self.old_context_len = config.rope_scaling[ + "original_max_position_embeddings" + ] # `8192` in the original implementation + + low_freq_wavelen = self.old_context_len / self.low_freq_factor + high_freq_wavelen = self.old_context_len / self.high_freq_factor + + wavelen = 2 * math.pi / self.inv_freq + # wavelen < high_freq_wavelen: do nothing; wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, self.inv_freq / self.factor, self.inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / ( + self.high_freq_factor - self.low_freq_factor + ) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class ParallelLlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + super().__init__() + self.config = config + self.megatron_config = megatron_config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # assign values after tp + tp_size = mpu.get_tensor_model_parallel_world_size() + assert self.num_heads % tp_size == 0, ( + f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" + ) + assert self.num_key_value_heads % tp_size == 0, ( + f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads=" + f"{self.num_key_value_heads}, tp_size={tp_size}" + ) + + self.num_heads_per_tp = self.num_heads // tp_size + self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size + self.hidden_size_per_tp = self.hidden_size // tp_size + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and " + f"`num_heads`: {self.num_heads})." + ) + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() + + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) + tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) + + # [self.q_size, self.k_size, self.v_size] + self.qkv_proj = QKVParallelLinear( + input_size=self.hidden_size, + num_heads=self.num_heads, + num_key_value_heads=self.num_key_value_heads, + head_dim=self.head_dim, + bias=config.attention_bias, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + self.q_size = self.num_heads_per_tp * self.head_dim + self.k_size = self.num_key_value_heads_per_tp * self.head_dim + self.v_size = self.num_key_value_heads_per_tp * self.head_dim + + self.o_proj = tensor_parallel.RowParallelLinear( + input_size=self.num_heads * self.head_dim, + output_size=self.hidden_size, + bias=config.attention_bias, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs, + ) + + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + rope_type_key = "type" if "type" in self.config.rope_scaling else "rope_type" + scaling_type = self.config.rope_scaling[rope_type_key] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "llama3": + self.rotary_emb = LlamaLlama3ScalingRotaryEmbedding( + self.head_dim, + self.config, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + qkv = self.qkv_proj(hidden_states)[0] + query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + + query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, " + f"but is {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, " + f"but is {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp) + attn_output = self.o_proj(attn_output)[0] + return attn_output + + +""" +Remove padding Attention +- Using Flash-attn 2 +- Compatible with sequence parallel +""" + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length): + batch_size = position_ids.shape[0] + + q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim) + k = pad_input(k, indices, batch_size, sequence_length) + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices) + k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices) + + return q_embed, k_embed + + +# use flash-attn rotary embeddings with rmpad +# cos/sin shoudl be: (seq_length, rotary_dim / 2) +def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): + q_embed = apply_rotary_emb( + q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + k_embed = apply_rotary_emb( + k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + return q_embed, k_embed + + +class ParallelLlamaAttentionRmPad(ParallelLlamaAttention): + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: torch.Tensor = None, + max_seqlen_in_batch: int = None, + ): + total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel + + if self.megatron_config.sequence_parallel: + total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() + + qkv = self.qkv_proj(hidden_states)[0] + query_states, key_states, value_states = qkv.split( + [self.q_size, self.k_size, self.v_size], dim=-1 + ) # (total_nnz, 1, hidden_size) + + if self.megatron_config.sequence_parallel: + sequence_parallel_pad = total_nnz - cu_seqlens[-1] + total_nnz = cu_seqlens[-1] # total_nnz before sp padding + query_states = query_states[:total_nnz] + key_states = key_states[:total_nnz] + value_states = value_states[:total_nnz] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dime x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim) + key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + + cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) + cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half + query_states, key_states = apply_rotary_pos_emb_rmpad_flash( + query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch + ) + # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, + # position_ids, indices, + + # TODO: llama does not have dropout in the config?? + # It is recommended to use dropout with FA according to the docs + # when training. + dropout_rate = 0.0 # if not self.training else self.attn_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(torch.float16) + key_states = key_states.to(torch.float16) + value_states = value_states.to(torch.float16) + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen_in_batch, + max_seqlen_k=max_seqlen_in_batch, + dropout_p=dropout_rate, + softmax_scale=None, + causal=True, + ) + + attn_output_unpad = attn_output_unpad.to(input_dtype) + attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous() + + # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled + # Here we need to repad + if self.megatron_config.sequence_parallel: + attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad)) + + attn_output_unpad = self.o_proj(attn_output_unpad)[0] + return attn_output_unpad diff --git a/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_decoder.py b/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_decoder.py new file mode 100644 index 000000000..f46e9457c --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_decoder.py @@ -0,0 +1,150 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from megatron.core import ModelParallelConfig +from torch import nn +from transformers import LlamaConfig + +from verl.utils.megatron_utils import TransformerConfig, convert_config + +from .parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad +from .parallel_mlp import ParallelLlamaMLP +from .parallel_rmsnorm import ParallelLlamaRMSNorm + + +class ParallelLlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.self_attn = ParallelLlamaAttention(config=config, megatron_config=megatron_config) + + self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config) + self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config) + self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Note: sequence parallel is hidden inside ColumnParallelLinear + # reduce scatter is hidden inside RowParallelLinear + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + # TODO: add sequence parallel operator reduce_scatter here + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + # TODO: add sequence parallel operator all_gather here + + hidden_states = self.mlp(hidden_states) + + # TODO: add sequence parallel operator reduce_scatter here + + hidden_states = residual + hidden_states + + outputs = hidden_states + + return outputs + + +class ParallelLlamaDecoderLayerRmPad(nn.Module): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.self_attn = ParallelLlamaAttentionRmPad(config=config, megatron_config=megatron_config) + + self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config) + self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config) + self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states # (total_nnz // sp, 1, hidden_size) + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size) + # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size) + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + # shape changes same as attn + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = hidden_states + + return outputs diff --git a/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_linear.py b/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_linear.py new file mode 100644 index 000000000..043726c46 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_linear.py @@ -0,0 +1,106 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py + +import torch +from megatron.core import tensor_parallel + + +class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): + def __init__( + self, + input_size, + num_heads, + num_key_value_heads, + head_dim, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): + # Keep input parameters, and already restrict the head numbers + self.input_size = input_size + self.q_output_size = num_heads * head_dim + self.kv_output_size = num_key_value_heads * head_dim + self.head_dim = head_dim + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + input_size = self.input_size + output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim + + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) + + +class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): + def __init__( + self, + input_size, + gate_ouput_size, + up_output_size, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): + # Keep input parameters, and already restrict the head numbers + self.input_size = input_size + self.output_size = gate_ouput_size + up_output_size + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + super().__init__( + input_size=self.input_size, + output_size=self.output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) + + +class LinearForLastLayer(torch.nn.Linear): + def __init__( + self, + input_size, + output_size, + *, + config, + bias=True, + ): + super().__init__(in_features=input_size, out_features=output_size, bias=bias) + self.sequence_parallel = config.sequence_parallel + if self.sequence_parallel: + self.weight.sequence_parallel = True + + def forward( + self, + input_, + weight=None, + runtime_gather_output=None, + ): + logits = super().forward(input_) + logits = logits.float() + if self.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits, None diff --git a/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_mlp.py b/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_mlp.py new file mode 100644 index 000000000..583a317eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_mlp.py @@ -0,0 +1,74 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core import ModelParallelConfig, tensor_parallel +from megatron.core import parallel_state as mpu +from torch import nn +from transformers.activations import ACT2FN + +from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear +from verl.utils.megatron import tensor_parallel as tp_utils + + +class ParallelLlamaMLP(nn.Module): + def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + # The weight is only [hidden_size, intermediate_size // model_parallel_world_size] + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() + + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) + tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) + + tp_size = mpu.get_tensor_model_parallel_world_size() + + self.gate_up_proj = MergedColumnParallelLinear( + input_size=self.hidden_size, + gate_ouput_size=self.intermediate_size, + up_output_size=self.intermediate_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + self.gate_size = self.intermediate_size // tp_size + + self.down_proj = tensor_parallel.RowParallelLinear( + input_size=self.intermediate_size, + output_size=self.hidden_size, + bias=False, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs, + ) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + gate_up = self.gate_up_proj(x)[0] + gate, up = gate_up.split(self.gate_size, dim=-1) + return self.down_proj(self.act_fn(gate) * up)[0] diff --git a/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_rmsnorm.py b/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_rmsnorm.py new file mode 100644 index 000000000..bc2e9ae36 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/llama/megatron/layers/parallel_rmsnorm.py @@ -0,0 +1,48 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers + +import torch +from apex.normalization.fused_layer_norm import fused_rms_norm_affine +from megatron.core import ModelParallelConfig +from torch import nn +from transformers import LlamaConfig + +from verl.utils.megatron import sequence_parallel as sp_utils + + +class ParallelLlamaRMSNorm(nn.Module): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + if isinstance(config.hidden_size, numbers.Integral): + normalized_shape = (config.hidden_size,) + self.normalized_shape = torch.Size(normalized_shape) + self.weight = nn.Parameter(torch.ones(self.normalized_shape)) + self.variance_epsilon = config.rms_norm_eps + + if megatron_config.sequence_parallel: + sp_utils.mark_parameter_as_sequence_parallel(self.weight) + + def forward(self, hidden_states): + return fused_rms_norm_affine( + input=hidden_states, + weight=self.weight, + normalized_shape=self.normalized_shape, + eps=self.variance_epsilon, + memory_efficient=True, + ) diff --git a/toolbox/verl/v0.5.0/verl/models/llama/megatron/modeling_llama_megatron.py b/toolbox/verl/v0.5.0/verl/models/llama/megatron/modeling_llama_megatron.py new file mode 100644 index 000000000..ed5022e0c --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/llama/megatron/modeling_llama_megatron.py @@ -0,0 +1,688 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch LLaMA model with Megatron-style acceleration.""" + +from typing import Optional + +import torch +import torch.utils.checkpoint +from megatron.core import ModelParallelConfig, mpu, tensor_parallel +from torch import nn +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import CausalLMOutputWithPast + +from verl.utils.megatron import sequence_parallel as sp_utils +from verl.utils.megatron import tensor_parallel as tp_utils +from verl.utils.megatron_utils import TransformerConfig, convert_config + +from .layers import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad, ParallelLlamaRMSNorm + +""" +TODO: +1. Add weight initialization. Here we need to be careful on TP weight init. +2. Add sequence parallel +3. Load checkpoint from meta LLama pretrained checkpoint +""" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class ParallelLlamaModel(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) + + self.layers = nn.ModuleList( + [ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) + self.norm = ParallelLlamaRMSNorm(config, megatron_config) + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | BaseModelOutputWithPast: + """ + + Args: + input_ids: input ids. shape (batch_size, seq_length) + attention_mask: attention_mask. shape (batch_size, seq_length) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + batch_size, seq_length = input_ids.shape + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds) + + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelLlamaForCausalLM(nn.Module): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.model = ParallelLlamaModel(config, megatron_config=megatron_config) + self.vocab_size = config.vocab_size + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = outputs + logits = self.lm_head(hidden_states)[0] + + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + + logits = logits.float() + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +class ParallelLlamaModelRmPad(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + self.megatron_config = megatron_config + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) + + self.layers = nn.ModuleList( + [ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) + self.norm = ParallelLlamaRMSNorm(config, megatron_config) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> tuple | BaseModelOutputWithPast: + """ + + Args: + input_ids: input ids. shape (1, totol_nnz) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + + # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) + inputs_embeds = inputs_embeds.transpose(0, 1) + if self.megatron_config.sequence_parallel: + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + + hidden_states = inputs_embeds + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelLlamaForCausalLMRmPad(nn.Module): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.megatron_config = megatron_config + self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config) + self.vocab_size = config.vocab_size + self._init_head(config) + + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + def _forward_head(self, hidden_states): + # all_gather from sequence parallel region is performed inside lm_head + logits = self.lm_head(hidden_states)[0] + logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size) + return logits + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + batch_size, sequence_length = input_ids.shape + + # remove padding here + input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) + + # pad input_ids to multiple of tp for all tp ranks + # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap + if self.megatron_config.sequence_parallel: + input_ids = sp_utils.pad_to_sequence_parallel(input_ids) + + input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad) + + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = outputs + + logits = self._forward_head(hidden_states) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + # add removed padding back + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad): + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + # lm_head is effectively the same as sequence parallel + sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) + + def _forward_head(self, hidden_states): + logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) + logits = logits.float() + if self.megatron_config.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + output = super().forward(input_ids, attention_mask, position_ids) + output.logits = torch.squeeze(output.logits, dim=-1) + return output + + +""" +Support pipeline parallelism +""" + + +class ParallelLlamaModelRmPadPP(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + This model definition supports pipeline parallelism. To support pp and vpp, + - This model only contains layer in this pp stage and vpp chunk + - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp. + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.pre_process = pre_process + self.post_process = post_process + self.megatron_config = megatron_config + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + if pre_process: + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) + else: + self.embed_tokens = None + + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = megatron_config.pipeline_model_parallel_size + self.num_layer_per_pp = config.num_hidden_layers // pp_size + vpp_size = megatron_config.virtual_pipeline_model_parallel_size + vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() + + if vpp_size is not None: + self.layers = nn.ModuleList() + self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size + self.num_layer_this_model = self.num_layer_vpp_chunk + offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk) + else: + self.num_layer_this_model = self.num_layer_per_pp + offset = pp_rank * self.num_layer_per_pp + + self.layers = nn.ModuleList() + for i in range(self.num_layer_this_model): + layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config, layer_idx=offset + i) + self.layers.add_module(f"{i}", layer) + + if post_process: + self.norm = ParallelLlamaRMSNorm(config, megatron_config) + else: + self.norm = None + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> tuple | BaseModelOutputWithPast: + """ + + Args: + input_ids: input ids. shape (1, totol_nnz) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + if self.pre_process: + inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + + # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron + # so need to deal with it by handle here: + # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) + inputs_embeds = inputs_embeds.transpose(0, 1) + if self.megatron_config.sequence_parallel: + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + + hidden_states = inputs_embeds + else: + # self.hidden_states should be passed by Megatron + hidden_states = self.input_tensor + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = layer_outputs + + if self.post_process: + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelLlamaForCausalLMRmPadPP(nn.Module): + def __init__( + self, + config: LlamaConfig, + megatron_config: ModelParallelConfig, + pre_process, + post_process, + share_embeddings_and_output_weights=False, + ): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.megatron_config = megatron_config + self.model = ParallelLlamaModelRmPadPP( + config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process + ) + assert share_embeddings_and_output_weights is False, ( + "Llama Model not supports sharing embedding and output weights" + ) + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.vocab_size = config.vocab_size + self.pre_process = pre_process + self.post_process = post_process + if post_process: + self._init_head(config) + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + assert len(input_tensor) == 1 + self.model.set_input_tensor(input_tensor[0]) + + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + def _forward_head(self, hidden_states): + # all_gather from sequence parallel region is performed inside lm_head + # logits shape before forward_head hidden_states.shape: [4, 32, 4096] + logits = self.lm_head(hidden_states)[0] + # logits shape after forward_head logits.shape: [8, 32, 8] + logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) + return logits + + def forward( + self, + # original input + *, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + + # Note that input_ids, attention_mask and position_ids should be passed to every pp layer. + # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model + batch_size, sequence_length = input_ids.shape + # remove padding here + input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) + + # pad input_ids to multiple of tp for all tp ranks + # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap + if self.megatron_config.sequence_parallel: + input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad) + + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad) + + outputs = self.model( + input_ids=input_ids_rmpad, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + if self.post_process: + hidden_states = outputs + # print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096]) + logits = self._forward_head(hidden_states) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + # add removed padding back. If input is already rmpad, we let the caller pad_input + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + else: + return outputs + + +class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP): + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + # lm_head is effectively the same as sequence parallel + sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) + + def _forward_head(self, hidden_states): + logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) + logits = logits.float() + if self.megatron_config.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits + + def forward( + self, + *, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + if self.post_process: + output.logits = torch.squeeze(output.logits, dim=-1) + return output + else: + return output diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/__init__.py b/toolbox/verl/v0.5.0/verl/models/mcore/__init__.py new file mode 100644 index 000000000..29d053177 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .registry import ( + get_mcore_forward_fn, + get_mcore_forward_fused_fn, + get_mcore_weight_converter, + hf_to_mcore_config, + init_mcore_model, +) + +__all__ = [ + "hf_to_mcore_config", + "init_mcore_model", + "get_mcore_forward_fn", + "get_mcore_weight_converter", + "get_mcore_forward_fused_fn", +] diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/config_converter.py b/toolbox/verl/v0.5.0/verl/models/mcore/config_converter.py new file mode 100644 index 000000000..95fe7e416 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/config_converter.py @@ -0,0 +1,392 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# convert huggingface config to mcore transformer config + + +import warnings + +import torch +import torch.nn.functional as F +from megatron.core import parallel_state as mpu +from megatron.core.transformer import MLATransformerConfig, TransformerConfig +from transformers import PretrainedConfig + + +def _get_base_transformer_config( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> dict: + """ + Create a base TransformerConfig with common parameters across different model architectures. + TODO: (ycl) use dataclass or converter config? + + Args: + hf_config: HuggingFace model configuration + dtype: Data type for the model + override_transformer_config_kwargs: Additional parameters to override defaults + + Returns: + TransformerConfig with common parameters + """ + + # Common parallel state parameters + overlap_p2p_comm = ( + mpu.get_virtual_pipeline_model_parallel_world_size() is not None + and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 + ) + batch_p2p_comm = False + + # Base configuration with common parameters + base_config = { + # Model architecture parameters + "num_layers": hf_config.num_hidden_layers, + "hidden_size": hf_config.hidden_size, + "num_attention_heads": hf_config.num_attention_heads, + "num_query_groups": hf_config.num_key_value_heads, + "ffn_hidden_size": hf_config.intermediate_size, + "attention_dropout": hf_config.attention_dropout, + "hidden_dropout": getattr(hf_config, "hidden_dropout", 0.0), + "kv_channels": getattr(hf_config, "head_dim", None), + "layernorm_epsilon": hf_config.rms_norm_eps, + "add_bias_linear": True, + # Activation and normalization + "activation_func": F.silu, + "normalization": "RMSNorm", + "gated_linear_unit": True, + # Data types + "pipeline_dtype": dtype, + "params_dtype": dtype, + "bf16": dtype is torch.bfloat16, + # Parallel configuration + "tensor_model_parallel_size": mpu.get_tensor_model_parallel_world_size(), + "pipeline_model_parallel_size": mpu.get_pipeline_model_parallel_world_size(), + "expert_model_parallel_size": mpu.get_expert_model_parallel_world_size(), + "expert_tensor_parallel_size": mpu.get_expert_tensor_parallel_world_size(), + "virtual_pipeline_model_parallel_size": mpu.get_virtual_pipeline_model_parallel_world_size(), + "context_parallel_size": mpu.get_context_parallel_world_size(), + "overlap_p2p_comm": overlap_p2p_comm, + "batch_p2p_comm": batch_p2p_comm, + "sequence_parallel": mpu.get_tensor_model_parallel_world_size() > 1, + # Common settings + "variable_seq_lengths": True, + "masked_softmax_fusion": True, + "moe_token_dispatcher_type": "alltoall", + } + + # Update with any provided overrides + # override_transformer_config_kwargs as kwargs shall never be none + base_config.update(override_transformer_config_kwargs) + + return base_config + + +def _get_mla_transformer_config( + hf_config: PretrainedConfig, mla_rope_config: dict, dtype: torch.dtype, **override_transformer_config_kwargs +) -> dict: + """ + Create a MLATransformerConfig with common parameters across different model architectures. + This is specifically for MLA models like DeepseekV3. + + Args: + hf_config: HuggingFace model configuration + mla_rope_config: MLA specific RoPE configuration + dtype: Data type for the model + override_transformer_config_kwargs: Additional parameters to override defaults + + Returns: + MLATransformerConfig with common parameters + """ + base_config = _get_base_transformer_config(hf_config=hf_config, dtype=dtype, **override_transformer_config_kwargs) + mla_config = { + # MLA specific parameters + "q_lora_rank": hf_config.q_lora_rank, + "kv_lora_rank": hf_config.kv_lora_rank, + "qk_head_dim": hf_config.qk_nope_head_dim, + "qk_pos_emb_head_dim": hf_config.qk_rope_head_dim, + "v_head_dim": hf_config.v_head_dim, + "rotary_base": hf_config.rope_theta, + "rotary_scaling_factor": mla_rope_config["factor"], + "rope_type": mla_rope_config["type"], + "max_position_embeddings": mla_rope_config["original_max_position_embeddings"], + "beta_fast": mla_rope_config["beta_fast"], + "beta_slow": mla_rope_config["beta_slow"], + "mscale": mla_rope_config["mscale"], + "mscale_all_dim": mla_rope_config["mscale_all_dim"], + } + + base_config.update(mla_config) + return base_config + + +def check_and_disable_incompatible_configs(original_config: dict) -> dict: + """ + Check and disable incompatible configurations for older Megatron version. + + Args: + original_config (dict): The original model configuration. + + Returns: + dict: The updated model configuration with incompatible settings disabled. + """ + removed_keys = [] + for key in original_config.keys(): + if not hasattr(TransformerConfig, key): + removed_keys.append(key) + if removed_keys: + warnings.warn( + f"The following keys are not supported in the current Megatron version and will be removed: {removed_keys}", + stacklevel=2, + ) + for key in removed_keys: + original_config.pop(key) + return original_config + + +def hf_to_mcore_config_dense( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + # for LlamaForCausalLM or Qwen2ForCausalLM + qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False) + qk_layernorm = True if "Qwen3ForCausalLM" in hf_config.architectures else False + + args: dict = _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + use_cpu_initialization=False, + add_bias_linear=False, + add_qkv_bias=qkv_bias, + qk_layernorm=qk_layernorm, + ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + args = check_and_disable_incompatible_configs(args) + print(f"Overridden TF init config: {args}") + return TransformerConfig(**args) + + +def hf_to_mcore_config_qwen2moe( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + args: dict = _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + use_cpu_initialization=False, + add_bias_linear=False, + layernorm_epsilon=hf_config.rms_norm_eps, + # MoE specific + moe_ffn_hidden_size=hf_config.moe_intermediate_size, + moe_router_bias_update_rate=0.001, + moe_router_topk=hf_config.num_experts_per_tok, + num_moe_experts=hf_config.num_experts, + moe_shared_expert_intermediate_size=hf_config.shared_expert_intermediate_size, + moe_aux_loss_coeff=hf_config.router_aux_loss_coef, + # moe_aux_loss_coeff=0.0, + moe_router_load_balancing_type="none", # turn off aux_loss as it hurts perf in RL + moe_shared_expert_overlap=True, + moe_grouped_gemm=True, + moe_router_score_function="softmax", + # Other optimizations + persist_layer_norm=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + # Qwen specific + moe_router_pre_softmax=True, + add_qkv_bias=True, + ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + args = check_and_disable_incompatible_configs(args) + print(f"Overridden TF init config: {args}") + return TransformerConfig(**args) + + +def hf_to_mcore_config_mixtral( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + args: dict = _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + use_cpu_initialization=False, + add_bias_linear=False, + layernorm_epsilon=hf_config.rms_norm_eps, + # MoE specific + num_moe_experts=hf_config.num_local_experts, + moe_aux_loss_coeff=hf_config.router_aux_loss_coef, + moe_router_topk=hf_config.num_experts_per_tok, + moe_router_pre_softmax=True, + moe_router_load_balancing_type="none", # turn off aux_loss as it hurts perf in RL + moe_router_score_function="softmax", + moe_shared_expert_intermediate_size=None, # mixtral has no shared expert + moe_shared_expert_overlap=False, # mixtral has no shared expert + moe_ffn_hidden_size=hf_config.intermediate_size, + moe_router_bias_update_rate=0.001, + # moe_permute_fusion=True, # need TE 2.1+ + moe_grouped_gemm=True, + # Other optimizations + persist_layer_norm=True, + apply_rope_fusion=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + args = check_and_disable_incompatible_configs(args) + print(f"Overridden TF init config: {args}") + return TransformerConfig(**args) + + +def hf_to_mcore_config_qwen3moe( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + args: dict = _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + use_cpu_initialization=False, + add_bias_linear=False, + layernorm_epsilon=hf_config.rms_norm_eps, + # MoE specific + moe_ffn_hidden_size=hf_config.moe_intermediate_size, + moe_router_bias_update_rate=0.001, + moe_router_topk=hf_config.num_experts_per_tok, + num_moe_experts=hf_config.num_experts, + moe_aux_loss_coeff=hf_config.router_aux_loss_coef, + # moe_aux_loss_coeff=0.0, + moe_router_load_balancing_type="none", # turn off aux_loss as it hurts perf in RL + moe_grouped_gemm=True, + moe_router_score_function="softmax", + # Other optimizations + persist_layer_norm=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + # Qwen specific + moe_router_pre_softmax=False, + qk_layernorm=True, + ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + args = check_and_disable_incompatible_configs(args) + print(f"Overridden TF init config: {args}") + return TransformerConfig(**args) + + +def hf_to_mcore_config_dpskv3( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> MLATransformerConfig: + # DeepseekV3ForCausalLM + from megatron.core.transformer.enums import AttnBackend + + from .patch_v012 import apply_patch + + apply_patch() + + mla_rope_config = { + "beta_fast": 32, + "beta_slow": 1, + "factor": 1, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "rope", + } + if "rope_scaling" in hf_config and hf_config.rope_scaling is not None: + mla_rope_config.update(hf_config.rope_scaling) + moe_layer_freq = [1] * hf_config.num_hidden_layers + for i in range(min(hf_config.first_k_dense_replace, hf_config.num_hidden_layers)): + moe_layer_freq[i] = 0 + + # disable MTP and quantization for now + if "num_nextn_predict_layers" in hf_config: + assert hf_config.num_nextn_predict_layers == 0, ( + "MTP is not supported for now, please modify the config.json to set num_nextn_predict_layers to 0" + ) + assert "quantization_config" not in hf_config or not hf_config.quantization_config, ( + "quantization is not supported for now, please modify the config.json to remove quantization_config" + ) + + args: dict = _get_mla_transformer_config( + hf_config=hf_config, + mla_rope_config=mla_rope_config, + dtype=dtype, + # Additional parameters + use_cpu_initialization=False, + add_bias_linear=False, + attention_backend=AttnBackend.fused, + qk_layernorm=True, + # Standard MoE parameters + moe_ffn_hidden_size=hf_config.moe_intermediate_size, + moe_token_dispatcher_type="alltoall", + moe_router_bias_update_rate=0.001, + moe_router_enable_expert_bias=True, + moe_router_topk=hf_config.num_experts_per_tok, + num_moe_experts=hf_config.n_routed_experts, + moe_shared_expert_intermediate_size=hf_config.moe_intermediate_size * hf_config.n_shared_experts, + moe_aux_loss_coeff=getattr(hf_config, "aux_loss_alpha", 0.001), + moe_router_load_balancing_type="seq_aux_loss", + moe_shared_expert_overlap=True, + # moe_permute_fusion=True, # need TE 2.1+ + moe_grouped_gemm=True, + moe_router_score_function="sigmoid", + moe_router_pre_softmax=True, + moe_router_topk_scaling_factor=hf_config.routed_scaling_factor, + moe_layer_freq=moe_layer_freq, + # mcore 0.12 moe + moe_router_dtype="fp64", + disable_bf16_reduced_precision_matmul=True, + # Other optimizations + # deallocate_pipeline_outputs=True, + # gradient_accumulation_fusion=True, + persist_layer_norm=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + args = check_and_disable_incompatible_configs(args) + transformer_config: MLATransformerConfig = MLATransformerConfig(**args) + print(f"Overridden MLA TF init config: {transformer_config}") + # MTP + if "num_nextn_predict_layers" in hf_config: + transformer_config.mtp_num_layers = hf_config.num_nextn_predict_layers + transformer_config.mtp_loss_scaling_factor = 0.1 + + return transformer_config + + +def hf_to_mcore_config_qwen2_5_vl( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + # Qwen2_5_VLForConditionalGeneration + + args = _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + add_bias_linear=False, + # qwen specific + add_qkv_bias=True, + mrope_section=hf_config.rope_scaling["mrope_section"], + ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + args = check_and_disable_incompatible_configs(args) + print(f"Overridden TF init config: {args}") + return TransformerConfig(**args) + + +def hf_to_mcore_config_llama4( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + # Llama4ForConditionalGeneration + raise NotImplementedError("Llama4ForConditionalGeneration is not supported yet") diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/loader.py b/toolbox/verl/v0.5.0/verl/models/mcore/loader.py new file mode 100644 index 000000000..659b4baa2 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/loader.py @@ -0,0 +1,492 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist + +from verl.utils.device import get_device_id, get_torch_device + +from .saver import _megatron_calc_global_rank + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module + from torch.nn.parallel import DistributedDataParallel as torchDDP + + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def broadcast_params(module): + for param in module.parameters(): + torch.distributed.broadcast( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + cp_rank = mpu.get_context_parallel_rank() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=cp_rank) + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == src_rank: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.decoder.layers) == num_layers_per_model + + def _broadcast_tensor(tensor, name) -> torch.Tensor: + """broadcast tensor from rank0 across mp_group""" + nonlocal state_dict + nonlocal mp_group + if torch.distributed.get_rank() == src_rank: + if name in state_dict: + weight = state_dict[name] + tensor_shape = weight.shape + else: + tensor_shape = None + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not in state_dict, skip load") + return + + if tensor is None: + tensor = torch.empty( + tensor_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + if torch.distributed.get_rank() == src_rank: + tensor.data.copy_(weight) + dist.broadcast(tensor, src=src_rank, group=mp_group) + + def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == src_rank: + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == src_rank: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=src_rank, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == src_rank: + if name in state_dict: + full_weight = state_dict[name] + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == src_rank: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=src_rank, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == src_rank: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank() == src_rank:} tensor {gate_name, up_name} shape " + f"{tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == src_rank: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=src_rank, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == src_rank: + assert q_name in state_dict and k_name in state_dict and v_name in state_dict + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + + if config.num_key_value_heads >= tp_size: + q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + sizes = [total_size * tp_size] + if not bias: + sizes.append(config.hidden_size) + new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + num_query_groups_per_partition = models[0].config.num_query_groups // tp_size + new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size] + q_part_per_head = torch.chunk(q_part, num_query_groups_per_partition, dim=0) + k_part_per_head = torch.chunk(k_part, num_query_groups_per_partition, dim=0) + v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0) + total_size_per_head = total_size // num_query_groups_per_partition + for j in range(num_query_groups_per_partition): + new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_( + torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0) + ) + + else: + q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + sizes = [total_size * tp_size] + if not bias: + sizes.append(config.hidden_size) + new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size] + q_part_per_head = torch.chunk(q_part, config.num_attention_heads, dim=0) + k_part_per_head = torch.chunk(k_part, config.num_attention_heads, dim=0) + v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0) + total_size_per_head = total_size // config.num_attention_heads + for j in range(config.num_attention_heads): + new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_( + torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0) + ) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == src_rank: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=src_rank, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + embed_tokens_weight = None + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.embedding.word_embeddings.weight + _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + for layer in range(config.num_hidden_layers): + layer_name = f"model.layers.{layer}" + print_rank_0(f"loading layer #{layer}, with layer_name model.layers.{layer}...") + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.decoder.layers[dst_layer_idx] + + _broadcast_tensor( + sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + if f"{layer_name}.self_attn.q_norm.weight" in state_dict: + _broadcast_tensor( + sync_layer.self_attention.q_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_norm.weight", + ) + _broadcast_tensor( + sync_layer.self_attention.k_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.k_norm.weight", + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attention.linear_qkv.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + if f"{layer_name}.self_attn.q_proj.bias" in state_dict: + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attention.linear_qkv.bias if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + bias=True, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attention.linear_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + _broadcast_tensor( + sync_layer.mlp.linear_fc1.layer_norm_weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.linear_fc2.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.decoder.final_layernorm, "weight", None), + "model.norm.weight", + ) + + print_rank_0("loading lm_head...") + lm_head_weight = None + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.output_layer.weight + + if is_value_model: + # if torch.distributed.get_rank() == src_rank: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "lm_head.weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "reward_head.weight") + print_rank_0("load lm_head from value_head weight") + else: + _broadcast_tensor(None, "lm_head.weight") + print_rank_0("fail to match lm_head in value_model") + # else: + + # _broadcast_tensor(lm_head_weight, "lm_head.weight") + + else: + _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") + dist.barrier() + # Broadcast weights inside data parallel groups + for wrapped_model in wrapped_models: + broadcast_params(wrapped_model) + pass + get_torch_device().empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/mbridge.py b/toolbox/verl/v0.5.0/verl/models/mcore/mbridge.py new file mode 100644 index 000000000..35c32d697 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/mbridge.py @@ -0,0 +1,23 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + from mbridge import AutoBridge + from mbridge.utils.post_creation_callbacks import freeze_moe_router, make_value_model +except ImportError: + print("mbridge package not found. Please install mbridge with `pip install verl[mcore]` or `pip install mbridge`") + raise + +__all__ = ["AutoBridge", "make_value_model", "freeze_moe_router"] diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/model_forward.py b/toolbox/verl/v0.5.0/verl/models/mcore/model_forward.py new file mode 100644 index 000000000..e70e11f4e --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/model_forward.py @@ -0,0 +1,148 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from verl.utils.megatron_utils import unwrap_model + +from .util import postprocess_packed_seqs, preprocess_packed_seqs, recover_left_padding, remove_left_padding + + +def gptmodel_forward( + model, + input_ids, + attention_mask, + position_ids, + sequence_parallel, + value_model=False, + pack_seqs=True, + logits_processor=None, + logits_processor_args: dict = None, + **kwargs, +): + """Default forward pass for GPT models with optional sequence packing.""" + pre_process = unwrap_model(model).pre_process + post_process = unwrap_model(model).post_process + if pack_seqs: + batch_size, seq_len = attention_mask.shape[:2] + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process) + input_ids_rmpad = input_ids_rmpad.contiguous() + output_orig = model( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids, + packed_seq_params=packed_seq_params, + ) + if post_process and logits_processor is not None: + args = { + k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0] + for k, v in logits_processor_args.items() + } + output_dict = logits_processor(output_orig, **args) + output = { + k: postprocess_packed_seqs( + v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + for k, v in output_dict.items() + } + else: + output = postprocess_packed_seqs( + output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + else: + assert logits_processor is None, "logits_processor is not supported for non-packed sequence" + batch_size, sequence_length = attention_mask.shape + new_input_ids, new_attention_mask, new_position_ids = remove_left_padding( + input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process + ) + output = model(input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids) + output = recover_left_padding( + output, new_attention_mask, attention_mask, sequence_length, post_process=post_process + ) + if value_model and post_process: + output = output[..., 0] + return output + + +def gptmodel_forward_qwen2_5_vl( + model, + input_ids, + attention_mask, + position_ids, + sequence_parallel, + value_model=False, + pack_seqs=True, + multi_modal_inputs=None, + logits_processor=None, + logits_processor_args: dict = None, + **kwargs, +): + from megatron.core import parallel_state as mpu + + assert mpu.get_context_parallel_world_size() == 1, "qwen2_5_vl's context parallel is not accurate yet" + pre_process = unwrap_model(model).pre_process + post_process = unwrap_model(model).post_process + pixel_values = ( + multi_modal_inputs["pixel_values"].to(input_ids.device) if "pixel_values" in multi_modal_inputs else None + ) + image_grid_thw = ( + multi_modal_inputs["image_grid_thw"].to(input_ids.device) if "image_grid_thw" in multi_modal_inputs else None + ) + if pack_seqs: + batch_size, seq_len = attention_mask.shape[:2] + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True) + input_ids_rmpad = input_ids_rmpad.contiguous() + output_orig = model( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids, + packed_seq_params=packed_seq_params, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + if post_process and logits_processor is not None: + args = { + k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0] + for k, v in logits_processor_args.items() + } + output_dict = logits_processor(output_orig, **args) + output = { + k: postprocess_packed_seqs( + v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + for k, v in output_dict.items() + } + else: + output = postprocess_packed_seqs( + output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + else: + batch_size, sequence_length = attention_mask.shape + new_input_ids, new_attention_mask, new_position_ids = remove_left_padding( + input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process + ) + output = model( + input_ids=new_input_ids, + position_ids=new_position_ids, + attention_mask=new_attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + output = recover_left_padding( + output, new_attention_mask, attention_mask, sequence_length, post_process=post_process + ) + if value_model and post_process: + output = output[..., 0] + return output diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/model_forward_fused.py b/toolbox/verl/v0.5.0/verl/models/mcore/model_forward_fused.py new file mode 100644 index 000000000..fc55ef1b0 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/model_forward_fused.py @@ -0,0 +1,327 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from typing import Optional + +import torch +from megatron.core import parallel_state +from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region +from torch import Tensor + +from verl.models.mcore.util import preprocess_packed_seqs +from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy +from verl.utils.megatron_utils import unwrap_model +from verl.utils.model import CausalLMOutputForPPO + +from .qwen2_5_vl.model import Qwen2_5VLModel +from .util import postprocess_packed_seqs_for_dict_output + + +def patch_fused_forward(model: torch.nn.Module): + model = unwrap_model(model) + if isinstance(model, GPTModel): + model = model + elif isinstance(model, Qwen2_5VLModel): + if not hasattr(model, "language_model"): + # the qwen2.5vl model might only have vision_model + return + model = model.language_model + else: + raise ValueError("Model is not a GPTModel or Qwen2_5VLModel") + model.forward_backup = model.forward + model.forward = _fused_GPTModel_forward.__get__(model, model.__class__) + return + + +def unpatch_fused_forward(model: torch.nn.Module): + model = unwrap_model(model) + if isinstance(model, GPTModel): + model = model + elif isinstance(model, Qwen2_5VLModel): + model = model.language_model + else: + raise ValueError("Model is not a GPTModel or Qwen2_5VLModel") + model.forward = model.forward_backup + return + + +def fused_forward_gptmodel( + model: GPTModel, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + labels: Tensor, + labels_mask: Tensor, + **kwargs, +): + pre_process: bool = unwrap_model(model).pre_process + post_process: bool = unwrap_model(model).post_process + + batch_size, seq_len = attention_mask.shape[:2] + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process) + input_ids_rmpad = input_ids_rmpad.contiguous() + labels_rmpad, _ = preprocess_packed_seqs(labels, attention_mask, pre_process=True) + labels_mask_rmpad, _ = preprocess_packed_seqs(labels_mask, attention_mask, pre_process=True) + labels_rmpad = labels_rmpad.contiguous() + labels_mask_rmpad = labels_mask_rmpad.contiguous() + + output_orig: CausalLMOutputForPPO = model( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids, + labels=labels_rmpad, + packed_seq_params=packed_seq_params, + ) + + if post_process: + # output_orig is in type of CausalLMOutputForPPO + output = postprocess_packed_seqs_for_dict_output( + labels_mask_rmpad, + output_orig, + packed_seq_params, + attention_mask, + batch_size, + seq_len, + post_process=post_process, + ) + else: + output = output_orig + return output + + +def fused_forward_qwen2_5_vl( + model: Qwen2_5VLModel, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + labels: Tensor, + labels_mask: Tensor, + multi_modal_inputs=None, + **kwargs, +): + # pre_process = unwrap_model(model).pre_process + post_process = unwrap_model(model).post_process + + pixel_values = ( + multi_modal_inputs["pixel_values"].to(input_ids.device) if "pixel_values" in multi_modal_inputs else None + ) + image_grid_thw = ( + multi_modal_inputs["image_grid_thw"].to(input_ids.device) if "image_grid_thw" in multi_modal_inputs else None + ) + + batch_size, seq_len = attention_mask.shape[:2] + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True) + labels_rmpad, _ = preprocess_packed_seqs(labels, attention_mask, pre_process=True) + labels_mask_rmpad, _ = preprocess_packed_seqs(labels_mask, attention_mask, pre_process=True) + labels_rmpad = labels_rmpad.contiguous() + labels_mask_rmpad = labels_mask_rmpad.contiguous() + input_ids_rmpad = input_ids_rmpad.contiguous() + output_orig: CausalLMOutputForPPO = model( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids, + packed_seq_params=packed_seq_params, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + labels=labels, + ) + if post_process: + # output_orig is in type of CausalLMOutputForPPO + output = postprocess_packed_seqs_for_dict_output( + labels_mask_rmpad, + output_orig, + packed_seq_params, + attention_mask, + batch_size, + seq_len, + post_process=post_process, + ) + else: + output = output_orig + return output + + +def _fused_GPTModel_forward( + self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_context: BaseInferenceContext = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + runtime_gather_output: Optional[bool] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, + temperature: float = 1.0, +) -> CausalLMOutputForPPO: + """ + Forward pass for GPT models with fused kernel support. + + Patch https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py + """ + + # If decoder_input is provided (not None), then input_ids and position_ids are ignored. + # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. + + # Decoder embedding. + if decoder_input is not None: + pass + elif self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + else: + # intermediate stage of pipeline + # decoder will get hidden_states from encoder.input_tensor + decoder_input = None + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + rotary_pos_emb = None + rotary_pos_cos = None + rotary_pos_sin = None + if self.position_embedding_type == "rope" and not self.config.multi_latent_attention: + if not self.training and self.config.flash_decode and inference_context: + assert inference_context.is_static_batching(), "GPTModel currently only supports static inference batching." + # Flash decoding uses precomputed cos and sin for RoPE + rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault( + inference_context.max_sequence_length, + self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length), + ) + else: + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_context, self.decoder, decoder_input, self.config, packed_seq_params + ) + rotary_pos_emb = self.rotary_pos_emb( + rotary_seq_len, + packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == "thd", + ) + elif self.position_embedding_type == "mrope" and not self.config.multi_latent_attention: + if self.training or not self.config.flash_decode: + rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section) + else: + # Flash decoding uses precomputed cos and sin for RoPE + raise NotImplementedError( + "Flash decoding uses precomputed cos and sin for RoPE, not implmented in MultimodalRotaryEmbedding yet." + ) + + if ( + (self.config.enable_cuda_graph or self.config.flash_decode) + and rotary_pos_cos is not None + and inference_context + and inference_context.is_static_batching() + and not self.training + ): + sequence_len_offset = torch.tensor( + [inference_context.sequence_len_offset] * inference_context.current_batch_size, + dtype=torch.int32, + device=rotary_pos_cos.device, # Co-locate this with the rotary tensors + ) + else: + sequence_len_offset = None + + # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the + # reference held by this caller function, enabling early garbage collection for + # skip inference + + # Run decoder. + hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + **(extra_block_kwargs or {}), + ) + + # Process inference output. + if inference_context and not inference_context.is_static_batching(): + hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1) + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + if self.mtp_process: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + loss_mask=loss_mask, + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + embedding=self.embedding, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, + compute_language_model_loss=self.compute_language_model_loss, + **(extra_block_kwargs or {}), + ) + + if not self.post_process: + return hidden_states + + output = CausalLMOutputForPPO( + loss=None, + logits=None, + past_key_values=None, + hidden_states=hidden_states, + attentions=None, + ) + + if self.config.sequence_parallel: + hidden_states = gather_from_sequence_parallel_region(hidden_states) + logprobs, entropy = linear_cross_entropy( + hidden_states, + self.output_layer.weight, + labels, + temperature, + "none", + parallel_state.get_tensor_model_parallel_group(), + ) + + if has_config_logger_enabled(self.config): + payload = OrderedDict( + { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": attention_mask, + "decoder_input": decoder_input, + "logprobs": logprobs, + "entropy": entropy, + } + ) + log_config_to_disk(self.config, payload, prefix="input_and_logits") + + output.entropy = entropy + output.log_probs = logprobs + + return output diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/model_initializer.py b/toolbox/verl/v0.5.0/verl/models/mcore/model_initializer.py new file mode 100644 index 000000000..4c01b124b --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/model_initializer.py @@ -0,0 +1,263 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# use mcore transformer config to initialize the model +from abc import ABC, abstractmethod + +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, get_gpt_mtp_block_spec +from megatron.core.models.gpt.gpt_model import GPTModel + +from .config_converter import PretrainedConfig, TransformerConfig + + +class BaseModelInitializer(ABC): + """Base class for model initializers.""" + + def __init__(self, tfconfig: TransformerConfig, hf_config: PretrainedConfig): + self.tfconfig = tfconfig + self.hf_config = hf_config + + @abstractmethod + def get_transformer_layer_spec(self): + """Get the transformer layer specification. + https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_layer_specs.py""" + pass + + def get_rope_scaling_args(self) -> dict: + """Get rope scaling args.""" + rope_scaling_args = {} + if "rope_scaling" in self.hf_config: + if self.hf_config.rope_scaling is not None: + # assert self.hf_config.rope_scaling["type"] == "linear", "only linear scaling is supported for now" + rope_scaling_args["seq_len_interpolation_factor"] = self.hf_config.rope_scaling["factor"] + return rope_scaling_args + + def initialize( + self, + pre_process: bool = True, + post_process: bool = True, + share_embeddings_and_output_weights: bool = False, + value: bool = False, + **extra_kwargs, + ) -> GPTModel: + """Initialize a GPT model with the given configuration. + https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py + + Args: + pre_process (bool): include embedding layer. + post_process (bool): including an output layer. + share_embeddings_and_output_weights (bool): input embeddings and output logit weights are shared. + value (bool): add an extra linear layer for classification or regression. + + Returns: + GPTModel: An initialized GPT model instance + """ + transformer_layer_spec = self.get_transformer_layer_spec() + rope_scaling_args = self.get_rope_scaling_args() + mtp_block_spec = extra_kwargs.get("mtp_block_spec", None) + model = GPTModel( + config=self.tfconfig, + transformer_layer_spec=transformer_layer_spec, + vocab_size=self.hf_config.vocab_size, + max_sequence_length=self.hf_config.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + share_embeddings_and_output_weights=share_embeddings_and_output_weights, + position_embedding_type="rope", + rotary_base=self.hf_config.rope_theta, + **rope_scaling_args, + mtp_block_spec=mtp_block_spec, + ) + + if post_process and value: + from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer + + model.output_layer = LinearForLastLayer( + input_size=self.tfconfig.hidden_size, output_size=1, config=self.tfconfig + ) + + return model + + +class DenseModel(BaseModelInitializer): + """Initializer for dense models like Llama and Qwen2.""" + + def get_transformer_layer_spec(self): + assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" + return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + + +class Qwen2MoEModel(BaseModelInitializer): + """Initializer for Qwen2 MoE models.""" + + def get_transformer_layer_spec(self): + assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" + transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + + # Patch layer spec for shared experts + for i in range(len(transformer_layer_spec.layer_specs)): + transformer_layer_spec.layer_specs[i].submodules.mlp.submodules.shared_experts.params["gate"] = True + + return transformer_layer_spec + + def initialize(self, **kwargs): + # Qwen default freeze_moe_router: true + model = super().initialize(**kwargs) + freeze_moe_router = kwargs.get("freeze_moe_router", True) + if freeze_moe_router: + for layer in model.decoder.layers: + layer.mlp.router.weight.requires_grad = False + return model + + +class MixtralModel(BaseModelInitializer): + """Initializer for Mixtral models.""" + + def get_transformer_layer_spec(self): + assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" + transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + return transformer_layer_spec + + def initialize(self, **kwargs): + model = super().initialize(**kwargs) + freeze_moe_router = kwargs.get("freeze_moe_router", False) + if freeze_moe_router: + for layer in model.decoder.layers: + layer.mlp.router.weight.requires_grad = False + return model + + +class Qwen3MoEModel(BaseModelInitializer): + """Initializer for Qwen3 MoE models.""" + + def get_transformer_layer_spec(self): + assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" + transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + return transformer_layer_spec + + def initialize(self, **kwargs): + # Qwen default freeze_moe_router: true + model = super().initialize(**kwargs) + freeze_moe_router = kwargs.get("freeze_moe_router", True) + if freeze_moe_router: + for layer in model.decoder.layers: + layer.mlp.router.weight.requires_grad = False + return model + + +class DeepseekV3Model(BaseModelInitializer): + """Initializer for DeepseekV3 models.""" + + def get_transformer_layer_spec(self): + transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + return transformer_layer_spec + + def get_rope_scaling_args(self) -> dict: + """Get rope scaling args.""" + rope_scaling_args = {} + return rope_scaling_args + + def initialize( + self, + **kwargs, + ): + freeze_moe_router = kwargs.get("freeze_moe_router", True) + if freeze_moe_router: + self.tfconfig.moe_router_load_balancing_type = "none" + # MTP + if self.tfconfig.mtp_num_layers is not None: + transformer_layer_spec = self.get_transformer_layer_spec() + mtp_block_spec = get_gpt_mtp_block_spec(self.tfconfig, transformer_layer_spec, use_transformer_engine=True) + kwargs["mtp_block_spec"] = mtp_block_spec + + model = super().initialize(**kwargs) + if freeze_moe_router: + for layer in model.decoder.layers: + if hasattr(layer.mlp, "router"): + layer.mlp.router.weight.requires_grad = False + return model + + +class Qwen25VLModel(BaseModelInitializer): + """Initializer for Qwen2.5 VL models.""" + + def get_transformer_layer_spec(self): + transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + return transformer_layer_spec + + def initialize( + self, + pre_process=None, + post_process=None, + share_embeddings_and_output_weights=False, + value=False, + **extra_kwargs, + ): + tfconfig = self.tfconfig + hf_config = self.hf_config + # Qwen2_5_VLForConditionalGeneration + from copy import deepcopy + + transformer_layer_spec = self.get_transformer_layer_spec() + + from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear + from megatron.core.models.gpt.moe_module_specs import MLPSubmodules + from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec + + from .qwen2_5_vl import Qwen2_5VLModel, get_vision_model_config, get_vision_projection_config + + vision_transformer_config = get_vision_model_config(deepcopy(tfconfig)) + vision_transformer_config.pipeline_model_parallel_size = 1 + vision_transformer_config.first_pipeline_num_layers = None + + vision_projection_config = get_vision_projection_config( + deepcopy(tfconfig), + vision_transformer_config.hidden_size, + spatial_merge_size=hf_config.vision_config.spatial_merge_size, + ) + vision_projection_layer_spec = MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ) + vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec() + + qwen25_vl_model = Qwen2_5VLModel( + language_transformer_config=tfconfig, + language_transformer_layer_spec=transformer_layer_spec, + language_vocab_size=hf_config.vocab_size, + language_max_sequence_length=hf_config.max_position_embeddings, + vision_transformer_config=vision_transformer_config, + vision_transformer_layer_spec=vision_transformer_layer_spec, + vision_projection_config=vision_projection_config, + vision_projection_layer_spec=vision_projection_layer_spec, + vision_projection_type="mlp", + language_rotary_base=hf_config.rope_theta, + pre_process=pre_process, + post_process=post_process, + add_decoder=True, + add_encoder=True, + parallel_output=True, + language_share_embeddings_and_output_weights=share_embeddings_and_output_weights, + ) + + if post_process and value: + from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer + + qwen25_vl_model.language_model.output_layer = LinearForLastLayer( + input_size=tfconfig.hidden_size, output_size=1, config=tfconfig + ) + + return qwen25_vl_model diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/patch_v012.py b/toolbox/verl/v0.5.0/verl/models/mcore/patch_v012.py new file mode 100644 index 000000000..d54a3eb34 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/patch_v012.py @@ -0,0 +1,215 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# there is some bug in mcore 0.12, so we need to patch it +# 1. `get_query_key_value_tensors` in `multi_latent_attention.py` works wrong when packed_seq_params is not None + + +def apply_patch(): + import torch + from megatron.core import parallel_state, tensor_parallel + from megatron.core.transformer.multi_latent_attention import ( + MLASelfAttention, + apply_rotary_pos_emb, + deprecate_inference_params, + gather_from_sequence_parallel_region, + gather_from_tensor_model_parallel_region, + scatter_to_sequence_parallel_region, + ) + + def patch_get_query_key_value_tensors( + self, + hidden_states, + key_value_states=None, + position_ids=None, + packed_seq_params=None, + inference_context=None, + *, + inference_params=None, + ): + """ + Derives `query`, `key` and `value` tensors from `hidden_states`. + """ + # s = sequence length, b = batch size, h = hidden size, n = num attention heads + # Attention heads [s, b, n*h] + assert hidden_states.ndim == 3, f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D" + + inference_context = deprecate_inference_params(inference_context, inference_params) + + # ========================================= + # Prepare RoPE and seqlen related params + # ========================================= + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_context, None, hidden_states, self.config, packed_seq_params + ) + + # rotary_pos_emb:[s, b, 1, 64] + mscale = 1.0 + if self.config.rope_type == "rope": + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == "thd" + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) + else: + rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len) + + # ========================================= + # QKV down projection and layernorm + # ========================================= + if self.config.q_lora_rank is not None: + # if linear_q_down_proj is ColumnParallelLinear: + # q_compressed: [s, b, q_lora_rank / TP] + # elif linear_q_down_proj is Linear: + # q_compressed: [s / TP, b, q_lora_rank] + q_compressed, _ = self.linear_q_down_proj(hidden_states) + + # When output is sharded (ColumnParallelLinear), two things are needed to be + # identical to a normal Linear. + # 1. Manually gather output to restore output dim q_lora_rank; + # 2. Scatter sequence back to s / TP if sequence-parallel since it was + # gathered by ColumnParallelLinear. + if q_compressed.size(-1) != self.config.q_lora_rank: + q_compressed = gather_from_tensor_model_parallel_region(q_compressed) + if self.config.sequence_parallel: + q_compressed = scatter_to_sequence_parallel_region(q_compressed) + + q_compressed = self.q_layernorm(q_compressed) + else: + q_compressed = hidden_states + + # if linear_kv_down_proj is ColumnParallelLinear: + # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim) / TP] + # elif linear_kv_down_proj is Linear: + # kv_combined: [s / TP, b, (kv_lora_rank + qk_pos_emb_head_dim)] + kv_combined, _ = self.linear_kv_down_proj(hidden_states) + if kv_combined.size(-1) != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim: + # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim)] + kv_combined = gather_from_tensor_model_parallel_region(kv_combined) + # kv_compressed:[s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim] + kv_compressed, k_pos_emb = torch.split( + kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1 + ) + if self.config.sequence_parallel: + # kv_compressed:[s / TP, b, kv_lora_rank] + kv_compressed = scatter_to_sequence_parallel_region(kv_compressed) + else: + # kv_compressed:[s / TP, b, kv_lora_rank], k_pos_emb: [s / TP, b, qk_pos_emb_head_dim] + kv_compressed, k_pos_emb = torch.split( + kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1 + ) + if parallel_state.get_tensor_model_parallel_world_size() > 1: + # k_pos_emb: [s, b, qk_pos_emb_head_dim] + k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb) + + kv_compressed = self.kv_layernorm(kv_compressed) + + # ========================================= + # QKV up projection and RoPE apply + # ========================================= + def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb): + if self.config.q_lora_rank is not None: + q, _ = self.linear_q_up_proj(q_compressed) + else: + # hidden_states:[s, b, 2048], q: [s, b, n * 192] + q, _ = self.linear_q_proj(q_compressed) + + q_len, bsz, _ = q.size() + + # q: [s, b, n, 192] + q = q.view(q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim) + + # kv: [s, b, 2048] + kv, _ = self.linear_kv_up_proj(kv_compressed) + + # kv: [s, b, n, 256] + kv = kv.view( + q_len, + bsz, + self.num_attention_heads_per_partition, + self.config.qk_head_dim + self.config.v_head_dim, + ) + + if inference_context is not None: + # add offset to the sequence start for inference + sequence_start = inference_context.sequence_len_offset + sequence_end = sequence_start + q_len + rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end] + else: + # Shorten rotary_pos_emb to the sequence length when inference_params + # is not provided. This makes sure we can run forward directly with + # any sequence length. During training, the sequence length is always + # the full rotary_pos_emb length. + rotary_pos_emb = rotary_pos_emb[0:q_len] + + # [s, b, 64] -> [s, b, 1, 64] + k_pos_emb = torch.unsqueeze(k_pos_emb, 2) + + # q: [s, b, n, 128], q_pos_emb: [s, b, n, 64] + q_no_pe, q_pos_emb = torch.split(q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1) + + # k_no_pe: [s, b, n, 128], value: [s, b, n, 128] + k_no_pe, value = torch.split(kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1) + + if packed_seq_params is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + q_pos_emb = q_pos_emb.squeeze(1) + k_pos_emb = k_pos_emb.squeeze(1) + q_no_pe = q_no_pe.squeeze(1) + k_no_pe = k_no_pe.squeeze(1) + value = value.squeeze(1) + else: + cu_seqlens_q = cu_seqlens_kv = None + + # q_pos_emb: [s, b, n, 64], k_pos_emb:[s, b, 1, 64] + q_pos_emb = apply_rotary_pos_emb( + q_pos_emb, + rotary_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_q, + mscale=mscale, + ) + k_pos_emb = apply_rotary_pos_emb( + k_pos_emb, + rotary_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_kv, + mscale=mscale, + ) + + # query: [s, b, n, 192] + query = torch.cat([q_no_pe, q_pos_emb], dim=-1) + if packed_seq_params is not None: + k_pos_emb = k_pos_emb.expand(-1, self.num_attention_heads_per_partition, -1) + key = torch.cat([k_no_pe, k_pos_emb], dim=-1) + else: + # key: [s, b, n, 192] + k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1) + key = torch.cat([k_no_pe, k_pos_emb], dim=-1) + + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + return query, key, value + + if self.recompute_up_proj: + self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput() + query, key, value = self.qkv_up_checkpoint.checkpoint( + qkv_up_proj_and_rope_apply, q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb + ) + else: + query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb) + + return query, key, value + + MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/__init__.py b/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/__init__.py new file mode 100644 index 000000000..8842d0249 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .model import Qwen2_5VLModel +from .vision_config import get_vision_model_config, get_vision_projection_config + +__all__ = ["Qwen2_5VLModel", "get_vision_model_config", "get_vision_projection_config"] diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/attention.py b/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/attention.py new file mode 100644 index 000000000..91a27cc3e --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/attention.py @@ -0,0 +1,221 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core.transformer.attention import * + +from .rope_utils import apply_rotary_pos_emb_absolute + + +class Qwen2_5VLSelfAttention(SelfAttention): + """ + Overrides the SelfAttention class, the difference is that qwen2_5_vl uses apply_rotary_pos_emb_absolute + instead of apply_rotary_pos_emb + """ + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + key_value_states: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[int] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Perform a forward pass through the attention module. + + Args: + hidden_states (Tensor): Hidden states. + attention_mask (Tensor): Attention mask. + key_value_states (Optional[Tensor]): Key/value states (for cross attention). + inference_context (Optional[BaseInferenceContext]): Inference context that manages + KV cache. + rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary + embedding tensor(s). + rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. + rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. + attention_bias (Optional[Tensor]): Attention bias. + packed_seq_params (Optional[PackedSeqparams]): Parameters used for THD format. + sequence_len_offset (Optional[int]): Sequence length offset used for + inference CUDA graphs. + + Return: + (Tuple[Tensor, Tensor]) Attention output and bias. + + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + if inference_context and inference_context.is_dynamic_batching(): + assert flash_decode_and_prefill_kernel is not None, ( + "Internal use only: install package `nvidia_chunked_flash_attn`." + ) + + # hidden_states: [sq, b, h] + if self.config.flash_decode and not self.training and inference_context is not None: + rotary_pos_emb = None + else: + assert rotary_pos_cos is None and rotary_pos_sin is None + + # For self attention we just duplicate the rotary_pos_emb if it isn't already + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + # ===================== + # Query, Key, and Value + # ===================== + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + + # =================================================== + # Adjust key, value, and rotary_pos_emb for inference + # =================================================== + + # This branch only runs in the decode phase of flash decoding and returns after the linear + # projection. This conditional is not used in the prefill phase or non-flash-decoding cases. + if ( + self.config.flash_decode + and inference_context is not None + and inference_context.is_decode_only() + and not self.training + and rotary_pos_cos is not None + ): + assert self.layer_number in inference_context.key_value_memory_dict + assert inference_context.sequence_len_offset is not None + inference_key_memory, inference_value_memory = inference_context.key_value_memory_dict[self.layer_number] + output = self.flash_decode( + sequence_len_offset=sequence_len_offset, + query_layer=query, + key_layer=key, + value_layer=value, + inference_key_memory=inference_key_memory, + inference_value_memory=inference_value_memory, + rotary_cos=rotary_pos_cos, + rotary_sin=rotary_pos_sin, + ) + out = output.transpose(0, 1).contiguous() + context_layer = out.view(out.size(0), out.size(1), -1) + output, bias = self.linear_proj(context_layer) + return output, bias + + query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( + inference_context, + query, + key, + value, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + ) + + if packed_seq_params is not None: + query = query.squeeze(1) + key = key.squeeze(1) + value = value.squeeze(1) + + # ================================================ + # relative positional embedding (rotary embedding) + # ================================================ + if rotary_pos_emb is not None and not self.config.flash_decode: + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None: + if packed_seq_params.cu_seqlens_q_padded is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded + else: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + if packed_seq_params.cu_seqlens_kv_padded is not None: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded + else: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + else: + cu_seqlens_q = cu_seqlens_kv = None + + if q_pos_emb is not None: + # TODO VIJAY: simplify + if inference_context is None or inference_context.is_static_batching(): + query = apply_rotary_pos_emb_absolute(query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q) + else: + query = inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q) + if k_pos_emb is not None: + key = apply_rotary_pos_emb_absolute(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv) + + # TODO, can apply positional embedding to value_layer so it has + # absolute positional embedding. + # otherwise, only relative positional embedding takes effect + # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) + + # ================================== + # core attention computation + # ================================== + + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + else: + if inference_context is None or inference_context.is_static_batching(): + # Static batching attention kernel. + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + + else: + # Dynamic batching attention kernel. + q, k, v = (query, key, value) + cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths() + cu_kv_lengths, max_seqlen_k = inference_context.cu_kv_lengths() + + core_attn_out = self.flash_decode_and_prefill( + q, k, v, max_seqlen_q, max_seqlen_k, cu_query_lengths, cu_kv_lengths + ) + core_attn_out = core_attn_out.squeeze(0).unsqueeze(1) + core_attn_out = rearrange(core_attn_out, "s b h d -> s b (h d)") + + if packed_seq_params is not None and packed_seq_params.qkv_format == "thd": + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + # t is the pack size = sum (sq_i) + # note that batch is a dummy dimension in the packed case + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.linear_proj(core_attn_out) + + return output, bias diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/model.py b/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/model.py new file mode 100644 index 000000000..74e4406c3 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/model.py @@ -0,0 +1,340 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +import torch +from megatron.core import InferenceParams, tensor_parallel +from megatron.core.models.gpt.gpt_model import GPTModel + +# from .transformer_config import Qwen2VLTransformerConfig +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig + +from .attention import Qwen2_5VLSelfAttention +from .vision_model import Qwen2_5VisionModel + + +# Note: This is under development and may be missing features. +class Qwen2_5VLModel(MegatronModule): + """Qwen2.5VL multi-modal model. + + Args: + language_transformer_config (TransformerConfig): Transformer config for the language model. + language_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the + language model. + language_vocab_size (int): Language model vocabulary size. + language_max_sequence_length (int): Language model maximum sequence length. This is used for + positional embedding. + vision_transformer_config (TransformerConfig): Transformer config for the vision model. + vision_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the + vision model. + vision_projection_config (TransformerConfig): Config for the projection from vision model outputs to + language model inputs. + vision_projection_layer_spec (ModuleSpec): Specifies the module to use for the vision + projection. + vision_projection_type (str): Type of the vision projection to use. Default is a 2-layer MLP. + parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks. This + is typically True for training and False for inference. + language_rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings + in the language model. Defaults to 1.0. + pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism). + Defaults to True. + post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline + parallelism). Defaults to True. + add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults to True. + When we use pipelining, the encoder + will live on only a subset of the pipeline stages (specifically, only the first stage). + add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults to True. + When we use pipelining, the decoder + will live on only a subset of the pipeline stages (specifically, every stage after the first one). + img_h (int): The height of each image that the ViT will see. + img_w (int): The width of each image that the ViT will see. + patch_dim (int): The size of each patch side. + img_embedding_idx (int): Index in the language_embeddings tensor where image_embeddings should be + inserted. Defaults to 0. + """ + + def __init__( + self, + language_transformer_config: TransformerConfig, + language_transformer_layer_spec: ModuleSpec, + language_vocab_size: int, + language_max_sequence_length: int, + vision_transformer_config: TransformerConfig, + vision_transformer_layer_spec: ModuleSpec, + vision_projection_config: TransformerConfig, + vision_projection_layer_spec: ModuleSpec, + vision_projection_type: str = "mlp", + parallel_output: bool = True, + language_rotary_percent: float = 1.0, + pre_process: bool = True, + post_process: bool = True, + add_encoder: bool = True, + add_decoder: bool = True, + language_rotary_base: int = 10000, + fp16_lm_cross_entropy: bool = False, + language_share_embeddings_and_output_weights: bool = False, + image_token_id: int = 151655, + video_token_id: int = 151656, + ) -> None: + super().__init__(config=language_transformer_config) + + # patch self_attention to use qwen2_5_vl attention + vision_transformer_layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention + for layer_spec in language_transformer_layer_spec.layer_specs: + layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention + + logging.getLogger(__name__).warning("Qwen2VL model is under development and may be missing features.") + + self.pre_process = pre_process + self.post_process = post_process + self.add_encoder = add_encoder + self.add_decoder = add_decoder + + self.encoder_hidden_state = None + self.vision_model = None + self.vision_projection = None + self.language_model = None + self.image_token_id = image_token_id + self.video_token_id = video_token_id + + self.square_merge_size = vision_projection_config.ffn_hidden_size // vision_transformer_config.hidden_size + + # This attribute is needed to check if an all-reduce is required + # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. + self.share_embeddings_and_output_weights = False + if self.pre_process: + self.vision_model = Qwen2_5VisionModel( + vision_transformer_config, + vision_transformer_layer_spec, + vision_projection_config, + vision_projection_layer_spec, + projection_type=vision_projection_type, + pre_process=True, + post_process=True, + ) + + self.language_model = GPTModel( + config=language_transformer_config, + transformer_layer_spec=language_transformer_layer_spec, + vocab_size=language_vocab_size, + max_sequence_length=language_max_sequence_length, + parallel_output=parallel_output, + position_embedding_type="mrope", + rotary_percent=language_rotary_percent, + pre_process=self.pre_process, + post_process=self.post_process, + rotary_base=language_rotary_base, + fp16_lm_cross_entropy=fp16_lm_cross_entropy, + share_embeddings_and_output_weights=language_share_embeddings_and_output_weights, + scatter_embedding_sequence_parallel=False, + ) + + self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights + + def shared_embedding_or_output_weight(self): + """This is a convenience method to surface the language model's word embeddings, which is + necessary for `finalize_model_grads._allreduce_word_embedding_grads`.""" + if self.add_decoder: + return self.language_model.shared_embedding_or_output_weight() + return None + + def set_input_tensor(self, input_tensor) -> None: + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + assert len(input_tensor) == 1, "input_tensor should only be length 1 for Qwen2VL" + + if self.pre_process: + self.encoder_hidden_state = input_tensor[0] + else: + self.language_model.set_input_tensor(input_tensor[0]) + + def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool): + """Freeze model modules. + + Make specific modules non-trainable by setting requires_grad to False for the module's parameters. + + Args: + freeze_language_model (bool): Freeze the language model module. + freeze_vision_model (bool): Freeze the vision model module. + freeze_vision_projection (bool): Freeze the vision projection module. + """ + modules = [] + if freeze_language_model and self.language_model is not None: + modules.append(self.language_model) + if freeze_vision_model and self.vision_model is not None: + modules.append(self.vision_model) + if freeze_vision_projection and self.vision_projection is not None: + modules.append(self.vision_projection) + + for module in modules: + for param in module.parameters(): + param.requires_grad = False + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor = None, + labels: torch.Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + pixel_values: torch.Tensor = None, + pixel_values_videos: torch.Tensor = None, + image_grid_thw: torch.Tensor = None, + video_grid_thw: torch.Tensor = None, + ) -> torch.Tensor: + """Forward function of the Qwen2VL model. + + Args: + image_data (torch.Tensor): input image of shape [total_thw_size, n_features]. + input_ids (torch.Tensor): input text ids [batch, text_seq_len]. + position_ids (torch.Tensor): input text position ids [batch, text_seq_len]. + attention_mask (torch.Tensor): attention mask for the language model [batch, 1, combined_seq_len, + combined_seq_len]. + labels (torch.Tensor): Optional target text labels [batch, combined_seq_len]. + inference_params (InferenceParams): Inference-time parameters including KV cache. + + video_start_index: + 0 -- all video + len(video_seq) -- all image + others -- mixture + *_input_mask: should not be None in the first PP stage + Returns: + output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape + [b, s, vocab_size]. + """ + video_start_index = 0 + vision_grid_thw = None + vision_data = None + if image_grid_thw is not None: + image_mask = input_ids == self.image_token_id + vision_grid_thw = image_grid_thw + vision_data = pixel_values + video_start_index = image_mask.sum().item() + if video_grid_thw is not None: + video_mask = input_ids == self.video_token_id + vision_grid_thw = torch.cat([vision_grid_thw, video_grid_thw], dim=0) + vision_data = torch.cat([vision_data, pixel_values_videos], dim=0) + video_start_index = image_mask.sum().item() + video_mask.sum().item() + use_inference_kv_cache = ( + inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict + ) + use_inference_kv_cache = ( + inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict + ) + if use_inference_kv_cache: + raise NotImplementedError() + + if self.pre_process: + vision_embeds = None + if vision_grid_thw is not None and vision_grid_thw.shape[0] > 0: + vision_embeds = self.vision_model( + vision_data=vision_data, # If None, vision model should use intermediate outputs (EPP > 1) + grid_thw=vision_grid_thw, # should provided in each EPP stage + ) + + # If running inference, the language model KV cache will be updated for image token positions. + # Here we store the image tokens sequence length, which can be used as an offset to the KV cache later. + if inference_params is not None: + raise NotImplementedError() + # inference_params.key_value_memory_dict["image_tokens_count"] = ( + # vision_embeddings.shape[0] + # ) + + # If running inference, we can skip image token computation if they were computed already earlier + # for this sample. + if use_inference_kv_cache: + language_embeddings: torch.Tensor = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, # NOTE: disable + ) # [text_seq_len, b, h_language] + # NOTE: why not cat here? is it the combined embeddings useless? + combined_embeddings = language_embeddings + elif vision_embeds is not None: + if video_start_index == 0: + image_embeds = None + video_embeds = vision_embeds + elif video_start_index == vision_embeds.shape[0]: + image_embeds = vision_embeds + video_embeds = None + elif 0 < video_start_index < vision_embeds.shape[0]: + image_embeds = vision_embeds[:video_start_index] + video_embeds = vision_embeds[video_start_index:] + else: + raise ValueError( + f"Expect video token start index in range [0, {vision_embeds.shape[0]}], but got " + f"{video_start_index}" + ) + + combined_embeddings = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, # NOTE: disable + ) # [text_seq_len, b, h_language] + + if image_embeds is not None or video_embeds is not None: + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + if image_embeds is not None: + image_mask = (input_ids == self.image_token_id).contiguous() + if image_mask.sum() > 0: + combined_embeddings = combined_embeddings.clone() + combined_embeddings[image_mask] = image_embeds.to( + dtype=combined_embeddings.dtype, device=combined_embeddings.device + ) + if video_embeds is not None: + video_mask = (input_ids == self.video_token_id).contiguous() + if video_mask.sum() > 0: + combined_embeddings = combined_embeddings.clone() + combined_embeddings[video_mask] = video_embeds.to( + dtype=combined_embeddings.dtype, device=combined_embeddings.device + ) + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + + else: + combined_embeddings = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, # NOTE: disable + ) # [text_seq_len, b, h_language] + if self.config.sequence_parallel: + combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings) + combined_embeddings = combined_embeddings.contiguous() + else: + combined_embeddings = None + from .rope_utils import get_rope_index + + position_ids, _ = get_rope_index( + input_ids, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, attention_mask=attention_mask + ) + + output = self.language_model( + input_ids=None, + position_ids=position_ids, # None in encoder + attention_mask=attention_mask, # None in encoder + decoder_input=combined_embeddings, # only not None in the first decoder PP stage + labels=labels, # only not None in the last decoder PP stage + # inference_params=inference_params, # currently always None + packed_seq_params=packed_seq_params, # currently always None + **(extra_block_kwargs or {}), + ) + + return output diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/rope_utils.py b/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/rope_utils.py new file mode 100644 index 000000000..fadc74daa --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/rope_utils.py @@ -0,0 +1,266 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import logging +from typing import Optional + +import torch +from megatron.core.models.common.embeddings.rope_utils import * +from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd +from torch import Tensor + +logger = logging.getLogger(__name__) + + +# Slightly modified from Qwen2VLForConditionalGeneration.get_rope_index +def get_rope_index( + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, +): + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + + Examples: + + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + + Examples: + + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each + second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal + tokens" are conceptually packed into a one-second interval of the video. + In this case, we have 25 tokens per second. So each second of the video will be + represented with 25 separate time points. It essentially defines the temporal + granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * + temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be + have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = 2 + tokens_per_second = 2 + image_token_id = 151655 + video_token_id = 151656 + vision_start_token_id = 151652 + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = expanded_range * second_per_grid_t * tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + +def apply_rotary_pos_emb_thd_absolute( + t: Tensor, cu_seqlens: Tensor, freqs: Tensor, rotary_interleaved: bool = False +) -> Tensor: + """A baseline implementation of applying RoPE for `thd` format. + + Args: + t (Tensor): Input tensor T is of shape [t, h, d] + cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, + with shape [b + 1] and dtype torch.int32. + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] + + Returns: + Tensor: Shape [t, h, d]. The input tensor after applying RoPE. + """ + return _apply_rotary_pos_emb_bshd(t[:, None], freqs, rotary_interleaved=rotary_interleaved).squeeze(1) + + +def apply_rotary_pos_emb_absolute( + t: Tensor, + freqs: Tensor, + config: TransformerConfig, + cu_seqlens: Optional[Tensor] = None, +): + """ + Reroute to the appropriate apply_rotary_pos_emb function depending on + bshd (conventional) / thd (packed seq) format + + In Qwen2-VL, the shape of freqs is (seq_length, bs, 1, 2 * dim) instead of [max_seqlen, 1, 1, 2 * dim] + """ + + if config.apply_rope_fusion: + if cu_seqlens is None: + # NOTE: TE backends do not support mRoPE in bshd format when bs > 1 + if freqs.shape[1] > 1: + return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) + else: + return fused_apply_rotary_pos_emb(t, freqs) + else: + # NOTE: as expected, thd format can use bshd + return fused_apply_rotary_pos_emb(t[:, None], freqs).squeeze(1) + else: + if cu_seqlens is None: + return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) + else: + return apply_rotary_pos_emb_thd_absolute(t, cu_seqlens, freqs, rotary_interleaved=config.rotary_interleaved) diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/vision_config.py b/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/vision_config.py new file mode 100644 index 000000000..0631c90f6 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/vision_config.py @@ -0,0 +1,85 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core import parallel_state +from megatron.core.transformer import TransformerConfig + + +def get_vision_model_config(config: TransformerConfig) -> TransformerConfig: + # Given a Transformer Config from decoder, build vision encoder config + # diff: out_hidden_size & intermediate_size + + # mlp: hidden_size -> intermediate_size -> embed_dim, silu + # NOTE: here we provide a workaround to solve the wrong layer amount when VPP of decoder is on + if config.num_layers in [28, 36]: + config.ffn_hidden_size = 3420 + else: + config.ffn_hidden_size = 3456 + + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + config.num_layers = 32 * parallel_state.get_virtual_pipeline_model_parallel_world_size() # depth + else: + config.num_layers = 32 # depth + config.num_attention_heads = 16 # num_heads + config.add_bias_linear = True # all nn.Linear has bias (MLP, attn) + config.add_qkv_bias = True # qkv_proj in attn has bias + config.hidden_size = 1280 # hidden_size + config.hidden_dropout = 0.0 + config.attention_dropout = 0.0 + + # config.gated_linear_unit = False # no gated + # config.activation_func = quick_gelu # hidden_act + config.kv_channels = config.hidden_size // config.num_attention_heads + config.num_query_groups = config.num_attention_heads # no GQA + config.layernorm_zero_centered_gamma = False # False + config.apply_query_key_layer_scaling = False # factor=math.sqrt(head_dim) + config.bias_activation_fusion = False # no swiglu, set false + config.bias_dropout_fusion = False # no dropout, set false + config.attention_softmax_in_fp32 = True # use True + # config.normalization = 'LayerNorm' # use RMSNorm + config.seq_length = 1 + + config.tp_comm_overlap = False + config.sequence_parallel = False + config.temporal_patch_size = 2 + config.patch_size = 14 + config.in_channels = 3 + config.spatial_merge_size = 2 + + config.fullatt_block_indexes = [7, 15, 23, 31] + config._qwen2_5_vl_window_size = 112 + return config + + +def get_vision_projection_config( + config: TransformerConfig, embed_dim: int, spatial_merge_size: int +) -> TransformerConfig: + # merger: + # context_dim = hidden_size * merge_size**2 + # out_hidden_size = hidden_size + # context_dim -> context_dim -> out_hidden_size + # MLP: + # input_size -> ffn_hidden_size -> hidden_size + # spec: LN -> Linear(bias=True) -> GELU -> Linear(bias=True) + config.gated_linear_unit = False + config.bias_activation_fusion = False + config.add_bias_linear = True + config.ffn_hidden_size = embed_dim * (spatial_merge_size**2) + config.activation_func = torch.nn.functional.gelu + config.tp_comm_overlap = False + config.sequence_parallel = False + return config diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/vision_model.py b/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/vision_model.py new file mode 100644 index 000000000..06b4fd328 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/vision_model.py @@ -0,0 +1,309 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from megatron.core import InferenceParams +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig +from torch import nn +from torch.nn import functional as F + +from .vision_transformer_block import Qwen2_5VisionTransformerBlock as TransformerBlock + + +# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs.float() + + +class Qwen2_5VisionModel(VisionModule): + """Qwen2.5 ViT vision model. + + Args: + transformer_config (TransformerConfig): Transformer config. + transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers. + ln_pre_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_pre. + add_class_token (bool, optional): Include a class token. Defaults to True. + class_token_len (int): Class token length. Defaults to 1 but 8 may be faster. + patch_dim (int): Image patch size. + img_h (int): Input image height. + img_w (int): Input image width. + """ + + def __init__( + self, + transformer_config: TransformerConfig, + transformer_layer_spec: ModuleSpec, + projection_config: TransformerConfig, + projection_layer_spec: ModuleSpec, + projection_type: str = "mlp", + pre_process: bool = True, + post_process: bool = False, + ) -> None: + super().__init__(config=transformer_config) + + self.spatial_merge_size = transformer_config.spatial_merge_size + + embed_dim = transformer_config.hidden_size + num_heads = transformer_config.num_attention_heads + temporal_patch_size = transformer_config.temporal_patch_size + patch_size = transformer_config.patch_size + in_channels = transformer_config.in_channels + + self.patch_size = transformer_config.patch_size + self.fullatt_block_indexes = transformer_config.fullatt_block_indexes + self.window_size = transformer_config._qwen2_5_vl_window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.max_sequence_length = transformer_config.seq_length + self.patch_embed = PatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + ) + + head_dim = embed_dim // num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.model_type = ModelType.encoder_or_decoder + self.pre_process = pre_process + self.post_process = post_process + + # Transformer layers. + # TODO: Follow-up changes will make pre and post_process configurable. They are needed for supporting + # pipeline parallelism. + # NOTE: a final layer norm and/or linear layer present in some implementations are omitted here. + self.decoder = TransformerBlock( + config=transformer_config, + spec=transformer_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + post_layer_norm=True, + ) + + self.merge_hidden_size = projection_config.ffn_hidden_size + self.square_merge_size = self.merge_hidden_size // embed_dim + + if self.post_process: + self.projection = MultimodalProjector( + projection_config, projection_layer_spec, projection_type, projection_config.ffn_hidden_size + ) + else: + self.projection = None + + self.input_tensor = None + + def set_input_tensor(self, input_tensor: torch.Tensor) -> None: + """Sets input tensor to the model. + + Args: + input_tensor (Tensor): Sets the input tensor for the model. + """ + if self.pre_process: # always True + self.input_tensor = input_tensor + else: + raise NotImplementedError() + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0).to(grid_thw.device) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).to(grid_thw.device) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward( + self, + vision_data: Optional[torch.Tensor], + grid_thw: torch.Tensor, + inference_params: Optional[InferenceParams] = None, + extra_block_kwargs: dict = None, + ) -> torch.Tensor: + """Forward function of the Qwen2 Vision Model. This function passes the input tensors + through the embedding layer and then the transformer. + + Args: + x (torch.Tensor): input image/video data of shape [n_tokens, n_dims] + grid_thw (torch.Tensor): the size tensor indicates grid size of each image/frame + packed_seq_params (PackedSeqParams): parameters to build attention mask in the backend + + Returns: + x (torch.Tensor): output after final transformer block of shape [b, s, h]. + """ + assert grid_thw is not None + assert self.input_tensor is None + assert inference_params is None + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + vision_data = self.patch_embed(vision_data) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=vision_data.device, + dtype=torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = vision_data.size() + vision_data = vision_data.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + vision_data = vision_data[window_index, :, :] + vision_data = vision_data.reshape(seq_len, 1, -1) + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, 1, 1, -1).repeat(1, 1, 1, 2) + + hidden_states = self.decoder( + hidden_states=vision_data, + attention_mask=None, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=self.build_packed_seq_params(None, cu_window_seqlens), + packed_seq_params_full=self.build_packed_seq_params(grid_thw), + fullatt_block_indexes=self.fullatt_block_indexes, + **(extra_block_kwargs or {}), + ) + + hidden_states = self.projection(hidden_states.view(-1, self.merge_hidden_size)) + reverse_indices = torch.argsort(window_index) + return hidden_states[reverse_indices, :] + + def build_packed_seq_params( + self, + grid_thw: Optional[torch.Tensor], + cu_seqlens: Optional[torch.Tensor] = None, + ) -> PackedSeqParams: + # NOTE: each frame is a sequence (rather than each grid) + if grid_thw is not None: + seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]) + cu_seqlens = seqlens.cumsum(dim=0) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).int() + else: + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + max_seqlen_q = seqlens.max() + return PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + qkv_format="thd", + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_q, + ) diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py b/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py new file mode 100644 index 000000000..8f765a0ff --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py @@ -0,0 +1,265 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from megatron.core.transformer.transformer_block import * + + +class Qwen2_5VisionTransformerBlock(TransformerBlock): + def _checkpointed_forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + context: Tensor, + context_mask: Tensor, + rotary_pos_emb: Tensor, + attention_bias: Tensor, + packed_seq_params: PackedSeqParams, + packed_seq_params_full: PackedSeqParams, + fullatt_block_indexes, + ): + """Forward method with activation checkpointing.""" + + def custom(start: int, end: int): + def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb): + for index in range(start, end): + if index in fullatt_block_indexes: + packed_seq_params_now = packed_seq_params_full + else: + packed_seq_params_now = packed_seq_params + layer = self._get_layer(index) + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + inference_context=None, + packed_seq_params=packed_seq_params_now, + ) + return hidden_states, context + + return custom_forward + + def checkpoint_handler(forward_func): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + else: + return tensor_parallel.checkpoint( + forward_func, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + + if self.config.recompute_method == "uniform": + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + layer_idx = 0 + while layer_idx < self.num_layers_per_pipeline_rank: + hidden_states, context = checkpoint_handler( + custom(layer_idx, layer_idx + self.config.recompute_num_layers) + ) + + layer_idx += self.config.recompute_num_layers + + elif self.config.recompute_method == "block": + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + recompute_skip_num_layers = 0 + for layer_idx in range(self.num_layers_per_pipeline_rank): + # Skip recomputation when input grad computation is not needed. + # Need to have at least one input tensor with gradient computation + # for re-enterant autograd engine. + if self.config.fp8 and not hidden_states.requires_grad: + recompute_skip_num_layers += 1 + if ( + layer_idx >= recompute_skip_num_layers + and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers + ): + hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) + else: + hidden_states, context = custom(layer_idx, layer_idx + 1)( + hidden_states, attention_mask, context, context_mask, rotary_pos_emb + ) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def forward( + self, + hidden_states: Union[Tensor, WrappedTensor], + attention_mask: Optional[Tensor], + context: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, + rotary_pos_emb: Optional[Tensor] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[Tensor] = None, + packed_seq_params_full: PackedSeqParams = None, + fullatt_block_indexes=None, + *, + inference_params: Optional[BaseInferenceContext] = None, + ): + """ + Perform the forward pass through the transformer block. + + This method handles the core computation of the transformer, including + self-attention, optional cross-attention, and feed-forward operations. + + Args: + hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h] + where s is the sequence length, b is the batch size, and h is the hidden size. + Can be passed as a WrappedTensor during inference to avoid an obsolete + reference in the calling function. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. + context (Tensor, optional): Context tensor for cross-attention. + context_mask (Tensor, optional): Mask for cross-attention context + rotary_pos_emb (Tensor, optional): Rotary positional embeddings. + attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable + to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. + Used as an alternative to apply attention mask for TE cuDNN attention. + inference_context (BaseInferenceContext, optional): Parameters for inference-time + optimizations. + packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence + processing. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + # Delete the obsolete reference to the initial input tensor if necessary + if isinstance(hidden_states, WrappedTensor): + hidden_states = hidden_states.unwrap() + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Update the inference parameters with the current batch size in case it is variable + if inference_context and not self.training: + inference_context.current_batch_size = hidden_states.size(1) + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), + # otherwise do nothing extra at the outer level + # if we are using other fp8 recipes, then the context manager enter&exit are free + # we can wrap fp8_context within the for loop over layers, so that we can fine-grained + # control which layer will be fp8 or bf16 + use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed + use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed + outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() + + with rng_context, outer_fp8_context: + # Forward pass. + if self.config.recompute_granularity == "full" and self.training: + hidden_states = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + packed_seq_params_full=packed_seq_params_full, + fullatt_block_indexes=fullatt_block_indexes, + ) + else: + for l_no, layer in enumerate(self.layers): + inner_fp8_context = ( + get_fp8_context(self.config, layer.layer_number - 1) if use_inner_fp8_context else nullcontext() + ) + if l_no in fullatt_block_indexes: + packed_seq_params_now = packed_seq_params_full + else: + packed_seq_params_now = packed_seq_params + with self.offload_context, inner_fp8_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_context=inference_context, + packed_seq_params=packed_seq_params_now, + sequence_len_offset=sequence_len_offset, + ) + + if ( + torch.is_grad_enabled() + and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + + # Final layer norm. + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) + # TENorm produces a "viewed" tensor. This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + return hidden_states diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/readme.md b/toolbox/verl/v0.5.0/verl/models/mcore/readme.md new file mode 100644 index 000000000..606dcf189 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/readme.md @@ -0,0 +1,99 @@ +# verl Megatron-Core Models +The earlier versions of verl use `Megatron-LM` 0.4 and workaround huggingface model classes. To better use the latest features and speedup of modern Megatron, we are migrating to `Megatron-Core`(mcore), and use the recommended `GPTModel` class for all language models. With mcore `GPTModel`, we can use the latest features like `context parallel`, `expert parallel`, `dist_checkpointing`, etc. and we can update mcore with little effort in the future for new features. + +The migration has been successful with the help of the mcore team and the community. What we have done is: +1. update `Megatron` version to `0.11.0` +2. migrate `LlamaForCausalLM` and `Qwen2ForCausalLM` to mcore `GPTModel` +3. support sequence packing/thd format. +4. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel`. +5. support the mcore `dist_checkpointing` feature and a basic offline weighs conversion script from huggingface to mcore `dist_checkpointing` format. + +We are working on the following features: +- support `Qwen2MoeForCausalLM` +- support `MixtralForCausalLM` +- support `DeepseekV3ForCausalLM` +- support `expert parallel` + +Features we invite the community to contribute: +- better scripts for offline weights conversion from huggingface to mcore `dist_checkpointing` format. + - conversion of large models with multiple GPUs + - conversion of large models with single GPU +- refactor the `megatron_checkpoint_manager.py` by `dist_checkpointing` format. +- support llama4 +- support qwen2.5-vl + +To track the progress of verl mcore integration, please refer to the [mcore integration issue](https://github.com/volcengine/verl/issues/1033). + +## How things work now +To engage the community in contributing, here are the key steps in our mcore integration process and features under development. + +The huggingface `transformers` is the de facto standard of model zoo while mcore is good at computation efficiency. The main challenge is conversion between the two. +main steps: +1. modelling the huggingface model with mcore `GPTModel` + - a. convert the huggingface config to mcore `TransformerConfig` + - b. init the mcore `GPTModel` with the converted config + - c. load the huggingface model weights to the `GPTModel` +2. online weight conversion from mcore to huggingface (due to the rollout engine `vLLM` is using huggingface format) + - a. bridge the gap between mcore and huggingface weights format and name mapping + - b. online resharding the mcore weights to rollout engine + - this part is very complicated with multiple parallel strategies composition between mcore and rollout engine +3. support the mcore features in verl + - a. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel` + - b. support recompute and other mcore speed up features + +4. checkpointing + - a. support recovering the verl training. + - b. support exporting the mcore checkpoint to huggingface format, for downstream inference. + +### Modelling the huggingface model with mcore `GPTModel` +The first step is to convert huggingface config to mcore `TransformerConfig` and init the mcore `GPTModel` with the converted config. See code in `verl/models/mcore/config_converter.py` and `verl/verl/models/mcore/models/model_initializer.py`. The corresponding model forward code is in `verl/verl/models/mcore/models/model_forward.py`. + +There are two ways of loading the huggingface model weights to the `GPTModel` +1. Runtime loading + - every rank loads the entire huggingface model weights and then shard and convert to mcore weights. + - speed is slow and memory consumption is high. + - this way is deprecated and will not support new models. +2. Offline loading + - use offline script to convert the huggingface model weights to mcore weights and save with mcore `dist_checkpointing` format. + - online loading and sharding is automatically done by mcore `dist_checkpointing` format. The speed is fast and memory consumption is low. + - the offline script is in `verl/scripts/converter_hf_to_mcore.py`. + +### online weight conversion from mcore to huggingface +See function `convert_megatron_model_to_transformers_model` in `verl/utils/megatron_utils.py` for the details. + +It should be refatored for extensibility and better performance. + +### support the mcore features in verl +Most of the features of `GPTModel` is out-of-the-box supported in verl through changing the `TransformerConfig`, except those about parallel strategies, such as `expert parallel`. +Features about parallel strategies should be supported with changes about the online weights conversion(especially the resharding part) and verl work dispatching. + +### checkpointing +The existing checkpointing code is in `verl/utils/checkpoint/megatron_checkpoint_manager.py`. And the script to convert checkpoint to huggingface format is in `verl/scripts/model_merger`. + +The existing checkpoint format simply saves every rank's weights and optimizer states. It should be refactored by `dist_checkpointing` format. + + +## How to support new models +1. make sure the model is supported by vLLM +2. modelling the huggingface model with mcore `GPTModel` (The [Pai-Megatron-Path](https://github.com/alibaba/Pai-Megatron-Patch/tree/main) is a good reference) + - a. convert the huggingface config to mcore `TransformerConfig` + - b. init the mcore `GPTModel` with the converted config + - c. load the huggingface model weights to the `GPTModel` + - d. for VLM the interface might be different, it is ok to add a new model class with GPTModel as its module. +3. offline weights conversion from huggingface to mcore `dist_checkpointing` format +4. support online weights conversion from mcore to huggingface + - it is recommended to initialize a vLLM model with the converted mcore weights, and then test if the generating sequence is correct. + + +## How to scale up to larger models like deepseek-v3 or other 100B+ models +The greatest challenge for scaling up to larger models is the memory consumption. + +The necessary features under development for scaling up are +1. Training engine part + - expert parallel +2. Rollout engine part + - pipeline parallel + - expert parallel + - more efficient and general weight resharding and loading +3. Offline weights conversion + - support weights larger than single GPU memory diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/registry.py b/toolbox/verl/v0.5.0/verl/models/mcore/registry.py new file mode 100644 index 000000000..23f01e8b7 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/registry.py @@ -0,0 +1,237 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Registry module for model architecture components. +""" + +from enum import Enum +from typing import Callable + +import torch +import torch.nn as nn + +from .config_converter import ( + PretrainedConfig, + TransformerConfig, + hf_to_mcore_config_dense, + hf_to_mcore_config_dpskv3, + hf_to_mcore_config_llama4, + hf_to_mcore_config_mixtral, + hf_to_mcore_config_qwen2_5_vl, + hf_to_mcore_config_qwen2moe, + hf_to_mcore_config_qwen3moe, +) +from .model_forward import ( + gptmodel_forward, + gptmodel_forward_qwen2_5_vl, +) +from .model_forward_fused import ( + fused_forward_gptmodel, + fused_forward_qwen2_5_vl, +) +from .model_initializer import ( + BaseModelInitializer, + DeepseekV3Model, + DenseModel, + MixtralModel, + Qwen2MoEModel, + Qwen3MoEModel, + Qwen25VLModel, +) +from .weight_converter import ( + McoreToHFWeightConverterDense, + McoreToHFWeightConverterDpskv3, + McoreToHFWeightConverterMixtral, + McoreToHFWeightConverterQwen2_5_VL, + McoreToHFWeightConverterQwen2Moe, + McoreToHFWeightConverterQwen3Moe, +) + + +class SupportedModel(Enum): + LLAMA = "LlamaForCausalLM" # tested + QWEN2 = "Qwen2ForCausalLM" # tested + QWEN2_MOE = "Qwen2MoeForCausalLM" # pending + DEEPSEEK_V3 = "DeepseekV3ForCausalLM" # not tested + MIXTRAL = "MixtralForCausalLM" # tested + QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration" # not supported + LLAMA4 = "Llama4ForConditionalGeneration" # not tested + QWEN3 = "Qwen3ForCausalLM" # tested + QWEN3_MOE = "Qwen3MoeForCausalLM" # not tested + + +# Registry for model configuration converters +MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = { + SupportedModel.LLAMA: hf_to_mcore_config_dense, + SupportedModel.QWEN2: hf_to_mcore_config_dense, + SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe, + SupportedModel.DEEPSEEK_V3: hf_to_mcore_config_dpskv3, + SupportedModel.MIXTRAL: hf_to_mcore_config_mixtral, + SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl, + SupportedModel.LLAMA4: hf_to_mcore_config_llama4, + SupportedModel.QWEN3: hf_to_mcore_config_dense, + SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe, + SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl, +} + +# Registry for model initializers +MODEL_INITIALIZER_REGISTRY: dict[SupportedModel, type[BaseModelInitializer]] = { + SupportedModel.LLAMA: DenseModel, + SupportedModel.QWEN2: DenseModel, + SupportedModel.QWEN2_MOE: Qwen2MoEModel, + SupportedModel.MIXTRAL: MixtralModel, + SupportedModel.DEEPSEEK_V3: DeepseekV3Model, + SupportedModel.QWEN2_5_VL: Qwen25VLModel, + SupportedModel.LLAMA4: DenseModel, + SupportedModel.QWEN3: DenseModel, + SupportedModel.QWEN3_MOE: Qwen3MoEModel, + SupportedModel.QWEN2_5_VL: Qwen25VLModel, +} + +# Registry for model forward functions +MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = { + SupportedModel.LLAMA: gptmodel_forward, + SupportedModel.QWEN2: gptmodel_forward, + SupportedModel.QWEN2_MOE: gptmodel_forward, + SupportedModel.MIXTRAL: gptmodel_forward, + SupportedModel.DEEPSEEK_V3: gptmodel_forward, + SupportedModel.QWEN2_5_VL: gptmodel_forward, + SupportedModel.LLAMA4: gptmodel_forward, + SupportedModel.QWEN3: gptmodel_forward, + SupportedModel.QWEN3_MOE: gptmodel_forward, + SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl, + SupportedModel.DEEPSEEK_V3: gptmodel_forward, +} + +# Registry for model forward functions +MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = { + SupportedModel.LLAMA: fused_forward_gptmodel, + SupportedModel.QWEN2: fused_forward_gptmodel, + SupportedModel.QWEN2_MOE: fused_forward_gptmodel, + SupportedModel.MIXTRAL: fused_forward_gptmodel, + SupportedModel.DEEPSEEK_V3: fused_forward_gptmodel, + SupportedModel.QWEN2_5_VL: fused_forward_qwen2_5_vl, + SupportedModel.LLAMA4: fused_forward_gptmodel, + SupportedModel.QWEN3: fused_forward_gptmodel, + SupportedModel.QWEN3_MOE: fused_forward_gptmodel, + SupportedModel.QWEN2_5_VL: fused_forward_qwen2_5_vl, + SupportedModel.DEEPSEEK_V3: fused_forward_gptmodel, +} + +# Registry for model weight converters +MODEL_WEIGHT_CONVERTER_REGISTRY: dict[SupportedModel, type] = { + SupportedModel.LLAMA: McoreToHFWeightConverterDense, + SupportedModel.QWEN2: McoreToHFWeightConverterDense, + SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe, + SupportedModel.MIXTRAL: McoreToHFWeightConverterMixtral, + SupportedModel.DEEPSEEK_V3: McoreToHFWeightConverterDpskv3, + SupportedModel.QWEN3: McoreToHFWeightConverterDense, + SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe, + SupportedModel.QWEN2_5_VL: McoreToHFWeightConverterQwen2_5_VL, +} + + +def get_supported_model(model_type: str) -> SupportedModel: + try: + return SupportedModel(model_type) + except ValueError as err: + supported_models = [e.value for e in SupportedModel] + raise NotImplementedError( + f"Model Type: {model_type} not supported. Supported models: {supported_models}" + ) from err + + +def hf_to_mcore_config( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + """Convert huggingface PretrainedConfig to mcore TransformerConfig. + + Args: + hf_config: The huggingface PretrainedConfig. + dtype: The dtype of the model. + **override_transformer_config_kwargs: The kwargs to override the transformer config. + + Returns: + The mcore TransformerConfig. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + model = get_supported_model(hf_config.architectures[0]) + return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype, **override_transformer_config_kwargs) + + +def init_mcore_model( + tfconfig: TransformerConfig, + hf_config: PretrainedConfig, + pre_process: bool = True, + post_process: bool = None, + *, + share_embeddings_and_output_weights: bool = False, + value: bool = False, + **extra_kwargs, # may be used for vlm and moe +) -> nn.Module: + """ + Initialize a Mcore model. + + Args: + tfconfig: The transformer config. + hf_config: The HuggingFace config. + pre_process: Optional pre-processing function. + post_process: Optional post-processing function. + share_embeddings_and_output_weights: Whether to share embeddings and output weights. + value: Whether to use value. + **extra_kwargs: Additional keyword arguments. + + Returns: + The initialized model. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + model = get_supported_model(hf_config.architectures[0]) + initializer_cls = MODEL_INITIALIZER_REGISTRY[model] + initializer = initializer_cls(tfconfig, hf_config) + return initializer.initialize( + pre_process=pre_process, + post_process=post_process, + share_embeddings_and_output_weights=share_embeddings_and_output_weights, + value=value, + **extra_kwargs, + ) + + +def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable: + """ + Get the forward function for given model architecture. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + model = get_supported_model(hf_config.architectures[0]) + return MODEL_FORWARD_REGISTRY[model] + + +def get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable: + """ + Get the forward function for given model architecture. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + model = get_supported_model(hf_config.architectures[0]) + return MODEL_FORWARD_FUSED_REGISTRY[model] + + +def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable: + """ + Get the weight converter for given model architecture. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + model = get_supported_model(hf_config.architectures[0]) + tfconfig = hf_to_mcore_config(hf_config, dtype) + return MODEL_WEIGHT_CONVERTER_REGISTRY[model](hf_config, tfconfig) diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/saver.py b/toolbox/verl/v0.5.0/verl/models/mcore/saver.py new file mode 100644 index 000000000..2a954b241 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/saver.py @@ -0,0 +1,497 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist +from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel as LocalDDP +from megatron.core.transformer.module import Float16Module +from torch.nn.parallel import DistributedDataParallel as torchDDP + +from verl.utils.device import get_device_id, get_torch_device +from verl.utils.logger import print_rank_0 +from verl.utils.megatron_utils import unwrap_model + + +def _megatron_calc_global_rank( + tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0 +): + """Calculate global rank with support for CP/EP parallelism""" + + # Get parallel sizes for each dimension + tp_size = mpu.get_tensor_model_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + # ep_size = mpu.get_expert_model_parallel_world_size() + + # Verify total GPU count matches (must be consistent with parallel_state.py) + total_size = tp_size * dp_size * pp_size * cp_size + assert total_size == torch.distributed.get_world_size(), ( + f"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}" + ) + + # Core calculation logic (corresponds to RankGenerator order parameter) + # Assumes default order is "tp-cp-ep-dp-pp" + return ((pp_rank * dp_size + dp_rank) * cp_size + cp_rank) * tp_size + tp_rank + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): + """Merge sharded parameters of a Megatron module into a merged checkpoint. + + Args: + wrapped_models (list of megatron.core.distributed.DistributedDataParallel): + The local DDP wrapped megatron modules. + config (str or None): + HF config for model + dtype: model params type + is_value_model: if model is value model + tie_word_embeddings: tie_word_embeddings + Returns: + state_dict (dict): + The merged state_dict in rank 0, and an empty dictionary in other ranks. + """ + start_time = time.time() + + def _get_gpt_model(model): + return model + + dp_rank = mpu.get_data_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + cp_rank = mpu.get_context_parallel_rank() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if dist.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + assert len(models[i].decoder.layers) == num_layers_per_model, ( + "len model layers {} not equal to num_layers_per_model {}".format( + len(models[i].decoder.layers), num_layers_per_model + ) + ) + + state_dict = dict() + + def _get_cpu_tensor(tensor: torch.Tensor): + if tensor is None: + return None + if tensor.device == torch.device("cpu"): + return tensor.detach().clone() + return tensor.detach().cpu() + + def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: + """broadcast tensor across mp_group""" + nonlocal state_dict + nonlocal mp_group + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + + if torch.distributed.get_rank() == src_rank: + if tensor is None: + weight = None + tensor_shape = None + else: + weight = tensor + tensor_shape = weight.shape + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not exist, skip collect") + return + + if weight is None: + weight = torch.empty( + tensor_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + dist.broadcast(weight, src=src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + state_dict[name] = _get_cpu_tensor(weight) + + def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + # tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=concat_dim) + if mutate_func is not None: + full_tensor = mutate_func(full_tensor) + state_dict[name] = full_tensor + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + # tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_list = [] + up_weight_list = [] + for i in range(tp_size): + gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] + gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] + up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] + gate_weight_list.append(gate_weight_tp) + up_weight_list.append(up_weight_tp) + + state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) + state_dict[up_name] = torch.cat(up_weight_list, dim=0) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + # tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + q_weight_list = [] + k_weight_list = [] + v_weight_list = [] + hidden_size_per_head = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + + if config.num_key_value_heads >= tp_size: + q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_size_chunk = q_size_tp // num_query_groups_per_partition + kv_size_chunk = kv_size_tp // num_query_groups_per_partition + for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): + q_part = qkv_part_chunk[:q_size_chunk] + k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] + v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] + q_weight_list.append(q_part) + k_weight_list.append(k_part) + v_weight_list.append(v_part) + else: + q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_size_chunk = q_size_tp // num_query_groups_per_partition + kv_size_chunk = kv_size_tp // num_query_groups_per_partition + for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): + q_part = qkv_part_chunk[:q_size_chunk] + k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] + v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] + q_weight_list.append(q_part) + if i * config.num_key_value_heads % tp_size == 0: + k_weight_list.append(k_part) + v_weight_list.append(v_part) + + state_dict[q_name] = torch.cat(q_weight_list, dim=0) + state_dict[k_name] = torch.cat(k_weight_list, dim=0) + state_dict[v_name] = torch.cat(v_weight_list, dim=0) + + # empty cache before collecting weights + get_torch_device().empty_cache() + # Embeddings + # ------------------- + if dp_rank == 0 and cp_rank == 0: # models are identical across cp ranks + # Embeddings + # ------------------- + print_rank_0("collecting embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + _broadcast_tp_shard_tensor( + gpt_model_module.embedding.word_embeddings.weight if pp_rank == 0 else None, + "model.embed_tokens.weight", + src_pp_rank=0, + ) + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + for layer in range(config.num_hidden_layers): + print_rank_0(f"collecting layer #{layer}...") + layer_name = f"model.layers.{layer}" + src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) + sync_layer = gpt_model_module.decoder.layers[src_layer_idx] + + _broadcast_tensor( + sync_layer.self_attention.linear_qkv.layer_norm_weight, + f"{layer_name}.input_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + if gpt_model_module.config.qk_layernorm: + _broadcast_tensor( + sync_layer.self_attention.q_layernorm.weight, + f"{layer_name}.self_attn.q_norm.weight", + src_pp_rank=src_pp_rank, + ) + _broadcast_tensor( + sync_layer.self_attention.k_layernorm.weight, + f"{layer_name}.self_attn.k_norm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attention.linear_qkv.weight, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + src_pp_rank=src_pp_rank, + ) + + if gpt_model_module.config.add_qkv_bias: + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attention.linear_qkv.bias, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attention.linear_proj.weight, + f"{layer_name}.self_attn.o_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + _broadcast_tensor( + sync_layer.mlp.linear_fc1.layer_norm_weight, + f"{layer_name}.post_attention_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.linear_fc1.weight, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.linear_fc2.weight, + f"{layer_name}.mlp.down_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + # Final Layernorm + # ------------------- + print_rank_0("collecting final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.decoder.final_layernorm, "weight", None), + "model.norm.weight", + src_pp_rank=pp_size - 1, + ) + + if tie_word_embeddings: + print_rank_0("tie word embedding skip load lm_head...") + else: + print_rank_0("collecting lm_head...") + + if is_value_model: + lm_head_weight = None + if pp_rank == pp_size - 1: + lm_head_weight = getattr(gpt_model_module.output_layer, "weight", None) + _broadcast_tensor(lm_head_weight, "lm_head.weight", src_pp_rank=pp_size - 1) + + else: + _broadcast_tp_shard_tensor( + getattr(gpt_model_module.output_layer, "weight", None) if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + + dist.barrier() + get_torch_device().empty_cache() + if torch.distributed.get_rank() == 0: + for k, v in state_dict.items(): + if dtype != v.dtype: + state_dict[k] = v.to(dtype) + + print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") + return state_dict + + +def merge_megatron_ckpt_gptmodel_qwen_moe( + wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False +): + raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen_moe is not implemented") + + +def merge_megatron_ckpt_gptmodel_qwen2_5_vl( + wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False +): + raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen2_5_vl is not implemented") + + +def merge_megatron_ckpt_gptmodel_dpskv3(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): + raise NotImplementedError("merge_megatron_ckpt_gptmodel_dpskv3 is not implemented") + + +def merge_megatron_ckpt_gptmodel_mixtral( + wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False +): + raise NotImplementedError("merge_megatron_ckpt_gptmodel_mixtral is not implemented") diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/util.py b/toolbox/verl/v0.5.0/verl/models/mcore/util.py new file mode 100644 index 000000000..c1ef7a211 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/util.py @@ -0,0 +1,240 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core import parallel_state as mpu +from megatron.core.packed_seq_params import PackedSeqParams + +from verl.utils.model import CausalLMOutputForPPO + + +def preprocess_packed_seqs( + input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True +) -> tuple[torch.Tensor, PackedSeqParams]: + """ + Preprocess packed sequences + CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 + gets second and second last chunks, and so on), this is for load balancing with causal masking. + See https://github.com/NVIDIA/TransformerEngine/issues/1368 + """ + batch_size = input_ids.shape[0] + + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size + + pad_size = (align_size - seqlens_in_batch % align_size) % align_size + seqlens_in_batch_padded = seqlens_in_batch + pad_size + cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0) + cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0) + max_seqlen_in_batch = seqlens_in_batch_padded.max().item() + + shape = list(input_ids.shape[1:]) + shape[0] = seqlens_in_batch_padded.sum().item() // cp_size + if pre_process: + input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device) + for i in range(batch_size): + if cp_size <= 1: + seqlen = seqlens_in_batch[i] + input_ids_rmpad[cu_seqlens_padded[i] : cu_seqlens_padded[i] + seqlen] = input_ids[i, attention_mask[i]] + continue + seqlen = seqlens_in_batch_padded[i] // cp_size + half_seqlen = seqlen // 2 + start_idx = cu_seqlens_padded[i] // cp_size + # split to 2 chunks + d = input_ids[i, attention_mask[i]] + input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[ + half_seqlen * cp_rank : half_seqlen * (cp_rank + 1) + ] + + remain_start = seqlens_in_batch_padded[i] - half_seqlen * (cp_rank + 1) + remain_end = seqlens_in_batch_padded[i] - half_seqlen * cp_rank + remain_end = min(remain_end, d.shape[0]) + remain_len = remain_end - remain_start + if remain_len > 0: + input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[ + remain_start:remain_end + ] + + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + max_seqlen_q=max_seqlen_in_batch, + cu_seqlens_kv=cu_seqlens_padded, + max_seqlen_kv=max_seqlen_in_batch, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + ) + if pre_process: + return input_ids_rmpad.unsqueeze(0), packed_seq_params + else: + return input_ids, packed_seq_params + + +def postprocess_packed_seqs( + output: torch.Tensor, + packed_seq_params: PackedSeqParams, + attention_mask: torch.Tensor, + batch_size: int, + seq_len: int, + post_process: bool = True, +) -> torch.Tensor: + """ + Postprocess packed sequences + """ + if not post_process: + return output + shape = [batch_size, seq_len] + list(output.shape[2:]) # 1,packed, dim -> batch_size, seq_len, dim + output_new = torch.zeros(shape, dtype=output.dtype, device=output.device) + + cp_size = mpu.get_context_parallel_world_size() + # all gather output across context parallel group + if cp_size > 1: + # output shape: [1, packed_len, hidden_dim] + # need to gather across cp group and concatenate in sequence dimension + output_list = [torch.empty_like(output) for _ in range(cp_size)] + torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group()) + output_list[mpu.get_context_parallel_rank()] = output + else: + output_list = [output] + for i in range(batch_size): + if cp_size <= 1: + s = attention_mask[i].sum().item() + output_new[i, attention_mask[i]] = output[0][ + packed_seq_params.cu_seqlens_q_padded[i] : packed_seq_params.cu_seqlens_q_padded[i] + s + ] + continue + s_len_padded_chunk = ( + packed_seq_params.cu_seqlens_q_padded[i + 1] - packed_seq_params.cu_seqlens_q_padded[i] + ) // cp_size + half_seqlen = s_len_padded_chunk // 2 + s_len = attention_mask[i].sum().item() + s_len_padded = s_len_padded_chunk * cp_size + tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device) + for j in range(cp_size): + o = output_list[j][0] + # split to 2 chunks + packed_start_idx = packed_seq_params.cu_seqlens_q_padded[i] // cp_size + o0, o1 = ( + o[packed_start_idx : packed_start_idx + half_seqlen], + o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk], + ) + tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0 + tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1 + output_new[i, attention_mask[i]] = tmp[:s_len] + + return output_new + + +def remove_left_padding( + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + sequence_parallel: bool = False, + pre_process: bool = True, +): + """ + Remove left padding from input_ids, attention_mask and position_ids + return new_input_ids, new_attention_mask, new_position_ids + """ + assert attention_mask.ndim == 2 + assert position_ids.ndim == 2 + cp_size = mpu.get_context_parallel_world_size() + assert cp_size == 1, "Context parallel size without seq_pack is not supported" + batch_size = input_ids.shape[0] + shape = list(input_ids.shape) # batch_size, seq_len,... + seq_lens = attention_mask.sum(dim=1) + seq_len = seq_lens.max().item() + if sequence_parallel: + sp_world_size = mpu.get_tensor_model_parallel_world_size() + pad_size = (sp_world_size - seq_len % sp_world_size) % sp_world_size + seq_len = seq_len + pad_size + shape[1] = seq_len + if pre_process: + new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape) + new_attention_mask = torch.zeros( + dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, seq_len) + ) + new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len)) + for i in range(batch_size): + if pre_process: + new_input_ids[i, : seq_lens[i]] = input_ids[i, attention_mask[i]] + new_attention_mask[i, : seq_lens[i]] = attention_mask[i, attention_mask[i]] + new_position_ids[i, : seq_lens[i]] = position_ids[i, attention_mask[i]] + if pre_process: + return new_input_ids, new_attention_mask, new_position_ids + else: + return input_ids, new_attention_mask, new_position_ids + + +def recover_left_padding( + result, + attention_mask: torch.Tensor, + original_attention_mask: torch.Tensor, + origin_seqlen: int, + post_process: bool = True, +): + """ + Recover left padding from result + return result + """ + if not post_process: + return result + shape = list(result.shape) + batch_size = shape[0] + shape[1] = origin_seqlen + new_result = torch.zeros(dtype=result.dtype, device=result.device, size=shape) + for i in range(batch_size): + new_result[i, original_attention_mask[i]] = result[i, attention_mask[i]] + return new_result + + +def postprocess_packed_seqs_for_dict_output( + labels_mask: torch.Tensor, + output: CausalLMOutputForPPO, + packed_seq_params: PackedSeqParams, + attention_mask: torch.Tensor, + batch_size: int, + seq_len: int, + post_process: bool = True, +) -> dict[str, torch.Tensor]: + """_summary_ + For fused kernels, the output is a dictionary with keys like 'log_probs', 'entropy', etc. + This function post-processes each tensor in the output dictionary. + Args: + output (CausalLMOutputForPPO): _description_ + packed_seq_params (PackedSeqParams): _description_ + attention_mask (torch.Tensor): _description_ + batch_size (int): _description_ + seq_len (int): _description_ + post_process (bool, optional): _description_. Defaults to True. + Returns: + CausalLMOutputForPPO: _description_ + """ + ret = {} + output.entropy = output.entropy.view(1, -1) + output.log_probs = output.log_probs.view(1, -1) + output.log_probs = output.log_probs.masked_fill(~labels_mask, 0.0) + ret["entropy"] = postprocess_packed_seqs( + output.entropy, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + ret["log_probs"] = postprocess_packed_seqs( + output.log_probs, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + return ret diff --git a/toolbox/verl/v0.5.0/verl/models/mcore/weight_converter.py b/toolbox/verl/v0.5.0/verl/models/mcore/weight_converter.py new file mode 100644 index 000000000..791513f32 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/mcore/weight_converter.py @@ -0,0 +1,479 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# online convert mcore weight to pure huggingface weight, no any fusion +# including format conversion and name mapping +# not including resharding +import torch +from megatron.core.transformer import TransformerConfig +from transformers import PretrainedConfig + + +class McoreToHFWeightConverterBase: + def __init__(self, hf_config: PretrainedConfig, mcore_config: TransformerConfig): + self.hf_config = hf_config + self.mcore_config = mcore_config + + def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> torch.Tensor: + raise NotImplementedError + + +class McoreToHFWeightConverterDense(McoreToHFWeightConverterBase): + def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # 'decoder.layers.0.self_attention.linear_proj.weight' + # 'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight' + # 'decoder.layers.0.self_attention.linear_qkv.weight' + # 'decoder.layers.0.self_attention.linear_qkv.bias' + layer_number = name.split(".")[2] + convert_names = [] + if "self_attention.linear_qkv.bias" in name or "self_attention.linear_qkv.weight" in name: + param_type = name.split(".")[-1] + assert param_type == "bias" or param_type == "weight" + convert_names.append(f"model.layers.{layer_number}.self_attn.q_proj.{param_type}") + convert_names.append(f"model.layers.{layer_number}.self_attn.k_proj.{param_type}") + convert_names.append(f"model.layers.{layer_number}.self_attn.v_proj.{param_type}") + assert len(params) == 3 + elif "self_attention.linear_proj.weight" in name: + convert_names.append(f"model.layers.{layer_number}.self_attn.o_proj.weight") + assert len(params) == 1 + elif "self_attention.linear_qkv.layer_norm_weight" in name: + convert_names.append(f"model.layers.{layer_number}.input_layernorm.weight") + assert len(params) == 1 + elif "self_attention.q_layernorm.weight" in name: + convert_names.append(f"model.layers.{layer_number}.self_attn.q_norm.weight") + assert len(params) == 1 + elif "self_attention.k_layernorm.weight" in name: + convert_names.append(f"model.layers.{layer_number}.self_attn.k_norm.weight") + assert len(params) == 1 + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params + + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight' + # 'decoder.layers.0.mlp.linear_fc1.weight' + # 'decoder.layers.0.mlp.linear_fc2.weight' + layer_number = name.split(".")[2] + convert_names = [] + if "mlp.linear_fc1.weight" in name: + # split gate_proj and up_proj + convert_names.append(f"model.layers.{layer_number}.mlp.gate_proj.weight") + convert_names.append(f"model.layers.{layer_number}.mlp.up_proj.weight") + assert len(params) == 2 + elif "mlp.linear_fc1.layer_norm_weight" in name: + convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + assert len(params) == 1 + elif "mlp.linear_fc2.weight" in name: + convert_names.append(f"model.layers.{layer_number}.mlp.down_proj.weight") + assert len(params) == 1 + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params + + def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + direct_name_mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + if name in direct_name_mapping: + return [direct_name_mapping[name]], [params_one_group[0]] + + if "self_attention" in name: + return self._convert_attention_param(name, params_one_group) + elif "mlp" in name: + return self._convert_mlp_param(name, params_one_group) + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + + +class McoreToHFWeightConverterQwen2Moe(McoreToHFWeightConverterDense): + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # 'decoder.layers.0.pre_mlp_layernorm.weight', + # 'decoder.layers.0.mlp.router.weight', + # 'decoder.layers.0.mlp.shared_experts.gate_weight', + # 'decoder.layers.0.mlp.shared_experts.linear_fc1.weight', + # 'decoder.layers.0.mlp.shared_experts.linear_fc2.weight' + # moe1 + # 'decoder.layers.0.mlp.experts.linear_fc1.weight0', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight1', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight2', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight3', + # moe2 + # 'decoder.layers.0.mlp.experts.linear_fc2.weight0', + # 'decoder.layers.0.mlp.experts.linear_fc2.weight1', + layer_number = name.split(".")[2] + convert_names = [] + if "pre_mlp_layernorm" in name: + convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + assert len(params) == 1 + elif "mlp.router.weight" in name: + convert_names.append(f"model.layers.{layer_number}.mlp.gate.weight") + assert len(params) == 1 + elif "shared_experts.gate_weight" in name: + convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert_gate.weight") + assert len(params) == 1 + elif "shared_experts.linear_fc1.weight" in name: # split gate_proj and up_proj + convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.gate_proj.weight") + convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.up_proj.weight") + assert len(params) == 2 + elif "shared_experts.linear_fc2.weight" in name: + convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.down_proj.weight") + assert len(params) == 1 + elif "mlp.experts.linear_fc1" in name: # split gate_proj and up_proj + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight") + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight") + assert len(params) == 2 + elif "mlp.experts.linear_fc2" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight") + assert len(params) == 1 + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params + + +class McoreToHFWeightConverterQwen2_5_VL(McoreToHFWeightConverterDense): + def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + direct_name_mapping = { + "language_model.embedding.word_embeddings.weight": "model.embed_tokens.weight", + "language_model.decoder.final_layernorm.weight": "model.norm.weight", + "language_model.output_layer.weight": "lm_head.weight", + "vision_model.patch_embed.proj.weight": "visual.patch_embed.proj.weight", + "vision_model.decoder.final_layernorm.weight": "visual.merger.ln_q.weight", + "vision_model.projection.encoder.linear_fc1.weight": "visual.merger.mlp.0.weight", + "vision_model.projection.encoder.linear_fc1.bias": "visual.merger.mlp.0.bias", + "vision_model.projection.encoder.linear_fc2.weight": "visual.merger.mlp.2.weight", + "vision_model.projection.encoder.linear_fc2.bias": "visual.merger.mlp.2.bias", + } + if name in direct_name_mapping: + return [direct_name_mapping[name]], [params_one_group[0]] + + if "self_attention" in name: + return self._convert_attention_param(name, params_one_group) + elif "mlp" in name: + return self._convert_mlp_param(name, params_one_group) + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + + def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + model_type, _, _, layer_number = name.split(".")[:4] + + convert_names = [] + if model_type == "language_model": + name_map_after_layer = { + "self_attention.linear_qkv.bias": [ + "self_attn.q_proj.bias", + "self_attn.k_proj.bias", + "self_attn.v_proj.bias", + ], + "self_attention.linear_qkv.weight": [ + "self_attn.q_proj.weight", + "self_attn.k_proj.weight", + "self_attn.v_proj.weight", + ], + "self_attention.linear_proj.weight": "self_attn.o_proj.weight", + "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight", + } + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer) + if isinstance(mapped_name, list): + assert len(params) == len(mapped_name) + for one in mapped_name: + convert_names.append(f"model.layers.{layer_number}.{one}") + else: + assert len(params) == 1 + convert_names.append(f"model.layers.{layer_number}.{mapped_name}") + elif model_type == "vision_model": + name_map_after_layer = { + "self_attention.linear_proj.weight": "attn.proj.weight", + "self_attention.linear_proj.bias": "attn.proj.bias", + "self_attention.linear_qkv.layer_norm_weight": "norm1.weight", + } + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer, None) + if mapped_name is None: + assert "linear_qkv" in name_after_layer + assert len(params) == 3 + new_param = torch.cat(params, dim=0) + params = [new_param] + if "bias" in name_after_layer: + convert_names.append(f"visual.blocks.{layer_number}.attn.qkv.bias") + else: + convert_names.append(f"visual.blocks.{layer_number}.attn.qkv.weight") + else: + assert len(params) == 1 + convert_names.append(f"visual.blocks.{layer_number}.{mapped_name}") + else: + raise NotImplementedError(f"Unsupported model type: {model_type}") + return convert_names, params + + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + model_type, _, _, layer_number = name.split(".")[:4] + + convert_names = [] + if model_type == "language_model": + name_map_after_layer = { + "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"], + "mlp.linear_fc1.bias": ["mlp.gate_proj.bias", "mlp.up_proj.bias"], + "mlp.linear_fc2.weight": "mlp.down_proj.weight", + "mlp.linear_fc2.bias": "mlp.down_proj.bias", + "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", + } + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer) + if isinstance(mapped_name, list): + assert len(params) == len(mapped_name) + for one in mapped_name: + convert_names.append(f"model.layers.{layer_number}.{one}") + else: + assert len(params) == 1 + convert_names.append(f"model.layers.{layer_number}.{mapped_name}") + + elif model_type == "vision_model": + name_map_after_layer = { + "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"], + "mlp.linear_fc1.bias": ["mlp.gate_proj.bias", "mlp.up_proj.bias"], + "mlp.linear_fc2.weight": "mlp.down_proj.weight", + "mlp.linear_fc2.bias": "mlp.down_proj.bias", + "mlp.linear_fc1.layer_norm_weight": "norm2.weight", + } + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer) + if isinstance(mapped_name, list): + assert len(params) == len(mapped_name) + for one in mapped_name: + convert_names.append(f"visual.blocks.{layer_number}.{one}") + else: + assert len(params) == 1 + convert_names.append(f"visual.blocks.{layer_number}.{mapped_name}") + else: + raise NotImplementedError(f"Unsupported model type: {model_type}") + return convert_names, params + + +class McoreToHFWeightConverterDpskv3(McoreToHFWeightConverterBase): + def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # mcore + # 'decoder.layers.0.input_layernorm.weight' + # 'decoder.layers.0.self_attention.linear_proj.weight' + # 'decoder.layers.0.self_attention.linear_q_proj.weight' + # 'decoder.layers.0.self_attention.linear_kv_down_proj.weight' + # 'decoder.layers.0.self_attention.linear_kv_up_proj.layer_norm_weight' + # 'decoder.layers.0.self_attention.linear_kv_up_proj.weight' + # 'decoder.layers.0.self_attention.linear_q_down_proj.weight' + # 'decoder.layers.0.self_attention.linear_q_up_proj.weight' + # 'decoder.layers.0.self_attention.linear_q_up_proj.layer_norm_weight' + # hf + # 'model.layers.0.input_layernorm.weight' + # 'model.layers.0.self_attn.o_proj.weight' + # 'model.layers.0.self_attn.q_proj.weight' + # 'model.layers.0.self_attn.kv_a_proj_with_mqa.weight' + # 'model.layers.0.self_attn.kv_a_layernorm.weight' + # 'model.layers.0.self_attn.kv_b_proj.weight' + # 'model.layers.0.self_attn.q_a_proj.weight' + # 'model.layers.0.self_attn.q_b_proj.weight' + # 'model.layers.0.self_attn.q_a_layernorm.weight' + name_map_after_layer = { + "input_layernorm.weight": "input_layernorm.weight", + "self_attention.linear_proj.weight": "self_attn.o_proj.weight", + "self_attention.linear_q_proj.weight": "self_attn.q_proj.weight", + "self_attention.linear_kv_down_proj.weight": "self_attn.kv_a_proj_with_mqa.weight", + "self_attention.linear_kv_up_proj.layer_norm_weight": "self_attn.kv_a_layernorm.weight", + "self_attention.linear_kv_up_proj.weight": "self_attn.kv_b_proj.weight", + "self_attention.linear_q_down_proj.weight": "self_attn.q_a_proj.weight", + "self_attention.linear_q_up_proj.weight": "self_attn.q_b_proj.weight", + "self_attention.linear_q_up_proj.layer_norm_weight": "self_attn.q_a_layernorm.weight", + } + assert len(params) == 1 + convert_names = [] + layer_number = name.split(".")[2] + name_after_layer = name.split(f".{layer_number}.")[1] + convert_names.append(f"model.layers.{layer_number}.{name_map_after_layer[name_after_layer]}") + return convert_names, params + + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # mcore dense + # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight' + # 'decoder.layers.0.mlp.linear_fc2.weight' + # 'decoder.layers.0.mlp.linear_fc1.weight' + # --- + # 'decoder.layers.1.mlp.shared_experts.linear_fc1.weight' + # --- + # 'decoder.layers.1.mlp.shared_experts.linear_fc2.weight' + # hf dense + # 'model.layers.0.post_attention_layernorm.weight' + # 'model.layers.0.mlp.down_proj.weight' + # 'model.layers.0.mlp.gate_proj.weight' + # 'model.layers.0.mlp.up_proj.weight' + # 'model.layers.1.mlp.shared_experts.gate_proj.weight' + # 'model.layers.1.mlp.shared_experts.up_proj.weight' + # 'model.layers.1.mlp.shared_experts.down_proj.weight' + + # mcore moe + # 'decoder.layers.1.pre_mlp_layernorm.weight' + # 'decoder.layers.1.mlp.router.weight' + # 'decoder.layers.1.mlp.router.expert_bias' + # 'decoder.layers.1.mlp.experts.linear_fc1.weight0' + # --- + # 'decoder.layers.1.mlp.experts.linear_fc2.weight0' + # hf moe + # 'model.layers.1.post_attention_layernorm.weight' + # 'model.layers.1.mlp.gate.weight' + # 'model.layers.1.mlp.gate.e_score_correction_bias' + # 'model.layers.1.mlp.experts.0.gate_proj.weight' + # 'model.layers.1.mlp.experts.0.up_proj.weight' + # 'model.layers.1.mlp.experts.0.down_proj.weight' + + name_map_after_layer = { + "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", + "mlp.linear_fc2.weight": "mlp.down_proj.weight", + "mlp.shared_experts.linear_fc2.weight": "mlp.shared_experts.down_proj.weight", + "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"], + "mlp.shared_experts.linear_fc1.weight": [ + "mlp.shared_experts.gate_proj.weight", + "mlp.shared_experts.up_proj.weight", + ], + "pre_mlp_layernorm.weight": "post_attention_layernorm.weight", + "mlp.router.weight": "mlp.gate.weight", + "mlp.router.expert_bias": "mlp.gate.e_score_correction_bias", + } + convert_names = [] + layer_number = name.split(".")[2] + name_after_layer = name.split(f".{layer_number}.")[1] + if name_after_layer in name_map_after_layer: + mapped_name = name_map_after_layer[name_after_layer] + if isinstance(mapped_name, list): + assert len(params) == len(mapped_name) + for one in mapped_name: + convert_names.append(f"model.layers.{layer_number}.{one}") + else: + assert len(params) == 1 + convert_names.append(f"model.layers.{layer_number}.{mapped_name}") + else: + if "mlp.experts.linear_fc1.weight" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight") + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight") + assert len(params) == 2 + elif "mlp.experts.linear_fc2.weight" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight") + assert len(params) == 1 + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + + return convert_names, params + + def _convert_mtp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + assert self.mcore_config.mtp_num_layers == 1, "only support one mtp layer for now" + assert self.mcore_config.num_layers == 61, "only support 61 layers for now" + direct_name_mapping = { + "mtp.layers.0.enorm.weight": "model.layers.61.enorm.weight", + "mtp.layers.0.hnorm.weight": "model.layers.61.hnorm.weight", + "mtp.layers.0.eh_proj.weight": "model.layers.61.eh_proj.weight", + "mtp.layers.0.final_layernorm.weight": "model.layers.61.shared_head.norm.weight", + } + if name in direct_name_mapping: + return [direct_name_mapping[name]], [params[0]] + assert "mtp.layers.0.transformer_layer" in name, "only support transformer layer for now" + # use proxy name to convert + proxy_name = name.replace("mtp.layers.0.transformer_layer", "decoder.layers.61") + if "self_attention" in proxy_name or "input_layernorm.weight" in proxy_name: + convert_names, params = self._convert_attention_param(proxy_name, params) + elif "mlp" in proxy_name: + convert_names, params = self._convert_mlp_param(proxy_name, params) + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params + + def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + direct_name_mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + if name in direct_name_mapping: + return [direct_name_mapping[name]], [params_one_group[0]] + if "mtp" in name: + return self._convert_mtp_param(name, params_one_group) + elif "self_attention" in name or "input_layernorm.weight" in name: + return self._convert_attention_param(name, params_one_group) + elif "mlp" in name: + return self._convert_mlp_param(name, params_one_group) + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + + +class McoreToHFWeightConverterMixtral(McoreToHFWeightConverterDense): + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # decoder.layers.0.mlp.router.weight + # decoder.layers.0.mlp.experts.linear_fc1.weight0 - weight7 + # decoder.layers.0.mlp.experts.linear_fc2.weight0 - weight7 + + layer_number = name.split(".")[2] + convert_names = [] + if "pre_mlp_layernorm" in name: + convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + elif "mlp.router.weight" in name: + convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.gate.weight") + elif "mlp.experts.linear_fc1.weight" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w1.weight") + convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w3.weight") + elif "mlp.experts.linear_fc2.weight" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w2.weight") + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params + + +class McoreToHFWeightConverterQwen3Moe(McoreToHFWeightConverterDense): + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # qwen3 moe no share expert + + # 'decoder.layers.0.pre_mlp_layernorm.weight', + # 'decoder.layers.0.mlp.router.weight', + # moe1 + # 'decoder.layers.0.mlp.experts.linear_fc1.weight0', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight1', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight2', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight3', + # moe2 + # 'decoder.layers.0.mlp.experts.linear_fc2.weight0', + # 'decoder.layers.0.mlp.experts.linear_fc2.weight1', + layer_number = name.split(".")[2] + convert_names = [] + if "pre_mlp_layernorm" in name: + convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + assert len(params) == 1 + elif "mlp.router.weight" in name: + convert_names.append(f"model.layers.{layer_number}.mlp.gate.weight") + assert len(params) == 1 + elif "mlp.experts.linear_fc1" in name: # split gate_proj and up_proj + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight") + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight") + assert len(params) == 2 + elif "mlp.experts.linear_fc2" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight") + assert len(params) == 1 + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params diff --git a/toolbox/verl/v0.5.0/verl/models/qwen2/__init__.py b/toolbox/verl/v0.5.0/verl/models/qwen2/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/qwen2/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/__init__.py b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/__init__.py new file mode 100644 index 000000000..57e33ee9e --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling_qwen2_megatron import ( + ParallelQwen2ForCausalLM, + # rmpad with megatron + ParallelQwen2ForCausalLMRmPad, + # rmpad with megatron and pipeline parallelism + ParallelQwen2ForCausalLMRmPadPP, + ParallelQwen2ForValueRmPad, + ParallelQwen2ForValueRmPadPP, + # original model with megatron + ParallelQwen2Model, +) + +__all__ = [ + "ParallelQwen2ForCausalLM", + "ParallelQwen2ForCausalLMRmPad", + "ParallelQwen2ForCausalLMRmPadPP", + "ParallelQwen2ForValueRmPad", + "ParallelQwen2ForValueRmPadPP", + "ParallelQwen2Model", +] diff --git a/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/__init__.py b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py new file mode 100644 index 000000000..3168635c7 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py @@ -0,0 +1,337 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist + +from verl.utils.device import get_device_id, get_torch_device + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_qwen2( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module + from torch.nn.parallel import DistributedDataParallel as torchDDP + + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def fetch_params(module): + for param in module.parameters(): + torch.distributed.fetch( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: " + f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.model.layers) == num_layers_per_model + + def _fetch_tensor(tensor, name) -> torch.Tensor: + """fetch tensor""" + nonlocal state_dict + if tensor is not None: + tensor = tensor.data.copy_(state_dict[name], non_blocking=True) + + def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """fetch tensor in tp shards""" + nonlocal state_dict + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + if tensor is not None: + tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) + else: + print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """fetch tensor in tp shards""" + nonlocal state_dict + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + if tensor is not None: + tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) + else: + print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """fetch gate_up tensor in tp shards""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if gate_name in state_dict and up_name in state_dict: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + if tensor is not None: + tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) + else: + print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: + """fetch tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + assert q_name in state_dict and k_name in state_dict and v_name in state_dict + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + if not bias: + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + else: + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + if not bias: + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + else: + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + if tensor is not None: + tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) + + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.model.embed_tokens.weight + _fetch_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + num_layer_per_pp = config.num_hidden_layers // pp_size + vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() + + layer_list = [] + if vpp_size is not None: + for vpp_rank in range(vpp_size): + num_layer_vpp_chunk = num_layer_per_pp // vpp_size + num_layer_this_model = num_layer_vpp_chunk + offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + ( + mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk + ) + layer_list.extend(list(range(offset, offset + num_layer_this_model))) + else: + num_layer_this_model = num_layer_per_pp + offset = pp_rank * num_layer_per_pp + layer_list.extend(list(range(offset, offset + num_layer_this_model))) + + for layer in layer_list: + print(f"{torch.distributed.get_rank()} loading layer #{layer}...") + layer_name = f"model.layers.{layer}" + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + print( + f"{torch.distributed.get_rank()} offset: {offset}, num_layer_this_model: {num_layer_this_model}, " + f"layer_name: {layer_name}, layer_map[layer]: {layer_map[layer]}" + ) + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[dst_layer_idx] + + _fetch_tensor( + sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + _fetch_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + + _fetch_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + bias=True, + ) + + _fetch_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + + _fetch_tensor( + sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _fetch_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) + + _fetch_tp_shard_tensor( + sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _fetch_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + ) + + if tie_word_embeddings: + print_rank_0("tie_word_embeddings skip load lm_head") + else: + print_rank_0("loading lm_head...") + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.lm_head.weight + + if is_value_model: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + _fetch_tensor(lm_head_weight, "lm_head.weight") + print_rank_0("load lm_head from value_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + _fetch_tensor(lm_head_weight, "reward_head.weight") + print_rank_0("load lm_head from value_head weight") + else: + _fetch_tensor(None, "lm_head.weight") + print_rank_0("fail to match lm_head in value_model") + + else: + _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight") + + dist.barrier() + get_torch_device().empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py new file mode 100644 index 000000000..770e36533 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py @@ -0,0 +1,475 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist + +from verl.utils.device import get_device_id, get_torch_device + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_qwen2( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module + from torch.nn.parallel import DistributedDataParallel as torchDDP + + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def broadcast_params(module): + for param in module.parameters(): + torch.distributed.broadcast( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: " + f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.model.layers) == num_layers_per_model + + def _broadcast_tensor(tensor, name) -> torch.Tensor: + """broadcast tensor from rank0 across mp_group""" + nonlocal state_dict + nonlocal mp_group + if torch.distributed.get_rank() == 0: + if name in state_dict: + weight = state_dict[name] + tensor_shape = weight.shape + else: + tensor_shape = None + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not in state_dict, skip load") + return + + if tensor is None: + tensor = torch.empty( + tensor_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + if torch.distributed.get_rank() == 0: + tensor.data.copy_(weight) + dist.broadcast(tensor, src=0, group=mp_group) + + def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + if name in state_dict: + full_weight = state_dict[name] + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape " + f"{tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + assert q_name in state_dict and k_name in state_dict and v_name in state_dict + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + if not bias: + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + else: + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) + + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + if not bias: + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + else: + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + embed_tokens_weight = None + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.model.embed_tokens.weight + _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + for layer in range(config.num_hidden_layers): + print_rank_0(f"loading layer #{layer}...") + layer_name = f"model.layers.{layer}" + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[dst_layer_idx] + + _broadcast_tensor( + sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + bias=True, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + + _broadcast_tensor( + sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + ) + + if tie_word_embeddings: + print_rank_0("tie_word_embeddings skip load lm_head") + else: + print_rank_0("loading lm_head...") + lm_head_weight = None + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.lm_head.weight + + if is_value_model: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "lm_head.weight") + print_rank_0("load lm_head from value_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "reward_head.weight") + print_rank_0("load lm_head from value_head weight") + else: + _broadcast_tensor(None, "lm_head.weight") + print_rank_0("fail to match lm_head in value_model") + + else: + _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") + + dist.barrier() + # Broadcast weights inside data parallel groups + for wrapped_model in wrapped_models: + broadcast_params(wrapped_model) + + get_torch_device().empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py new file mode 100644 index 000000000..737f73b4c --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py @@ -0,0 +1,448 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist +from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel as LocalDDP +from megatron.core.transformer.module import Float16Module +from torch.nn.parallel import DistributedDataParallel as torchDDP + +from verl.utils.device import get_device_id, get_torch_device +from verl.utils.logger import print_rank_0 +from verl.utils.megatron_utils import unwrap_model + + +def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): + """given TP,DP,PP rank to get the global rank.""" + + tp_size = mpu.get_tensor_model_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), ( + f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" + ) + # We only support TP-DP-PP grouping, for correctness when resharding + return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): + """Merge sharded parameters of a Megatron module into a merged checkpoint. + + Args: + wrapped_models (list of megatron.core.distributed.DistributedDataParallel): + The local DDP wrapped megatron modules. + config (str or None): + HF config for model + dtype: model params type + is_value_model: if model is value model + tie_word_embeddings: tie_word_embeddings + Returns: + state_dict (dict): + The merged state_dict in rank 0, and an empty dictionary in other ranks. + """ + start_time = time.time() + + def _get_gpt_model(model): + return model + + dp_rank = mpu.get_data_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if dist.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + assert len(models[i].model.layers) == num_layers_per_model, ( + "len model layers {} not equal to num_layers_per_model {}".format( + len(models[i].model.layers), num_layers_per_model + ) + ) + + state_dict = dict() + + def _get_cpu_tensor(tensor: torch.Tensor): + if tensor is None: + return None + if tensor.device == torch.device("cpu"): + return tensor.detach().clone() + return tensor.detach().cpu() + + def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: + """broadcast tensor across mp_group""" + nonlocal state_dict + nonlocal mp_group + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + if torch.distributed.get_rank() == src_rank: + if tensor is None: + weight = None + tensor_shape = None + else: + weight = tensor + tensor_shape = weight.shape + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not exist, skip collect") + return + + if weight is None: + weight = torch.empty( + tensor_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + dist.broadcast(weight, src=src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + state_dict[name] = _get_cpu_tensor(weight) + + def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=concat_dim) + if mutate_func is not None: + full_tensor = mutate_func(full_tensor) + state_dict[name] = full_tensor + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_list = [] + up_weight_list = [] + for i in range(tp_size): + gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] + gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] + up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] + gate_weight_list.append(gate_weight_tp) + up_weight_list.append(up_weight_tp) + + state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) + state_dict[up_name] = torch.cat(up_weight_list, dim=0) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + q_weight_list = [] + k_weight_list = [] + v_weight_list = [] + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_part = qkv_part[:q_size_tp] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] + q_weight_list.append(q_part) + k_weight_list.append(k_part) + v_weight_list.append(v_part) + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_part = qkv_part[:q_size_tp] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] + q_weight_list.append(q_part) + if i * config.num_key_value_heads % tp_size == 0: + k_weight_list.append(k_part) + v_weight_list.append(v_part) + + state_dict[q_name] = torch.cat(q_weight_list, dim=0) + state_dict[k_name] = torch.cat(k_weight_list, dim=0) + state_dict[v_name] = torch.cat(v_weight_list, dim=0) + + # empty cache before collecting weights + get_torch_device().empty_cache() + # Embeddings + # ------------------- + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("collecting embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + _broadcast_tp_shard_tensor( + gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None, + "model.embed_tokens.weight", + src_pp_rank=0, + ) + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + for layer in range(config.num_hidden_layers): + print_rank_0(f"collecting layer #{layer}...") + layer_name = f"model.layers.{layer}" + src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[src_layer_idx] + + _broadcast_tensor( + sync_layer.input_layernorm.weight, + f"{layer_name}.input_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.bias, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight, + f"{layer_name}.self_attn.o_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + _broadcast_tensor( + sync_layer.post_attention_layernorm.weight, + f"{layer_name}.post_attention_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.down_proj.weight, + f"{layer_name}.mlp.down_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + # Final Layernorm + # ------------------- + print_rank_0("collecting final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + src_pp_rank=pp_size - 1, + ) + + if tie_word_embeddings: + print_rank_0("tie word embedding skip load lm_head...") + else: + print_rank_0("collecting lm_head...") + + if is_value_model: + _broadcast_tensor( + gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + _broadcast_tensor( + gpt_model_module.reward_head.weight + if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None + else None, + "reward_head.weight", + src_pp_rank=pp_size - 1, + ) + + else: + _broadcast_tp_shard_tensor( + getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + + dist.barrier() + + get_torch_device().empty_cache() + if torch.distributed.get_rank() == 0: + for k, v in state_dict.items(): + if dtype != v.dtype: + state_dict[k] = v.to(dtype) + + print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") + return state_dict diff --git a/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/__init__.py b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/__init__.py new file mode 100644 index 000000000..263ea596f --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .parallel_attention import ParallelQwen2Attention +from .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad +from .parallel_mlp import ParallelQwen2MLP +from .parallel_rmsnorm import ParallelQwen2RMSNorm + +__all__ = [ + "ParallelQwen2Attention", + "ParallelQwen2DecoderLayer", + "ParallelQwen2DecoderLayerRmPad", + "ParallelQwen2MLP", + "ParallelQwen2RMSNorm", +] diff --git a/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_attention.py b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_attention.py new file mode 100644 index 000000000..702c429c2 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_attention.py @@ -0,0 +1,399 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch.nn.functional as F +from einops import rearrange +from transformers.utils import is_flash_attn_2_available + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +import torch +from flash_attn.layers.rotary import apply_rotary_emb +from megatron.core import ModelParallelConfig, tensor_parallel +from megatron.core import parallel_state as mpu +from torch import nn +from transformers import Qwen2Config + +from verl.models.qwen2.megatron.layers.parallel_linear import QKVParallelLinear +from verl.utils.megatron import tensor_parallel as tp_utils + + +class Qwen2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class Qwen2LinearScalingRotaryEmbedding(Qwen2RotaryEmbedding): + """Qwen2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding): + """Qwen2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class ParallelQwen2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config = config + self.megatron_config = megatron_config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # assign values after tp + tp_size = mpu.get_tensor_model_parallel_world_size() + assert self.num_heads % tp_size == 0, ( + f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" + ) + assert self.num_key_value_heads % tp_size == 0, ( + f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads=" + f"{self.num_key_value_heads}, tp_size={tp_size}" + ) + + self.num_heads_per_tp = self.num_heads // tp_size + self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size + self.hidden_size_per_tp = self.hidden_size // tp_size + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and " + f"`num_heads`: {self.num_heads})." + ) + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() + + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) + tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) + + # [self.q_size, self.k_size, self.v_size] + self.qkv_proj = QKVParallelLinear( + input_size=self.hidden_size, + num_heads=self.num_heads, + num_key_value_heads=self.num_key_value_heads, + head_dim=self.head_dim, + # bias=config.attention_bias, + bias=True, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + self.q_size = self.num_heads_per_tp * self.head_dim + self.k_size = self.num_key_value_heads_per_tp * self.head_dim + self.v_size = self.num_key_value_heads_per_tp * self.head_dim + + self.o_proj = tensor_parallel.RowParallelLinear( + input_size=self.num_heads * self.head_dim, + output_size=self.hidden_size, + # bias=config.attention_bias, + bias=False, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs, + ) + + self._init_rope() + + def _init_rope(self): + self.rotary_emb = Qwen2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + qkv = self.qkv_proj(hidden_states)[0] + query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + + query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, " + f"but is {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, " + f"but is {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp) + attn_output = self.o_proj(attn_output)[0] + return attn_output + + +""" +Remove padding Attention +- Using Flash-attn 2 +- Compatible with sequence parallel +""" + + +def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length): + batch_size = position_ids.shape[0] + + q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim) + k = pad_input(k, indices, batch_size, sequence_length) + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices) + k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices) + + return q_embed, k_embed + + +# use flash-attn rotary embeddings with rmpad +# cos/sin shoudl be: (seq_length, rotary_dim / 2) +def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): + q_embed = apply_rotary_emb( + q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + k_embed = apply_rotary_emb( + k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + return q_embed, k_embed + + +class ParallelQwen2AttentionRmPad(ParallelQwen2Attention): + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: torch.Tensor = None, + max_seqlen_in_batch: int = None, + ): + total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel + + if self.megatron_config.sequence_parallel: + total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() + + qkv = self.qkv_proj(hidden_states)[0] + query_states, key_states, value_states = qkv.split( + [self.q_size, self.k_size, self.v_size], dim=-1 + ) # (total_nnz, 1, hidden_size) + + if self.megatron_config.sequence_parallel: + sequence_parallel_pad = total_nnz - cu_seqlens[-1] + total_nnz = cu_seqlens[-1] # total_nnz before sp padding + query_states = query_states[:total_nnz] + key_states = key_states[:total_nnz] + value_states = value_states[:total_nnz] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dime x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim) + key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + + cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) + cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half + query_states, key_states = apply_rotary_pos_emb_rmpad_flash( + query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch + ) + # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, + # position_ids, indices, + + # It is recommended to use dropout with FA according to the docs + # when training. + dropout_rate = 0.0 # if not self.training else self.attn_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (Qwen2RMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(torch.float16) + key_states = key_states.to(torch.float16) + value_states = value_states.to(torch.float16) + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen_in_batch, + max_seqlen_k=max_seqlen_in_batch, + dropout_p=dropout_rate, + softmax_scale=None, + causal=True, + ) + + attn_output_unpad = attn_output_unpad.to(input_dtype) + attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous() + + # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled + # Here we need to repad + if self.megatron_config.sequence_parallel: + attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad)) + + attn_output_unpad = self.o_proj(attn_output_unpad)[0] + return attn_output_unpad diff --git a/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_decoder.py b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_decoder.py new file mode 100644 index 000000000..3c8a2a6ee --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_decoder.py @@ -0,0 +1,150 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from megatron.core import ModelParallelConfig +from torch import nn +from transformers import Qwen2Config + +from verl.utils.megatron_utils import TransformerConfig, convert_config + +from .parallel_attention import ParallelQwen2Attention, ParallelQwen2AttentionRmPad +from .parallel_mlp import ParallelQwen2MLP +from .parallel_rmsnorm import ParallelQwen2RMSNorm + + +class ParallelQwen2DecoderLayer(nn.Module): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.self_attn = ParallelQwen2Attention(config=config, megatron_config=megatron_config) + + self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config) + self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Note: sequence parallel is hidden inside ColumnParallelLinear + # reduce scatter is hidden inside RowParallelLinear + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + # TODO: add sequence parallel operator reduce_scatter here + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + # TODO: add sequence parallel operator all_gather here + + hidden_states = self.mlp(hidden_states) + + # TODO: add sequence parallel operator reduce_scatter here + + hidden_states = residual + hidden_states + + outputs = hidden_states + + return outputs + + +class ParallelQwen2DecoderLayerRmPad(nn.Module): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.self_attn = ParallelQwen2AttentionRmPad(config=config, megatron_config=megatron_config) + + self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config) + self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states # (total_nnz // sp, 1, hidden_size) + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size) + # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size) + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + # shape changes same as attn + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = hidden_states + + return outputs diff --git a/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_linear.py b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_linear.py new file mode 100644 index 000000000..e6d4a09f4 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_linear.py @@ -0,0 +1,79 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py + + +from megatron.core import tensor_parallel + + +class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): + def __init__( + self, + input_size, + num_heads, + num_key_value_heads, + head_dim, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): + # Keep input parameters, and already restrict the head numbers + self.input_size = input_size + self.q_output_size = num_heads * head_dim + self.kv_output_size = num_key_value_heads * head_dim + self.head_dim = head_dim + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + input_size = self.input_size + output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim + + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) + + +class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): + def __init__( + self, + input_size, + gate_ouput_size, + up_output_size, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): + # Keep input parameters, and already restrict the head numbers + self.input_size = input_size + self.output_size = gate_ouput_size + up_output_size + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + super().__init__( + input_size=self.input_size, + output_size=self.output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) diff --git a/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_mlp.py b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_mlp.py new file mode 100644 index 000000000..672908a21 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_mlp.py @@ -0,0 +1,74 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core import ModelParallelConfig, tensor_parallel +from megatron.core import parallel_state as mpu +from torch import nn +from transformers.activations import ACT2FN + +from verl.models.qwen2.megatron.layers.parallel_linear import MergedColumnParallelLinear +from verl.utils.megatron import tensor_parallel as tp_utils + + +class ParallelQwen2MLP(nn.Module): + def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + # The weight is only [hidden_size, intermediate_size // model_parallel_world_size] + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() + + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) + tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) + + tp_size = mpu.get_tensor_model_parallel_world_size() + + self.gate_up_proj = MergedColumnParallelLinear( + input_size=self.hidden_size, + gate_ouput_size=self.intermediate_size, + up_output_size=self.intermediate_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + self.gate_size = self.intermediate_size // tp_size + + self.down_proj = tensor_parallel.RowParallelLinear( + input_size=self.intermediate_size, + output_size=self.hidden_size, + bias=False, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs, + ) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + gate_up = self.gate_up_proj(x)[0] + gate, up = gate_up.split(self.gate_size, dim=-1) + return self.down_proj(self.act_fn(gate) * up)[0] diff --git a/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py new file mode 100644 index 000000000..2f4c90dd4 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py @@ -0,0 +1,48 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers + +import torch +from apex.normalization.fused_layer_norm import fused_rms_norm_affine +from megatron.core import ModelParallelConfig +from torch import nn +from transformers import Qwen2Config + +from verl.utils.megatron import sequence_parallel as sp_utils + + +class ParallelQwen2RMSNorm(nn.Module): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + if isinstance(config.hidden_size, numbers.Integral): + normalized_shape = (config.hidden_size,) + self.normalized_shape = torch.Size(normalized_shape) + self.weight = nn.Parameter(torch.ones(self.normalized_shape)) + self.variance_epsilon = config.rms_norm_eps + + if megatron_config.sequence_parallel: + sp_utils.mark_parameter_as_sequence_parallel(self.weight) + + def forward(self, hidden_states): + return fused_rms_norm_affine( + input=hidden_states, + weight=self.weight, + normalized_shape=self.normalized_shape, + eps=self.variance_epsilon, + memory_efficient=True, + ) diff --git a/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/modeling_qwen2_megatron.py b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/modeling_qwen2_megatron.py new file mode 100644 index 000000000..92e81be8d --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/qwen2/megatron/modeling_qwen2_megatron.py @@ -0,0 +1,737 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2 model.""" + +from typing import Optional + +import torch +import torch.utils.checkpoint +from megatron.core import ModelParallelConfig, mpu, parallel_state, tensor_parallel +from torch import nn +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen2.modeling_qwen2 import CausalLMOutputWithPast + +from verl.utils.device import get_device_name +from verl.utils.megatron import sequence_parallel as sp_utils +from verl.utils.megatron import tensor_parallel as tp_utils +from verl.utils.megatron_utils import TransformerConfig, convert_config + +from .layers import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad, ParallelQwen2RMSNorm + +""" +TODO: +1. Add weight initialization. Here we need to be careful on TP weight init. +2. Add sequence parallel +3. Load checkpoint from Qwen2 pretrained checkpoint +""" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class ParallelQwen2Model(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, megatron_config) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) + + self.layers = nn.ModuleList( + [ParallelQwen2DecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) + self.norm = ParallelQwen2RMSNorm(config, megatron_config) + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | BaseModelOutputWithPast: + """ + + Args: + input_ids: input ids. shape (batch_size, seq_length) + attention_mask: attention_mask. shape (batch_size, seq_length) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + batch_size, seq_length = input_ids.shape + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds) + + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelQwen2ForCausalLM(nn.Module): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.model = ParallelQwen2Model(config, megatron_config=megatron_config) + self.vocab_size = config.vocab_size + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = outputs + logits = self.lm_head(hidden_states)[0] + + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + + logits = logits.float() + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +class ParallelQwen2ModelRmPad(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + self.megatron_config = megatron_config + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) + + self.layers = nn.ModuleList( + [ParallelQwen2DecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) + self.norm = ParallelQwen2RMSNorm(config, megatron_config) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> tuple | BaseModelOutputWithPast: + """ + + Args: + input_ids: input ids. shape (1, totol_nnz) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + + # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) + inputs_embeds = inputs_embeds.transpose(0, 1) + if self.megatron_config.sequence_parallel: + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + + hidden_states = inputs_embeds + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelQwen2ForCausalLMRmPad(nn.Module): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.megatron_config = megatron_config + self.model = ParallelQwen2ModelRmPad(config, megatron_config=megatron_config) + self.vocab_size = config.vocab_size + self._init_head(config) + + def _init_head(self, config: Qwen2Config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + def _forward_head(self, hidden_states): + # all_gather from sequence parallel region is performed inside lm_head + logits = self.lm_head(hidden_states)[0] + logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size) + return logits + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + batch_size, sequence_length = input_ids.shape + + # remove padding here + input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) + + # pad input_ids to multiple of tp for all tp ranks + # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap + if self.megatron_config.sequence_parallel: + input_ids = sp_utils.pad_to_sequence_parallel(input_ids) + + input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad) + + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = outputs + + logits = self._forward_head(hidden_states) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + # add removed padding back + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +class ParallelQwen2ForValueRmPad(ParallelQwen2ForCausalLMRmPad): + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + # lm_head is effectively the same as sequence parallel + sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) + + def _forward_head(self, hidden_states): + logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) + logits = logits.float() + if self.megatron_config.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + output = super().forward(input_ids, attention_mask, position_ids) + output.logits = torch.squeeze(output.logits, dim=-1) + return output + + +""" +Support pipeline parallelism +""" + + +class ParallelQwen2ModelRmPadPP(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + This model definition supports pipeline parallelism. To support pp and vpp, + - This model only contains layer in this pp stage and vpp chunk + - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp. + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pre_process, post_process): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.pre_process = pre_process + self.post_process = post_process + self.megatron_config = megatron_config + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + if pre_process: + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) + else: + self.embed_tokens = None + + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = megatron_config.pipeline_model_parallel_size + self.num_layer_per_pp = config.num_hidden_layers // pp_size + vpp_size = megatron_config.virtual_pipeline_model_parallel_size + vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() + + if vpp_size is not None: + self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size + self.num_layer_this_model = self.num_layer_vpp_chunk + offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk) + else: + self.num_layer_this_model = self.num_layer_per_pp + offset = pp_rank * self.num_layer_per_pp + + self.layers = nn.ModuleList() + for i in range(self.num_layer_this_model): + layer = ParallelQwen2DecoderLayerRmPad(config, megatron_config, layer_idx=i + offset) + self.layers.add_module(f"{i}", layer) + + if post_process: + self.norm = ParallelQwen2RMSNorm(config, megatron_config) + else: + self.norm = None + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> tuple | BaseModelOutputWithPast: + """ + + Args: + input_ids: input ids. shape (1, totol_nnz) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + if self.pre_process: + inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + + # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron + # so need to deal with it by handle here: + # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) + inputs_embeds = inputs_embeds.transpose(0, 1) + if self.megatron_config.sequence_parallel: + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + + hidden_states = inputs_embeds + else: + # self.hidden_states should be passed by Megatron + hidden_states = self.input_tensor + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = layer_outputs + + if self.post_process: + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelQwen2ForCausalLMRmPadPP(nn.Module): + def __init__( + self, + config: Qwen2Config, + megatron_config: ModelParallelConfig, + pre_process, + post_process, + share_embeddings_and_output_weights, + ): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.megatron_config = megatron_config + self.model = ParallelQwen2ModelRmPadPP( + config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process + ) + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.vocab_size = config.vocab_size + self.pre_process = pre_process + self.post_process = post_process + if post_process: + self._init_head(config) + if pre_process or post_process: + self.setup_embeddings_and_output_layer() + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + assert len(input_tensor) == 1 + self.model.set_input_tensor(input_tensor[0]) + + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights, + **column_kwargs, + ) + + def setup_embeddings_and_output_layer(self) -> None: + """Sets up embedding layer in first stage and output layer in last stage. + + This function initalizes word embeddings in the final stage when we are + using pipeline parallelism and sharing word embeddings, and sets up param + attributes on the embedding and output layers. + """ + # Set `is_embedding_or_output_parameter` attribute. + if self.pre_process: + self.model.embed_tokens.weight.is_embedding_or_output_parameter = True + if self.post_process and self.lm_head.weight is not None: + self.lm_head.weight.is_embedding_or_output_parameter = True + + if not self.share_embeddings_and_output_weights: + return + + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + # Zero out wgrad if sharing embeddings between two layers on same + # pipeline stage to make sure grad accumulation into main_grad is + # correct and does not include garbage values (e.g., from torch.empty). + self.shared_embedding_or_output_weight().zero_out_wgrad = True + return + + if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process: + self.shared_embedding_or_output_weight().shared_embedding = True + + if self.post_process and not self.pre_process: + assert not parallel_state.is_pipeline_first_stage() + # set word_embeddings weights to 0 here, then copy first + # stage's weights using all_reduce below. + self.lm_head.weight.data.fill_(0) + self.lm_head.weight.shared = True + self.lm_head.weight.shared_embedding = True + + if torch.distributed.is_initialized() and parallel_state.is_rank_in_embedding_group(): + weight = self.shared_embedding_or_output_weight() + weight.data = weight.data.to(get_device_name()) + torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group()) + + def shared_embedding_or_output_weight(self) -> torch.Tensor: + if self.pre_process: + return self.model.embed_tokens.weight + elif self.post_process: + return self.lm_head.weight + return None + + def _forward_head(self, hidden_states): + # all_gather from sequence parallel region is performed inside lm_head + # print(f'logits shape before forward_head: {hidden_states.shape}, vocab_size = ' + # f'{self.config.vocab_size}') # [4, 32, 4096] + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + logits = self.lm_head(hidden_states, weight=output_weight)[0] + # print(f'logits shape after forward_head: {logits.shape}') # [8, 32, 8] + logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) + return logits + + def forward( + self, + # original input + *, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + + # Note that input_ids, attention_mask and position_ids should be passed to every pp layer. + # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model + batch_size, sequence_length = input_ids.shape + # remove padding here + input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) + + # pad input_ids to multiple of tp for all tp ranks + # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap + if self.megatron_config.sequence_parallel: + input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad) + + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad) + + outputs = self.model( + input_ids=input_ids_rmpad, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + if self.post_process: + hidden_states = outputs + logits = self._forward_head(hidden_states) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + # add removed padding back. If input is already rmpad, we let the caller pad_input + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + else: + return outputs + + +class ParallelQwen2ForValueRmPadPP(ParallelQwen2ForCausalLMRmPadPP): + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + # lm_head is effectively the same as sequence parallel + sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) + + def _forward_head(self, hidden_states): + logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) + logits = logits.float() + if self.megatron_config.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits + + def forward( + self, + *, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + if self.post_process: + output.logits = torch.squeeze(output.logits, dim=-1) + return output + else: + return output diff --git a/toolbox/verl/v0.5.0/verl/models/registry.py b/toolbox/verl/v0.5.0/verl/models/registry.py new file mode 100644 index 000000000..829b9e20c --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/registry.py @@ -0,0 +1,58 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from typing import Optional + +import torch.nn as nn + +# Supported models in Megatron-LM +# Architecture -> (module, class). +_MODELS = { + "LlamaForCausalLM": ( + "llama", + ("ParallelLlamaForCausalLMRmPadPP", "ParallelLlamaForValueRmPadPP", "ParallelLlamaForCausalLMRmPad"), + ), + "Qwen2ForCausalLM": ( + "qwen2", + ("ParallelQwen2ForCausalLMRmPadPP", "ParallelQwen2ForValueRmPadPP", "ParallelQwen2ForCausalLMRmPad"), + ), + "MistralForCausalLM": ( + "mistral", + ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP", "ParallelMistralForCausalLMRmPad"), + ), +} + + +# return model class +class ModelRegistry: + @staticmethod + def load_model_cls(model_arch: str, value=False) -> Optional[type[nn.Module]]: + if model_arch not in _MODELS: + return None + + megatron = "megatron" + + module_name, model_cls_name = _MODELS[model_arch] + if not value: # actor/ref + model_cls_name = model_cls_name[0] + elif value: # critic/rm + model_cls_name = model_cls_name[1] + + module = importlib.import_module(f"verl.models.{module_name}.{megatron}.modeling_{module_name}_megatron") + return getattr(module, model_cls_name, None) + + @staticmethod + def get_supported_archs() -> list[str]: + return list(_MODELS.keys()) diff --git a/toolbox/verl/v0.5.0/verl/models/transformers/__init__.py b/toolbox/verl/v0.5.0/verl/models/transformers/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/transformers/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/models/transformers/dense_common.py b/toolbox/verl/v0.5.0/verl/models/transformers/dense_common.py new file mode 100644 index 000000000..56fe293f5 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/transformers/dense_common.py @@ -0,0 +1,193 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast + + +@dataclass +class CausalLMOutputForPPO(CausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None + + +def forward_base_model( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> CausalLMOutputWithPast: + r""" + Copy paste LLaMa's forward + https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/llama.py + + This function should be generic enough for all pure text models. + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + return outputs + + +def forward_with_torch_backend( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union["Cache", list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: int | torch.Tensor = 0, + temperature: float = 1.0, + **loss_kwargs, +) -> tuple | CausalLMOutputForPPO: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_torch_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") + + fused_linear_for_ppo = FusedLinearForPPO() + log_probs, entropy = fused_linear_for_ppo.forward( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + temperature=temperature, + ) + + return CausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def forward_with_triton_backend( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union["Cache", list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: int | torch.Tensor = 0, + temperature: float = 1.0, + **loss_kwargs, +) -> tuple | CausalLMOutputForPPO: + from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_triton_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states, + self.lm_head.weight, + rolled_labels, + temperature, + "none", + ) + + return CausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/toolbox/verl/v0.5.0/verl/models/transformers/kimi_vl.py b/toolbox/verl/v0.5.0/verl/models/transformers/kimi_vl.py new file mode 100644 index 000000000..edd79364b --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/transformers/kimi_vl.py @@ -0,0 +1,185 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import torch.nn.functional as F +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import _flash_attention_forward + +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def _ulysses_flash_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # patch + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + if ulysses_sp_size > 1: + validate_ulysses_config(self.num_heads, ulysses_sp_size) + + num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads + k_pe = repeat_kv(k_pe, ulysses_sp_size) # to keep heads=1 after a2a + k_nope = repeat_kv(k_nope, num_key_value_groups) + value_states = repeat_kv(value_states, num_key_value_groups) + q = gather_seq_scatter_heads(q, seq_dim=2, head_dim=1) + k_pe = gather_seq_scatter_heads(k_pe, seq_dim=2, head_dim=1) + k_nope = gather_seq_scatter_heads(k_nope, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + # (batch_size, num_head / sp_size, seq_length, head_size) + full_q_len = q.size(2) # full_q_len = seq_length + + else: + full_q_len = q_len + + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + cos, sin = self.rotary_emb(value_states, seq_len=full_q_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + full_q_len, + dropout=dropout_rate, + sliding_window=None, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + position_ids=position_ids, # important: pass position ids + softmax_scale=self.softmax_scale, + ) + + if ulysses_sp_size > 1: + attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1) + + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, None, None diff --git a/toolbox/verl/v0.5.0/verl/models/transformers/llama.py b/toolbox/verl/v0.5.0/verl/models/transformers/llama.py new file mode 100644 index 000000000..687ceab71 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/transformers/llama.py @@ -0,0 +1,239 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from typing import Callable, Optional + +import torch + +if sys.version_info >= (3, 11): + pass +else: + pass + +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb +from transformers.utils import logging + +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) + +logger = logging.get_logger(__name__) + + +def llama_flash_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """ + Adapted from transformers 4.47.1 to support Ulysses sequence parallelism. + + NOTE: This function is used for transformers versions in the range [4.45.0, 4.47.1]. + """ + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # trade off: repeat first and then all to all + # key_states = repeat_kv(key_states, self.num_key_value_groups) + # value_states = repeat_kv(value_states, self.num_key_value_groups) + + ########## AlltoAll for Ulysses ########## + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + validate_ulysses_config(self.num_heads, ulysses_sp_size) + + # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim) + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + + full_q_len = query_states.size(2) # full seq length + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to " + f"the fact you have upcasted embedding or layer norm layers in float32. We will cast back the " + f"input in {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + full_q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def llama_attn_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """ + Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0. + + NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0. + """ + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from transformers.models.llama.modeling_llama import eager_attention_forward + + bsz, q_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + ########## AlltoAll for Ulysses ########## + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size) + + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + + full_q_len = query_states.size(2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + "Falling back to eager attention. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights diff --git a/toolbox/verl/v0.5.0/verl/models/transformers/monkey_patch.py b/toolbox/verl/v0.5.0/verl/models/transformers/monkey_patch.py new file mode 100644 index 000000000..d6be65a77 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/transformers/monkey_patch.py @@ -0,0 +1,340 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Apply monkey-patch function to models +""" + +import importlib.metadata +import sys +from functools import lru_cache +from typing import Optional + +import torch +from packaging import version +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_utils import PreTrainedModel + +from verl.utils.import_utils import is_trl_available +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_group, + get_ulysses_sequence_parallel_world_size, + slice_input_tensor, +) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=2, repeats=n_rep). The hidden states go from (batch, + seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim) + """ + batch, slen, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_key_value_heads, n_rep, head_dim) + return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, head_dim) + + +def _ulysses_flash_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + *args, + position_ids: Optional[torch.Tensor] = None, + **kwargs, +): + """Insert all-to-all before and after flash attention. + DeepSpeed-Ulysses: https://arxiv.org/pdf/2309.14509 + + Args: + query_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads, head_dim) + key_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim) + value_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim) + position_ids (torch.Tensor, optional): (batch_size, seqlen/sp_size) + + Returns: + torch.Tensor: (batch_size, seqlen/sp_size, nheads, head_dim) + """ + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + assert position_ids is not None, "position_ids is required for Ulysses sequence parallelism" + + # NOTE: repeat kv heads to be divided by sequence parallel. Instead of repeating nheads_q//nheads_k, + # we choose to repeat sp_size//nheads_k, since flash_attention supports MQA/GQA. + # For example: + # - nheads_k=4, sp=8, repeats=2 + # - nheads_k=8, sp=8, repeats=1 + # - nheads_k=16, sp=8, repeats=1 + repeats = max(ulysses_sp_size // key_states.size(2), 1) + key_states = repeat_kv(key_states, repeats) + value_states = repeat_kv(value_states, repeats) + + # (bsz, seq_len/n, n_head, head_dim) -> (bsz, seq_len, n_head/n, head_dim) + query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2) + key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2) + value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2) + + # TODO: all_gather position_ids because `prepare_fa2_from_position_ids` needs it, we can eliminate + # this all_gather by passing cu_seq_lens_q, cu_seq_lens_k, max_length_k, max_length_q explicitly. + # https://github.com/huggingface/transformers/pull/33932 + + # (bsz, seq_len/n) -> (bsz, seq_len) + position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)] + torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group()) + position_ids = torch.concat(position_ids_list, dim=-1) + + # (bsz, seq_len, n_head/n, head_dim) + attn_output = _flash_attention_forward( + query_states, key_states, value_states, *args, position_ids=position_ids, **kwargs + ) + + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim) + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + + return attn_output + + +def patch_vlm_for_ulysses_input_slicing(model_class: type): + """ + Applies a monkey patch to the forward method of a given model class + to enable Ulysses sequence parallelism input slicing. + """ + + def _create_ulysses_wrapped_decoder_forward(original_forward): + def ulysses_wrapped_decoder_forward(self, *args, **kwargs): + inputs_embeds = kwargs.get("inputs_embeds") + call_kwargs = kwargs.copy() + + current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + slice_now = ( + inputs_embeds is not None + and current_ulysses_sp_size > 1 + and getattr(self, "_needs_initial_slice", True) + ) + if slice_now: + call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False) + self._needs_initial_slice = False + try: + return original_forward(self, *args, **call_kwargs) + finally: + if slice_now: + self._needs_initial_slice = True + + return ulysses_wrapped_decoder_forward + + original_forward = model_class.forward + wrapped_forward = _create_ulysses_wrapped_decoder_forward(original_forward) + model_class.forward = wrapped_forward + print(f"Monkey patch {model_class.__name__}.forward for Ulysses SP input slicing.") + + +def patch_forward_with_backends( + model: PreTrainedModel, + use_fused_kernels: bool = False, + fused_kernels_backend: str = None, +): + """ + Choose the forward function based on the model and backend. + Args: + model (PreTrainedModel): The model to apply the monkey patch. + use_fused_kernels (bool): Whether to use fused kernels. + fused_kernels_backend (str): The backend to use for fused kernels. + """ + if not use_fused_kernels or fused_kernels_backend not in ["triton", "torch"]: + print( + f"Skipping monkey patch for {model.__class__.__name__} as use_fused_kernels is " + f"{use_fused_kernels} or fused_kernels_backend is {fused_kernels_backend}" + ) + return + + forward_with_torch_backend_function = model.__class__.forward + forward_with_triton_backend_function = model.__class__.forward + if model.config.model_type == "qwen2_5_vl": + from verl.models.transformers.qwen2_5_vl import forward_with_torch_backend, forward_with_triton_backend + + forward_with_torch_backend_function = forward_with_torch_backend + forward_with_triton_backend_function = forward_with_triton_backend + elif model.config.model_type == "qwen2_vl": + from verl.models.transformers.qwen2_vl import forward_with_torch_backend, forward_with_triton_backend + + forward_with_torch_backend_function = forward_with_torch_backend + forward_with_triton_backend_function = forward_with_triton_backend + else: + from verl.models.transformers.dense_common import forward_with_torch_backend, forward_with_triton_backend + + forward_with_torch_backend_function = forward_with_torch_backend + forward_with_triton_backend_function = forward_with_triton_backend + + if fused_kernels_backend == "triton": + model.__class__.forward = forward_with_triton_backend_function + print(f"Using Triton backend for fused kernels in {model.__class__.__name__}") + elif fused_kernels_backend == "torch": + model.__class__.forward = forward_with_torch_backend_function + print(f"Using Torch backend for fused kernels in {model.__class__.__name__}") + else: + raise ValueError(f"Unsupported fused_kernels_backend: {fused_kernels_backend}. Choose 'triton' or 'torch'.") + + +def apply_monkey_patch( + model: PreTrainedModel, + ulysses_sp_size: int = 1, + use_remove_padding: bool = True, + use_fused_kernels: bool = False, + fused_kernels_backend: str = None, +): + """ + Apply monkey patch to the models for ulysses sequence parallel and fused kernel. + + In the end of this function forward function of the model is patched for fused kernel. + If the model is not supported with fused kernel, please return after patch. + """ + + """Replace _flash_attention_forward to _ulysses_flash_attention_forward""" + module = sys.modules[model.__module__] + + try: + num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads + except AttributeError: + num_attention_heads, num_key_value_heads = ( + model.config.text_config.num_attention_heads, + model.config.text_config.num_key_value_heads, + ) + + assert num_attention_heads % ulysses_sp_size == 0, ( + f"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}" + ) + assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, ( + f"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size " + f"{ulysses_sp_size}or vise versa. Upon ulysses_sp_size % num_key_value_heads == 0," + f"kv heads are repeated to ensure correctness." + ) + + if is_trl_available(): + from trl import AutoModelForCausalLMWithValueHead # type: ignore + + def state_dict(self, *args, **kwargs): + return torch.nn.Module.state_dict(self, *args, **kwargs) + + AutoModelForCausalLMWithValueHead.state_dict = state_dict + print("Monkey patch state_dict in AutoModelForCausalLMWithValueHead. ") + + # TODO: VLM models only, unify monkey patch to LLM models. + if model.config.model_type == "qwen2_5_vl": + if is_transformers_version_in_range(min_version="4.53.0"): + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention + + # TODO: Support transformers 4.53 + raise ValueError("Transformers 4.53 is not supported") + else: + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention, + ) + + if use_remove_padding or ulysses_sp_size > 1: + from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward + + Qwen2_5_VLAttention.forward = ulysses_flash_attn_forward + print("Monkey patch FlashAttention2.forward in Qwen2.5VL") + + if ulysses_sp_size > 1: + if is_transformers_version_in_range(min_version="4.52.0"): + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel + + patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel) + else: + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel + + patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLModel) + + elif model.config.model_type == "qwen2_vl": + if is_transformers_version_in_range(min_version="4.53.0"): + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention + + # TODO: Support transformers 4.53 + raise ValueError("Transformers 4.53 is not supported") + else: + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention + + if use_remove_padding or ulysses_sp_size > 1: + from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward + + Qwen2VLAttention.forward = ulysses_flash_attn_forward + print("Monkey patch FlashAttention2.forward in Qwen2VL") + + if ulysses_sp_size > 1: + if is_transformers_version_in_range(min_version="4.52.0"): + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel + + patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel) + else: + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel + + patch_vlm_for_ulysses_input_slicing(Qwen2VLModel) + + elif model.config.model_type == "kimi_vl": + if use_remove_padding or ulysses_sp_size > 1: + # TODO: Changes need to be made when transformers are adapted. + from verl.models.transformers.kimi_vl import _ulysses_flash_attn_forward + + module.DeepseekV3FlashAttention2.forward = _ulysses_flash_attn_forward + print("Monkey patch FlashAttention2.forward in KimiVL") + + if ulysses_sp_size > 1: + patch_vlm_for_ulysses_input_slicing(module.DeepseekV3ForCausalLM) + + if use_fused_kernels: + print("Not support fused kernels for KimiVL") + + return + + # transformers<=4.47.1 + if use_remove_padding or ulysses_sp_size > 1: + if hasattr(module, "_flash_attention_forward"): + module._flash_attention_forward = _ulysses_flash_attention_forward + print(f"Monkey patch _flash_attention_forward in {model.__module__}") + else: + # transformers>=4.48.0 + from transformers.integrations import flash_attention + + flash_attention._flash_attention_forward = _ulysses_flash_attention_forward + print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}") + + patch_forward_with_backends(model, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend) + + +@lru_cache +def is_transformers_version_in_range(min_version: Optional[str] = None, max_version: Optional[str] = None) -> bool: + try: + # Get the installed version of the transformers library + transformers_version_str = importlib.metadata.version("transformers") + except importlib.metadata.PackageNotFoundError as e: + raise ModuleNotFoundError("The `transformers` package is not installed.") from e + + transformers_version = version.parse(transformers_version_str) + + lower_bound_check = True + if min_version is not None: + lower_bound_check = version.parse(min_version) <= transformers_version + + upper_bound_check = True + if max_version is not None: + upper_bound_check = transformers_version <= version.parse(max_version) + + return lower_bound_check and upper_bound_check diff --git a/toolbox/verl/v0.5.0/verl/models/transformers/npu_patch.py b/toolbox/verl/v0.5.0/verl/models/transformers/npu_patch.py new file mode 100644 index 000000000..e6bb37368 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/transformers/npu_patch.py @@ -0,0 +1,50 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch_npu +from torch_npu import npu_rotary_mul as apply_rotary_emb +from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2RMSNorm + + +# This patch takes effect when using apply_rotary_pos_emb_flashatt on qwen2_5_vl and will be removed in +# subsequent versions +# https://github.com/huggingface/transformers/pull/38491 +def apply_rotary_pos_emb_flashatt_npu( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + cos = cos.repeat(1, 2) + sin = sin.repeat(1, 2) + q_embed = apply_rotary_emb( + q.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float() + ).type_as(q) + k_embed = apply_rotary_emb( + k.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float() + ).type_as(k) + return q_embed, k_embed + + +# This api can improve performance on ASCEND NPU +def rms_norm_forward(self, x): + return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0] + + +Qwen2RMSNorm.forward = rms_norm_forward +modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_npu diff --git a/toolbox/verl/v0.5.0/verl/models/transformers/qwen2.py b/toolbox/verl/v0.5.0/verl/models/transformers/qwen2.py new file mode 100644 index 000000000..e55fb26d5 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/transformers/qwen2.py @@ -0,0 +1,241 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional + +import torch +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv +from transformers.utils import logging + +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) + +logger = logging.get_logger(__name__) + + +def qwen2_flash_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 +): + """ + Adapted from transformers 4.47.1 to support Ulysses sequence parallelism. + + NOTE: This function is only tested on transformers versions between 4.45.0 and 4.47.1. + """ + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + ########## AlltoAll for Ulysses ########## + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + validate_ulysses_config(self.num_heads, ulysses_sp_size) + + # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim) + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + + full_q_len = query_states.size(2) # full seq length + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to " + f"the fact you have upcasted embedding or layer norm layers in float32. We will cast back the " + f"input in {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + full_q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + # use full_q_len to reshape + attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def qwen2_attn_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """ + Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0. + + NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0. + """ + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + bsz, q_len, _ = hidden_states.shape + hidden_shape = (bsz, q_len, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + ########## AlltoAll for Ulysses ########## + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size) + + # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim) + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + + full_q_len = query_states.size(2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + sliding_window = None + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + + from transformers.models.qwen2.modeling_qwen2 import eager_attention_forward + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + "Falling back to eager attention. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=sliding_window, # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim) + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights diff --git a/toolbox/verl/v0.5.0/verl/models/transformers/qwen2_5_vl.py b/toolbox/verl/v0.5.0/verl/models/transformers/qwen2_5_vl.py new file mode 100644 index 000000000..51d9753fb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/transformers/qwen2_5_vl.py @@ -0,0 +1,288 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional + +import torch +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLCausalLMOutputWithPast, + Qwen2_5_VLForConditionalGeneration, +) + + +@dataclass +class Qwen2_5_VLCausalLMOutputForPPO(Qwen2_5_VLCausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None + + +def forward_base_model( + self: Qwen2_5_VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, +) -> tuple | Qwen2_5_VLCausalLMOutputWithPast: + r""" + Copy paste Qwen2_5_VL's forward + https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_5_vl.py + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, " + f"features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, " + f"features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0 + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + return outputs + + +def forward_with_torch_backend( + self: Qwen2_5_VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> tuple | Qwen2_5_VLCausalLMOutputForPPO: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + second_per_grid_ts=second_per_grid_ts, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_torch_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") + + fused_linear_for_ppo = FusedLinearForPPO() + log_probs, entropy = fused_linear_for_ppo.forward( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + temperature=temperature, + ) + + return Qwen2_5_VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) + + +def forward_with_triton_backend( + self: Qwen2_5_VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> tuple | Qwen2_5_VLCausalLMOutputForPPO: + from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + second_per_grid_ts=second_per_grid_ts, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_triton_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states, + self.lm_head.weight, + rolled_labels, + temperature, + "none", + ) + + return Qwen2_5_VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) diff --git a/toolbox/verl/v0.5.0/verl/models/transformers/qwen2_vl.py b/toolbox/verl/v0.5.0/verl/models/transformers/qwen2_vl.py new file mode 100644 index 000000000..358b00b6b --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/transformers/qwen2_vl.py @@ -0,0 +1,559 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +from dataclasses import dataclass +from typing import Optional + +import torch +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLCausalLMOutputWithPast, + Qwen2VLForConditionalGeneration, +) +from transformers.utils import is_flash_attn_greater_or_equal + +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) + +try: + from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) +except ImportError: + flash_attn_varlen_func = None + + +def get_rope_index( + processor, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + video_grid_thw: Optional[torch.Tensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence. + The batch dim has been removed and the input_ids should be a 1D tensor representing a single example. + https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546 + """ + spatial_merge_size = processor.image_processor.merge_size + tokens_per_second = 2 + image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") + video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>") + vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>") + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen) + image_index, video_index = 0, 0 + input_ids = input_ids[attention_mask == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + second_per_grid_t = second_per_grid_ts[video_index] if second_per_grid_ts is not None else 1.0 + + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) + t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device) + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device) + else: + position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1) + + return position_ids + + +def prepare_fa2_from_position_ids( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor +): + query = query.view(-1, query.size(-2), query.size(-1)) + key = key.view(-1, key.size(-2), key.size(-1)) + value = value.view(-1, value.size(-2), value.size(-1)) + position_ids = position_ids.flatten() + indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + cu_seqlens = torch.cat( + ( + indices_q[position_ids == 0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + ) + ) + max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope + return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length)) + + +def flash_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + is_causal: bool = True, + position_ids: Optional[torch.Tensor] = None, + sliding_window: Optional[int] = None, + use_top_left_mask: bool = False, + deterministic: Optional[bool] = None, + **kwargs, +): + """ + Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length) + """ + causal = is_causal if not use_top_left_mask else is_causal and query_length != 1 + + # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). + use_sliding_windows = ( + _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window + ) + flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} + + if is_flash_attn_greater_or_equal("2.4.1"): + if deterministic is None: + deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + flash_kwargs["deterministic"] = deterministic + + if position_ids is not None and query_length != 1 and not (torch.diff(position_ids[0], dim=-1) >= 0).all(): + batch_size = query_states.size(0) + query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids[0] + ) # remove channel dimension + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=kwargs.pop("dropout", 0.0), + softmax_scale=kwargs.pop("softmax_scale", None), + causal=causal, + **flash_kwargs, + ) + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_length, + is_causal=is_causal, + sliding_window=sliding_window, + use_top_left_mask=use_top_left_mask, + deterministic=deterministic, + **kwargs, + ) # do not pass position_ids to old flash_attention_forward + + return attn_output + + +def ulysses_flash_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, +) -> tuple[torch.Tensor, None, None]: + from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb, repeat_kv + + bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size + query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + validate_ulysses_config(self.num_heads, ulysses_sp_size) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + # (batch_size, num_head / sp_size, seq_length, head_size) + full_q_len = query_states.size(2) # full_q_len = seq_length + else: + full_q_len = q_len + + # Because the input can be padded, the absolute sequence length depends on the max position id. + if position_embeddings is None: + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + full_q_len, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + position_ids=position_ids, # important: pass position ids + ) # (batch_size, seq_length, num_head / sp_size, head_size) + if ulysses_sp_size > 1: + attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, None, None + + +@dataclass +class Qwen2VLCausalLMOutputForPPO(Qwen2VLCausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None + + +def forward_base_model( + self: Qwen2VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> tuple | Qwen2VLCausalLMOutputWithPast: + r""" + Copy paste Qwen2VL's forward + https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_vl.py + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.get_dtype()) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, " + f"features {n_image_features}" + ) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, " + f"features {n_video_features}" + ) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + return outputs + + +def forward_with_torch_backend( + self: Qwen2VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> tuple | Qwen2VLCausalLMOutputForPPO: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_torch_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") + + fused_linear_for_ppo = FusedLinearForPPO() + log_probs, entropy = fused_linear_for_ppo.forward( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + temperature=temperature, + ) + + return Qwen2VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) + + +def forward_with_triton_backend( + self: Qwen2VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> tuple | Qwen2VLCausalLMOutputForPPO: + from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_triton_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states, + self.lm_head.weight, + rolled_labels, + temperature, + "none", + ) + + return Qwen2VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) diff --git a/toolbox/verl/v0.5.0/verl/models/weight_loader_registry.py b/toolbox/verl/v0.5.0/verl/models/weight_loader_registry.py new file mode 100644 index 000000000..8aa3bc71f --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/models/weight_loader_registry.py @@ -0,0 +1,56 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def get_weight_loader(arch: str): + from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel + + _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY = { + "LlamaForCausalLM": load_state_dict_to_megatron_gptmodel, + "Qwen2ForCausalLM": load_state_dict_to_megatron_gptmodel, + } + + if arch in _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY: + return _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY[arch] + raise ValueError( + f"Model architectures {arch} loader are not supported for now. Supported architectures: " + f"{_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY.keys()}" + ) + + +def get_weight_saver(arch: str): + from verl.models.mcore.saver import ( + merge_megatron_ckpt_gptmodel, + merge_megatron_ckpt_gptmodel_dpskv3, + merge_megatron_ckpt_gptmodel_mixtral, + merge_megatron_ckpt_gptmodel_qwen2_5_vl, + merge_megatron_ckpt_gptmodel_qwen_moe, + ) + + _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY = { + "LlamaForCausalLM": merge_megatron_ckpt_gptmodel, + "Qwen2ForCausalLM": merge_megatron_ckpt_gptmodel, + "MixtralForCausalLM": merge_megatron_ckpt_gptmodel_mixtral, + "Qwen2MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe, + "Qwen2_5_VLForConditionalGeneration": merge_megatron_ckpt_gptmodel_qwen2_5_vl, + "DeepseekV3ForCausalLM": merge_megatron_ckpt_gptmodel_dpskv3, + "Qwen3ForCausalLM": merge_megatron_ckpt_gptmodel, + "Qwen3MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe, + } + if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY: + return _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY[arch] + raise ValueError( + f"Model architectures {arch} saver are not supported for now. Supported architectures: " + f"{_MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY.keys()}" + ) diff --git a/toolbox/verl/v0.5.0/verl/protocol.py b/toolbox/verl/v0.5.0/verl/protocol.py new file mode 100644 index 000000000..39979f848 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/protocol.py @@ -0,0 +1,964 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Implement base data transfer protocol between any two functions, modules. +We can subclass Protocol to define more detailed batch info with specific keys +""" + +import contextlib +import copy +import logging +import os +import pickle +from dataclasses import dataclass, field +from typing import Callable, Optional + +import numpy as np +import pandas as pd +import ray +import tensordict +import torch +import torch.distributed +from packaging import version +from tensordict import TensorDict +from torch.utils.data import DataLoader + +from verl.utils.device import get_device_id, get_torch_device +from verl.utils.py_functional import union_two_dict +from verl.utils.torch_functional import allgather_dict_tensors + +__all__ = ["DataProto", "union_tensor_dict"] + +with contextlib.suppress(Exception): + tensordict.set_lazy_legacy(False).set() + + +class _DataProtoConfigMeta(type): + _config = {} + + auto_padding_key = "_verl_auto_padding" + + @property + def auto_padding(cls): + enabled_by_env = os.getenv("VERL_AUTO_PADDING", "FALSE").upper() in ["TRUE", "1"] + return enabled_by_env or cls._config.get(cls.auto_padding_key, False) + + @auto_padding.setter + def auto_padding(cls, enabled: bool): + assert isinstance(enabled, bool), f"enabled must be a boolean, got {enabled} as {type(enabled)}" + cls._config[cls.auto_padding_key] = enabled + + +class DataProtoConfig(metaclass=_DataProtoConfigMeta): + pass + + +_padding_size_key = "_padding_size_key_x123d" + + +def pad_dataproto_to_divisor(data: "DataProto", size_divisor: int): + """Pad a DataProto to size divisible by size_divisor + + Args: + size_divisor (int): size divisor + + Returns: + data: (DataProto): the padded DataProto + pad_size (int) + """ + assert isinstance(data, DataProto), "data must be a DataProto" + if len(data) % size_divisor != 0: + pad_size = size_divisor - len(data) % size_divisor + padding_protos = [] + remaining_pad = pad_size + while remaining_pad > 0: + take_size = min(remaining_pad, len(data)) + padding_protos.append(data[:take_size]) + remaining_pad -= take_size + data_padded = DataProto.concat([data] + padding_protos) + else: + if len(data) == 0: + logging.warning("padding a DataProto with no item, no changed made") + pad_size = 0 + data_padded = data + return data_padded, pad_size + + +def unpad_dataproto(data: "DataProto", pad_size): + """Unpad the data proto with pad_size. i.e. `data[:-pad_size]`""" + if pad_size != 0: + data = data[:-pad_size] + return data + + +def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict: + """Union two tensordicts.""" + assert tensor_dict1.batch_size == tensor_dict2.batch_size, ( + f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}" + ) + for key in tensor_dict2.keys(): + if key not in tensor_dict1.keys(): + tensor_dict1[key] = tensor_dict2[key] + else: + assert tensor_dict1[key].equal(tensor_dict2[key]), ( + f"{key} in tensor_dict1 and tensor_dict2 are not the same object" + ) + + return tensor_dict1 + + +def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str, np.ndarray]) -> dict[str, np.ndarray]: + for key, val in tensor_dict2.items(): + if key in tensor_dict1: + assert isinstance(tensor_dict2[key], np.ndarray) + assert isinstance(tensor_dict1[key], np.ndarray) + # to properly deal with nan and object type + assert pd.DataFrame(tensor_dict2[key]).equals(pd.DataFrame(tensor_dict1[key])), ( + f"{key} in tensor_dict1 and tensor_dict2 are not the same object" + ) + tensor_dict1[key] = val + + return tensor_dict1 + + +def list_of_dict_to_dict_of_list(list_of_dict: list[dict]): + if len(list_of_dict) == 0: + return {} + keys = list_of_dict[0].keys() + output = {key: [] for key in keys} + for data in list_of_dict: + for key, item in data.items(): + assert key in output + output[key].append(item) + return output + + +def fold_batch_dim(data: "DataProto", new_batch_size): + """ + Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx] + """ + batch_size = data.batch.batch_size[0] + + assert batch_size % new_batch_size == 0 + + tensor: TensorDict = data.batch + non_tensor = data.non_tensor_batch + + tensor = tensor.view(new_batch_size, -1) + tensor.auto_batch_size_(batch_dims=1) + + for key, val in non_tensor.items(): + non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:])) + + return type(data)(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info) + + +def unfold_batch_dim(data: "DataProto", batch_dims=2): + """ + Unfold the first n dims as new batch dim + """ + tensor: TensorDict = data.batch + non_tensor = data.non_tensor_batch + tensor.auto_batch_size_(batch_dims=batch_dims) + tensor = tensor.view(-1) + + batch_size = tensor.batch_size[0] + + non_tensor_new = {} + + for key, val in non_tensor.items(): + non_tensor_new[key] = np.reshape(val, newshape=(batch_size, *val.shape[batch_dims:])) + + return type(data)(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info) + + +def collate_fn(x: list["DataProtoItem"]): + batch = [] + non_tensor_batch = [] + for data in x: + batch.append(data.batch) + non_tensor_batch.append(data.non_tensor_batch) + batch = torch.stack(batch).contiguous() + non_tensor_batch = list_of_dict_to_dict_of_list(non_tensor_batch) + for key, val in non_tensor_batch.items(): + non_tensor_batch[key] = np.array(val, dtype=object) + return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) + + +@dataclass +class DataProtoItem: + # TODO(zhangchi.usc1992) add consistency check + batch: TensorDict = None + non_tensor_batch: dict = field(default_factory=dict) + meta_info: dict = field(default_factory=dict) + + +@dataclass +class DataProto: + """ + A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions. + It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/. + TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the + same batch size should be put inside batch. + """ + + batch: TensorDict = None + non_tensor_batch: dict = field(default_factory=dict) + meta_info: dict = field(default_factory=dict) + + def __post_init__(self): + # perform necessary checking + self.check_consistency() + + def __len__(self): + if self.batch is not None: + return self.batch.batch_size[0] + elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0: + random_key = list(self.non_tensor_batch.keys())[0] + return self.non_tensor_batch[random_key].shape[0] + else: + return 0 + + def __getitem__(self, item): + """ + Enhanced indexing for DataProto objects. + + Args: + item: Can be one of: + - int: A single index + - slice: A slice object (start:stop:step) + - list: A list of indices + - numpy.ndarray: An array of indices + - torch.Tensor: A tensor of indices + + Returns: + DataProto: For all indexing types except single integers + DataProtoItem: Only for single integer indices + """ + # Case 1: Slice object - use the slice method + if isinstance(item, slice): + return self.slice(item.start, item.stop, item.step) + + # Case 2: List, numpy array, or torch tensor - use sel_idxs + elif isinstance(item, list | np.ndarray | torch.Tensor): + return self.select_idxs(item) + + # Case 3: Single integer - return DataProtoItem for backward compatibility + elif isinstance(item, int | np.integer): + tensor_data = self.batch[item] if self.batch is not None else None + non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()} + return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) + + # # Case 4: Unsupported type + else: + raise TypeError(f"Indexing with {type(item)} is not supported") + + def __getstate__(self): + import io + + buffer = io.BytesIO() + if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None: + self.batch = self.batch.contiguous() + self.batch = self.batch.consolidate() + torch.save(self.batch, buffer) + buffer_bytes = buffer.getvalue() + return buffer_bytes, self.non_tensor_batch, self.meta_info + + def __setstate__(self, data): + import io + + batch_deserialized_bytes, non_tensor_batch, meta_info = data + batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes) + batch = torch.load( + batch_deserialized, + weights_only=False, + map_location="cpu" if not get_torch_device().is_available() else None, + ) + self.batch = batch + self.non_tensor_batch = non_tensor_batch + self.meta_info = meta_info + + def save_to_disk(self, filepath): + with open(filepath, "wb") as f: + pickle.dump(self, f) + + @staticmethod + def load_from_disk(filepath) -> "DataProto": + with open(filepath, "rb") as f: + data = pickle.load(f) + return data + + def print_size(self, prefix=""): + size_of_tensordict = 0 + if self.batch is not None: + for _, tensor in self.batch.items(): + size_of_tensordict += tensor.element_size() * tensor.numel() + size_of_numpy_array = 0 + for _, numpy_array in self.non_tensor_batch.items(): + size_of_numpy_array += numpy_array.nbytes + + size_of_numpy_array /= 1024**3 + size_of_tensordict /= 1024**3 + + message = f"Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB" + + if prefix: + message = f"{prefix}, " + message + print(message) + + def check_consistency(self): + """Check the consistency of the DataProto. Mainly for batch and non_tensor_batch + We expose this function as a public one so that user can call themselves directly + """ + if self.batch is not None: + assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1" + + if self.non_tensor_batch is not None: + for key, val in self.non_tensor_batch.items(): + assert isinstance(val, np.ndarray) + + if self.batch is not None and self.non_tensor_batch is not None and len(self.non_tensor_batch) != 0: + # TODO: we can actually lift this restriction if needed + assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1 when non_tensor_batch is not empty." + + batch_size = self.batch.batch_size[0] + for key, val in self.non_tensor_batch.items(): + assert isinstance(val, np.ndarray), ( + f"data in the non_tensor_batch must be a numpy.array with dtype=object, but for " + f"{key=}, got {type(val)=}" + ) + assert val.shape[0] == batch_size, ( + f"key {key} length {len(val)} is not equal to batch size {batch_size}" + ) + + @classmethod + def from_single_dict(cls, data: dict[str, torch.Tensor | np.ndarray], meta_info=None, auto_padding=False): + """Create a DataProto from a dict of tensors and non_tensors""" + tensors = {} + non_tensors = {} + + for key, val in data.items(): + if isinstance(val, torch.Tensor): + tensors[key] = val + elif isinstance(val, np.ndarray): + non_tensors[key] = val + else: + raise ValueError(f"Unsupported type in data {type(val)}") + + return cls.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info, auto_padding=auto_padding) + + @classmethod + def from_dict( + cls, + tensors: Optional[dict[str, torch.Tensor]] = None, + non_tensors=None, + meta_info=None, + num_batch_dims=1, + auto_padding=False, + ): + """Create a DataProto from a dict of tensors. This assumes that + 1. All the tensor in tensors have the same dim0 + 2. Only dim0 is the batch dim + """ + + assert num_batch_dims > 0, "num_batch_dims must be greater than zero" + if non_tensors is not None: + assert num_batch_dims == 1, "only support num_batch_dims=1 when non_tensors is not None." + + if tensors is None: + tensors = {} + if meta_info is None: + meta_info = {} + if non_tensors is None: + non_tensors = {} + + assert isinstance(non_tensors, dict) + + # get and check batch size + batch_size = None + pivot_key = None + for key, tensor in tensors.items(): + if batch_size is None: + batch_size = tensor.shape[:num_batch_dims] + pivot_key = key + else: + current_batch = tensor.shape[:num_batch_dims] + assert batch_size == current_batch, ( + f"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. " + f"Got {pivot_key} has {batch_size}, {key} has {current_batch}" + ) + + for key, val in non_tensors.items(): + if not isinstance(val, np.ndarray): + non_tensors[key] = np.array(val, dtype=object) + + tensor_dict = TensorDict(source=tensors, batch_size=batch_size) if tensors else None + if auto_padding: + meta_info[DataProtoConfig.auto_padding_key] = True + return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info) + + def to(self, device) -> "DataProto": + """move the batch to device + + Args: + device (torch.device, str): torch device + + Returns: + DataProto: the current DataProto + + """ + if self.batch is not None: + self.batch = self.batch.to(device) + return self + + def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> "DataProto": + """Select a subset of the DataProto via batch_keys and meta_info_keys + + Args: + batch_keys (list, optional): a list of strings indicating the keys in batch to select + meta_info_keys (list, optional): a list of keys indicating the meta info to select + + Returns: + DataProto: the DataProto with the selected batch_keys and meta_info_keys + """ + # TODO (zhangchi.usc1992) whether to copy + if batch_keys is not None: + batch_keys = tuple(batch_keys) + sub_batch = self.batch.select(*batch_keys) + else: + sub_batch = self.batch + + if non_tensor_batch_keys is not None: + non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys} + else: + non_tensor_batch = self.non_tensor_batch + + if deepcopy: + non_tensor_batch = copy.deepcopy(non_tensor_batch) + + if meta_info_keys is not None: + sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys} + else: + sub_meta_info = self.meta_info + + if deepcopy: + sub_meta_info = copy.deepcopy(sub_meta_info) + + return type(self)(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info) + + def select_idxs(self, idxs): + """ + Select specific indices from the DataProto. + + Args: + idxs (torch.Tensor or numpy.ndarray or list): Indices to select + + Returns: + DataProto: A new DataProto containing only the selected indices + """ + if isinstance(idxs, list): + idxs = torch.tensor(idxs) + if idxs.dtype != torch.bool: + idxs = idxs.type(torch.int32) + + if isinstance(idxs, np.ndarray): + idxs_np = idxs + idxs_torch = torch.from_numpy(idxs) + else: # torch.Tensor + idxs_torch = idxs + idxs_np = idxs.detach().cpu().numpy() + + batch_size = int(idxs_np.sum()) if idxs_np.dtype == bool else idxs_np.shape[0] + + if self.batch is not None: + # Use TensorDict's built-in indexing capabilities + selected_batch = TensorDict( + source={key: tensor[idxs_torch] for key, tensor in self.batch.items()}, + batch_size=(batch_size,), + device=self.batch.device, + ) + else: + selected_batch = None + + selected_non_tensor = {} + for key, val in self.non_tensor_batch.items(): + selected_non_tensor[key] = val[idxs_np] + + return type(self)(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info) + + def slice(self, start=None, end=None, step=None): + """ + Slice the DataProto and return a new DataProto object. + This is an improved version of direct slicing which returns a DataProtoItem. + + Args: + start (int, optional): Start index. Defaults to None (start from beginning). + end (int, optional): End index (exclusive). Defaults to None (go to end). + step (int, optional): Step size. Defaults to None (step=1). + + Returns: + DataProto: A new DataProto containing the sliced data + + Examples: + # Using the slice method directly + sliced_data = data_proto.slice(10, 20) + + # Using enhanced indexing (returns DataProto) + sliced_data = data_proto[10:20] + sliced_data = data_proto[::2] # Every other element + + # Using list indexing (returns DataProto) + indices = [1, 5, 10] + selected_data = data_proto[indices] + + # Single index still returns DataProtoItem + single_item = data_proto[5] + """ + # Create a slice object + slice_obj = slice(start, end, step) + + # Handle the batch data + if self.batch is not None: + # Use TensorDict's built-in slicing capabilities + sliced_batch = self.batch[slice_obj] + else: + sliced_batch = None + + # Handle the non-tensor batch data + sliced_non_tensor = {} + for key, val in self.non_tensor_batch.items(): + sliced_non_tensor[key] = val[slice_obj] + + # Return a new DataProto object + return type(self)(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info) + + def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> "DataProto": + """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys` + + Args: + batch_keys (list, optional): a list of strings indicating the keys in batch to pop + meta_info_keys (list, optional): a list of keys indicating the meta info to pop + + Returns: + DataProto: the DataProto with the poped batch_keys and meta_info_keys + """ + if batch_keys is None: + batch_keys = [] + if meta_info_keys is None: + meta_info_keys = [] + if non_tensor_batch_keys is None: + non_tensor_batch_keys = [] + + tensors = {} + # tensor batch + for key in batch_keys: + assert key in self.batch.keys() + tensors[key] = self.batch.pop(key) + non_tensors = {} + # non tensor batch + for key in non_tensor_batch_keys: + assert key in self.non_tensor_batch.keys() + non_tensors[key] = self.non_tensor_batch.pop(key) + meta_info = {} + for key in meta_info_keys: + assert key in self.meta_info.keys() + meta_info[key] = self.meta_info.pop(key) + return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info) + + def rename(self, old_keys=None, new_keys=None) -> "DataProto": + """ + Note that this function only rename the key in the batch + """ + + def validate_input(keys): + if keys is not None: + if isinstance(keys, str): + keys = [keys] + elif isinstance(keys, list): + pass + else: + raise TypeError(f"keys must be a list or a string, but got {type(keys)}") + return keys + + old_keys = validate_input(old_keys) + new_keys = validate_input(new_keys) + + if len(new_keys) != len(old_keys): + raise ValueError( + f"new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}" + ) + + self.batch.rename_key_(tuple(old_keys), tuple(new_keys)) + + return self + + def union(self, other: "DataProto") -> "DataProto": + """Union with another DataProto. Union batch and meta_info separately. + Throw an error if + + - there are conflict keys in batch and they are not equal + - the batch size of two data batch is not the same + - there are conflict keys in meta_info and they are not the same. + + Args: + other (DataProto): another DataProto to union + + Returns: + DataProto: the DataProto after union + """ + self.batch = union_tensor_dict(self.batch, other.batch) + self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch) + self.meta_info = union_two_dict(self.meta_info, other.meta_info) + return self + + def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None): + r"""Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch + dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details. + + + Args: + mini_batch_size (int): mini-batch size when iterating the dataset. We require that + ``batch.batch_size[0] % mini_batch_size == 0``. + epochs (int): number of epochs when iterating the dataset. + dataloader_kwargs (Any): internally, it returns a DataLoader over the batch. The + dataloader_kwargs is the kwargs passed to the DataLoader. + + Returns: + Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration + steps is ``self.batch.batch_size * epochs // mini_batch_size`` + """ + assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0" + # we can directly create a dataloader from TensorDict + if dataloader_kwargs is None: + dataloader_kwargs = {} + + if seed is not None: + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = None + + assert isinstance(dataloader_kwargs, dict) + train_dataloader = DataLoader( + dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs + ) + + def get_data(): + for _ in range(epochs): + for d in train_dataloader: + d.meta_info = self.meta_info + yield d + + return iter(get_data()) + + def is_padding_enabled(self): + """ + Check if padding is enabled for the DataProto. + Returns: + bool: True if padding is enabled, False otherwise. + """ + dataproto_specific_padding = self.meta_info.get(DataProtoConfig.auto_padding_key, False) + return dataproto_specific_padding or DataProtoConfig.auto_padding + + def padding(self, padding_size, padding_candidate=""): + """Pad the DataProto by concating with padding_candidate.repeat(padding_size) + + Args: + padding_size (int): the number of repeated padding_candidate + padding_candidate: the item to be repeated and appended to the DataProto, only supporting ["first", "last"] + """ + if padding_size == 0: + return + padding_candidate = self.select_idxs([0 if padding_candidate == "first" else len(self) - 1]) + padding_part = padding_candidate.repeat(padding_size) + padded_dp = DataProto.concat([self, padding_part]) + self.batch = padded_dp.batch + self.non_tensor_batch = padded_dp.non_tensor_batch + + def chunk(self, chunks: int) -> list["DataProto"]: + """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split. + + Args: + chunks (int): the number of chunks to split on dim=0 + + Returns: + List[DataProto]: a list of DataProto after splitting + """ + if not self.is_padding_enabled(): + assert len(self) % chunks == 0, ( + f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}." + ) + + bsz_in_batch = None + if self.batch is not None: + batch_lst = self.batch.chunk(chunks=chunks, dim=0) + bsz_in_batch = np.array([batch.batch_size[0] for batch in batch_lst]) + chunk_indices = np.cumsum(bsz_in_batch)[:-1] + else: + batch_lst = [None for _ in range(chunks)] + + non_tensor_batch_lst = [{} for _ in range(chunks)] + for key, val in self.non_tensor_batch.items(): + assert isinstance(val, np.ndarray) + if bsz_in_batch is not None: + non_tensor_lst = np.array_split(val, chunk_indices.tolist()) + else: + non_tensor_lst = np.array_split(val, chunks) + assert len(non_tensor_lst) == chunks + for i in range(chunks): + non_tensor_batch_lst[i][key] = non_tensor_lst[i] + + output = [] + for i in range(chunks): + output.append( + type(self)(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info) + ) + + return output + + def split(self, split_size: int) -> list["DataProto"]: + """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split. + + Args: + split_size (int): the size of each split + + Returns: + List[DataProto]: a list of DataProto after splitting + """ + return [self[i : i + split_size] for i in range(0, len(self), split_size)] + + @staticmethod + def concat(data: list["DataProto"]) -> "DataProto": + """Concat a list of DataProto. The batch is concatenated among dim=0. + The meta_info is assumed to be identical and will use the first one. + + Args: + data (List[DataProto]): list of DataProto + + Returns: + DataProto: concatenated DataProto + """ + batch_lst = [] + for batch in data: + batch_lst.append(batch.batch) + new_batch = torch.cat(batch_lst, dim=0) if batch_lst[0] is not None else None + + non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data]) + for key, val in non_tensor_batch.items(): + non_tensor_batch[key] = np.concatenate(val, axis=0) + + cls = type(data[0]) if len(data) > 0 else DataProto + return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info) + + def reorder(self, indices): + """ + Note that this operation is in-place + """ + indices_np = indices.detach().numpy() + self.batch = self.batch[indices] + self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()} + + def repeat(self, repeat_times=2, interleave=True): + """ + Repeat the batch data a specified number of times. + + Args: + repeat_times (int): Number of times to repeat the data. + interleave (bool): Whether to interleave the repeated data. + + Returns: + DataProto: A new DataProto with repeated data. + """ + if self.batch is not None: + if interleave: + # Interleave the data + repeated_tensors = { + key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items() + } + else: + # Stack the data + repeated_tensors = { + key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:]) + for key, tensor in self.batch.items() + } + + repeated_batch = TensorDict( + source=repeated_tensors, + batch_size=(self.batch.batch_size[0] * repeat_times,), + ) + else: + repeated_batch = None + + repeated_non_tensor_batch = {} + for key, val in self.non_tensor_batch.items(): + if interleave: + repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0) + else: + repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1)) + + return type(self)( + batch=repeated_batch, + non_tensor_batch=repeated_non_tensor_batch, + meta_info=self.meta_info, + ) + + def unfold_column_chunks(self, n_split: int, split_keys: Optional[list[str]] = None): + """Split along the second dim into `n_split`, unfold it to the first dim (batch dim) + Useful in passing grouped tensors that doesn't want to be shuffled in dataset. + keys not in split_keys are repeated to match the shape + Note that if the `split_keys` is not provided, it will repeat all the keys in the second dim. + """ + if self.batch is not None: + unfolded_batch = {} + for key in self.batch.keys(): + if key in split_keys if split_keys is not None else False: + shape = list(self.batch[key].shape) + shape[0] = self.batch[key].shape[0] * n_split + shape[1] = self.batch[key].shape[1] // n_split + unfolded_batch[key] = self.batch[key].reshape(*shape) + else: + unfolded_batch[key] = torch.repeat_interleave(self.batch[key], n_split, dim=0) + # locate the `unfolded_batch` as a TensorDict on the same device as the original batch + unfolded_batch = TensorDict( + source=unfolded_batch, batch_size=(self.batch.batch_size[0] * n_split,), device=self.batch.device + ) + else: + unfolded_batch = None + + repeated_non_tensor_batch = {} + for key, val in self.non_tensor_batch.items(): + if key in split_keys: + shape = list(val.shape) + shape[0] = val.shape[0] * n_split + shape[1] = val.shape[1] // n_split + repeated_non_tensor_batch[key] = val.reshape(*shape) + else: + repeated_non_tensor_batch[key] = np.repeat(val, n_split, axis=0) + + return type(self)( + batch=unfolded_batch, + non_tensor_batch=repeated_non_tensor_batch, + meta_info=self.meta_info, + ) + + def sample_level_repeat(self, repeat_times): + """ + Repeat each row of the batch data a specified number of times. + + Args: + repeat_times (torch.tensor, list, tuple, ndarray): Number of times to repeat the data. + + Returns: + DataProto: A new DataProto with repeated data. + """ + if isinstance(repeat_times, tuple): + repeat_times = list(repeat_times) + elif isinstance(repeat_times, torch.Tensor): + assert len(repeat_times.shape) == 1 + repeat_times = repeat_times.tolist() + elif isinstance(repeat_times, np.ndarray): + assert len(repeat_times.shape) == 1 + repeat_times = repeat_times.tolist() + else: + assert isinstance(repeat_times, list), ( + f"repeat_times type must be in [list, torch.Tensor, np.ndarray, tuple], got {type(repeat_times)}" + ) + repeat_times = torch.tensor(repeat_times) + + if self.batch is not None: + # Interleave the data + repeated_tensors = { + key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items() + } + + repeated_batch = TensorDict( + source=repeated_tensors, + batch_size=(repeat_times.sum().item(),), + device=self.batch.device, + ) + else: + repeated_batch = None + + repeated_non_tensor_batch = {} + for key, val in self.non_tensor_batch.items(): + repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0) + + return type(self)( + batch=repeated_batch, + non_tensor_batch=repeated_non_tensor_batch, + meta_info=self.meta_info, + ) + + +@dataclass +class DataProtoFuture: + """ + DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait + for data so that asynchronous execution becomes possible. + DataProtoFuture contains a list of futures from another WorkerGroup of size world_size. + - collect_fn is a Callable that reduces the list of futures to a DataProto + - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size + and then select + + Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination + - DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any + operation on the DataProtoFuture in driver. + """ + + collect_fn: Callable + futures: list[ray.ObjectRef] + dispatch_fn: Callable = None + + @staticmethod + def concat(data: list[ray.ObjectRef]) -> "DataProtoFuture": + output = DataProtoFuture(collect_fn=DataProto.concat, futures=data) + return output + + def chunk(self, chunks: int) -> list["DataProtoFuture"]: + from functools import partial + + arg_future_lst = [] + for i in range(chunks): + # note that we can't directly pass i and chunks + def dispatch_fn(x, i, chunks): + return x.chunk(chunks=chunks)[i] + + arg_future = DataProtoFuture( + collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures + ) + arg_future_lst.append(arg_future) + return arg_future_lst + + def get(self): + output = ray.get(self.futures) # dp_size. + for o in output: + assert isinstance(o, DataProto) + output = self.collect_fn(output) # select dp, concat + if self.dispatch_fn is not None: + output = self.dispatch_fn(output) # split in batch dim, select using dp + return output + + +def all_gather_data_proto(data: DataProto, process_group): + # Note that this is an inplace operator just like torch.distributed.all_gather + group_size = torch.distributed.get_world_size(group=process_group) + assert isinstance(data, DataProto) + prev_device = data.batch.device + data.batch = data.batch.to(get_device_id()) + data.batch = allgather_dict_tensors(data.batch.contiguous(), size=group_size, group=process_group, dim=0) + data.batch = data.batch.to(prev_device) + # all gather non_tensor_batch + all_non_tensor_batch = [None for _ in range(group_size)] + torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=process_group) + data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch} diff --git a/toolbox/verl/v0.5.0/verl/py.typed b/toolbox/verl/v0.5.0/verl/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/toolbox/verl/v0.5.0/verl/single_controller/__init__.py b/toolbox/verl/v0.5.0/verl/single_controller/__init__.py new file mode 100644 index 000000000..ad6c42a80 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/single_controller/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from . import base +from .base import * + +version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) + +# Note(haibin.lin): single_controller.__version__ is deprecated +with open(os.path.join(os.path.join(version_folder, os.pardir), "version/version")) as f: + __version__ = f.read().strip() + + +__all__ = base.__all__ diff --git a/toolbox/verl/v0.5.0/verl/single_controller/base/__init__.py b/toolbox/verl/v0.5.0/verl/single_controller/base/__init__.py new file mode 100644 index 000000000..b24bd9942 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/single_controller/base/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .worker import Worker +from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup + +__all__ = ["Worker", "WorkerGroup", "ClassWithInitArgs", "ResourcePool"] diff --git a/toolbox/verl/v0.5.0/verl/single_controller/base/decorator.py b/toolbox/verl/v0.5.0/verl/single_controller/base/decorator.py new file mode 100644 index 000000000..303d9ed90 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/single_controller/base/decorator.py @@ -0,0 +1,527 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from functools import wraps +from types import FunctionType + +from verl.protocol import DataProtoFuture, _padding_size_key +from verl.utils.py_functional import DynamicEnum + +# here we add a magic number of avoid user-defined function already have this attribute +MAGIC_ATTR = "attrs_3141562937" + + +class Dispatch(DynamicEnum): + """Enum class defining different dispatch modes for distributed computation. + + Each mode represents a specific strategy for distributing data across + different ranks in a distributed system. The modes are used to control + how data is partitioned and processed across different worker groups. + """ + + _registry = {} + _next_value = 0 + + +def init_predefined_dispatch_mode(): + Dispatch.register("RANK_ZERO") + Dispatch.register("ONE_TO_ALL") + Dispatch.register("ALL_TO_ALL") + Dispatch.register("MEGATRON_COMPUTE") + Dispatch.register("MEGATRON_PP_AS_DP") + Dispatch.register("MEGATRON_PP_ONLY") + Dispatch.register("MEGATRON_COMPUTE_PROTO") + Dispatch.register("MEGATRON_PP_AS_DP_PROTO") + Dispatch.register("DP_COMPUTE") + Dispatch.register("DP_COMPUTE_PROTO") + Dispatch.register("DP_COMPUTE_PROTO_WITH_FUNC") + Dispatch.register("DP_COMPUTE_METRIC") + # This is a special dispatch mode for vllm ExternalRayDistributedExecutor + Dispatch.register("DIRECT_ROLLOUT_METHOD") + + +class Execute(DynamicEnum): + """Enum class defining different execution modes for distributed computation. + + These modes control how a function should be executed across different ranks + in a distributed system. + """ + + _registry = {} + _next_value = 0 + + +def init_predefined_execute_mode(): + Execute.register("ALL") + Execute.register("RANK_ZERO") + + +# Initialize the two Dynamic Enum Classes +init_predefined_dispatch_mode() +init_predefined_execute_mode() + + +def _split_args_kwargs_data_proto(chunks, *args, **kwargs): + from verl.protocol import DataProto, DataProtoFuture + + splitted_args = [] + for arg in args: + assert isinstance(arg, DataProto | DataProtoFuture) + splitted_args.append(arg.chunk(chunks=chunks)) + + splitted_kwargs = {} + for key, val in kwargs.items(): + assert isinstance(val, DataProto | DataProtoFuture) + splitted_kwargs[key] = val.chunk(chunks=chunks) + + return splitted_args, splitted_kwargs + + +def _split_args_kwargs_data_proto_with_auto_padding(chunks, *args, **kwargs): + from verl.protocol import DataProto, DataProtoFuture + + data_proto_len = None + padding_size = None + + def _padding_and_split_data(obj, chunks): + nonlocal data_proto_len, padding_size + assert isinstance(obj, DataProto | DataProtoFuture) + if isinstance(obj, DataProto) and obj.is_padding_enabled(): + # for padding, we only support DataProto with same length + if data_proto_len is None: + data_proto_len = len(obj) + padding_size = (chunks - (data_proto_len % chunks)) if (data_proto_len % chunks > 0) else 0 + else: + assert data_proto_len == len(obj), ( + f"expecting all arg share same length of {data_proto_len}, but got {len(obj)}" + ) + obj.padding(padding_size=padding_size) + return obj.chunk(chunks=chunks) + + splitted_args = [_padding_and_split_data(arg, chunks) for arg in args] + splitted_kwargs = {key: _padding_and_split_data(val, chunks) for key, val in kwargs.items()} + if padding_size is not None: + splitted_kwargs[_padding_size_key] = padding_size + + return splitted_args, splitted_kwargs + + +def dispatch_one_to_all(worker_group, *args, **kwargs): + args = tuple([arg] * worker_group.world_size for arg in args) + kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()} + return args, kwargs + + +def dummy_direct_rollout_call(worker_group, *args, **kwargs): + raise NotImplementedError("Direct rollout call is forbidden.") + + +def dispatch_all_to_all(worker_group, *args, **kwargs): + return args, kwargs + + +def collect_all_to_all(worker_group, output): + return output + + +def dispatch_megatron_compute(worker_group, *args, **kwargs): + """ + User passes in dp data. The data is dispatched to all tp/pp ranks with the same dp + """ + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup), ( + f"worker_group must be MegatronWorkerGroup, Got {type(worker_group)}" + ) + + # ray put all the args in advance to avoid duplicate serialization cost + import ray + + args = [[ray.put(dp_arg) for dp_arg in arg] for arg in args] + kwargs = {k: [ray.put(dp_v) for dp_v in v] for k, v in kwargs.items()} + + def _transform_data(obj_list, worker_group): + assert isinstance(obj_list, tuple | list) and len(obj_list) == worker_group.dp_size + transformed_data = [] + for i in range(worker_group.world_size): + local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank + transformed_data.append(obj_list[local_dp_rank]) + return transformed_data + + all_args = tuple([_transform_data(arg, worker_group) for arg in args]) + all_kwargs = {key: _transform_data(val, worker_group) for key, val in kwargs.items()} + + return all_args, all_kwargs + + +def collect_megatron_compute(worker_group, output): + """ + Only collect the data from the tp=0 and pp=last and every dp ranks + """ + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup) + output_in_dp = [] + pp_size = worker_group.get_megatron_global_info().pp_size + for global_rank in range(worker_group.world_size): + local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank) + if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == pp_size - 1 and local_rank_info.cp_rank == 0: + output_in_dp.append(output[global_rank]) + return output_in_dp + + +def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs): + """ + All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank + """ + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup) + + splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs) + return dispatch_megatron_compute(worker_group, *splitted_args, **splitted_kwargs) + + +def _concat_data_proto_or_future(output: list): + import ray + + from verl.protocol import DataProto, DataProtoFuture + + # make sure all the elements in output has the same type + for o in output: + assert type(o) is type(output[0]) + + o = output[0] + + if isinstance(o, DataProto): + return DataProto.concat(output) + elif isinstance(o, ray.ObjectRef): + return DataProtoFuture.concat(output) + else: + raise NotImplementedError + + +def collect_megatron_compute_data_proto(worker_group, output): + """ + Each output must be a DataProto. We concat the dim=0 of output + """ + import ray + + from verl.protocol import DataProto + + output = collect_megatron_compute(worker_group, output) + for o in output: + assert isinstance(o, DataProto | ray.ObjectRef), f"expecting {o} to be DataProto, but got {type(o)}" + + return _concat_data_proto_or_future(output) + + +def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs): + """ + treat pp as dp. + """ + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup) + + pp_size = worker_group.pp_size + dp_size = worker_group.dp_size + cp_size = worker_group.cp_size + pp_dp_cp_size = pp_size * dp_size * cp_size + + def _transform_data(obj_list, worker_group): + assert isinstance(obj_list, list | tuple) and len(obj_list) == pp_dp_cp_size + transformed_data = [] + for i in range(worker_group.world_size): + local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank + local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank + local_cp_rank = worker_group.get_megatron_rank_info(rank=i).cp_rank + # compute the rank in obj_list. Note that the order is dp then cp then pp + # Also note that the outputs within a pp group will be firstly allgathered, then only the + # output of pp0 will be collected. + # For pp=2 dp=4, a batch of data "ABCDEFGH" should be dispatched and collected in below order: + # dispatch: pp_allgther: collect: + # dp 0 1 2 3 dp 0 1 2 3 + # pp +---------+ pp +-------------+ + # 0 | A C E G | 0 | AB CD EF GH | ABCDEFGH + # 1 | B D F H | 1 | AB CD EF GH | + # +---------+ +-------------+ + dp_cp_rank = local_cp_rank * dp_size + local_dp_rank + arg_rank = dp_cp_rank * pp_size + local_pp_rank + transformed_data.append(obj_list[arg_rank]) + return transformed_data + + all_args = tuple([_transform_data(arg, worker_group) for arg in args]) + all_kwargs = {key: _transform_data(val, worker_group) for key, val in kwargs.items()} + + return all_args, all_kwargs + + +def collect_megatron_pp_as_dp(worker_group, output): + """ + treat pp as dp. Only collect data on tp=0 + """ + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup) + output_in_dp = [] + for global_rank in range(worker_group.world_size): + local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank) + if local_rank_info.tp_rank == 0: + output_in_dp.append(output[global_rank]) + return output_in_dp + + +def collect_megatron_pp_only(worker_group, output): + """ + Only collect output of megatron pp. This is useful when examine weight names as they are identical in tp/dp + """ + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup) + output_in_pp = [] + for global_rank in range(worker_group.world_size): + local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank) + if local_rank_info.tp_rank == 0 and local_rank_info.dp_rank == 0: + output_in_pp.append(output[global_rank]) + return output_in_pp + + +def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs): + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup) + + pp_dp_cp_size = worker_group.dp_size * worker_group.pp_size * worker_group.cp_size + splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(pp_dp_cp_size, *args, **kwargs) + ret = dispatch_megatron_pp_as_dp(worker_group, *splitted_args, **splitted_kwargs) + return ret + + +def collect_megatron_pp_as_dp_data_proto(worker_group, output): + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup) + + output = collect_megatron_pp_as_dp(worker_group, output) + return _concat_data_proto_or_future(output) + + +def dispatch_dp_compute(worker_group, *args, **kwargs): + from verl.single_controller.base.worker_group import WorkerGroup + + assert isinstance(worker_group, WorkerGroup) + for arg in args: + assert isinstance(arg, tuple | list) and len(arg) == worker_group.world_size + for k, v in kwargs.items(): + assert isinstance(v, tuple | list) and len(v) == worker_group.world_size + return args, kwargs + + +def collect_dp_compute(worker_group, output): + from verl.single_controller.base.worker_group import WorkerGroup + + assert isinstance(worker_group, WorkerGroup) + assert len(output) == worker_group.world_size + return output + + +def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs): + from verl.single_controller.base.worker_group import WorkerGroup + + assert isinstance(worker_group, WorkerGroup) + # Note: enable auto padding for dp compute DatapProto + splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding( + worker_group.world_size, + *args, + **kwargs, + ) + return splitted_args, splitted_kwargs + + +def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs): + from verl.single_controller.base.worker_group import WorkerGroup + + assert isinstance(worker_group, WorkerGroup) + assert isinstance(args[0], FunctionType) # NOTE: The first one args is a function! + + splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs) + splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args + return splitted_args_with_func, splitted_kwargs + + +def collect_dp_compute_data_proto(worker_group, output): + import ray + + from verl.protocol import DataProto + + for o in output: + assert isinstance(o, DataProto | ray.ObjectRef), f"expecting {o} to be DataProto, but got {type(o)}" + + output = collect_dp_compute(worker_group, output) + return _concat_data_proto_or_future(output) + + +# Global registry for dispatch mode. +DISPATCH_MODE_FN_REGISTRY = { + Dispatch.ONE_TO_ALL: { + "dispatch_fn": dispatch_one_to_all, + "collect_fn": collect_all_to_all, + }, + Dispatch.ALL_TO_ALL: { + "dispatch_fn": dispatch_all_to_all, + "collect_fn": collect_all_to_all, + }, + Dispatch.MEGATRON_COMPUTE: { + "dispatch_fn": dispatch_megatron_compute, + "collect_fn": collect_megatron_compute, + }, + Dispatch.MEGATRON_PP_AS_DP: { + "dispatch_fn": dispatch_megatron_pp_as_dp, + "collect_fn": collect_megatron_pp_as_dp, + }, + Dispatch.MEGATRON_PP_ONLY: {"dispatch_fn": dispatch_one_to_all, "collect_fn": collect_megatron_pp_only}, + Dispatch.MEGATRON_COMPUTE_PROTO: { + "dispatch_fn": dispatch_megatron_compute_data_proto, + "collect_fn": collect_megatron_compute_data_proto, + }, + Dispatch.MEGATRON_PP_AS_DP_PROTO: { + "dispatch_fn": dispatch_megatron_pp_as_dp_data_proto, + "collect_fn": collect_megatron_pp_as_dp_data_proto, + }, + Dispatch.DP_COMPUTE: {"dispatch_fn": dispatch_dp_compute, "collect_fn": collect_dp_compute}, + Dispatch.DP_COMPUTE_PROTO: { + "dispatch_fn": dispatch_dp_compute_data_proto, + "collect_fn": collect_dp_compute_data_proto, + }, + Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: { + "dispatch_fn": dispatch_dp_compute_data_proto_with_func, + "collect_fn": collect_dp_compute_data_proto, + }, + Dispatch.DP_COMPUTE_METRIC: {"dispatch_fn": dispatch_dp_compute_data_proto, "collect_fn": collect_dp_compute}, + Dispatch.DIRECT_ROLLOUT_METHOD: { + "dispatch_fn": dummy_direct_rollout_call, + "collect_fn": dummy_direct_rollout_call, + }, +} + + +def get_predefined_dispatch_fn(dispatch_mode): + return DISPATCH_MODE_FN_REGISTRY[dispatch_mode] + + +def register_dispatch_mode(dispatch_mode_name, dispatch_fn, collect_fn): + """ + Register a new dispatch mode. + """ + dispatch_mode = Dispatch.register(dispatch_mode_name) + _check_dispatch_mode(dispatch_mode) + assert dispatch_mode not in DISPATCH_MODE_FN_REGISTRY, f"dispatch_mode_name {dispatch_mode_name} already exists" + DISPATCH_MODE_FN_REGISTRY[dispatch_mode] = {"dispatch_fn": dispatch_fn, "collect_fn": collect_fn} + + +def update_dispatch_mode(dispatch_mode, dispatch_fn, collect_fn): + """ + Update the dispatch mode. + """ + _check_dispatch_mode(dispatch_mode) + assert dispatch_mode in DISPATCH_MODE_FN_REGISTRY, f"dispatch_mode {dispatch_mode} not found" + DISPATCH_MODE_FN_REGISTRY[dispatch_mode] = {"dispatch_fn": dispatch_fn, "collect_fn": collect_fn} + + +def get_predefined_execute_fn(execute_mode): + """ + Note that here we only asks execute_all and execute_rank_zero to be implemented + Leave the choice of how these two functions handle argument 'blocking' to users + """ + predefined_execute_mode_fn = { + Execute.ALL: {"execute_fn_name": "execute_all"}, + Execute.RANK_ZERO: {"execute_fn_name": "execute_rank_zero"}, + } + return predefined_execute_mode_fn[execute_mode] + + +def _check_dispatch_mode(dispatch_mode): + assert isinstance(dispatch_mode, Dispatch | dict), ( + f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}" + ) + if isinstance(dispatch_mode, dict): + necessary_keys = ["dispatch_fn", "collect_fn"] + for key in necessary_keys: + assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary" + + +def _check_execute_mode(execute_mode): + assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. Got {execute_mode}" + + +def _materialize_futures(*args, **kwargs): + new_args = [] + for arg in args: + if isinstance(arg, DataProtoFuture): + arg = arg.get() + # add more type to materialize + new_args.append(arg) + for k, v in kwargs.items(): + if isinstance(v, DataProtoFuture): + kwargs[k] = v.get() + + new_args = tuple(new_args) + return new_args, kwargs + + +def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True): + """Register a function with distributed execution configuration. + + This decorator registers a function with specific dispatch and execution modes + for distributed computation. It handles both synchronous and asynchronous + functions, and optionally materializes futures before execution. + + Args: + dispatch_mode: + Dispatch mode for computation distribution. Default: Dispatch.ALL_TO_ALL. + execute_mode: + Execute mode for computation distribution. Default: Execute.ALL. + blocking: + Whether the execution should be blocking. Defaults to True. + materialize_futures: + Whether to materialize the data before dispatching. Defaults to True. + + Returns: + A decorator that wraps the original function with distributed execution + configuration. + """ + _check_dispatch_mode(dispatch_mode=dispatch_mode) + _check_execute_mode(execute_mode=execute_mode) + + def decorator(func): + @wraps(func) + def inner(*args, **kwargs): + if materialize_futures: + args, kwargs = _materialize_futures(*args, **kwargs) + return func(*args, **kwargs) + + @wraps(func) + async def async_inner(*args, **kwargs): + if materialize_futures: + args, kwargs = _materialize_futures(*args, **kwargs) + return await func(*args, **kwargs) + + wrapper = async_inner if inspect.iscoroutinefunction(func) else inner + attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking} + setattr(wrapper, MAGIC_ATTR, attrs) + return wrapper + + return decorator diff --git a/toolbox/verl/v0.5.0/verl/single_controller/base/megatron/__init__.py b/toolbox/verl/v0.5.0/verl/single_controller/base/megatron/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/single_controller/base/megatron/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/single_controller/base/megatron/worker.py b/toolbox/verl/v0.5.0/verl/single_controller/base/megatron/worker.py new file mode 100644 index 000000000..e938a0bff --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/single_controller/base/megatron/worker.py @@ -0,0 +1,106 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from verl.single_controller.base.worker import DistGlobalInfo, DistRankInfo, Worker + + +class MegatronWorker(Worker): + def __init__(self, cuda_visible_devices=None) -> None: + super().__init__(cuda_visible_devices) + + def get_megatron_global_info(self): + from megatron.core import parallel_state as mpu + + tp_size = mpu.get_tensor_model_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + info = DistGlobalInfo(tp_size=tp_size, dp_size=dp_size, pp_size=pp_size, cp_size=cp_size) + return info + + def get_megatron_rank_info(self): + from megatron.core import parallel_state as mpu + + tp_rank = mpu.get_tensor_model_parallel_rank() + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + cp_rank = mpu.get_context_parallel_rank() + info = DistRankInfo(tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank, cp_rank=cp_rank) + return info + + def _init_hf_config_and_tf_config( + self, + model_path, + tokenizer_or_path, + dtype, + override_model_config, + override_transformer_config, + trust_remote_code=False, + use_mbridge=False, + ): + from transformers import AutoConfig + + from verl.models.mcore import hf_to_mcore_config + from verl.utils import hf_processor, hf_tokenizer + from verl.utils.fs import copy_to_local + from verl.utils.model import update_model_config + + # Step 1: initialize the tokenizer + self.local_path = copy_to_local(model_path) + if tokenizer_or_path is None: + self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code) + self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code) + elif isinstance(tokenizer_or_path, str): + self.tokenizer = hf_tokenizer(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code) + self.processor = hf_processor(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code) + else: + self.tokenizer = tokenizer_or_path + self.processor = tokenizer_or_path + + if self.config.model.get("custom_chat_template", None) is not None: + if self.processor is not None: + self.processor.chat_template = self.config.model.custom_chat_template + else: + self.tokenizer.chat_template = self.config.model.custom_chat_template + + # Step 2: get the hf + hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code) + + # Step 3: override the hf config + override_config_kwargs = { + "bos_token_id": self.tokenizer.bos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + } + override_config_kwargs.update(override_model_config.get("model_config", {})) + self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False) + update_model_config(hf_config, override_config_kwargs=override_config_kwargs) + self.architectures = getattr(hf_config, "architectures", None) + if self.rank == 0: + print(f"Model config after override: {hf_config}") + tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config) + + if use_mbridge: + from verl.models.mcore.mbridge import AutoBridge + + bridge = AutoBridge.from_config(hf_config) + bridge.set_extra_args(**override_transformer_config) + tf_config = bridge.config + self.bridge = bridge + else: + self.bridge = None + + print(f"TF config: {tf_config}") + self.hf_config = hf_config + self.tf_config = tf_config diff --git a/toolbox/verl/v0.5.0/verl/single_controller/base/megatron/worker_group.py b/toolbox/verl/v0.5.0/verl/single_controller/base/megatron/worker_group.py new file mode 100644 index 000000000..b9beb844c --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/single_controller/base/megatron/worker_group.py @@ -0,0 +1,55 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from verl.single_controller.base import ResourcePool, WorkerGroup + +from .worker import DistGlobalInfo, DistRankInfo + + +class MegatronWorkerGroup(WorkerGroup): + def __init__(self, resource_pool: ResourcePool, **kwargs): + super().__init__(resource_pool=resource_pool, **kwargs) + self._megatron_rank_info = None + self._megatron_global_info: DistGlobalInfo = None + + def init_megatron(self, default_megatron_kwargs: dict = None): + raise NotImplementedError("MegatronWorkerGroup.init_megatron should be overwritten") + + def get_megatron_rank_info(self, rank: int) -> DistRankInfo: + assert 0 <= rank < self.world_size, f"rank must be from [0, world_size), Got {rank}" + return self._megatron_rank_info[rank] + + @property + def tp_size(self): + assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" + return self._megatron_global_info.tp_size + + @property + def dp_size(self): + assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" + return self._megatron_global_info.dp_size + + @property + def pp_size(self): + assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" + return self._megatron_global_info.pp_size + + @property + def cp_size(self): + assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" + return self._megatron_global_info.cp_size + + def get_megatron_global_info(self): + return self._megatron_global_info diff --git a/toolbox/verl/v0.5.0/verl/single_controller/base/register_center/__init__.py b/toolbox/verl/v0.5.0/verl/single_controller/base/register_center/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/single_controller/base/register_center/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/single_controller/base/register_center/ray.py b/toolbox/verl/v0.5.0/verl/single_controller/base/register_center/ray.py new file mode 100644 index 000000000..ac071cde5 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/single_controller/base/register_center/ray.py @@ -0,0 +1,37 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import ray + + +@ray.remote +class WorkerGroupRegisterCenter: + def __init__(self, rank_zero_info): + self.rank_zero_info = rank_zero_info + # rank -> node_id + self.workers_info: dict[int, str] = {} + + def get_rank_zero_info(self): + return self.rank_zero_info + + def set_worker_info(self, rank, node_id) -> None: + self.workers_info[rank] = node_id + + def get_worker_info(self) -> dict[int, str]: + return self.workers_info + + +def create_worker_group_register_center(name, info): + return WorkerGroupRegisterCenter.options(name=name).remote(info) diff --git a/toolbox/verl/v0.5.0/verl/single_controller/base/worker.py b/toolbox/verl/v0.5.0/verl/single_controller/base/worker.py new file mode 100644 index 000000000..2606a3ef3 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/single_controller/base/worker.py @@ -0,0 +1,301 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +the class for Worker +""" + +import os +import socket +from dataclasses import dataclass + +import ray + +from verl.utils.device import get_torch_device, get_visible_devices_keyword + +from .decorator import Dispatch, Execute, register + + +@dataclass +class DistRankInfo: + tp_rank: int + dp_rank: int + pp_rank: int + cp_rank: int + + +@dataclass +class DistGlobalInfo: + tp_size: int + dp_size: int + pp_size: int + cp_size: int + + +class WorkerHelper: + @staticmethod + def _get_node_ip(): + if os.getenv("WG_BACKEND", None) == "ray": + return ray.util.get_node_ip_address() + else: + raise NotImplementedError("WG_BACKEND now just support ray mode.") + + @staticmethod + def _get_free_port(): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + def get_availale_master_addr_port(self): + return self._get_node_ip().strip("[]"), str(self._get_free_port()) + + +# we assume that in each WorkerGroup, there is a Master Worker +class Worker(WorkerHelper): + """A distributed worker that handles initialization and configuration for distributed training. + + This class manages worker initialization, configuration, and provides methods for executing + distributed operations. It handles communication settings, device configuration, and worker + metadata management. + """ + + fused_worker_attr_name = "fused_worker_dict" + + def __new__(cls, *args, **kwargs): + """Create a new Worker instance with proper initialization based on environment settings.""" + instance = super().__new__(cls) + + # note that here we use int to distinguish + disable_worker_init = int(os.environ.get("DISABLE_WORKER_INIT", 0)) + if disable_worker_init: + return instance + + rank = os.environ.get("RANK", None) + worker_group_prefix = os.environ.get("WG_PREFIX", None) + + # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init + if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__: + instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank)) + + return instance + + def _configure_before_init(self, register_center_name: str, rank: int): + """Configure worker settings before initialization. + + Args: + register_center_name (str): + Name of the register center Ray actor for worker coordination + rank (int): + Rank of the worker in the distributed setup + """ + assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}" + + if rank == 0: + master_addr, master_port = self.get_availale_master_addr_port() + rank_zero_info = { + "MASTER_ADDR": master_addr, + "MASTER_PORT": master_port, + } + + if os.getenv("WG_BACKEND", None) == "ray": + from verl.single_controller.base.register_center.ray import create_worker_group_register_center + + self.register_center = create_worker_group_register_center( + name=register_center_name, info=rank_zero_info + ) + + os.environ.update(rank_zero_info) + else: + self.register_center = ray.get_actor(register_center_name) + + # set worker info for node affinity scheduling + ray.get(self.register_center.set_worker_info.remote(rank, ray.get_runtime_context().get_node_id())) + + @classmethod + def env_keys(cls): + """The keys of the environment variables that are used to configure the Worker.""" + return [ + "WORLD_SIZE", + "RANK", + "LOCAL_WORLD_SIZE", + "LOCAL_RANK", + "MASTER_ADDR", + "MASTER_PORT", + get_visible_devices_keyword().upper(), + ] + + def __init__(self, cuda_visible_devices=None) -> None: + """Initialize the worker with environment settings and device configuration. + + Args: + cuda_visible_devices (str, optional): + CUDA visible devices configuration. Defaults to None. + """ + # construct a meta from environment variable. Note that the import must be inside the class because + # it is executed remotely + import os + + self._setup_env_cuda_visible_devices() + + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["RANK"]) + self._rank = rank + self._world_size = world_size + + master_addr = os.environ["MASTER_ADDR"] + master_port = os.environ["MASTER_PORT"] + + local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + + store = { + "_world_size": world_size, + "_rank": rank, + "_local_world_size": local_world_size, + "_local_rank": local_rank, + "_master_addr": master_addr, + "_master_port": master_port, + } + if cuda_visible_devices is not None: + store[f"_{get_visible_devices_keyword()}".lower()] = cuda_visible_devices + + self._configure_with_store(store=store) + + self.fused_worker_dict = {} + + def get_fused_worker_by_name(self, worker_name: str): + """Get a fused worker by its name. + + Args: + worker_name (str): + Name of the worker to retrieve + """ + return self.fused_worker_dict.get(worker_name, None) + + def _setup_env_cuda_visible_devices(self): + from verl.utils.ray_utils import ray_noset_visible_devices + + is_ray_noset_visible_devices = ray_noset_visible_devices() + + # Prevent use of clashing `{CUDA/HIP/ROCR}_VISIBLE_DEVICES`` + rocr_val = os.environ.get("ROCR_VISIBLE_DEVICES", None) + hip_val = os.environ.get("HIP_VISIBLE_DEVICES", None) + cuda_val = os.environ.get("CUDA_VISIBLE_DEVICES", None) + if hip_val: + # Switch the use of HIP_VISIBLE_DEVICES to CUDA_VISIBLE_DEVICES for consistency. + # Make sure that the HIP_VISIBLE_DEVICES is set to the same value as CUDA_VISIBLE_DEVICES + # at this point. + val = os.environ.pop("HIP_VISIBLE_DEVICES") + hip_val = None + if cuda_val: + assert val == cuda_val, ( + f"Please use the same HIP_VISIBLE_DEVICES or CUDA_VISIBLE_DEVICES, inconsistant values " + f"found: {val} and {cuda_val}." + ) + else: + cuda_val = val + os.environ["CUDA_VISIBLE_DEVICES"] = val + # os.environ["HIP_VISIBLE_DEVICES"] = val + + if rocr_val: + # You must take care if both HIP/CUDA and ROCR env vars are set as they have + # different meanings. Both env vars accept either a list of ints or a + # list of UUIDs. The ROCR env var is processed first which then reduces + # the number of GPUs that HIP can select from. + # https://github.com/pytorch/pytorch/pull/144026 + # To avoid the complexity of this, we simply gives out error if both are set + # (Also to keep consistency with ray's practice with 2.45.0). + # Otherwise, we will set ROCR_VISIBLE_DEVICES to CUDA_VISIBLE_DEVICES + # and remove ROCR_VISIBLE_DEVICES. + if cuda_val: + raise ValueError("Please don't set ROCR_VISIBLE_DEVICES when HIP/CUDA_VISIBLE_DEVICES is set.") + + cuda_val = os.environ.pop("ROCR_VISIBLE_DEVICES") + os.environ["CUDA_VISIBLE_DEVICES"] = cuda_val + rocr_val = None + + if is_ray_noset_visible_devices: + # NOTE: Ray will automatically set the *_VISIBLE_DEVICES + # environment variable for each actor, unless + # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set, + # so we need to set local rank when the flag is set. + local_rank = os.environ.get("RAY_LOCAL_RANK") + os.environ["LOCAL_RANK"] = local_rank + get_torch_device().set_device(int(local_rank)) + + def _configure_with_store(self, store: dict): + """ + This function should only be called inside by WorkerGroup + """ + store_env_dict = {f"_{key.lower()}": store.get(f"_{key.lower()}", None) for key in type(self).env_keys()} + self.__dict__.update(store_env_dict) # this is hacky + # print(f"__dict__: {self.__dict__}") + for key in type(self).env_keys(): + val = self.__dict__.get(f"_{key.lower()}", None) + if val is not None: + # print(f"set {key} to {val}") + os.environ[key] = str(val) + os.environ["REDIS_STORE_SERVER_HOST"] = ( + str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else "" + ) + + def get_master_addr_port(self): + """Get the master address and port for distributed communication.""" + return self._master_addr, self._master_port + + def get_cuda_visible_devices(self): + """Get the CUDA visible devices configuration.""" + import os + + visible_devices = os.environ.get(get_visible_devices_keyword().upper(), "not set") + return visible_devices + + @property + def world_size(self): + """Get the total number of workers in the distributed setup.""" + return self._world_size + + @property + def rank(self): + """Get the rank of this worker in the distributed setup.""" + return self._rank + + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC) + def execute_with_func_generator(self, func, *args, **kwargs): + """Execute a function with function generator dispatch mode. + + Args: + func: + Function to execute + *args: + Positional arguments for the function + **kwargs: + Keyword arguments for the function + """ + ret_proto = func(self, *args, **kwargs) + return ret_proto + + @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO) + def execute_func_rank_zero(self, func, *args, **kwargs): + """Execute a function in rank zero execution mode. + + Args: + func: + Function to execute + *args: + Positional arguments for the function + **kwargs: + Keyword arguments for the function + """ + result = func(*args, **kwargs) + return result diff --git a/toolbox/verl/v0.5.0/verl/single_controller/base/worker_group.py b/toolbox/verl/v0.5.0/verl/single_controller/base/worker_group.py new file mode 100644 index 000000000..cb86ab4f5 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/single_controller/base/worker_group.py @@ -0,0 +1,252 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +the class of WorkerGroup +""" + +import logging +import signal +import threading +import time +from typing import Any, Callable + +from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn + + +class ResourcePool: + """ + Manages a pool of resources across multiple nodes, tracking process counts and GPU allocations. + The class provides methods to calculate world size, local world sizes, and local ranks + across all nodes in the pool. + """ + + def __init__(self, process_on_nodes=None, max_colocate_count: int = 10, n_gpus_per_node=8) -> None: + """Initialize the ResourcePool with node processes and GPU configuration. + + Args: + process_on_nodes (List[int], optional): List of process counts per node. Defaults to empty list. + max_colocate_count (int, optional): Maximum number of processes that can be colocated. Defaults to 10. + n_gpus_per_node (int, optional): Number of GPUs available per node. Defaults to 8. + """ + if process_on_nodes is None: + process_on_nodes = [] + self._store = process_on_nodes + self.max_colocate_count = max_colocate_count + self.n_gpus_per_node = n_gpus_per_node # this is left for future huawei GPU that contains 16 GPUs per node + + def add_node(self, process_count): + self._store.append(process_count) + + @property + def world_size(self): + """Total number of processes across all nodes in the pool.""" + return sum(self._store) + + def __call__(self) -> Any: + return self._store + + @property + def store(self): + return self._store + + def local_world_size_list(self) -> list[int]: + """Returns a flat list where each process has its local world size.""" + nested_local_world_size_list = [ + [local_world_size for _ in range(local_world_size)] for local_world_size in self._store + ] + return [item for row in nested_local_world_size_list for item in row] + + def local_rank_list(self) -> list[int]: + """Returns a flat list of local ranks for all processes across all nodes.""" + nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store] + return [item for row in nested_local_rank_list for item in row] + + +class ClassWithInitArgs: + """ + Wrapper class that stores constructor arguments for deferred instantiation. + This class is particularly useful for remote class instantiation where + the actual construction needs to happen at a different time or location. + """ + + def __init__(self, cls, *args, **kwargs) -> None: + """Initialize the ClassWithInitArgs instance. + + Args: + cls: The class to be instantiated later + *args: Positional arguments for the class constructor + **kwargs: Keyword arguments for the class constructor + """ + self.cls = cls + self.args = args + self.kwargs = kwargs + + self.fused_worker_used = False + + def __call__(self) -> Any: + """Instantiate the stored class with the stored arguments.""" + return self.cls(*self.args, **self.kwargs) + + +def check_workers_alive(workers: list, is_alive: Callable, gap_time: float = 1) -> None: + """Continuously monitors worker processes and raises SIGABRT if any worker dies. + + Args: + workers (List): + List of worker objects to monitor + is_alive (Callable): + Function to check if a worker is alive + gap_time (float): + Time interval between checks + """ + import time + + while True: + for worker in workers: + if not is_alive(worker): + logging.warning(f"worker {worker} is not alive sending signal to main thread") + signal.raise_signal(signal.SIGABRT) + time.sleep(gap_time) + + +class WorkerGroup: + """ + Base class for managing a group of workers in a distributed system. + The class provides methods for worker management, aliveness checking, and method binding. + """ + + fused_worker_execute_fn_name = "_fuw_execute" + + def __init__(self, resource_pool: ResourcePool, **kwargs) -> None: + self._is_init_with_detached_workers = resource_pool is None + + self.fused_worker_used = False + + if resource_pool is not None: + # handle the case when WorkGroup is attached to an existing one + self._procecss_dispatch_config = resource_pool() + else: + self._procecss_dispatch_config = None + + self._workers = [] + self._worker_names = [] + + self._master_addr = None + self._master_port = None + + self._checker_thread: threading.Thread = None + + def _is_worker_alive(self, worker): + """Check if a worker is alive. Must be implemented by derived classes.""" + raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.") + + def _block_until_all_workers_alive(self) -> None: + """Blocks until all workers in the group are alive.""" + while True: + all_state = [self._is_worker_alive(worker) for worker in self._workers] + if False in all_state: + time.sleep(1) + else: + break + + def start_worker_aliveness_check(self, every_n_seconds=1) -> None: + """Starts a background thread to monitor worker aliveness. + + Args: + every_n_seconds (int): Interval between aliveness checks + """ + # before starting checking worker aliveness, make sure all workers are already alive + self._block_until_all_workers_alive() + + self._checker_thread = threading.Thread( + target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds) + ) + self._checker_thread.start() + + @property + def world_size(self): + """Number of workers in the group.""" + return len(self._workers) + + def _bind_worker_method(self, user_defined_cls, func_generator): + """Binds worker methods to the WorkerGroup based on registered attributes. + + Args: + user_defined_cls (type): The class containing methods to bind + func_generator (Callable): Function that generates the bound method + + Returns: + List[str]: List of method names that were successfully bound + """ + method_names = [] + for method_name in dir(user_defined_cls): + try: + method = getattr(user_defined_cls, method_name) + assert callable(method), f"{method_name} in {user_defined_cls} is not callable" + except Exception: + # if it is a property, it will fail because Class doesn't have instance property + continue + + if hasattr(method, MAGIC_ATTR): + # this method is decorated by register + attribute = getattr(method, MAGIC_ATTR) + assert isinstance(attribute, dict), f"attribute must be a dictionary. Got {type(attribute)}" + assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key" + + dispatch_mode = attribute["dispatch_mode"] + execute_mode = attribute["execute_mode"] + blocking = attribute["blocking"] + + # get dispatch fn + if isinstance(dispatch_mode, Dispatch): + # get default dispatch fn + fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode) + dispatch_fn = fn["dispatch_fn"] + collect_fn = fn["collect_fn"] + else: + assert isinstance(dispatch_mode, dict) + assert "dispatch_fn" in dispatch_mode + assert "collect_fn" in dispatch_mode + dispatch_fn = dispatch_mode["dispatch_fn"] + collect_fn = dispatch_mode["collect_fn"] + + # get execute_fn_name + execute_mode = get_predefined_execute_fn(execute_mode=execute_mode) + wg_execute_fn_name = execute_mode["execute_fn_name"] + + # get execute_fn from string + try: + execute_fn = getattr(self, wg_execute_fn_name) + assert callable(execute_fn), "execute_fn must be callable" + except Exception: + print(f"execute_fn {wg_execute_fn_name} is invalid") + raise + + # bind a new method to the RayWorkerGroup + func = func_generator( + self, + method_name, + dispatch_fn=dispatch_fn, + collect_fn=collect_fn, + execute_fn=execute_fn, + blocking=blocking, + ) + + try: + setattr(self, method_name, func) + method_names.append(method_name) + except Exception as e: + raise ValueError(f"Fail to set method_name {method_name}") from e + + return method_names diff --git a/toolbox/verl/v0.5.0/verl/single_controller/ray/__init__.py b/toolbox/verl/v0.5.0/verl/single_controller/ray/__init__.py new file mode 100644 index 000000000..d2a5d6d3c --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/single_controller/ray/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, + create_colocated_worker_cls, + create_colocated_worker_cls_fused, +) + +__all__ = [ + "RayClassWithInitArgs", + "RayResourcePool", + "RayWorkerGroup", + "create_colocated_worker_cls", + "create_colocated_worker_cls_fused", +] diff --git a/toolbox/verl/v0.5.0/verl/single_controller/ray/base.py b/toolbox/verl/v0.5.0/verl/single_controller/ray/base.py new file mode 100644 index 000000000..b692206be --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/single_controller/ray/base.py @@ -0,0 +1,893 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import logging +import time +from copy import deepcopy +from typing import Any, Optional + +import ray +from ray.experimental.state.api import get_actor +from ray.util import list_named_actors +from ray.util.placement_group import PlacementGroup, placement_group +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy + +from verl.protocol import DataProto, _padding_size_key +from verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup +from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch +from verl.utils.py_functional import temp_env_var + +__all__ = ["Worker"] + + +def get_random_string(length: int) -> str: + import random + import string + + letters_digits = string.ascii_letters + string.digits + return "".join(random.choice(letters_digits) for _ in range(length)) + + +def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking): + class Functor: + def __call__(this, *args, **kwargs): + args, kwargs = dispatch_fn(self, *args, **kwargs) + padding_count = kwargs.pop(_padding_size_key, 0) + output = execute_fn(method_name, *args, **kwargs) + if blocking: + output = ray.get(output) + output = collect_fn(self, output) + if padding_count > 0: + if isinstance(output, DataProto): + indices = [i for i in range(len(output))][:-padding_count] + output = output.select_idxs(indices) + elif isinstance(output, list): + output = output[:-padding_count] + return output + + # use class type to pass the method_name to get a better observability + return type(method_name, (Functor,), {})() + + +def sort_placement_group_by_node_ip(pgs: list[PlacementGroup]) -> list[PlacementGroup]: + """ + Sort the placement groups by node ip, all bundles in a single placement group should be on the same node. + + FSDPCheckpointManager saves sharded model states and optimizer states in local storage, which requires RANK + to be consistent across nodes when resume from checkpoint. + + With this function, if there's only one resource pool and there's no node change, RANK should be consistent + across nodes in multiple ray jobs, even if the whole ray cluster is restarted. + """ + node_ip = {node["NodeID"]: node["NodeManagerAddress"] for node in ray.nodes()} + pg_ip = {} + for pg in pgs: + specs = ray._private.state.state.placement_group_table(pg.id) + # all bunles should be on the same node + node_id = specs["bundles_to_node_id"][0] + pg_ip[pg.id] = node_ip[node_id] + return sorted(pgs, key=lambda pg: pg_ip[pg.id]) + + +class RayResourcePool(ResourcePool): + def __init__( + self, + process_on_nodes: Optional[list[int]] = None, + use_gpu: bool = True, + name_prefix: str = None, + max_colocate_count: int = 10, + detached=False, + accelerator_type: Optional[str] = None, + ) -> None: + super().__init__(process_on_nodes, max_colocate_count) + self.use_gpu = use_gpu + # print(f"in RayProcessDispatchConfiguration: name_prefix = {name_prefix}") + self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix + self.pgs = None + self.detached = detached + self.accelerator_type = accelerator_type + + def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="cuda"): + if self.pgs is not None: + return self.pgs + + pg_name_prefix = ( + name if name else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" + ) + # print(f"pg_name_prefix = {pg_name_prefix}") + if device_name == "npu": + device_name = "NPU" + elif device_name == "cuda": + device_name = "GPU" + + bundle = {"CPU": self.max_colocate_count} + if self.use_gpu: + bundle[device_name] = 1 + if self.accelerator_type is not None: + bundle[self.accelerator_type] = 1e-4 + pg_scheme = [[bundle.copy() for _ in range(process_count)] for process_count in self._store] + + lifetime = "detached" if self.detached else None + + pgs = [ + placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime) + for idx, bundles in enumerate(pg_scheme) + ] + + ray.get([pg.ready() for pg in pgs]) + + self.pgs = pgs + return pgs + + +def extract_pg_from_exist( + resource_pools: dict[str, RayResourcePool], src_role_names: list[str], resource_pool: RayResourcePool +) -> list: + src_pgs = [ + pg + for role_name, resource_pool in resource_pools.items() + for pg in resource_pool.get_placement_groups() + if role_name in src_role_names + ] + + sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True) + sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True) + + unsorted_pgs: list[tuple[int, PlacementGroup]] = [] + searching_idx = 0 + for request_process, original_idx in sorted_process_on_nodes: + assert searching_idx < len(sorted_src_pgs), f"no enough nodes for request: searching {searching_idx} th node" + assert request_process <= sorted_src_pgs[searching_idx].bundle_count, ( + f"requesting {request_process} processes, bundle count cannot satisfy" + ) + unsorted_pgs.append((original_idx, sorted_src_pgs[searching_idx])) + searching_idx += 1 + + return [pg for _, pg in sorted(unsorted_pgs)] + + +def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool: + assert rp1.use_gpu == rp2.use_gpu, "Both RayResourcePool must either use_gpu or not" + assert rp1.max_colocate_count == rp2.max_colocate_count, "Both RayResourcePool must has the same max_colocate_count" + assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, "Both RayResourcePool must has the same n_gpus_per_node" + assert rp1.detached == rp2.detached, "Detached ResourcePool cannot be merged with non-detached ResourcePool" + + new_store = rp1.store + rp2.store + + merged = type(rp1)(new_store, rp1.use_gpu, f"{rp1.name_prefix}_{rp2.name_prefix}") + merged.pgs = rp1.get_placement_groups() + rp2.get_placement_groups() + + return merged + + +class RayClassWithInitArgs(ClassWithInitArgs): + """A wrapper class for Ray actors with initialization arguments. + + This class extends ClassWithInitArgs to provide additional functionality for + configuring and creating Ray actors with specific resource requirements and + scheduling strategies. + """ + + def __init__(self, cls, *args, **kwargs) -> None: + # self._options = kwargs.pop('options', dict()) + super().__init__(cls, *args, **kwargs) + self._options = {} + self._additional_resource = {} + + def set_additional_resource(self, additional_resource): + """Set additional resource requirements for the actor. + + Args: + additional_resource: Dictionary specifying additional resource requirements + """ + self._additional_resource = additional_resource + + def update_options(self, options: dict): + """Update the Ray actor creation options. + + Args: + options: Dictionary of options to update + """ + self._options.update(options) + + def __call__( + self, + placement_group, + placement_group_bundle_idx, + use_gpu: bool = True, + num_gpus=1, + sharing_with=None, + device_name="cuda", + ) -> Any: + """Create and return a Ray actor with the configured options. + + Args: + placement_group: Ray placement group for scheduling + placement_group_bundle_idx: Index of the bundle in the placement group + use_gpu: Whether to use GPU resources + num_gpus: Number of GPUs to allocate + sharing_with: Actor to share resources with + device_name: Device for training + + Returns: + A Ray actor handle with the configured options + """ + if sharing_with is not None: + target_node_id = ray.get(sharing_with.get_node_id.remote()) + visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote()) + options = {"scheduling_strategy": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)} + return self.cls.options(**options).remote(*self.args, cuda_visible_devices=visible_devices, **self.kwargs) + + options = { + "scheduling_strategy": PlacementGroupSchedulingStrategy( + placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx + ) + } + options.update(self._options) + + if use_gpu and device_name == "cuda": + options["num_gpus"] = num_gpus + if use_gpu and device_name == "npu": + options["resources"] = {"NPU": num_gpus} + + if len(self._additional_resource) > 1: + for k, v in self._additional_resource.items(): + options[k] = v + + # print("cls:", self.cls) + # print("args: ", self.args) + # print("kwargs: ", self.kwargs) + return self.cls.options(**options).remote(*self.args, **self.kwargs) + + +class RayWorkerGroup(WorkerGroup): + """A group of Ray workers that can be managed collectively. + + This class extends WorkerGroup to provide Ray-specific functionality for + creating and managing groups of Ray actors with specific resource requirements + and scheduling strategies. + """ + + def __init__( + self, + resource_pool: RayResourcePool = None, + ray_cls_with_init: RayClassWithInitArgs = None, + bin_pack: bool = True, + name_prefix: str = None, + detached=False, + worker_names=None, + worker_handles: list[ray.actor.ActorHandle] = None, + ray_wait_register_center_timeout: int = 300, + **kwargs, + ) -> None: + """Initialize a RayWorkerGroup. + + Args: + resource_pool: Resource pool for worker allocation + ray_cls_with_init: Class with initialization arguments for workers + bin_pack: Whether to use strict bin packing for resource allocation + name_prefix: Prefix for worker names + detached: Whether workers should be detached + worker_names: Names of existing workers to attach to + ray_wait_register_center_timeout: Timeout for waiting on register center + **kwargs: Additional keyword arguments + """ + super().__init__(resource_pool=resource_pool, **kwargs) + self.ray_cls_with_init = ray_cls_with_init + self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix + self._ray_wait_register_center_timeout = ray_wait_register_center_timeout + # Whether the WorkerGroup is a Colocate WorkerGroup created by FusedWorker. + self.fused_worker_used = ray_cls_with_init.fused_worker_used + # if a WorkerGroup is spawned from Colocate WorkerGroup, this indicates which sub-class is binded to + # this WorkerGroup. + self.sub_cls_name = "" + self.device_name = kwargs.get("device_name", "cuda") + self.profile_steps = kwargs.get("profile_steps", None) + self.worker_nsight_options = kwargs.get("worker_nsight_options", None) + if self.worker_nsight_options is not None and self.worker_nsight_options["capture-range-end"] is None: + self.worker_nsight_options["capture-range-end"] = f"repeat-shutdown:{6 * len(self.profile_steps)}" + + if worker_names is not None and (not self.fused_worker_used): + assert self._is_init_with_detached_workers + self._worker_names = worker_names + + if self._is_init_with_detached_workers: + self._init_with_detached_workers(worker_names=worker_names, worker_handles=worker_handles) + else: + self._init_with_resource_pool( + resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, bin_pack=bin_pack, detached=detached + ) + + if ray_cls_with_init is not None: + self._bind_worker_method(self.ray_cls_with_init.cls, func_generator) + + self.wg_dict = None + self.method_names = [] + + def _is_worker_alive(self, worker: ray.actor.ActorHandle): + """Check if a worker actor is still alive. + + Args: + worker: Ray actor handle to check + + Returns: + bool: True if the worker is alive, False otherwise + """ + worker_state_dict = get_actor(worker._actor_id.hex()) + return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False + + def _init_with_detached_workers(self, worker_names, worker_handles): + # ray.get_actor holds a weak reference to the actor, which causes actors garbage collected unexpectedly + # if we only hold spawn RayWorkerGroup. By passing actor handle explicitly, spawn RayWorkerGroup have + # strong reference to these actors. + # https://github.com/ray-project/ray/pull/45699 + workers = worker_handles if worker_handles else [ray.get_actor(name=name) for name in worker_names] + self._workers = workers + self._world_size = len(worker_names) + + def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached): + """Initialize the worker group by creating new workers from a resource pool. + + Args: + resource_pool: Resource pool for worker allocation + ray_cls_with_init: Class with initialization arguments for workers + bin_pack: Whether to use strict bin packing for resource allocation + detached: Whether workers should be detached + """ + use_gpu = resource_pool.use_gpu + + strategy = "PACK" + if bin_pack: + strategy = "STRICT_PACK" + pgs = resource_pool.get_placement_groups(strategy=strategy, device_name=self.device_name) + world_size = resource_pool.world_size + self._world_size = world_size + # cia.add_kwarg("_world_size", world_size) + num_gpus = 1 / resource_pool.max_colocate_count + + rank = -1 + local_world_size = resource_pool.store[0] + for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)): + assert local_world_size <= pg.bundle_count, f"when generating for {self.name_prefix}, for the " + for local_rank in range(local_world_size): + rank += 1 + + # we pass in environment variable at option so that Worker can use environment variable to set + env_vars = { + "WORLD_SIZE": str(world_size), + "RANK": str(rank), + "WG_PREFIX": self.name_prefix, + "WG_BACKEND": "ray", + "RAY_LOCAL_WORLD_SIZE": str(local_world_size), + "RAY_LOCAL_RANK": str(local_rank), + } + if rank != 0: + env_vars["MASTER_ADDR"] = self._master_addr + env_vars["MASTER_PORT"] = self._master_port + + import re + + cia_name = type(ray_cls_with_init.cls).__name__ + match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)" + cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj" + name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}" # e.g. Worker_2:5 + + if self.profile_steps and self.device_name == "cuda": + ray_cls_with_init.update_options( + { + "runtime_env": { + "env_vars": env_vars, + "nsight": self.worker_nsight_options, + }, + "name": name, + } + ) + else: + ray_cls_with_init.update_options({"runtime_env": {"env_vars": env_vars}, "name": name}) + + if detached: + ray_cls_with_init.update_options({"lifetime": "detached"}) + + # create a worker + worker = ray_cls_with_init( + placement_group=pg, + placement_group_bundle_idx=local_rank, + use_gpu=use_gpu, + num_gpus=num_gpus, + device_name=self.device_name, + ) + self._workers.append(worker) + self._worker_names.append(name) + + if rank == 0: + register_center_actor = None + actor_name = f"{self.name_prefix}_register_center" + start_time = time.time() + + while time.time() - start_time < self._ray_wait_register_center_timeout: + if actor_name in list_named_actors(): + register_center_actor = ray.get_actor(actor_name) + break + + elapsed = int(time.time() - start_time) + if elapsed % 30 == 0: + logging.warning( + "Waiting for register center actor %s to be ready. Elapsed time: %s seconds out of " + "%s seconds.", + actor_name, + elapsed, + self._ray_wait_register_center_timeout, + ) + time.sleep(1) + + if register_center_actor is None: + raise TimeoutError( + f"Failed to get register_center_actor {actor_name} " + f"in {list_named_actors(all_namespaces=True)} " + f"for {self._ray_wait_register_center_timeout} seconds. " + "Ensure that any lingering Ray resources from previous " + "runs are cleaned up (e.g., by restarting the Ray cluster), " + "or adjust the waiting time by modifying the config " + "`trainer.ray_wait_register_center_timeout`." + ) + + rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote()) + self._master_addr, self._master_port = rank_zero_info["MASTER_ADDR"], rank_zero_info["MASTER_PORT"] + # print(f"rank_zero_info: {rank_zero_info}") + # print(f"master_addr: {self._master_addr}, master_port: {self._master_port}") + + @property + def worker_names(self): + return self._worker_names + + @classmethod + def from_detached( + cls, + name_prefix=None, + worker_names=None, + worker_handles=None, + ray_cls_with_init=None, + **kwargs, + ): + """Create a worker group from existing detached workers. + + Args: + name_prefix: Prefix for worker names + worker_names: Names of existing workers to attach to + ray_cls_with_init: Class with initialization arguments for workers + + Returns: + A new RayWorkerGroup instance + """ + worker_group = cls( + resource_pool=None, + ray_cls_with_init=ray_cls_with_init, + name_prefix=name_prefix, + worker_names=worker_names, + worker_handles=worker_handles, + **kwargs, + ) + return worker_group + + def spawn(self, prefix_set): + """Spawn to a dictionary of worker groups, each with a subset of method with prefix. + + Args: + prefix_set: Set of prefixes to create worker groups for + + Returns: + Dictionary of worker groups keyed by prefix + """ + if self.fused_worker_used: + return self.spawn_fused(prefix_set) + + def _rebind_actor_methods(worker_group, actor_name): + prefix: str = actor_name + "_" + for method_name in dir(worker_group): + if method_name.startswith(prefix): + original_method_name = method_name.removeprefix(prefix) + method = getattr(worker_group, method_name) + setattr(worker_group, original_method_name, method) + + new_worker_group_dict = {} + for prefix in prefix_set: + new_worker_group = self.from_detached( + name_prefix=self.name_prefix, + worker_names=self._worker_names, + worker_handles=self._workers, + ray_cls_with_init=self.ray_cls_with_init, + profile_steps=self.profile_steps, + worker_nsight_options=self.worker_nsight_options, + ) + + _rebind_actor_methods(new_worker_group, prefix) + new_worker_group_dict[prefix] = new_worker_group + return new_worker_group_dict + + def spawn_fused(self, prefix_set): + """Create a dictionary of worker groups for fused workers. + + Args: + prefix_set: Set of prefixes to create worker groups for + + Returns: + Dictionary of worker groups keyed by prefix + """ + wg_dict = dict() + for key in prefix_set: + new_wg = deepcopy(self) + new_wg._bind_worker_method(self.ray_cls_with_init.cls.raw_cls_dict[key], func_generator) + new_wg.sub_cls_name = key + wg_dict[key] = new_wg + return wg_dict + + def fuse(self, prefix_set): + """Fuse multiple worker groups into the current worker group. + + Args: + prefix_set: Set of prefixes to fuse into the worker group + """ + if self.wg_dict is None: + self.wg_dict = self.spawn(prefix_set) + for role_name, role_wg in self.wg_dict.items(): + setattr(self, role_name, role_wg) + self.method_names = self._bind_worker_method(self.ray_cls_with_init.cls, func_generator) + + def _execute_remote_single_worker(self, worker, method_name: str, *args, **kwargs): + """Execute a method on a single worker remotely. + + Args: + worker: The worker actor handle + method_name: Name of the method to execute + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + Remote object reference to the method execution + """ + if self.fused_worker_used and method_name not in self.method_names: + remote_call = getattr(worker, self.fused_worker_execute_fn_name) + return remote_call.remote(f"{self.sub_cls_name}_fwmn_{method_name}", *args, **kwargs) + # fused worker not used + remote_call = getattr(worker, method_name) + return remote_call.remote(*args, **kwargs) + + def execute_rank_zero_sync(self, method_name: str, *args, **kwargs): + """Execute a method on rank zero worker synchronously. + + Args: + method_name: Name of the method to execute + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + Result of the method execution + """ + return ray.get(self.execute_rank_zero_async(method_name, *args, **kwargs)) + + def execute_rank_zero_async(self, method_name: str, *args, **kwargs): + """Execute a method on rank zero worker asynchronously. + + Args: + method_name: Name of the method to execute + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + Remote object reference to the method execution + """ + return self._execute_remote_single_worker(self._workers[0], method_name, *args, **kwargs) + + def execute_rank_zero(self, method_name: str, *args, **kwargs): + """Alias for execute_rank_zero_async. + + Args: + method_name: Name of the method to execute + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + Remote object reference to the method execution + """ + return self.execute_rank_zero_async(method_name, *args, **kwargs) + + def execute_all(self, method_name: str, *args, **kwargs): + """Alias for execute_all_async. + + Args: + method_name: Name of the method to execute + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + List of remote object references to the method executions + """ + return self.execute_all_async(method_name, *args, **kwargs) + + def execute_all_sync(self, method_name: str, *args, **kwargs): + """Execute a method on all workers synchronously. + + Args: + method_name: Name of the method to execute + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + List of results from all workers + """ + return ray.get(self.execute_all_async(method_name, *args, **kwargs)) + + def execute_all_async(self, method_name: str, *args, **kwargs): + """Execute a method on all workers asynchronously. + + Args: + method_name: Name of the method to execute + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + List of remote object references to the method executions + """ + # Here, we assume that if all arguments in args and kwargs are lists, + # and their lengths match len(self._workers), we'll distribute each + # element in these lists to the corresponding worker + # print(f"execute_all_async: method {method_name}({args}, {kwargs})") + length = len(self._workers) + if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()): + if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()): + # print(f"splitting args and kwargs into {length} shards") + result = [] + for i in range(length): + sliced_args = tuple(arg[i] for arg in args) + sliced_kwargs = {k: v[i] for k, v in kwargs.items()} + result.append( + self._execute_remote_single_worker(self._workers[i], method_name, *sliced_args, **sliced_kwargs) + ) + return result + + return [self._execute_remote_single_worker(worker, method_name, *args, **kwargs) for worker in self._workers] + + @property + def master_address(self): + return self._master_addr + + @property + def master_port(self): + return self._master_port + + @property + def workers(self): + return self._workers + + @property + def world_size(self): + return self._world_size + + +""" +Utilities that enables creating workers inside the same ray.Actor, +with code written in separate ray.Actors. +""" + + +# deprecated, switching to FusedWorker +def _bind_workers_method_to_parent(cls, key, user_defined_cls): + """ + Binds the methods of each worker to the WorkerDict. + Note that we only bind public methods that are decorated by register + """ + + for method_name in dir(user_defined_cls): + try: + method = getattr(user_defined_cls, method_name) + assert callable(method), f"{method_name} in {user_defined_cls} is not callable" + except Exception: + # if it is a property, it will fail because Class doesn't have instance property + continue + + if hasattr(method, MAGIC_ATTR): + + def generate_function(name, key=key): + def func(self, *args, **kwargs): + # dispatch to the actual worker + return getattr(self.worker_dict[key], name)(*args, **kwargs) + + async def async_func(self, *args, **kwargs): + # dispatch to the actual worker + return await getattr(self.worker_dict[key], name)(*args, **kwargs) + + wrapper = async_func if inspect.iscoroutinefunction(method) else func # noqa: B023 + + return wrapper + + func = generate_function(method_name) + # pass MAGIC_ATTR for outer worker group + attrs = getattr(method, MAGIC_ATTR) + setattr(func, MAGIC_ATTR, attrs) + try: + # bind direct rollout method to class without prefix + if attrs["dispatch_mode"] == Dispatch.DIRECT_ROLLOUT_METHOD and "rollout" in key: + assert not hasattr(cls, method_name), ( + f"conflict direct rollout method {method_name} with role {key}" + ) + setattr(cls, method_name, func) + print(f"bind role {key} method {method_name} to class {cls}") + else: + method_name_with_prefix = key + "_" + method_name + setattr(cls, method_name_with_prefix, func) + except Exception as e: + raise ValueError(f"Fail to set method_name {method_name}") from e + + +def _unwrap_ray_remote(cls): + if hasattr(cls, "__ray_actor_class__"): + cls = cls.__ray_actor_class__ + return cls + + +def _determine_fsdp_megatron_base_class(mros: list): + """ + - megatron: base class should be MegatronWorker + - fsdp: base class should be Worker + """ + for cls in mros[0]: + if cls.__name__ == "MegatronWorker": + return cls + if cls.__name__ == "Worker": + return cls + raise ValueError(f"Cannot determine base class for {mros}") + + +# deprecated, switching to FusedWorker +def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): + """ + This function should return a class instance that delegates the calls to every + cls in cls_dict + """ + cls_dict = {} + init_args_dict = {} + worker_cls = _determine_fsdp_megatron_base_class( + [cls.cls.__ray_actor_class__.__mro__ for cls in class_dict.values()] + ) + assert issubclass(worker_cls, Worker), f"worker_cls {worker_cls} should be a subclass of Worker" + print(f"colocated worker base class {worker_cls}") + + for key, cls in class_dict.items(): + cls_dict[key] = cls.cls + init_args_dict[key] = {"args": cls.args, "kwargs": cls.kwargs} + + assert cls_dict.keys() == init_args_dict.keys() + + # TODO: create a class with customizable name + class WorkerDict(worker_cls): + def __init__(self): + super().__init__() + self.worker_dict = {} + for key, user_defined_cls in cls_dict.items(): + user_defined_cls = _unwrap_ray_remote(user_defined_cls) + # directly instantiate the class without remote + # in worker class, e.g. + # when DISABLE_WORKER_INIT == 1 it will return immediately + with temp_env_var("DISABLE_WORKER_INIT", "1"): + self.worker_dict[key] = user_defined_cls( + *init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {}) + ) + + # now monkey-patch the methods from inner class to WorkerDict + for key, user_defined_cls in cls_dict.items(): + user_defined_cls = _unwrap_ray_remote(user_defined_cls) + _bind_workers_method_to_parent(WorkerDict, key, user_defined_cls) + + remote_cls = ray.remote(WorkerDict) + remote_cls = RayClassWithInitArgs(cls=remote_cls) + return remote_cls + + +FusedWorkerCLSName = "FusedWorker" + + +def create_colocated_worker_raw_cls(class_dict: dict[str, RayClassWithInitArgs]): + """ + This function returns a FusedWorker class. + + `FusedWorker.{class_name}` -> FusedClass + Use `class_name` as a param to directly access the underlying class. + + `FusedWorker._fuw_execute("{class_name}_fwmn_{method_name}", *args, **kwargs)` + First param must be "{class_name}_fwmn_{method_name}" in order to access `method_name` + of underlying class `{class_name}`. + + `FusedWorker.fused_worker_dict` -> {"class_name": FusedClass} + Stores all underlying classes. + + `FusedClass.fused_worker_dict` -> {"class_name": FusedClass} + The same as `FusedWorker.fused_worker_dict`, enables underlying class to access other + underlying classes. + """ + raw_cls_dict = {cls_name: _unwrap_ray_remote(cia.cls) for cls_name, cia in class_dict.items()} + init_args_dict = {cls_name: cia.args for cls_name, cia in class_dict.items()} + init_kwargs_dict = {cls_name: cia.kwargs for cls_name, cia in class_dict.items()} + cls_names = list(class_dict.keys()) + + # FusedWorker_Actor_Critic + class_name_renamed = "_".join([FusedWorkerCLSName] + cls_names) + + class FusedWorker(Worker): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cls_names = cls_names + self.raw_cls_dict = raw_cls_dict + self.init_args_dict = init_args_dict + self.init_kwargs_dict = init_kwargs_dict + + for cls_name, udc, ud_args, ud_kwargs in zip( + self.cls_names, + self.raw_cls_dict.values(), + self.init_args_dict.values(), + self.init_kwargs_dict.values(), + strict=True, + ): + with temp_env_var("DISABLE_WORKER_INIT", "1"): + udc._get_ray_actor_cls_name = lambda x, name_renamed=class_name_renamed: name_renamed + udc._get_ray_method_prefix = lambda x, name_prefixed=cls_name: f"{name_prefixed}_" + # cls_name = "actor", "critic", udc = ActorWorker, CriticWorker + self.fused_worker_dict[cls_name] = udc(*ud_args, **ud_kwargs) + setattr(self, cls_name, self.fused_worker_dict[cls_name]) + + # injecting fused_worker to each sub worker so they can be aware of existence of each other + for _, worker in self.fused_worker_dict.items(): + setattr(worker, Worker.fused_worker_attr_name, self.fused_worker_dict) + + def _fuw_execute(self, method_name: str, *args, **kwargs): + # for fused_worker, method_name is in a form of "{cls_name}_fwmn_{method_name}" + # where fwmn stands "fused worker method name" + names = method_name.split("_fwmn_") + cls_name = names[0] + method_name = names[1] + + assert cls_name in self.fused_worker_dict, ( + f"calling {cls_name}'s {method_name}, but {cls_name} not in fused_worker_dict" + ) + udc_method = getattr(self.fused_worker_dict[cls_name], method_name) + return udc_method(*args, **kwargs) + + renamed_fused_worker_cls = type(class_name_renamed, (FusedWorker,), {}) + renamed_fused_worker_cls.is_fused_worker = True + renamed_fused_worker_cls.raw_cls_dict = raw_cls_dict + + return renamed_fused_worker_cls + + +def create_colocated_worker_cls_fused(class_dict: dict[str, RayClassWithInitArgs]): + """ + This function returns a RayClassWithInitArgs instance of FusedWorker, which is an replacement + of `create_colocated_worker_cls`. WorkerGroup constructed using this class will be a colocated + WorkerGroup, which will be referenced as `ColocateWorkerGroup` below. + + `ColocateWorkerGroup.spawn(prefix_set)` + returns a dict of WorkerGroup {"class_name": WorkerGroup}, WorkerGroup in this dict will + have methods of underlying class `class_name` attached. + + `ColocateWorkerGroup.fuse(prefix_set)` + After executing this function, `ColocateWorkerGroup.{class_name}` will return WorkerGroup + with methods of underlying class `class_name` attached. + """ + raw_colocated_worker_cls = create_colocated_worker_raw_cls(class_dict) + + remote_cls = ray.remote(raw_colocated_worker_cls) + cia = RayClassWithInitArgs(cls=remote_cls) + cia.fused_worker_used = True + + return cia diff --git a/toolbox/verl/v0.5.0/verl/single_controller/ray/megatron.py b/toolbox/verl/v0.5.0/verl/single_controller/ray/megatron.py new file mode 100644 index 000000000..b46fe44a1 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/single_controller/ray/megatron.py @@ -0,0 +1,77 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import ray + +from verl.single_controller.base.megatron.worker import DistGlobalInfo, DistRankInfo +from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + +from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup + + +# NOTE(sgm): for open-source megatron-core +class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): + """ + MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup + so that the dispatcher can use it to dispatch data. + """ + + def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs): + """ + Initialize the NVMegatronRayWorkerGroup. + + Args: + resource_pool (RayResourcePool): The resource pool containing worker resources + ray_cls_with_init (RayClassWithInitArgs): The Ray class with initialization arguments + **kwargs: Additional keyword arguments to pass to the parent class + """ + super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs) + self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name="get_megatron_rank_info") + self._megatron_global_info: DistGlobalInfo = ray.get( + self.execute_rank_zero_async(method_name="get_megatron_global_info") + ) + + +class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): + """ + MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup + so that the dispatcher can use it to dispatch data. + """ + + def __init__( + self, + resource_pool: RayResourcePool, + ray_cls_with_init: RayClassWithInitArgs, + default_megatron_kwargs: dict = None, + **kwargs, + ): + super().__init__( + resource_pool=resource_pool, + ray_cls_with_init=ray_cls_with_init, + default_megatron_kwargs=default_megatron_kwargs, + **kwargs, + ) + self.init_megatron(default_megatron_kwargs=default_megatron_kwargs) + self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name="get_megatron_rank_info") + self._megatron_global_info: DistGlobalInfo = ray.get( + self.execute_rank_zero_async(method_name="get_megatron_global_info") + ) + + def init_megatron(self, default_megatron_kwargs: Optional[dict] = None): + # after super, we will call init of each worker + if not self._is_init_with_detached_workers: + # only init_megatron if the WorkerGroup is created from scratch + self.execute_all_sync(method_name="init_megatron", default_megatron_kwargs=default_megatron_kwargs) diff --git a/toolbox/verl/v0.5.0/verl/third_party/__init__.py b/toolbox/verl/v0.5.0/verl/third_party/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/third_party/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/third_party/sglang/__init__.py b/toolbox/verl/v0.5.0/verl/third_party/sglang/__init__.py new file mode 100644 index 000000000..15593caaf --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/third_party/sglang/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/third_party/sglang/parallel_state.py b/toolbox/verl/v0.5.0/verl/third_party/sglang/parallel_state.py new file mode 100644 index 000000000..cdec743d1 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/third_party/sglang/parallel_state.py @@ -0,0 +1,328 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The SGlang team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""Model and data parallel groups.""" + +import os +from typing import Optional + +import sglang.srt.distributed.parallel_state as ps +import torch +import torch.distributed +from sglang.srt.distributed.parallel_state import ( + get_pp_group, + get_world_group, + init_distributed_environment, + init_model_parallel_group, +) + +""" +This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron. +- We assume the Megatron tp+dp+pp world is already established before calling this function. + +""" + +# Device mesh for using DTensor +_DEVICE_MESH = None + +# Tensor model parallel group that the current rank belongs to. +_TP = None +# Pipeline model parallel group that the current rank belongs to. +_PP = None + + +# This method is for initializing the ParallelGroup when using HybridEngine +# NOTE(linjunrong): this function is for megatron +def initialize_parallel_state( + distributed_init_method: str = "env://", + backend: str = "nccl", + tensor_model_parallel_size: int = 1, + num_tp_per_train_tp: int = 1, + pipeline_model_parallel_size: int = 1, +): + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. + rank = int(os.getenv("RANK", "-1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + + # Use the world_size set by TORCHRUN + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) + if torch.distributed.get_world_size() > 1: + # NOTE: build a separate inference group with infer tp & micro dp + initialize_model_parallel_for_sglang( + tensor_model_parallel_size=tensor_model_parallel_size, + num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp, + ) + else: + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + + +# NOTE(linjunrong): After init SGLang rollout using class EngineFragment, user should always remember to call +# this function to sync the _TP, _PP define at the beginning of this file. Otherwise, only the conterparts +# inside sglang.srt.distributed are init as ProcessGroup, the symbols defined in this file remain as None. +# It could be weird to maintain two _TP and _PP, I follow the same way to maintain an extra ones for +# verl itself as how it was done in verl.third_party.vllm.parallel_state. Note that the process is a little +# bit different +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. + """ + # get the backend of _DEVICE_WORLD_GROUP + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + if not model_parallel_is_initialized(): + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + return + + assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( + f"tensor parallel group already initialized, but of unexpected size: " + f"{get_tensor_model_parallel_world_size()=} vs. {tensor_model_parallel_size=}" + ) + pp_world_size = get_pp_group().world_size + assert pp_world_size == pipeline_model_parallel_size, ( + f"pipeline parallel group already initialized, but of unexpected size: {pp_world_size=} vs. " + f"{pipeline_model_parallel_size=}" + ) + + +# TODO(sgm): deviate from the v0.5.4, not pp now +# NOTE(linjunrong): the SGLang version using _TP instead of ps._TP +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return _TP is not None + # and _PIPELINE_MODEL_PARALLEL_GROUP is not None) + + +def initialize_model_parallel_for_sglang( + tensor_model_parallel_size: int, + num_tensor_model_parallel_groups_per_train_tp: int = 1, + pipeline_model_parallel_size: int = 1, +) -> None: + pass + + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + + assert isinstance(tensor_model_parallel_size, int) + + # assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group + # assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group + + # Build the tensor model-parallel groups. + assert ps._TP is None, "tensor model parallel group is already initialized" + + global _TP + + world_size: int = torch.distributed.get_world_size() + + backend = torch.distributed.get_backend() + + num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size + + if num_tensor_model_parallel_groups_per_train_tp == 1: + # if tensor_model_parallel_size == train_tensor_parallel_size: + # using the same tp group as Megatron/vllm + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + group_ranks.append(ranks) + _TP = init_model_parallel_group( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + backend=backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + # _MICRO_DATA_PARALLEL_GROUP is move to hybrid engine + else: + # initialize a micro_dp group and a tp group + # assume training tp=4, infer tp=2, then, weight is partitioned as + # [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference + + # Build the inference tp groups + # train_tp = train_tensor_parallel_size + train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size + # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp): + start = train_tp * i + end = train_tp * (i + 1) + for j in range(num_tensor_model_parallel_groups_per_train_tp): + ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp)) + for i in range(len(ranks)): + ranks[i] += j + group_ranks.append(ranks) + _TP = init_model_parallel_group( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + backend=backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + + # Build the pipeline model-parallel groups. + # global _PIPELINE_MODEL_PARALLEL_GROUP + # global _PIPELINE_GLOBAL_RANKS + # assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized") + + # ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group() + # ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks() + + # TODO: init using device mesh (not support hybrid engine now) + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + ps._PP = _PP # for verl + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + NOTE: This method is a hack from the open-sourced version without + asertion of world_size = tp * pp + + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group) + + # NOTE(sgm) we don't assert world_size == tp * pp + # DP is not managed by vllm but by the VeRL WorkerGroup + # if (world_size != + # tensor_model_parallel_size * pipeline_model_parallel_size): + # raise RuntimeError( + # f"world_size ({world_size}) is not equal to " + # f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + # f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") + + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + + global _TP + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) + group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group + if ps._TP is not None: + _TP = ps._TP + else: + _TP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + + # TODO: init using device mesh (not support hybrid engine now) + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + if ps._TP is not None: + _PP = ps._TP + else: + _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + ps._PP = _PP + + +""" +Device mesh utilities +""" + + +def get_device_mesh(): + assert _DEVICE_MESH is not None, "device mesh is not initialized" + return _DEVICE_MESH + + +""" +Tensor model parallel utilities +""" + + +# NOTE(linjunrong): In the vllm version parallel_state.py. verl created its own _TP and _PP as verl want to use +# the process group for some extra purpose. Under the hood, there is no difference between them and the original +# one in vllm.distributed.parallel_state. However, the implementation need to hack the init process of inference +# engine, as we do not maintain another SGLang here, I just use the original _TP and _PP directly. +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP.device_group + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size diff --git a/toolbox/verl/v0.5.0/verl/third_party/torch/__init__.py b/toolbox/verl/v0.5.0/verl/third_party/torch/__init__.py new file mode 100644 index 000000000..7664279b7 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/third_party/torch/__init__.py @@ -0,0 +1,87 @@ +# official torch 2.6.0 set_model_state_dict API leads to OOM +# this is a copy of torch/distributed/checkpoint from torch 2.7.0 + +# From PyTorch: + +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +# From Caffe2: + +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. + +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. + +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. + +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain + +# All contributions by Cruise LLC: +# Copyright (c) 2022 Cruise LLC. +# All rights reserved. + +# All contributions by Tri Dao: +# Copyright (c) 2024 Tri Dao. +# All rights reserved. + +# All contributions by Arm: +# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates + +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. + +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. + +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. + +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. diff --git a/toolbox/verl/v0.5.0/verl/third_party/torch/distributed/__init__.py b/toolbox/verl/v0.5.0/verl/third_party/torch/distributed/__init__.py new file mode 100644 index 000000000..7664279b7 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/third_party/torch/distributed/__init__.py @@ -0,0 +1,87 @@ +# official torch 2.6.0 set_model_state_dict API leads to OOM +# this is a copy of torch/distributed/checkpoint from torch 2.7.0 + +# From PyTorch: + +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +# From Caffe2: + +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. + +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. + +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. + +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain + +# All contributions by Cruise LLC: +# Copyright (c) 2022 Cruise LLC. +# All rights reserved. + +# All contributions by Tri Dao: +# Copyright (c) 2024 Tri Dao. +# All rights reserved. + +# All contributions by Arm: +# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates + +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. + +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. + +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. + +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. diff --git a/toolbox/verl/v0.5.0/verl/third_party/torch/distributed/_state_dict_utils.py b/toolbox/verl/v0.5.0/verl/third_party/torch/distributed/_state_dict_utils.py new file mode 100644 index 000000000..d308449f7 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/third_party/torch/distributed/_state_dict_utils.py @@ -0,0 +1,840 @@ +# official torch 2.6.0 set_model_state_dict API leads to OOM +# this is a copy of torch/distributed/checkpoint from torch 2.7.0 + +# From PyTorch: + +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +# From Caffe2: + +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. + +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. + +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. + +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain + +# All contributions by Cruise LLC: +# Copyright (c) 2022 Cruise LLC. +# All rights reserved. + +# All contributions by Tri Dao: +# Copyright (c) 2024 Tri Dao. +# All rights reserved. + +# All contributions by Arm: +# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates + +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. + +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. + +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. + +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + + +# ruff: noqa: B028, UP038, UP007, E721, E501 +# mypy: allow-untyped-defs +import copy +import io +import math +import weakref +from collections.abc import Mapping, MutableMapping +from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Union, cast + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed._functional_collectives import AsyncCollectiveTensor + +if dist.is_available() or TYPE_CHECKING: + from torch.distributed import distributed_c10d + from torch.distributed._shard.sharded_tensor import ShardedTensor + from torch.distributed.tensor import DTensor, Replicate, distribute_tensor + from torch.distributed.tensor._utils import compute_local_shape_and_global_offset + + +def _identity_func( + obj: torch.Tensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + companion_obj: Any, +) -> torch.Tensor: + return obj + + +def _all_gather_sharded_tensor( + sharded_tensor: "ShardedTensor", + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, +) -> torch.Tensor: + if pg is None: + pg = distributed_c10d._get_default_group() + world_size = dist.get_world_size(pg) + shards = sharded_tensor.local_shards() + dim_0_size = sharded_tensor.size()[0] # type: ignore[index] + tensor_numel = sharded_tensor.size().numel() # type: ignore[union-attr] + chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size + pg_device = distributed_c10d._get_pg_default_device(pg) if device is None else device + if shards: + local_tensor = shards[0].tensor.flatten() + if local_tensor.device.type != pg_device.type: + local_tensor = local_tensor.to(pg_device) + num_padding = chunk_size - local_tensor.numel() + if num_padding > 0: + local_tensor = F.pad(local_tensor, [0, num_padding]) + else: + local_tensor = torch.zeros(chunk_size, dtype=sharded_tensor.dtype, device=pg_device) + + tensor = torch.empty( + chunk_size * world_size, + dtype=local_tensor.dtype, + device=pg_device, + ) + dist.all_gather_into_tensor(tensor, local_tensor, group=pg) + + tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size()) + return tensor + + +class CompanionMismatch(Exception): + pass + + +def _iterate_state_dict( + iter_object: Any, + sharded_tensor_func: Callable, + dtensor_func: Callable, + tensor_func: Callable, + *, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + cpu_offload: bool = False, + companion_obj: Any = None, + ranks_only: tuple[int, ...] = (), + type_check: bool = True, + non_blocking: bool = True, +) -> dict[str, Any]: + """Iterate through the state dict, applying the given functions to each tensor type. + + Args: + iter_object (Any): the target state_dict. + sharded_tensor_func (Callable): the function to apply to ShardedTensor + dtensor_func (Callable): the function to apply to DTensor + tensor_func (Callable): the function to apply to Tensor + pg (Optional[dist.ProcessGroup]): process group passed to tensor functions + device (Optional[torch.device]): device passed to tensor functions + cpu_offload (bool): whether to offload the tensors to CPU memory. This option is ignored + if a companion_obj is supplied. + companion_obj (Any): A companion object to the state dict. If this object + is supplied, we attempt to copy the tensor to the companion object. + ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + non_blocking (bool): whether to use non-blocking copy when copying to the companion object. + """ + # TODO: should we use pytree? + cpu_device = torch.device("cpu") + if isinstance(iter_object, ShardedTensor): + ret = sharded_tensor_func(iter_object, pg, device, companion_obj) + elif isinstance(iter_object, DTensor): + ret = dtensor_func(iter_object, pg, device, companion_obj) + elif isinstance(iter_object, torch.Tensor): + ret = tensor_func(iter_object, pg, device, companion_obj) + elif isinstance(iter_object, (int, float, str, bytes, io.BytesIO)) or iter_object is None: + ret = iter_object + elif isinstance(iter_object, dict): + if companion_obj is not None and ( + not isinstance(companion_obj, dict) or set(companion_obj.keys()) != set(iter_object.keys()) + ): + msg = "" if isinstance(companion_obj, dict) else f"{set(companion_obj.keys())=} {set(iter_object.keys())=}" + raise CompanionMismatch(msg) + + ret = { + key: _iterate_state_dict( + value, + sharded_tensor_func, + dtensor_func, + tensor_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + companion_obj=companion_obj[key] if companion_obj is not None else None, + ranks_only=ranks_only, + type_check=type_check, + non_blocking=non_blocking, + ) + for key, value in iter_object.items() + } + elif isinstance(iter_object, (list, tuple)): + if companion_obj is not None and ( + not isinstance(companion_obj, (list, tuple)) or len(companion_obj) != len(iter_object) + ): + raise CompanionMismatch + + ret = [ + _iterate_state_dict( + v, + sharded_tensor_func, + dtensor_func, + tensor_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + companion_obj=companion_obj[idx] if companion_obj is not None else None, + ranks_only=ranks_only, + type_check=type_check, + non_blocking=non_blocking, + ) + for idx, v in enumerate(iter_object) + ] + if isinstance(iter_object, tuple): + ret = tuple(ret) + elif not type_check: + ret = copy.deepcopy(iter_object) + else: + raise ValueError(f"Unexpected value type {type(iter_object)}") + + if not ranks_only or dist.get_rank(pg) in ranks_only: + if isinstance(ret, torch.Tensor): + if cpu_offload and companion_obj is None: + ret = ret.to(cpu_device) + + if companion_obj is not None: + if isinstance(companion_obj, DTensor): + assert isinstance(ret, DTensor) + companion_obj._local_tensor.copy_(ret._local_tensor, non_blocking=non_blocking) + else: + companion_obj.copy_(ret, non_blocking=non_blocking) + ret = companion_obj + else: + ret = {} if isinstance(ret, dict) else None + + return ret + + +def _gather_state_dict( + state_dict: dict[str, Any], + *, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + cpu_offload: bool = False, + ranks_only: tuple[int, ...] = (), + type_check: bool = True, +) -> dict[str, Any]: + """ + Given a state_dict, this API gathers all the ShardedTensors or DTensors in + the state_dict. + + + Args: + state_dict (Dict[str, Any]): the target sharded state_dict. + pg (Optional[dist.ProcessGroup]): the process group that is used to + gather ShardedTensor. Note that gathering a DTensor will use + the DeviceMesh. So this argument will be ignored when gathering a + DTensor. + device: (Optional[torch.device]): the device that is used to + perform allgather for ShardedTensor. Note that gathering a DTensor + will use the DeviceMesh. So this argument will be ignored when + gathering a DTensor. + cpu_offload (bool): whether to offload the tensors to CPU memory. The + default value is False. + ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check: (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + The gathered state dictionary. + """ + + def sharded_tensor_func(value, pg, device, companion_obj): + # ShardedTensor does not seem to record the original device type. + # So if the tensor is moved to CPU, we won't know the original type. + # As a result, we have to rely on the user to tell us the correct one. + cpu_device = torch.device("cpu") + output_tensor = _all_gather_sharded_tensor(value, pg, device) + local_shard_device = value.local_shards()[0].tensor.device if value.local_shards() else cpu_device + if output_tensor.device != local_shard_device: + value = output_tensor.to(local_shard_device) + else: + value = output_tensor + return value + + def dtensor_func(value, pg, device, companion_obj): + if value.device != value.device_mesh.device_type: + value = value.to(value.device_mesh.device_type) + # FSDP all_gather: [Shard(0)] -> [Replicate()] + # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()] + # 2D FSDP + TP all_gather: + # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()] + # - [Shard(0), Replicate()] -> [Replicate(), Replicate()] + placements = [Replicate() for _ in value.placements] + value = value.redistribute( + device_mesh=value.device_mesh, + placements=placements, + ) + # Call `wait()` to force the tensor to be synchronous with respect + # to the main stream. + # See the discussion in https://github.com/pytorch/pytorch/pull/117799. + value = value.to_local() + if isinstance(value, AsyncCollectiveTensor): + value = value.wait() + return value + + return _iterate_state_dict( + state_dict, + sharded_tensor_func, + dtensor_func, + _identity_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + ranks_only=ranks_only, + type_check=type_check, + ) + + +def _offload_state_dict_to_cpu( + state_dict: dict[str, Any], + *, + ranks_only: tuple[int, ...] = (), + type_check: bool = True, +) -> dict[str, Any]: + """ + Given a state_dict, this API offload all the tensors to CPU memory. + + Args: + state_dict (Dict[str, Any]): the target state_dict. + pg (Optional[dist.ProcessGroup]): the process group that is used to + gather ShardedTensor. Note that gathering a DTensor will use + the DeviceMesh. So this argument will be ignored when gathering a + DTensor. + ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check: (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + The gathered state dictionary. + """ + + ret = _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + _identity_func, + pg=None, + device=None, + cpu_offload=True, + ranks_only=ranks_only, + type_check=type_check, + ) + return ret + + +@torch.no_grad() +def _copy_state_dict( + state_dict: dict[str, Any], + copy_state_dict: dict[str, Any], + non_blocking: bool = False, + type_check: bool = True, +) -> dict[str, Any]: + """ + Copies all tensors in a given state dict into a different state_dict with the + same structure. Additionally, a copied state dict with the same value references + is returned. Editing the keys on this state dict will not affect the + passed in copy_state_dict (but the value references are the same). + + .. warning:: + It is expected by this function that state_dict and copy_state_dict share + the same structure and data types. + + .. warning:: + The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Args: + state_dict (Dict[str, Any]): the target state_dict. + copy_state_dict (Dict[str, Any]): + The state dict we are copying into. This state_dict must have exactly + the same structure as the source `state_dict`. + non_blocking: (bool): Whether copy ops should be performed asynchronously + type_check (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + State Dict copy + """ + + return _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + _identity_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + companion_obj=copy_state_dict, + type_check=type_check, + non_blocking=non_blocking, + ) + + +@torch.no_grad() +def _create_cpu_state_dict( + state_dict: dict[str, Any], pin_memory: bool = False, share_memory: bool = False +) -> dict[str, Any]: + """ + Given a state_dict, create another state_dict with the same structure and elements. + However, all tensors in the returned state_dict are new tensors on CPU. These + tensors can be placed on pin_memory or share_memory based on the provided arguments. + + .. warning:: + Setting both `pin_memory` and `share_memory` to True significantly increases the + latency of this method because of the nuances which require us to register memory + as pinned directly as opposed to relying on the pin_memory cache allocator. This + option should only be used for long lived tensors which are required to be shared. + This is not the case as long as at least one of `pin_memory` or `share_memory` is + set to False. + + """ + + def tensor_func( + obj: torch.Tensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + _: Any, + ) -> torch.Tensor: + if len(obj.size()) == 0: + return torch.tensor(0, dtype=obj.dtype) + + if share_memory: + t = torch.empty(*tuple(obj.size()), dtype=obj.dtype) + t = t.share_memory_() + if pin_memory: + + def unpin_memory(t): + succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr())) + assert succ == 0, f"Unpinning shared memory failed with error-code: {succ}" + + weakref.finalize(t, unpin_memory, t) + succ = int( + torch.cuda.cudart().cudaHostRegister( + t.data_ptr(), + t.numel() * t.element_size(), + 1, # lines up with 'cudaHostRegisterPortable' + ) + ) + assert succ == 0, f"Pinning shared memory failed with error-code: {succ}" + return t + elif pin_memory: + return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory() + else: + return torch.empty(*tuple(obj.size()), dtype=obj.dtype) + + def dtensor_func( + obj: DTensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + _: Any, + ) -> DTensor: + if len(obj.size()) == 0: + return obj + + if obj.device != torch.device("cpu"): + ret = cast(DTensor, obj.to(device="cpu")) + else: + ret = copy.deepcopy(obj) + ret._local_tensor = tensor_func(ret._local_tensor, pg, device, None) + return ret + + ret = _iterate_state_dict( + state_dict, + _identity_func, + dtensor_func, + tensor_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + type_check=False, + ) + return ret + + +def _check_state_dict_similarity( + state_dict: dict[str, Any], + compared_state_dict: dict[str, Any], +) -> bool: + """ + Given two state_dicts, check if the structures are the same. And + if a [key, tensor] pair exist in one state_dict there must be + the a corresponding pait, [key, other_tensor], in the other state_dict, + where tensor and other_tensor have the same size and dtype. + + Return the check result. + """ + + def tensor_func( + obj: torch.Tensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + companion_obj: Any, + ) -> torch.Tensor: + if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size(): + raise CompanionMismatch + return obj + + try: + _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + tensor_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + companion_obj=compared_state_dict, + type_check=False, + ) + except CompanionMismatch: + return False + + return True + + +class _TensorInfo(NamedTuple): + size: torch.Size + dtype: torch.dtype + + +def _broadcast_tensors( + full_state_dict: dict[str, Any], + local_state_dict: dict[str, Any], + keys: list[str], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, +) -> None: + tensors = [] + for key in keys: + if dist.get_rank() == 0: + full_state = full_state_dict[key] + assert isinstance(full_state, torch.Tensor) + full_tensor = full_state.detach().to(device) + else: + tensor_info = full_state_dict[key] + full_tensor = torch.empty( + size=tensor_info.size, + device=device, + dtype=tensor_info.dtype, + ) + tensors.append(full_tensor) + local_state = local_state_dict.get(key, None) + if local_state is None: + continue + elif isinstance(local_state, DTensor): + local_state_dict[key] = (local_state, full_tensor) + else: + local_state_dict[key] = full_tensor + + if pg is None: + pg = dist.distributed_c10d._get_default_group() + + if len(tensors) > 1: + dist._broadcast_coalesced(pg, tensors, 500, 0) + else: + dist.broadcast(tensors[0], src=0, group=pg) + + _distribute_tensors(local_state_dict, keys, device, pg) + + +def _distribute_tensors( + local_state_dict: dict[str, Any], + keys: list[str], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, +) -> None: + if pg is None: + pg = dist.distributed_c10d._get_default_group() + for key in keys: + _local_state = local_state_dict.get(key, None) + if _local_state is None or torch.is_tensor(_local_state): + continue + + local_state = _local_state[0] + full_tensor = _local_state[1] + + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, local_state.device_mesh, local_state.placements + ) + slices = [ + slice(cur_offset, cur_offset + cur_shape) for cur_shape, cur_offset in zip(shape, offset, strict=False) + ] + if local_state.is_meta: + # Use .clone() here rather than view to clone and return only the sliced portion, minimizing memory access and cost. + local_tensor = full_tensor[slices].detach().clone() + # TODO: currently, we cannot handle strided sharding if the dp dimension is not even. For example, + # one of the case that is not yet supported is when placements = (Shard(0), _StridedShard(0, sf=2)). + ret = DTensor.from_local( + local_tensor, + local_state.device_mesh, + local_state.placements, + shape=local_state.shape, + stride=local_state.stride(), + ) + else: + ret = local_state + # Copy full_tensor[slices] into local_state.to_local() to reduce memory footprint. + ret.to_local().copy_(full_tensor[slices]) + local_state_dict[key] = ret + + +def _broadcast_state_dict( + full_state_dict: dict[str, Any], + local_state_dict: dict[str, Any], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, + strict: bool = False, + cpu_offload: bool = False, +) -> None: + # Broadcast from rank0's `full_state_dict` to all ranks' `local_state_dict`. + # If strict is True, any keys in `local_state_dict` but not in `full_state_dict` + # will be removed from `local_state_dict`. + ret = {} + if dist.get_rank() == 0: + for key, value in full_state_dict.items(): + if not torch.is_tensor(value): + ret[key] = value + elif value.dim() == 0: + ret[key] = value.cpu() + else: + ret[key] = _TensorInfo(value.size(), value.dtype) + + broadcast_list = [ret] + dist.broadcast_object_list(broadcast_list, src=0, group=pg) + ret = broadcast_list[0] + # Gather values + keys = [] + local_state_dict_keys = set(local_state_dict.keys()) + global_keys = set() + for key, value in ret.items(): + global_keys.add(key) + if not isinstance(value, _TensorInfo): + if key in local_state_dict: + local_state_dict[key] = value + continue + + if dist.get_rank() == 0: + ret[key] = full_state_dict[key] + + keys.append(key) + # Broadcast every tensor to avoid OOM for now. + if len(keys) >= 1: + _broadcast_tensors(ret, local_state_dict, keys, device, pg) + if cpu_offload: + for key in keys: + local_state_dict[key] = local_state_dict[key].cpu() + keys.clear() + + if strict: + if missing_keys := (local_state_dict_keys - global_keys): + for key in missing_keys: + local_state_dict.pop(key) + + if keys: + _broadcast_tensors(ret, local_state_dict, keys, device, pg) + if cpu_offload: + for key in keys: + local_state_dict[key] = local_state_dict[key].cpu() + + +def _distribute_state_dict( + full_state_dict: dict[str, Any], + local_state_dict: dict[str, Any], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, +) -> None: + # Full_state_dict = True, broadcast_from_rank0 = False here. Each rank has + # full_state_dict. Skip the broadcast in ``_broadcast_state_dict`` and + # distribute tensors in each rank + for key, value in full_state_dict.items(): + if key not in full_state_dict: + continue + if not torch.is_tensor(value): + local_state_dict[key] = value + elif value.dim() == 0: + local_state_dict[key] = value.cpu() + else: + assert isinstance(value, torch.Tensor) + local_state = local_state_dict.get(key, None) + if local_state is None: + continue + elif isinstance(local_state, DTensor): + local_state_dict[key] = distribute_tensor( + value.detach().to(device), + local_state.device_mesh, + local_state.placements, + ) + else: + local_state_dict[key] = value.detach().to(device) + + +# These APIs are from torch.distributed.checkpoint. +# TODO: We should consolidate the code here as some not all modules can depend on +# DCP. +PATH_ITEM = Union[str, int] +OBJ_PATH = tuple[PATH_ITEM, ...] +FLATTEN_MAPPING = dict[str, OBJ_PATH] +STATE_DICT_TYPE = dict[str, Any] +CONTAINER_TYPE = MutableMapping[PATH_ITEM, Any] + + +def _traverse_state_dict( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, Any], None], +) -> None: + """ + Invoke ``visitor`` for each value recursively in ``state_dict``. + Mapping, list, and tuple will be flattened and other value types are treated + as the terminal values and will invoke ``visitor``. + """ + + def _traverse_obj(path: OBJ_PATH, value: Any) -> None: + if isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif isinstance(value, (list, tuple)): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + else: + visitor(path, value) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + +def _flatten_state_dict( + state_dict: STATE_DICT_TYPE, +) -> tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]: + """ + Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary. + + Use ``unflatten_state_dict`` to revert this process. + Returns: + A tuple with the flatten state_dict and a mapping from original to new state_dict. + N.B. The new keys are derived from the object paths, joined by dot. + For example: ``{ 'a': {'b':...}}`` results in the key `a.b`. + """ + flattened: STATE_DICT_TYPE = {} + mappings: FLATTEN_MAPPING = {} + + def flat_copy(path: OBJ_PATH, value: Any) -> None: + new_fqn = ".".join(map(str, path)) + if new_fqn in flattened: + raise ValueError(f"duplicated flatten key {new_fqn}") + flattened[new_fqn] = value + mappings[new_fqn] = path + + _traverse_state_dict(state_dict, flat_copy) + return flattened, mappings + + +def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None: + """Set ``value`` in ``root_dict`` along the ``path`` object path.""" + cur_container = cast(CONTAINER_TYPE, root_dict) + + def extend_list(lst: list[Any], idx: int) -> None: + while len(lst) <= idx: + lst.append(None) + + for i in range(1, len(path)): + prev_key = path[i - 1] + key = path[i] + def_val: CONTAINER_TYPE | list[Any] = {} if type(key) == str else [] + + if isinstance(cur_container, Mapping): + cur_container = cast(CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)) + else: + extend_list(cur_container, prev_key) + if cur_container[prev_key] is None: + cur_container[prev_key] = def_val + cur_container = cur_container[prev_key] + + key = path[-1] + if type(key) == int: + extend_list(cast(list[Any], cur_container), key) + + cur_container[key] = value + + +def _unflatten_state_dict(state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING) -> STATE_DICT_TYPE: + """Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``.""" + nested: STATE_DICT_TYPE = {} + for key, value in state_dict.items(): + _set_element(nested, mapping[key], value) + return nested diff --git a/toolbox/verl/v0.5.0/verl/third_party/torch/distributed/checkpoint/__init__.py b/toolbox/verl/v0.5.0/verl/third_party/torch/distributed/checkpoint/__init__.py new file mode 100644 index 000000000..7664279b7 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/third_party/torch/distributed/checkpoint/__init__.py @@ -0,0 +1,87 @@ +# official torch 2.6.0 set_model_state_dict API leads to OOM +# this is a copy of torch/distributed/checkpoint from torch 2.7.0 + +# From PyTorch: + +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +# From Caffe2: + +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. + +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. + +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. + +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain + +# All contributions by Cruise LLC: +# Copyright (c) 2022 Cruise LLC. +# All rights reserved. + +# All contributions by Tri Dao: +# Copyright (c) 2024 Tri Dao. +# All rights reserved. + +# All contributions by Arm: +# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates + +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. + +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. + +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. + +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. diff --git a/toolbox/verl/v0.5.0/verl/third_party/torch/distributed/checkpoint/state_dict.py b/toolbox/verl/v0.5.0/verl/third_party/torch/distributed/checkpoint/state_dict.py new file mode 100644 index 000000000..e4555802a --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/third_party/torch/distributed/checkpoint/state_dict.py @@ -0,0 +1,1493 @@ +# official torch 2.6.0 set_model_state_dict API leads to OOM +# this is a copy of torch/distributed/checkpoint from torch 2.7.0 + +# From PyTorch: + +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +# From Caffe2: + +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. + +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. + +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. + +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain + +# All contributions by Cruise LLC: +# Copyright (c) 2022 Cruise LLC. +# All rights reserved. + +# All contributions by Tri Dao: +# Copyright (c) 2024 Tri Dao. +# All rights reserved. + +# All contributions by Arm: +# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates + +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. + +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. + +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. + +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +# ruff: noqa: B028, UP038, UP007, E721 +# mypy: allow-untyped-defs +import contextlib +import functools +import gc +import warnings +from collections.abc import Generator, Iterable +from dataclasses import asdict, dataclass, field +from itertools import chain +from typing import Any, Callable, Optional, Union, cast, no_type_check + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_PREFIX, +) +from torch.distributed.fsdp import ( + FullOptimStateDictConfig, + FullStateDictConfig, + OptimStateDictConfig, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + StateDictConfig, + StateDictType, +) +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, +) +from torch.distributed.fsdp._common_utils import ( + FSDP_WRAPPED_MODULE, + _get_module_fsdp_state_if_fully_sharded_module, +) +from torch.distributed.tensor import DTensor +from torch.nn.modules.module import _IncompatibleKeys +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils._pytree import tree_map_only + +from verl.third_party.torch.distributed._state_dict_utils import ( + _broadcast_state_dict, + _distribute_state_dict, + _flatten_state_dict, + _gather_state_dict, + _offload_state_dict_to_cpu, + _unflatten_state_dict, +) + +__all__ = [ + "FQNS_T", + "PrimitiveType", + "ValueType", + "DictValueType", + "ListDictValueType", + "OptimizerStateType", + "StateDictOptions", + "get_model_state_dict", + "get_optimizer_state_dict", + "get_state_dict", + "set_model_state_dict", + "set_optimizer_state_dict", + "set_state_dict", +] + + +_FLAT_PARAM = "_flat_param" +_PG = "param_groups" +_PARAMS = "params" +_STATE = "state" + +FQNS_T = set[str] +PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str] +ValueType = Union[PrimitiveType, list[PrimitiveType], tuple[PrimitiveType], dict[str, "ValueType"]] +DictValueType = dict[str, ValueType] +ListDictValueType = list[DictValueType] +OptimizerStateType = dict[str, DictValueType | ListDictValueType] + + +_patched_state_dict: set[Callable] = set() + + +@contextlib.contextmanager +def _gc_context(): + is_enabled = gc.isenabled() + gc.disable() + try: + yield + finally: + if is_enabled: + gc.enable() + + +@dataclass +class StateDictOptions: + """ + This dataclass specifies how get_state_dict/set_state_dict will work. + + - ``full_state_dict``: if this is set to True, all the tensors in the + returned state_dict will be gathered. No ShardedTensor and DTensor + will be in the returned state_dict. + + - ``cpu_offload``: offload all the tensors to cpu. To prevent CPU OOM, if + ``full_state_dict`` is also true, then only the rank0 will get the + state_dict and all other ranks will get empty state_dict. + + - ``ignore_frozen_params``: if the value is True, the returned state_dict + won't contain any frozen parameters -- the ``requires_grad`` is False. + The default value is False. + + - ``keep_submodule_prefixes`` (deprecated): when ``submodules`` is not None, this option + indicates whether to keep the submodule prefixes from the state_dict keys. + or example, if the submodule is ``module.pretrain`` and the full FQN of + the parameter is ``pretrain.layer1.weight`` of the param. When this option + is True, the parameter's key in the returned state_dict will be + ``pretrain.layer1.weight``. If the options is False, the key will be + ``layer1.weight``. + Note that if ``keep_submodule_prefixes`` is False, there may be conflicted + FQNs, hence there should be only one submodule in ``submodules``. + + - ``strict``: the ``strict`` option when ``set_state_dict`` calls + model.load_state_dict(). + + - ``broadcast_from_rank0``: when the option is True, rank0 should receive a + full state_dict and will broadcast the tensors in the state_dict/ + optim_state_dict one by one to other ranks. Other ranks will receive + the tensors and shard according to the local shards in the model and + optimizer. ``full_state_dict`` must be set to True when using this option. + This option currently only supports DTensor, not the legacy ShardedTensor. + """ + + full_state_dict: bool = False + cpu_offload: bool = False + ignore_frozen_params: bool = False + keep_submodule_prefixes: bool = True + strict: bool = True + broadcast_from_rank0: bool = False + flatten_optimizer_state_dict: bool = False + dsd_fqn_modifiers: str = "_fqn_modifiers" + + +@dataclass +class _StateDictInfo(StateDictOptions): + fqn_param_mapping: dict[ + str | torch.Tensor, + FQNS_T | torch.Tensor, + ] = field(default_factory=dict) + shared_params_mapping: dict[ + str | torch.Tensor, + FQNS_T | torch.Tensor, + ] = field(default_factory=dict) + submodule_prefixes: set[str] = field(default_factory=set) + handle_model: bool = True + handle_optim: bool = True + fsdp_context: Callable = contextlib.nullcontext + fsdp_modules: list[nn.Module] = field(default_factory=list) + + +@functools.cache +def _get_fqns( + model: nn.Module, + name: str, + dsd_fqn_modifiers: str = "_fqn_modifiers", + skip_ddp_prefix: bool = True, + skip_compiler_prefix: bool = True, +) -> FQNS_T: + """ + This API is used to convert the name of a parameter to the FQNs. For FSDP + without `use_orig_params`, the name of FlatParameter can be mapped to + multiple original parameters. As a result, the return type of this function + is `set[str]`. + + Args: + module (nn.Module): the root model. + name (str): the name + skip_ddp_prefix (bool): whether to skip DDP's `module` prefix + + Returns: + The canonical FQNs based on the model traversal. + """ + + # Remove the checkpoint prefix, if it exists. + name = name.replace(_CHECKPOINT_PREFIX, "") + if "." not in name: + return {name} + + obj_names = name.split(".") + fqn_obj_names = [] + curr_obj = model + for i, curr_obj_name in enumerate(obj_names): + if isinstance(curr_obj, DDP): + assert curr_obj_name == "module" + curr_obj = curr_obj.module + if not skip_ddp_prefix: + fqn_obj_names.append(curr_obj_name) + elif isinstance(curr_obj, FSDP): + if i < len(obj_names) - 1 and obj_names[i + 1] == _FLAT_PARAM: + prefix = ".".join(fqn_obj_names) + flat_param = getattr(curr_obj, _FLAT_PARAM) + if prefix: + prefix = f"{prefix}." + return {f"{prefix}{fqn}" for fqn in flat_param._fqns} + curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE) + if curr_obj_name != FSDP_WRAPPED_MODULE: + fqn_obj_names.append(curr_obj_name) + curr_obj = getattr(curr_obj, curr_obj_name) + elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule): + assert curr_obj_name == "_orig_mod" + curr_obj = curr_obj._orig_mod + if not skip_compiler_prefix: + fqn_obj_names.append(curr_obj_name) + else: + # In some modeuls, _fqn_modifiers would not shown in the state_dict keys, + # skip them in the fqn to ensure load stat dict successfully for them. + if hasattr(curr_obj, dsd_fqn_modifiers): + if removed_fqn := getattr(curr_obj, dsd_fqn_modifiers)().get(curr_obj_name): + if hasattr(curr_obj, removed_fqn): + curr_obj = getattr(curr_obj, removed_fqn) + fqn_obj_names.append(curr_obj_name) + if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX: + if i != len(obj_names) - 1: + raise RuntimeError("Expect `_extra_state` to be the last obj name") + else: + curr_obj = getattr(curr_obj, curr_obj_name) + + return {".".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, "")} + + +class _EXTRA_STATE: + pass + + +def _iterate_valid_model_state(model, dsd_fqn_modifiers="_fqn_modifiers"): + visited_modules: set[nn.Module] = set() + + def recurse(module: nn.Module, curr_fqn: str) -> Generator: + visited_modules.add(module) + + curr_fqn = f"{curr_fqn}." if curr_fqn else "" + for name, submodule in module.named_children(): + if submodule in visited_modules: + continue + # if user have state_dict_hooks in their model, they can add the state_dict key changes + # at dsd_fqn_modifiers in input to align with the function of state_dict_hook + if hasattr(module, dsd_fqn_modifiers) and name in getattr(module, dsd_fqn_modifiers)().values(): + # skip _fqn_modifiers here thus remove the last `.` added + new_fqn = curr_fqn[:-1] + else: + new_fqn = f"{curr_fqn}{name}" + yield from recurse(submodule, new_fqn) + + for name, obj in chain(module.named_buffers(recurse=False), module.named_parameters(recurse=False)): + if name in module._non_persistent_buffers_set: + continue + new_fqn = f"{curr_fqn}{name}" + yield new_fqn, obj + + if getattr(module.__class__, "get_extra_state", nn.Module.get_extra_state) != nn.Module.get_extra_state: + new_fqn = f"{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}" + yield new_fqn, _EXTRA_STATE() + + yield from recurse(model, "") + + +def _verify_options( + model: nn.Module, + optims: tuple[torch.optim.Optimizer, ...], + optim_only: bool, + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> _StateDictInfo: + """ + Verify the model and options passed by the user and generates _StateDictInfo. + """ + if submodules: + warnings.warn( + "Getting submodules only model/optim state_dict is deprecated and " + "will be removed in 2.5. This feature can be achieved by manually " + "filtering out the state_dict returned from get_state_dict.", + FutureWarning, + ) + if optim_only and not optims: + raise RuntimeError("Optimizers are not passed in but optim_only is set to True.") + + options = options or StateDictOptions() + + fqn_param_mapping: dict[str | torch.Tensor, set[str] | torch.Tensor] = {} + shared_params_mapping: dict[str | torch.Tensor, set[str] | torch.Tensor] = {} + for name, param in _iterate_valid_model_state(model): + if isinstance(param, _EXTRA_STATE): + continue + + fqns = _get_fqns(model, name) + fqn = fqn_param_mapping.get(param, None) + if fqn is not None: + cast(set[str], fqn_param_mapping[param]).update(fqns) + shared_params_mapping[param] = fqn_param_mapping[param] + else: + # We need to do copy as _get_fqns is lru_cached + fqn_param_mapping[param] = fqns.copy() + for fqn in fqns: + if not isinstance(param, _EXTRA_STATE): + fqn_param_mapping[fqn] = param + + for param_, fqns_ in list(shared_params_mapping.items()): + for fqn in fqns_: + shared_params_mapping[fqn] = cast(torch.Tensor, param_) + + submodule_prefixes: set[str] = set() + if submodules: + submodules = set(submodules) + for name, module in model.named_modules(): + if module not in submodules: + continue + fqns = _get_fqns(model, name) + assert len(fqns) == 1, "Submodule FQN should only have 1 instance" + submodule_prefixes.update(f"{fqn}." for fqn in fqns) + + if options.broadcast_from_rank0 and not options.full_state_dict: + raise ValueError("full_state_dict must be True when broadcast_from_rank0 is True.") + fsdp_modules = FSDP.fsdp_modules(model) + state_dict_config: StateDictConfig + optim_state_dict_config: OptimStateDictConfig + fsdp_context: Callable + if fsdp_modules: + # FSDP API only work if at least one FSDP instance exists. + if options.full_state_dict: + state_dict_config = FullStateDictConfig(offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload) + optim_state_dict_config = FullOptimStateDictConfig( + offload_to_cpu=options.cpu_offload, + rank0_only=(options.cpu_offload or options.broadcast_from_rank0), + ) + state_dict_type = StateDictType.FULL_STATE_DICT + else: + state_dict_config = ShardedStateDictConfig( + offload_to_cpu=options.cpu_offload, + ) + optim_state_dict_config = ShardedOptimStateDictConfig( + offload_to_cpu=options.cpu_offload, + ) + state_dict_type = StateDictType.SHARDED_STATE_DICT + + @contextlib.contextmanager + def fsdp_state_dict_type_without_warning( + module, + state_dict_type, + state_dict_config, + optim_state_dict_config, + ): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="FSDP.state_dict_type", category=FutureWarning) + with FSDP.state_dict_type( + module=module, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config, + ): + yield + + fsdp_context = functools.partial( + fsdp_state_dict_type_without_warning, + module=model, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config, + ) + else: + fsdp_context = contextlib.nullcontext + + return _StateDictInfo( + **asdict(options), + fqn_param_mapping=fqn_param_mapping, + shared_params_mapping=shared_params_mapping, + submodule_prefixes=submodule_prefixes, + fsdp_context=fsdp_context, + fsdp_modules=cast(list[nn.Module], fsdp_modules), + handle_model=not optim_only, + handle_optim=(len(optims) > 0), + ) + + +def _verify_state_dict( + model_state_dict: dict[str, ValueType], + optim_state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> None: + for module in info.fsdp_modules: + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + assert fsdp_state is not None, "Expected a fsdp_state with a fsdp module." + + # Verify if the model_state_dict and optim_state_dict are valid. This API + # should give the users an explicit error message to debug or report. + if ( + info.handle_model + and not model_state_dict + and not info.submodule_prefixes + and not info.ignore_frozen_params + and not (info.cpu_offload and info.full_state_dict) + and info.strict + and not info.broadcast_from_rank0 + ): + raise RuntimeError( + "The option indicates that model state_dict is required to save " + "or load, but model state_dict is empty." + f"rank = {dist.get_rank()=}." + ) + + if info.handle_optim: + if not optim_state_dict and not (info.cpu_offload and info.full_state_dict) and (not info.broadcast_from_rank0): + raise RuntimeError( + "The option indicates that model state_dict is required to save, " + f"or load but optim state_dict is empty. {optim_state_dict}" + ) + + for key in model_state_dict.keys(): + if _FLAT_PARAM in key: + raise RuntimeError(f"{key} contains {_FLAT_PARAM}. This can happen if the model is not the root module.") + + +def _state_dict_fn(obj: nn.Module | torch.optim.Optimizer, api: str) -> Callable: + call = getattr(obj, api) + if call in _patched_state_dict: + call = functools.partial(getattr(obj.__class__, api), self=obj) + return call + + +def _maybe_full_or_cpu_state_dict(state_dict: dict[str, Any], info: _StateDictInfo) -> dict[str, Any]: + if info.full_state_dict: + ranks_only = () if (not info.cpu_offload or not torch.distributed.is_initialized()) else (0,) + return _gather_state_dict(state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only) + elif info.cpu_offload: + return _offload_state_dict_to_cpu(state_dict) + else: + return state_dict + + +@torch.no_grad() +def _get_model_state_dict(model: nn.Module, info: _StateDictInfo) -> dict[str, ValueType]: + if not info.handle_model: + return {} + + with info.fsdp_context(): + state_dict = _state_dict_fn(model, "state_dict")() + + for key in list(state_dict.keys()): + fqns = _get_fqns(model, key) + assert len(fqns) == 1, (key, fqns) + fqn = next(iter(fqns)) + if fqn != key: + # As we only support FSDP, DDP, and TP, the only cases are + # wrapper-based DDP and compiler. Verify if the assumption + # is correct. + def verify(key, fqn) -> bool: + if len(fqn) >= len(key): + return False + fqn_split = fqn.split(".") + key_split = key.split(".") + fqn_idx = 0 + for key_idx, key_name in enumerate(key_split): + if key_name == fqn_split[fqn_idx]: + fqn_idx += 1 + if fqn_idx == len(fqn_split): + return key_idx == len(key_split) - 1 + elif key_name in ("module", "_orig_mod"): + continue + else: + return False + return True + + if not verify(key, fqn): + raise RuntimeError(f"An unexpected key, {key}, exists. FQN is {fqn}") + state_dict[fqn] = state_dict.pop(key) + + if info.submodule_prefixes: + new_state_dict: dict[str, ValueType] = {} + # TODO: make this faster. + for fqn in state_dict.keys(): + for prefix in info.submodule_prefixes: + if not fqn.startswith(prefix): + continue + if info.keep_submodule_prefixes: + new_state_dict[fqn] = state_dict[fqn] + else: + new_fqn = fqn[len(prefix) :] + new_state_dict[new_fqn] = state_dict[fqn] + state_dict = new_state_dict + + if info.ignore_frozen_params: + for key, param in model.named_parameters(): + if param.requires_grad: + continue + fqns = _get_fqns(model, key) + for fqn in fqns: + state_dict.pop(fqn) + + for key, p in list(state_dict.items()): + if torch.is_tensor(p) and p.is_meta: + state_dict.pop(key) + + return _maybe_full_or_cpu_state_dict(state_dict, info) + + +@torch.no_grad() +def _load_model_state_dict( + model: nn.Module, + state_dict: dict[str, ValueType], + info: _StateDictInfo, +) -> _IncompatibleKeys: + if not info.handle_model or (not state_dict and not info.broadcast_from_rank0): + return _IncompatibleKeys({}, {}) + + local_state_dict = {} + for key, value in _iterate_valid_model_state(model, info.dsd_fqn_modifiers): + fqns = _get_fqns(model, key, info.dsd_fqn_modifiers) + fqns_with_prefix = _get_fqns( + model, + key, + info.dsd_fqn_modifiers, + skip_ddp_prefix=False, + skip_compiler_prefix=False, + ) + + for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix, strict=False): + if (not info.broadcast_from_rank0 or dist.get_rank() == 0) and fqn != fqn_with_prefix: + load_value = state_dict.pop(fqn, None) + if load_value is None: + if info.strict: + raise RuntimeError(f"Missing key: {fqn}.") + else: + state_dict[fqn_with_prefix] = load_value + local_state_dict[fqn_with_prefix] = value + + assign = False + if info.broadcast_from_rank0 or info.full_state_dict: + devices = set() + for key, value in local_state_dict.items(): + if torch.is_tensor(value) and value.dim() > 0: + devices.add(value.device) + # In lora state_dict, there could be multiple devices, with meta device inside. + # Take the other device in the broadcast/distribtue, and set assign to True + if torch.device("meta") in devices: + devices.remove(torch.device("meta")) + assign = True + if len(devices) == 0: + devices.add(dist.distributed_c10d._get_pg_default_device()) + elif len(devices) > 1: + raise ValueError("Multiple devices found") + + if info.broadcast_from_rank0: + _broadcast_state_dict( + state_dict, + local_state_dict, + device=devices.pop(), + strict=info.strict, + cpu_offload=info.cpu_offload, + ) + elif info.full_state_dict: + _distribute_state_dict(state_dict, local_state_dict, device=devices.pop()) + for fqn, local_state in local_state_dict.items(): + state_dict[fqn] = local_state + + with info.fsdp_context(): + return cast( + _IncompatibleKeys, + _state_dict_fn(model, "load_state_dict")(state_dict=state_dict, strict=info.strict, assign=assign), + ) + + +def _init_optim_state(optim: torch.optim.Optimizer) -> None: + """ + Initialize optim states by calling the step() with zero grads. + """ + if optim.state: + # The optimizer state is initialized. + return + + # There are some stateless optimizers like SGD. These optimizer will + # not return in the above condition. So if gradients exist, we should also + # return. If gradients do not exist, the following initialization should + # not disturb SGD because the gradients and lr are both zero. + for param_group in optim.param_groups: + for param in param_group[_PARAMS]: + if param.grad is not None: + return + + for param_group in optim.param_groups: + for param in param_group[_PARAMS]: + if param.requires_grad: + param.grad = torch.zeros_like(param) + + # Some optimizers will update parameters regardless of grads due to lr, so + # make lr to zero when calling `step()`. + lrs = [] + for param_group in optim.param_groups: + if "lr" in param_group: + lrs.append(param_group["lr"]) + param_group["lr"] = torch.tensor(0.0) if isinstance(param_group["lr"], torch.Tensor) else 0.0 + optim.step(closure=None) + # Whether to recover the "lr" should not matter too much as we will + # restore checkpointing later. + for param_group in optim.param_groups: + if "lr" in param_group: + param_group["lr"] = lrs.pop(0) + optim.zero_grad(set_to_none=True) + + +def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> dict[str, ValueType]: + """ + This API flattens the optimizer state_dict to support optimizer resharding for + MPMD, e.g., pipeline parallelism. + + Without the API, the original optimizer state_dict looks like: + { + "state": { + "layer1.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + "layer2.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + }, + "param_group": [ + { + "lr": 0.0, + "betas": (0.9, 0.95), ..., + "params": ["layer1.weight", "layer2.weight"] + } + ] + } + + With this API, the optimizer state_dict looks like: + { + "state.layer1.weight.step": 10, + "state.layer2.weight.step": 10, + "state.layer1.weight.exp_avg": SomeTensor, + "state.layer2.weight.exp_avg": SomeTensor, + "state.layer1.weight.exp_avg_sq": SomeTensor, + "state.layer2.weight.exp_avg_sq": SomeTensor, + "param_group.layer1.weight.lr" : 0.1, + "param_group.layer2.weight.lr" : 0.1, + "param_group.layer1.weight.betas" : (0.9, 0.95), + "param_group.layer2.weight.betas" : (0.9, 0.95), + } + + Note that if any of the value is a container, like the betas in the example, + this API won't flattent it. + """ + + def _raise_if_type_not_supported(v): + if not isinstance(v, (torch.Tensor, int, float)): + raise NotImplementedError( + f"Flattening optimizer state_dict only supports tensor, int, float states now. Type is {type(v)}." + ) + + ret: dict[str, ValueType] = {} + for fqn, state in cast(DictValueType, state_dict[_STATE]).items(): + for k, v in cast(DictValueType, state).items(): + _raise_if_type_not_supported(v) + ret[f"{_STATE}.{fqn}.{k}"] = v + + for param_group in cast(ListDictValueType, state_dict[_PG]): + fqns = param_group.pop(_PARAMS) + for fqn in cast(list[str], fqns): + for k, v in param_group.items(): + ret[f"{_PG}.{fqn}.{k}"] = v + return ret + + +def _unflatten_optim_state_dict( + optim: torch.optim.Optimizer, + state_dict: dict[str, ValueType], + info: _StateDictInfo, +) -> OptimizerStateType: + """ + This API unflattens the state_dict generated by _flatten_optim_state_dict(). + See the docstring of _flatten_optim_state_dict() for more detail. + """ + state: DictValueType = {} + pg_state: ListDictValueType = [] + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} + + for param_group in optim.param_groups: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: + for fqn in info.fqn_param_mapping[param]: + # If a parameter is shared, only one of the FQN will be used. + # So we need to verify which if this fqn is actually used in + # the state_dict. + if fqn in info.shared_params_mapping: + in_params = False + for k in param_group.keys(): + if k == _PARAMS: + continue + flatten_key = f"{_PG}.{fqn}.{k}" + if flatten_key in state_dict: + in_params = True + break + else: + in_params = True + + if not in_params: + continue + + params = pg_state[-1][_PARAMS] + assert isinstance(params, list) # typing + params.append(fqn) + if not param.requires_grad: + continue + state[fqn] = {} + for state_name in optim.state[param].keys(): + cast(DictValueType, state[fqn])[state_name] = state_dict[f"{_STATE}.{fqn}.{state_name}"] + + first_param_fqn = cast(list[str], pg_state[-1][_PARAMS])[0] + for k in param_group.keys(): + if k == _PARAMS: + continue + value = state_dict[f"{_PG}.{first_param_fqn}.{k}"] + if k not in pg_state[-1]: + pg_state[-1][k] = value + elif pg_state[-1][k] != value: + raise RuntimeError( + "All the parameters in the same parameter group should have " + f"the same saved param_group value. But {first_param_fqn}.{k} " + f"is {value} while other(s) is {pg_state[-1][k]}." + ) + + return return_osd + + +@torch.no_grad() +def _get_optim_state_dict( + model: nn.Module, + optimizers: tuple[torch.optim.Optimizer, ...], + info: _StateDictInfo, +) -> OptimizerStateType: + if not info.handle_optim: + return {} + + optim_state_dict: OptimizerStateType = {_STATE: {}, _PG: []} + for optim in optimizers: + _init_optim_state(optim) + osd = _state_dict_fn(optim, "state_dict")() + if info.fsdp_modules: + with info.fsdp_context(): + osd = FSDP.optim_state_dict(model, optim, osd) + + # We need to specially handle FlatParameter FSDP as + # FlatParameter FSDP converts the FQNs. + # There are no easy ways to do this conversion systematically. + # We can only use a string replacment without correctness check. + if not osd: + continue + for k in list(osd[_STATE].keys()): + if "_orig_mod" in k: + osd[_STATE][k.replace("_orig_mod.", "")] = osd[_STATE].pop(k) + for g in osd[_PG]: + params = [k.replace("_orig_mod.", "") for k in g[_PARAMS]] + g[_PARAMS] = params + else: + params = list(chain.from_iterable(g[_PARAMS] for g in optim.param_groups)) + param_pid_mapping = dict(zip(params, range(len(params)), strict=False)) + fqn_pid_mapping = {} + for key, param in model.named_parameters(): + fqns = _get_fqns(model, key) + assert len(fqns) == 1 + fqn = next(iter(fqns)) + if param not in param_pid_mapping: + continue + pid = param_pid_mapping[param] + fqn_pid_mapping[fqn] = pid + fqn_pid_mapping[pid] = fqn + + for key in list(osd[_STATE].keys()): + fqn = fqn_pid_mapping[key] + osd[_STATE][fqn] = osd[_STATE].pop(key) + + for group in osd[_PG]: + group[_PARAMS] = [fqn_pid_mapping[pid] for pid in group[_PARAMS]] + + if not osd: + continue + + cast(DictValueType, optim_state_dict[_STATE]).update(osd[_STATE]) + cast(ListDictValueType, optim_state_dict[_PG]).extend(osd[_PG]) + + if info.flatten_optimizer_state_dict: + optim_state_dict = cast(OptimizerStateType, _flatten_optim_state_dict(optim_state_dict)) + + return _maybe_full_or_cpu_state_dict(optim_state_dict, info) + + +def _split_optim_state_dict( + model: nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> OptimizerStateType: + """ + Extract the corresponding optim state_dict from ``optim_state_dict`` for + ``optim`` and return the result optim state_dict. + + Args: + model (nn.Module): the root model. + optim (torch.optim.Optimizer): the optimizer. + optim_state_dict (Dict[str, ValueType]): the superset optim state_dict that + contains the optim state_dict of ``optim``. + info (_StateDictInfo): state dict information. + + Returns: + The optim state_dict of ``optim``. + """ + + state: DictValueType = {} + pg_state: ListDictValueType = [] + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} + pg_mapping: dict[int, int] = {} + + if all(isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys()): + return optim_state_dict + + for param_group in optim.param_groups: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: + for fqn in info.fqn_param_mapping[param]: + if fqn in info.shared_params_mapping: + in_params = False + for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]): + if fqn in cast(list[str], loaded_param_group[_PARAMS]): + in_params = True + break + else: + in_params = True + if not in_params: + continue + + params = pg_state[-1][_PARAMS] + assert isinstance(params, list) + params.append(fqn) + if param.requires_grad: + state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn] + for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]): + if fqn in cast(list[str], loaded_param_group[_PARAMS]): + pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 + + if len(param_group[_PARAMS]) == 0: + # Param_group with empty params. + ret = [] + for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]): + if len(cast(list[str], loaded_param_group[_PARAMS])) == 0: + ret.append(loaded_param_group) + if len(ret) != 1: + raise ValueError( + "There are param groups that have zero parameters. " + "In such a case, DSD only support exactly one param group " + "with zero parameters." + "But the loaded state_dict has zero or more than one param groups " + "that have zero parameters." + ) + if len(optim_state_dict[_PG]) != len(optim.param_groups): + raise ValueError( + "When there is a parameter group that has zero parameters, multiple optimizers are not supported." + ) + pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 + + for param_group in cast(ListDictValueType, optim_state_dict[_PG]): + pg_idx = pg_mapping.get(id(param_group), -1) + if pg_idx == -1: + continue + + for key, value in param_group.items(): + if key == _PARAMS: + continue + # TODO: check if value is the same if exists. + pg_state[pg_idx][key] = value + + return return_osd + + +@torch.no_grad() +def _load_optim_state_dict( + model: nn.Module, + optimizers: tuple[torch.optim.Optimizer, ...], + state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> None: + if not info.handle_optim: + return + + for optim in optimizers: + _init_optim_state(optim) + if state_dict: + if _STATE in state_dict: + optim_state_dict = _split_optim_state_dict(model, optim, state_dict, info) + else: + optim_state_dict = _unflatten_optim_state_dict(optim, cast(dict[str, ValueType], state_dict), info) + else: + optim_state_dict = {} + if info.fsdp_modules: + # We need to specially handle FlatParameter FSDP as + # FlatParameter FSDP converts the FQNs. + for original_fqn, _ in model.named_parameters(): + fqns = _get_fqns(model, original_fqn) + fqns_with_compiler = _get_fqns(model, original_fqn, skip_compiler_prefix=False) + if fqns == fqns_with_compiler: + continue + + assert len(fqns) == 1 + fqn = fqns.pop() + fqn_with_compiler = fqns_with_compiler.pop() + for g in optim_state_dict[_PG]: + val = cast(dict[str, Any], g) + params = [key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS]] + val[_PARAMS] = params + osd_state = cast(DictValueType, optim_state_dict[_STATE]) + for k in list(osd_state.keys()): + if fqn in k: + osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k) + + with info.fsdp_context(): + optim_state_dict = FSDP.optim_state_dict_to_load(model, optim, optim_state_dict) + elif info.full_state_dict: + info.full_state_dict = False + local_state_dict = _get_optim_state_dict(model, (optim,), info) + info.full_state_dict = True + device = None + + def _device(t): + if t.dim() > 0: + nonlocal device + if device is None: + device = t.device + elif device != t.device: + raise ValueError("Device mismatch") + return t + + _ = tree_map_only(torch.Tensor, _device, local_state_dict) + assert device is not None + flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict) + flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict) + if info.broadcast_from_rank0: + _broadcast_state_dict(flatten_osd, flatten_local_osd, device=device) + else: + _distribute_state_dict(flatten_osd, flatten_local_osd, device=device) + # The modifications listed seek to address the problem where optim might possess + # dissimilar parameters in comparison to optim_state_dict. This is achieved by + # incorporating differential parameters within local, which may result in optim + # having additional parameters ultimately. + for optim_key in flatten_osd.keys(): + if optim_key not in flatten_local_osd: + assert optim_key in osd_mapping + flatten_local_osd[optim_key] = flatten_osd[optim_key] + local_osd_mapping[optim_key] = osd_mapping[optim_key] + optim_state_dict = _unflatten_state_dict(flatten_local_osd, local_osd_mapping) + for pg in optim_state_dict[_PG]: + if _PARAMS not in pg: + cast(dict[str, ValueType], pg)[_PARAMS] = [] + + # Note that we do not have to convert the FQN back to param id here if + # order in optim.param_groups[idx][_PARAMS] is the same as the one in + # optim_state_dict[_PG][idx][_PARAMS]. + _state_dict_fn(optim, "load_state_dict")(state_dict=optim_state_dict) + + +def get_model_state_dict( + model: nn.Module, + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> dict[str, ValueType]: + """ + Return the model state_dict of ``model``. + + See ``get_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + The state_dict for ``model``. + + :rtype: typing.Dict[str, ValueType] + """ + with _gc_context(): + info = _verify_options( + model, + (), + optim_only=False, + submodules=submodules, + options=options, + ) + model_state_dict = _get_model_state_dict(model, info) + _verify_state_dict(model_state_dict, {}, info) + return model_state_dict + + +def get_optimizer_state_dict( + model: nn.Module, + optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer], + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> OptimizerStateType: + """ + Return the combined state_dict for optimizers. + + See ``get_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[None, Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + The state_dict for ``optimizers``. + + :rtype: OptimizerStateType + """ + with _gc_context(): + optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers) + info = _verify_options( + model, + optimizers, + optim_only=True, + submodules=submodules, + options=options, + ) + optim_state_dict = _get_optim_state_dict(model, optimizers, info) + _verify_state_dict({}, optim_state_dict, info) + return optim_state_dict + + +def get_state_dict( + model: nn.Module, + optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer], + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> tuple[dict[str, ValueType], OptimizerStateType]: + """ + Return the model state_dict and optimizers state_dict. + + ``get_state_dict`` can process any module that is parallelized by PyTorch + FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any + combination of these parallelisms. The main functions of ``get_state_dict`` + are: 1.) returning a model and optimizer state_dict that can be resharded + with a different number of trainers and/or different parallelisms. + 2.) hiding the parallelism-specific state_dict APIs. Users don't have to call + these APIs. + 3.) sanity checking the result state_dict. + + The keys of the result state dictionary are the canonical FQNs (Fully + Qualified Names). A canonical FQN refers to the FQN based on a parameter's + position in an nn.Module hierarchy. More specifically, a canonical FQN to a + parameter is the FQN returned by ``module.named_parameters()`` or + ``module.named_buffers()`` when the module is not distributed by any + parallelisms. Since the optimizer internally uses parameter IDs to represent + a parameter, there will be a conversion from the parameter IDs to the + canonical FQNs when calling this API. + + ``get_state_dict`` can also process a module that is not parallelized. In + such a case, ``get_state_dict`` only performs one function -- converting the + optimizer parameter IDs to the canonical FQNs. + + Example: + >>> # xdoctest: +SKIP + >>> import torch + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> from torch.nn.parallel import DistributedDataParallel as DDP + >>> from torch.distributed.checkpoint.state_dict import get_state_dict + + >>> fsdp_model = FSDP(copy.deepcopy(model)) + >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) + >>> ddp_model = DDP(copy.deepcopy(model)) + >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) + + + >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim) + >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict( + ... fsdp_model, fsdp_optim + ... ) + + >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), + >>> # the asserts will fail. + >>> assert ddp_state_dict == fsdp_state_dict + >>> assert ddp_optim_state == fsdp_optim_state_dict + + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[None, Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + ``Tuple`` that contain model state_dict and optimizer state_dict. + + :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType] + """ + + with _gc_context(): + optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers) + info = _verify_options( + model, + optimizers, + optim_only=False, + submodules=submodules, + options=options, + ) + model_state_dict = _get_model_state_dict(model, info) + optim_state_dict = _get_optim_state_dict(model, optimizers, info) + _verify_state_dict(model_state_dict, optim_state_dict, info) + return model_state_dict, optim_state_dict + + +def _unflatten_model_state_dict( + model: nn.Module, + state_dict: dict[nn.Module, dict[str, ValueType]] | dict[str, ValueType], +) -> dict[str, ValueType]: + if not state_dict: + return {} + + if isinstance(next(iter(state_dict.keys())), nn.Module): + warnings.warn( + "Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]``" + "is deprecated and will be removed in 2.5. If you need this " + "feature, please preprocessing the model_state_dict to achieve the " + "same functionality.", + FutureWarning, + ) + cast_state_dict = cast(dict[nn.Module, dict[str, ValueType]], state_dict) + new_state_dict: dict[str, ValueType] = {} + for submodule, sub_state_dict in cast_state_dict.items(): + for name, m in model.named_modules(): + if m != submodule: + continue + + fqns = _get_fqns(model, name) + assert len(fqns) == 1, "FQNs for a submodule should only have 1 element" + prefix = f"{next(iter(fqns))}." + new_state_dict.update({prefix + subfqn: value for subfqn, value in sub_state_dict.items()}) + return new_state_dict + else: + return cast(dict[str, ValueType], state_dict) + + +def set_model_state_dict( + model: nn.Module, + model_state_dict: dict[str, ValueType], + *, + options: Optional[StateDictOptions] = None, +) -> _IncompatibleKeys: + """Load the model state_dict. + + The counterpart of ``get_model_state_dict`` to set the state_dict to the + model. See ``set_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + model_state_dict: (Dict[str, ValueType]): + the model state_dict to load. If the key of the ``model_state_dict`` + is nn.Module, the key is a submodule of ``model`` and the value should + be the state_dict of the submodule. When loading the state_dict, + the prefix of the submodule will be append to the state_dict. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + :type model_state_dict: typing.Dict[str, ValueType] + """ + model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(model, model_state_dict) + with _gc_context(): + info = _verify_options(model, (), optim_only=False, options=options) + + _verify_state_dict(model_state_dict, {}, info) + return _load_model_state_dict(model, model_state_dict, info) + + +def set_optimizer_state_dict( + model: nn.Module, + optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer], + optim_state_dict: OptimizerStateType, + *, + options: Optional[StateDictOptions] = None, +) -> None: + """Load the optimizers state_dict. + + The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the + optimizers. See ``set_state_dict`` for the detail usage. + + WARN: ``set_optimizer_state_dict`` can only be called before ``backward()`` or after + ``step()`` is called on the optimizers. Otherwise, the optimizer states won't be + initialized correctly. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + optim_state_dict: OptimizerStateType: + the optimizer state_dict to load. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + None + + :type optim_state_dict: typing.OptimizerStateType + """ + with _gc_context(): + optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers) + info = _verify_options(model, optimizers, optim_only=True, options=options) + + _verify_state_dict({}, optim_state_dict, info) + _load_optim_state_dict(model, optimizers, optim_state_dict, info) + + +def set_state_dict( + model: nn.Module, + optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer], + *, + model_state_dict: dict[str, ValueType], + optim_state_dict: OptimizerStateType, + options: Optional[StateDictOptions] = None, +) -> _IncompatibleKeys: + """Load the model state_dict and optimizers state_dict. + + The counterpart of ``get_state_dict`` to set the state_dict to the model and + optimizers. The given ``model_state_dict`` and ``optim_state_dict`` do not + have to be returned by ``get_state_dict`` but must meet the following + requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``, + 2) if a tensor is sharded, it must be either a ShardedTensor or DTensor, + 3) optimizer state_dict cannot contain the parameter IDs; the keys should be + the canonical FQNs. + + WARN: ``set_state_dict`` can only be called before ``backward()`` or after ``step()`` + is called on the optimizers. Otherwise, the optimizer states won't be initialized + correctly. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): + the model state_dict to load. If the key of the ``model_state_dict`` + is nn.Module, the key is a submodule of ``model`` and the value should + be the state_dict of the submodule. When loading the state_dict, + the prefix of the submodule will be append to the state_dict. + optim_state_dict: OptimizerStateType: + the optimizer state_dict to load. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys of the model state_dict. + * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict. + + :type model_state_dict: typing.Dict[str, ValueType] + :type optim_state_dict: typing.OptimizerStateType + """ + + model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(model, model_state_dict) + with _gc_context(): + optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers) + info = _verify_options(model, optimizers, optim_only=not model_state_dict, options=options) + + _verify_state_dict(model_state_dict, optim_state_dict, info) + _load_optim_state_dict(model, optimizers, optim_state_dict, info) + return _load_model_state_dict(model, model_state_dict, info) + + +# TODO: correct the state_dict function signature. +# TODO: this API is not yet fully tested. Make it private +@no_type_check +def _patch_model_state_dict( + model: nn.Module, + *, + options: Optional[StateDictOptions] = None, +) -> None: + """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``. + + Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model`` to + be a partial function to call ``get_state_dict`` and ``set_state_dict``. + + Example: + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.checkpoint.state_dict import patch_model_state_dict + + model = fsdp(model) + patch_model_state_dict(model) + + Args: + model (nn.Module): the nn.Module to the model. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + Returns: + None + """ + + _state_dict_call = functools.partial( + get_model_state_dict, + model=model, + options=options, + ) + + def state_dict_call(): + return _state_dict_call() + + model.state_dict = state_dict_call + + _load_state_dict_call = functools.partial( + set_model_state_dict, + model=model, + options=options, + ) + + def load_state_dict_call(state_dict: dict[str, Any]): + _load_state_dict_call(model_state_dict=state_dict) + + model.load_state_dict = load_state_dict_call + + _patched_state_dict.add(state_dict_call) + _patched_state_dict.add(load_state_dict_call) + + +# TODO: correct the load_state_dict function signature. +# TODO: this API is not yet fully tested. Make it private +@no_type_check +def _patch_optimizer_state_dict( + model: nn.Module, + *, + optimizers: tuple[torch.optim.Optimizer, ...], + options: Optional[StateDictOptions] = None, +) -> None: + """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``. + + Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers`` to + be a partial function to call ``get_state_dict`` and ``set_state_dict``. + + Note that if there are multiple optimizers, all of the optimizers will be patched. + So users only need to call one of the state_dict() to get the full result. + + Example: + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.checkpoint.state_dict import patch_model_state_dict + + model = fsdp(model) + patch_model_state_dict(model) + + Args: + model (nn.Module): the nn.Module to the model. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + Returns: + None + """ + + _state_dict_call = functools.partial( + get_optimizer_state_dict, + model=model, + optimizers=optimizers, + options=options, + ) + + def state_dict_call(): + return _state_dict_call() + + _load_state_dict_call = functools.partial( + set_optimizer_state_dict, + model=model, + optimizers=optimizers, + options=options, + ) + + def load_state_dict_call(state_dict: dict[str, Any]): + _load_state_dict_call(optim_state_dict=state_dict) + + _patched_state_dict.add(state_dict_call) + _patched_state_dict.add(load_state_dict_call) + optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers) + for optim in optimizers: + optim.state_dict = state_dict_call + optim.load_state_dict = load_state_dict_call diff --git a/toolbox/verl/v0.5.0/verl/third_party/vllm/__init__.py b/toolbox/verl/v0.5.0/verl/third_party/vllm/__init__.py new file mode 100644 index 000000000..d0fdeb208 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/third_party/vllm/__init__.py @@ -0,0 +1,59 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from importlib.metadata import PackageNotFoundError, version + +from packaging import version as vs + +from verl.utils.import_utils import is_sglang_available + + +def get_version(pkg): + try: + ver = version(pkg) + if "+" in ver: + return ver.split("+")[0] + else: + return version(pkg) + except PackageNotFoundError: + return None + + +package_name = "vllm" +package_version = get_version(package_name) +vllm_version = None + +if package_version is None: + if not is_sglang_available(): + raise ValueError( + f"vllm version {package_version} not supported and SGLang also not Found. Currently supported " + f"vllm versions are 0.7.0+" + ) +elif vs.parse(package_version) >= vs.parse("0.7.0"): + vllm_version = package_version + from vllm import LLM + from vllm.distributed import parallel_state +else: + if vs.parse(package_version) in [vs.parse("0.5.4"), vs.parse("0.6.3")]: + raise ValueError( + f"vLLM version {package_version} support has been removed. vLLM 0.5.4 and 0.6.3 are no longer " + f"supported. Please use vLLM 0.7.0 or later." + ) + if not is_sglang_available(): + raise ValueError( + f"vllm version {package_version} not supported and SGLang also not Found. Currently supported " + f"vllm versions are 0.7.0+" + ) + +__all__ = ["LLM", "parallel_state"] diff --git a/toolbox/verl/v0.5.0/verl/tools/__init__.py b/toolbox/verl/v0.5.0/verl/tools/__init__.py new file mode 100644 index 000000000..c4b932b1a --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/tools/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/tools/base_tool.py b/toolbox/verl/v0.5.0/verl/tools/base_tool.py new file mode 100644 index 000000000..9a1189d20 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/tools/base_tool.py @@ -0,0 +1,92 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from typing import Any, Optional +from uuid import uuid4 + +from verl.utils.rollout_trace import rollout_trace_op + +from .schemas import OpenAIFunctionToolSchema + + +class BaseTool: + """Base class for tools. + + A tool should support the following methods: + + - `to_openai_function_tool_schema`: return the tool schema in OpenAI format. + - `create`: create a tool instance for a trajectory. + - `execute`: execute the tool. + - `calc_reward`: calculate the reward respect to tool state. + - `release`: release the tool instance. + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + self.config = config + self.tool_schema = tool_schema or self.get_openai_tool_schema() + assert self.tool_schema is not None, "Tool schema is not set!" + self.name = self.tool_schema.function.name + print(json.dumps(self.tool_schema.model_dump(exclude_unset=True, exclude_none=True), indent=2)) + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, **kwargs) -> str: + """Create a tool instance. + + Args: + instance_id: The instance id of the tool. + + Returns: + The instance id of the tool. + """ + if instance_id is None: + return str(uuid4()) + else: + return instance_id + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + """Execute the tool. + + Args: + instance_id: The instance id of the tool. + parameters: The json string of the parameters of the tool. + + Returns: tool_response, tool_reward_score, tool_metrics + tool_response: The response str of the tool. + tool_reward_score: The step reward score of the tool. + tool_metrics: The metrics of the tool. + """ + return "Updated the tool state.", 0.0, {} + + async def calc_reward(self, instance_id: str, **kwargs) -> float: + """Calculate the reward of the tool. + + Args: + instance_id: The instance id of the tool. + + Returns: + The reward of the tool. + """ + return 0.0 + + async def release(self, instance_id: str, **kwargs) -> None: + """Release the tool instance. + + Args: + instance_id: The instance id of the tool. + """ + pass diff --git a/toolbox/verl/v0.5.0/verl/tools/geo3k_tool.py b/toolbox/verl/v0.5.0/verl/tools/geo3k_tool.py new file mode 100644 index 000000000..6ffd6fb2c --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/tools/geo3k_tool.py @@ -0,0 +1,99 @@ +# Copyright 2023-2025 SGLang Team +# Copyright Amazon.com, Inc. or its affiliates. +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from verl.utils.reward_score import geo3k +from verl.utils.rollout_trace import rollout_trace_op + +from .base_tool import BaseTool +from .schemas import OpenAIFunctionToolSchema + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class Geo3kTool(BaseTool): + """A demo tool for calculating the reward of geo3k. + - `to_openai_function_tool_schema`: return the tool schema in OpenAI format. + - `create`: create a tool instance for a trajectory. + - `execute`: execute the tool. + - `calc_reward`: calculate the reward respect to tool state. + - `release`: release the tool instance. + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + """ + _tool_schema = OpenAIFunctionToolSchema.model_validate({ + "type": "function", + "function": { + "name": "calc_geo3k_reward", + "description": "A tool for calculating the reward of geo3k", + "parameters": { + "type": "object", + "properties": { + "answer": { + "type": "string", + "description": "The answer to the question, enclosed in \\boxed{}", + }, + }, + "required": ["answer"], + }, + } + }) + """ + super().__init__(config, tool_schema) + self._instance_dict = {} + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str: + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "ground_truth": ground_truth, + "reward": 0.0, + } + return instance_id, None + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + answer = parameters.get("answer", "") + if not isinstance(answer, str): + answer = str(answer) + self._instance_dict[instance_id]["response"] = answer + reward = await self.calc_reward(instance_id) + # penalty for non improved answer submission + tool_reward = 0.0 if reward > self._instance_dict[instance_id]["reward"] else -0.05 + # update the reward + self._instance_dict[instance_id]["reward"] = reward + return f"Current parsed {answer=} {reward=}", tool_reward, {} + + async def calc_reward(self, instance_id: str, **kwargs) -> float: + return geo3k.compute_score( + self._instance_dict[instance_id]["response"], + self._instance_dict[instance_id]["ground_truth"], + use_boxed=False, + format_score=0.0, + ) + + async def release(self, instance_id: str, **kwargs) -> None: + del self._instance_dict[instance_id] diff --git a/toolbox/verl/v0.5.0/verl/tools/gsm8k_tool.py b/toolbox/verl/v0.5.0/verl/tools/gsm8k_tool.py new file mode 100644 index 000000000..f6d89134d --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/tools/gsm8k_tool.py @@ -0,0 +1,106 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from verl.utils.reward_score import gsm8k +from verl.utils.rollout_trace import rollout_trace_op + +from .base_tool import BaseTool +from .schemas import OpenAIFunctionToolSchema + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class Gsm8kTool(BaseTool): + """A demo tool for calculating the reward of gsm8k. + + - `to_openai_function_tool_schema`: return the tool schema in OpenAI format. + - `create`: create a tool instance for a trajectory. + - `execute`: execute the tool. + - `calc_reward`: calculate the reward respect to tool state. + - `release`: release the tool instance. + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + """ + _tool_schema = OpenAIFunctionToolSchema.model_validate({ + "type": "function", + "function": { + "name": "calc_gsm8k_reward", + "description": "A tool for calculating the reward of gsm8k", + "parameters": { + "type": "object", + "properties": { + "answer": { + "type": "string", + "description": "The answer to the question", + }, + }, + "required": ["answer"], + }, + } + }) + """ + super().__init__(config, tool_schema) + self._instance_dict = {} + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str: + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "ground_truth": ground_truth, + "reward": 0.0, + } + return instance_id + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + answer = parameters.get("answer", "") + if not isinstance(answer, str): + answer = str(answer) + + if answer.startswith("#### "): + self._instance_dict[instance_id]["response"] = answer + else: + self._instance_dict[instance_id]["response"] = "#### " + answer + + reward = await self.calc_reward(instance_id) + # penalty for non improved answer submission + tool_reward = 0.0 if reward > self._instance_dict[instance_id]["reward"] else -0.05 + # update the reward + self._instance_dict[instance_id]["reward"] = reward + + return f"Current parsed {answer=} {reward=}", tool_reward, {} + + async def calc_reward(self, instance_id: str, **kwargs) -> float: + return gsm8k.compute_score( + self._instance_dict[instance_id]["response"], + self._instance_dict[instance_id]["ground_truth"], + method="flexible", + format_score=0.0, + score=1.0, + ) + + async def release(self, instance_id: str, **kwargs) -> None: + del self._instance_dict[instance_id] diff --git a/toolbox/verl/v0.5.0/verl/tools/mcp_base_tool.py b/toolbox/verl/v0.5.0/verl/tools/mcp_base_tool.py new file mode 100644 index 000000000..dacd18ebe --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/tools/mcp_base_tool.py @@ -0,0 +1,116 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from fastmcp.exceptions import ClientError + +from verl.tools.utils.mcp_clients.McpClientManager import ClientManager +from verl.utils.rollout_trace import rollout_trace_op + +from .base_tool import BaseTool +from .schemas import OpenAIFunctionToolSchema + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class MCPBaseTool(BaseTool): + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + super().__init__(config, tool_schema) + self._instance_dict = {} + self.timeout = config.get("timeout", 30) + + # TODO(hechanghao): create a global client manager to manage the rate limit, client and pool + logger.info(f"Initialized MCPBaseTool with config: {config}") + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + """Return the OpenAI tool schema.""" + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, **kwargs) -> str: + """Create a tool instance. + + Args: + instance_id: The instance id of the tool. + + Returns: + The instance id of the tool. + """ + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "reward": [], + } + return instance_id + + async def _call_tool(self, instance_id, parameters) -> tuple[str, dict]: + err_msg = "" + try: + call_tool_result = await ClientManager.call_tool(self.name, parameters, self.timeout) + except ClientError as e: + err_msg = f"\n Tool call failed: {e}" + except ConnectionError as e: + err_msg = f"\n Connection failed: {e}" + except Exception as e: + err_msg = f"\n An unexpected error occurred: {e}" + + logger.debug(f"Tool result for instance {instance_id} with tool {self.name}: {call_tool_result.content}") + result, metadata = self._parse_tool_result(call_tool_result.content) + metadata["api_request_error"] += err_msg + return result, metadata + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + if self.name == "" or self.name is None or parameters is None: + error_msg = "Error: 'parameters' is missing or empty." + logger.error(f"[MCPTool] {error_msg} Received tool name: {self.name}, parameters: {parameters}") + return json.dumps({"result": error_msg}), 0.0, {} + + try: + result_text, metadata = await self._call_tool(instance_id, parameters) + + # Store results in instance dictionary + self._instance_dict[instance_id]["reward"].append(result_text.strip()) + + # Convert metadata to metrics + metrics = { + "query_count": metadata.get("query_count", 0), + "status": metadata.get("status", "unknown"), + "total_results": metadata.get("total_results", 0), + "api_request_error": metadata.get("api_request_error"), + } + + return result_text, 0.0, metrics + + except Exception as e: + error_result = json.dumps({"result": f"Tool execution failed: {e}"}) + logger.error(f"[MCPBaseTool] Execution failed: {e}") + return error_result, 0.0, {"error": str(e)} + + async def calc_reward(self, instance_id: str, **kwargs) -> str: + return self._instance_dict[instance_id]["reward"] + + async def release(self, instance_id: str, **kwargs) -> None: + if instance_id in self._instance_dict: + del self._instance_dict[instance_id] + + def _parse_tool_result(self, content: list) -> tuple[str, dict]: + tools_content = [part.text for part in filter(lambda x: x.type == "text", content)] + return " ".join(tools_content), {} diff --git a/toolbox/verl/v0.5.0/verl/tools/mcp_search_tool.py b/toolbox/verl/v0.5.0/verl/tools/mcp_search_tool.py new file mode 100644 index 000000000..ac823719b --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/tools/mcp_search_tool.py @@ -0,0 +1,69 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import re + +from verl.tools.mcp_base_tool import MCPBaseTool + +from .schemas import OpenAIFunctionToolSchema + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class MCPSearchTool(MCPBaseTool): + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + super().__init__(config, tool_schema) + + def _parse_tool_result(self, content: list) -> tuple[str, dict]: + res = "" + res_cnt = 0 + query_list = [] + metadata = { + "api_request_error": "", + "status": "unknown", + "total_results": 0, + } + try: + for part in content: + if part.type != "text": + continue + text = part.text.replace("'", '"') + query_match = re.search(r'query"\s*:\s*"([^"]+)"', text) + query = query_match.group(1) if query_match else "" + query_list.append(query) + + title_matches = re.findall(r'"title"\s*:', text) + title_count = len(title_matches) + + results_match = re.search(r'"results"\s*:\s*(\[.*?\])', text, re.DOTALL) + results_content = results_match.group(1) if results_match else "" + + res += results_content + res_cnt += title_count + except json.JSONDecodeError: + err_msg = "json parse error." + logger.error(err_msg) + metadata["api_request_error"] = err_msg + metadata["status"] = "error" + + # update metadata + metadata["status"] = "success" + metadata["queries"] = query_list + metadata["query_count"] = len(query_list) + metadata["total_results"] = res_cnt + return res, metadata diff --git a/toolbox/verl/v0.5.0/verl/tools/sandbox_fusion_tools.py b/toolbox/verl/v0.5.0/verl/tools/sandbox_fusion_tools.py new file mode 100644 index 000000000..c3a2748d9 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/tools/sandbox_fusion_tools.py @@ -0,0 +1,193 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import threading +from contextlib import ExitStack +from enum import Enum +from typing import Any, Callable, Optional, TypeVar +from uuid import uuid4 + +import ray + +from verl.tools.base_tool import BaseTool +from verl.utils.reward_score.sandbox_fusion.utils import _process_single_case +from verl.utils.rollout_trace import rollout_trace_op + +from .schemas import OpenAIFunctionToolSchema + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +T = TypeVar("T") + + +class PoolMode(Enum): + ThreadMode = 1 + ProcessMode = 2 + + +@ray.remote(concurrency_groups={"acquire": 1, "release": 10}) +class TokenBucketWorker: + def __init__(self, rate_limit: int): + self.rate_limit = rate_limit + # this only used for observalability + self.current_count = 0 + self._semaphore = threading.Semaphore(rate_limit) + + @ray.method(concurrency_group="acquire") + def acquire(self): + self._semaphore.acquire() + self.current_count += 1 + + @ray.method(concurrency_group="release") + def release(self): + self._semaphore.release() + self.current_count -= 1 + + def get_current_count(self): + return self.current_count + + +class ExecutionWorker: + def __init__(self, enable_global_rate_limit=True, rate_limit=10): + self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None + + def _init_rate_limit(self, rate_limit): + # TODO validation for rate_limit + # A Singleton Rate Limitor + return TokenBucketWorker.options(name="rate-limiter", get_if_exists=True).remote(rate_limit) + + def ping(self): + return True + + def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T: + with ExitStack() as stack: + stack.callback(self.rate_limit_worker.release.remote) + ray.get(self.rate_limit_worker.acquire.remote()) + try: + return fn(*fn_args, **fn_kwargs) + except Exception as e: + # TODO we should make this available to the tool caller + logger.warning(f"Error when executing code: {e}") + + +def init_execution_pool( + num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode +): + if mode == PoolMode.ThreadMode: + return ( + ray.remote(ExecutionWorker) + .options(max_concurrency=num_workers) + .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit) + ) + else: + raise NotImplementedError("Process mode is not implemented yet") + # return ray.util.multiprocessing.Pool(processes=num_workers) + + +class SandboxFusionTool(BaseTool): + """A tool for executing the code using sanbox fusion image. + + - `to_openai_function_tool_schema`: return the tool schema in OpenAI format. + - `create`: create a tool instance for a trajectory. + - `execute`: execute the tool. + - `calc_reward`: calculate the reward respect to tool state. + - `release`: release the tool instance. + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + """ + _tool_schema = OpenAIFunctionToolSchema.model_validate({ + "type": "function", + "function": { + "name": "code_interpreter", + "description": "A tool for execute code", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "code needs to be execute and grad", + }, + }, + "required": ["code"], + }, + } + }) + """ + super().__init__(config, tool_schema) + self._instance_dict = {} + # TODO: better documentation for the config + self.num_workers = config.get("num_workers", 10) + self.rate_limit = config.get("rate_limit", 10) + self.default_timeout = config.get("default_timeout", 30) + self.default_language = config.get("default_language", "python") + self.enable_global_rate_limit = config.get("enable_global_rate_limit", True) + self.execution_pool = init_execution_pool( + num_workers=self.num_workers, + enable_global_rate_limit=self.enable_global_rate_limit, + rate_limit=self.rate_limit, + mode=PoolMode.ThreadMode, + ) + self.sandbox_fusion_url = config.get("sandbox_fusion_url", "") + self.memory_limit_mb = config.get("memory_limit_mb", 1024) + if self.sandbox_fusion_url == "": + raise ValueError("sandbox_fusion_url is not set") + log_msg = f"Init SandboxFusionTool with config: {config}" + logger.info(log_msg) + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str: + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "ground_truth": ground_truth, + "reward": [], + } + return instance_id + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + code = parameters.get("code", "") + timeout = parameters.get("timeout", self.default_timeout) + language = parameters.get("language", self.default_language) + if not isinstance(code, str): + code = str(code) + + result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language) + # sandbox has no score or metrics, use Nones + return result, None, None + + def execute_code(self, instance_id, code, timeout=30, language="python"): + result_status, metadata = _process_single_case( + 0, None, None, self.sandbox_fusion_url, code, timeout, self.memory_limit_mb, language + ) + # we should always expect this since we don't have correct answer + if metadata["run_status"] == "Finished": + actual_output = metadata["stdout"] + metadata["stderr"] + logger.debug(f"actual_output from sandbox fusion: {actual_output},{instance_id}") + return actual_output + else: + return "no stdout here" + + async def calc_reward(self, instance_id: str, **kwargs) -> str: + return self._instance_dict[instance_id]["reward"] + + async def release(self, instance_id: str, **kwargs) -> None: + del self._instance_dict[instance_id] diff --git a/toolbox/verl/v0.5.0/verl/tools/schemas.py b/toolbox/verl/v0.5.0/verl/tools/schemas.py new file mode 100644 index 000000000..c0c65a30e --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/tools/schemas.py @@ -0,0 +1,89 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from typing import Any, Literal + +from pydantic import BaseModel + + +class OpenAIFunctionPropertySchema(BaseModel): + """The schema of a parameter in OpenAI format.""" + + type: str + description: str | None = None + enum: list[str] | None = None + + +class OpenAIFunctionParametersSchema(BaseModel): + """The schema of parameters in OpenAI format.""" + + type: str + properties: dict[str, OpenAIFunctionPropertySchema] + required: list[str] + + +class OpenAIFunctionSchema(BaseModel): + """The schema of a function in OpenAI format.""" + + name: str + description: str + parameters: OpenAIFunctionParametersSchema + strict: bool = False + + +class OpenAIFunctionToolSchema(BaseModel): + """The schema of a tool in OpenAI format.""" + + type: str + function: OpenAIFunctionSchema + + +class OpenAIFunctionParsedSchema(BaseModel): + """The parsed schema of a tool in OpenAI format.""" + + name: str + arguments: str # JSON string + + +class OpenAIFunctionCallSchema(BaseModel): + """The parsed schema of a tool in OpenAI format.""" + + name: str + arguments: dict[str, Any] + + @staticmethod + def from_openai_function_parsed_schema( + parsed_schema: OpenAIFunctionParsedSchema, + ) -> tuple["OpenAIFunctionCallSchema", bool]: + has_decode_error = False + try: + arguments = json.loads(parsed_schema.arguments) + except json.JSONDecodeError: + arguments = {} + has_decode_error = True + # If the arguments is not a dict, it means the arguments is not a valid JSON string + if not isinstance(arguments, dict): + arguments = {} + has_decode_error = True + + return OpenAIFunctionCallSchema(name=parsed_schema.name, arguments=arguments), has_decode_error + + +class OpenAIFunctionToolCall(BaseModel): + """The tool call in OpenAI format.""" + + id: str + type: Literal["function"] = "function" + function: OpenAIFunctionCallSchema diff --git a/toolbox/verl/v0.5.0/verl/tools/search_tool.py b/toolbox/verl/v0.5.0/verl/tools/search_tool.py new file mode 100644 index 000000000..3cc6cda53 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/tools/search_tool.py @@ -0,0 +1,278 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import threading +from contextlib import ExitStack +from enum import Enum +from typing import Any, Callable, Optional, TypeVar +from uuid import uuid4 + +import ray +import ray.actor + +from verl.tools.utils.search_r1_like_utils import perform_single_search_batch +from verl.utils.rollout_trace import rollout_trace_op + +from .base_tool import BaseTool +from .schemas import OpenAIFunctionToolSchema + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +T = TypeVar("T") + + +# Adapted from verl/tools/sandbox_fusion_tools.py +class PoolMode(Enum): + """Execution pool mode enumeration.""" + + ThreadMode = 1 + ProcessMode = 2 + + +@ray.remote(concurrency_groups={"acquire": 1, "release": 10}) +class TokenBucketWorker: + """Ray actor for rate limiting using token bucket algorithm.""" + + def __init__(self, rate_limit: int): + self.rate_limit = rate_limit + self.current_count = 0 # For observability + self._semaphore = threading.Semaphore(rate_limit) + + @ray.method(concurrency_group="acquire") + def acquire(self): + """Acquire a token from the bucket.""" + self._semaphore.acquire() + self.current_count += 1 + + @ray.method(concurrency_group="release") + def release(self): + """Release a token back to the bucket.""" + self._semaphore.release() + self.current_count -= 1 + + def get_current_count(self): + """Get current number of acquired tokens.""" + return self.current_count + + +class SearchExecutionWorker: + """Worker for executing search operations with optional rate limiting.""" + + def __init__(self, enable_global_rate_limit=True, rate_limit=10): + self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None + + def _init_rate_limit(self, rate_limit): + """Initialize singleton rate limiter.""" + return TokenBucketWorker.options(name="rate-limiter", get_if_exists=True).remote(rate_limit) + + def ping(self): + """Health check method.""" + return True + + def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T: + """Execute function with optional rate limiting.""" + if self.rate_limit_worker: + with ExitStack() as stack: + stack.callback(self.rate_limit_worker.release.remote) + ray.get(self.rate_limit_worker.acquire.remote()) + try: + return fn(*fn_args, **fn_kwargs) + except Exception as e: + # TODO we should make this available to the tool caller + logger.warning(f"Error when executing search: {e}") + else: + return fn(*fn_args, **fn_kwargs) + + +def init_search_execution_pool( + num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode +): + """Initialize search execution pool.""" + if mode == PoolMode.ThreadMode: + return ( + ray.remote(SearchExecutionWorker) + .options(max_concurrency=num_workers) + .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit) + ) + else: + raise NotImplementedError("Process mode is not implemented yet") + + +class SearchTool(BaseTool): + """Search tool for retrieving information using external retrieval services. + + This tool provides search functionality with rate limiting and concurrent execution + support through Ray. It integrates with external retrieval services to perform + semantic search operations. + + Methods: + get_openai_tool_schema: Return the tool schema in OpenAI format + create: Create a tool instance for a trajectory + execute: Execute the search tool + calc_reward: Calculate the reward with respect to tool state + release: Release the tool instance + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + """Initialize SearchTool with configuration and schema. + + Args: + config: Configuration dictionary containing tool settings + tool_schema: OpenAI function tool schema definition + + Example tool_schema: + { + "type": "function", + "function": { + "name": "search", + "description": "Searches for relevant information based on queries.", + "parameters": { + "type": "object", + "properties": { + "query_list": { + "type": "array", + "items": {"type": "string"}, + "description": "List of search queries" + } + }, + "required": ["query_list"] + } + } + } + """ + super().__init__(config, tool_schema) + self._instance_dict = {} + + # Worker and rate limiting configuration + self.num_workers = config.get("num_workers", 120) + self.rate_limit = config.get("rate_limit", 120) + self.timeout = config.get("timeout", 30) + + self.enable_global_rate_limit = config.get("enable_global_rate_limit", True) + self.execution_pool = init_search_execution_pool( + num_workers=self.num_workers, + enable_global_rate_limit=self.enable_global_rate_limit, + rate_limit=self.rate_limit, + mode=PoolMode.ThreadMode, + ) + + # Retrieval service configuration + self.retrieval_service_url = config.get("retrieval_service_url") + assert self.retrieval_service_url, "Configuration must include 'retrieval_service_url'" + self.topk = config.get("topk", 3) + if self.retrieval_service_url == "": + raise ValueError("retrieval_service_url is not set") + + logger.info(f"Initialized SearchTool with config: {config}") + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + """Return the OpenAI tool schema.""" + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, **kwargs) -> str: + """Create a tool instance. + + Args: + instance_id: The instance id of the tool. + + Returns: + The instance id of the tool. + """ + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "reward": [], + } + return instance_id + + def execute_search(self, instance_id: str, query_list: list, retrieval_service_url: str, topk: int, timeout: int): + """Execute search operation using retrieval service. + + Args: + instance_id: Tool instance ID + query_list: List of search queries + retrieval_service_url: URL of the retrieval service + topk: Number of top results to return + timeout: Request timeout in seconds + + Returns: + Tuple of (result_text, metadata) + """ + result_text, metadata = perform_single_search_batch( + retrieval_service_url=retrieval_service_url, + query_list=query_list, + topk=topk, + concurrent_semaphore=None, # Ray handles concurrency control + timeout=timeout, + ) + logger.debug(f"Search result for instance {instance_id}: {result_text}") + return result_text, metadata + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + """Execute the search tool. + + Args: + instance_id: The instance ID of the tool + parameters: Tool parameters containing query_list and optional timeout + + Returns: tool_response, tool_reward_score, tool_metrics + tool_response: The response str of the tool. + tool_reward_score: The step reward score of the tool. + tool_metrics: The metrics of the tool. + """ + timeout = self.timeout + query_list_from_params = parameters.get("query_list") + + if not query_list_from_params or not isinstance(query_list_from_params, list): + error_msg = "Error: 'query_list' is missing, empty, or not a list in parameters." + logger.error(f"[SearchTool] {error_msg} Received parameters: {parameters}") + return json.dumps({"result": error_msg}), 0.0, {} + + # Execute search using Ray execution pool + try: + result_text, metadata = await self.execution_pool.execute.remote( + self.execute_search, instance_id, query_list_from_params, self.retrieval_service_url, self.topk, timeout + ) + + # Store results in instance dictionary + self._instance_dict[instance_id]["reward"].append(result_text.strip()) + + # Convert metadata to metrics + metrics = { + "query_count": metadata.get("query_count", 0), + "status": metadata.get("status", "unknown"), + "total_results": metadata.get("total_results", 0), + "api_request_error": metadata.get("api_request_error"), + } + + return result_text, 0.0, metrics + + except Exception as e: + error_result = json.dumps({"result": f"Search execution failed: {e}"}) + logger.error(f"[SearchTool] Execution failed: {e}") + return error_result, 0.0, {"error": str(e)} + + async def calc_reward(self, instance_id: str, **kwargs) -> str: + return self._instance_dict[instance_id]["reward"] + + async def release(self, instance_id: str, **kwargs) -> None: + if instance_id in self._instance_dict: + del self._instance_dict[instance_id] diff --git a/toolbox/verl/v0.5.0/verl/tools/utils/__init__.py b/toolbox/verl/v0.5.0/verl/tools/utils/__init__.py new file mode 100644 index 000000000..c4b932b1a --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/tools/utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/tools/utils/mcp_clients/McpClientManager.py b/toolbox/verl/v0.5.0/verl/tools/utils/mcp_clients/McpClientManager.py new file mode 100644 index 000000000..ee5fe3119 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/tools/utils/mcp_clients/McpClientManager.py @@ -0,0 +1,97 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import json +import logging +from typing import Any + +from fastmcp import Client +from fastmcp.client.transports import SSETransport + +from verl.tools.utils.mcp_clients.utils import TokenBucket, mcp2openai + +logger = logging.getLogger(__name__) + + +class MCPClientManager: + rootServerName = "mcpServers" + initialized = False + clients = [] + tool_client_mapping = {} + rate_limiter = None + + async def initialize(self, config_path, rate_limit: float = 10.0): + if self.initialized: + return + """Initialize the MCP Client Manager and start all clients""" + result = self._load_config(config_path) + servers = result[self.rootServerName] + exclude_sse_servers = {self.rootServerName: {}} + for server_name in servers.keys(): + server = servers[server_name] + if "auth_token" in server: + transport = SSETransport(url=server["url"], headers={"Authorization": f"Bearer {server['auth_token']}"}) + client = Client(transport) + self.clients.append(client) + else: + exclude_sse_servers[self.rootServerName][server_name] = server + + if exclude_sse_servers[self.rootServerName]: + self.clients.append(Client(exclude_sse_servers)) + + # Initialize rate limiter + self.rate_limiter = TokenBucket(rate_limit) + self.initialized = True + + async def call_tool(self, tool_name, parameters, timeout): + # Apply rate limiting + while not self.rate_limiter.acquire(): + await asyncio.sleep(0.1) + + client = self.get_client_with_tool_name(tool_name) + async with client: + return await client.call_tool_mcp(tool_name, parameters) + + async def fetch_tool_schemas(self, tool_selected_list: list[str]) -> list[dict]: + tool_schemas = [] + for client in self.clients: + async with client: + tools = await client.list_tools_mcp() + for tool in tools.tools: + if not tool_selected_list: + self.tool_client_mapping[tool.name] = client + tool_schemas.append(mcp2openai(tool)) + elif tool.name in tool_selected_list: + self.tool_client_mapping[tool.name] = client + tool_schemas.append(mcp2openai(tool)) + + return tool_schemas + + def get_client_with_tool_name(self, tool_name: str): + return self.tool_client_mapping[tool_name] + + def _load_config(self, file: str) -> dict[str, Any]: + try: + with open(file) as f: + return json.load(f) + except FileNotFoundError: + logger.warning(f'the "{file}" file was not found') + except Exception: + logger.error(f'there was an error reading the "{file}" file') + + return {} + + +ClientManager = MCPClientManager() diff --git a/toolbox/verl/v0.5.0/verl/tools/utils/mcp_clients/utils.py b/toolbox/verl/v0.5.0/verl/tools/utils/mcp_clients/utils.py new file mode 100644 index 000000000..22a5f6353 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/tools/utils/mcp_clients/utils.py @@ -0,0 +1,58 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import threading +import time + +from mcp import Tool + +logger = logging.getLogger(__file__) + + +class TokenBucket: + def __init__(self, rate_limit: float): + self.rate_limit = rate_limit # tokens per second + self.tokens = rate_limit + self.last_update = time.time() + self.lock = threading.Lock() + + def acquire(self) -> bool: + with self.lock: + now = time.time() + # Add new tokens based on time elapsed + new_tokens = (now - self.last_update) * self.rate_limit + self.tokens = min(self.rate_limit, self.tokens + new_tokens) + self.last_update = now + + if self.tokens >= 1: + self.tokens -= 1 + return True + return False + + +def mcp2openai(mcp_tool: Tool) -> dict: + """Convert a MCP Tool to an OpenAI ChatCompletionTool.""" + openai_format = { + "type": "function", + "function": { + "name": mcp_tool.name, + "description": mcp_tool.description, + "parameters": mcp_tool.inputSchema, + "strict": False, + }, + } + if not openai_format["function"]["parameters"].get("required", None): + openai_format["function"]["parameters"]["required"] = [] + return openai_format diff --git a/toolbox/verl/v0.5.0/verl/tools/utils/search_r1_like_utils.py b/toolbox/verl/v0.5.0/verl/tools/utils/search_r1_like_utils.py new file mode 100644 index 000000000..23669e44c --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/tools/utils/search_r1_like_utils.py @@ -0,0 +1,243 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import threading +import time +import traceback +import uuid +from typing import Any, Optional + +import requests + +DEFAULT_TIMEOUT = 30 # Default search request timeout +MAX_RETRIES = 10 +INITIAL_RETRY_DELAY = 1 +API_TIMEOUT = 10 + +logger = logging.getLogger(__name__) + + +def call_search_api( + retrieval_service_url: str, + query_list: list[str], + topk: int = 3, + return_scores: bool = True, + timeout: int = DEFAULT_TIMEOUT, +) -> tuple[Optional[dict[str, Any]], Optional[str]]: + """ + Calls the remote search API to perform retrieval with retry logic for various errors, + using increasing delay between retries. Logs internal calls with a unique ID. + + Args: + retrieval_service_url: The URL of the retrieval service API. + query_list: List of search queries. + topk: Number of top results to return. + return_scores: Whether to return scores. + timeout: Request timeout in seconds. + + Returns: + A tuple (response_json, error_message). + If successful, response_json is the API's returned JSON object, error_message is None. + If failed after retries, response_json is None, error_message contains the error information. + """ + request_id = str(uuid.uuid4()) + log_prefix = f"[Search Request ID: {request_id}] " + + payload = {"queries": query_list, "topk": topk, "return_scores": return_scores} + + headers = {"Content-Type": "application/json", "Accept": "application/json"} + + last_error = None + + for attempt in range(MAX_RETRIES): + try: + logger.info( + f"{log_prefix}Attempt {attempt + 1}/{MAX_RETRIES}: Calling search API at {retrieval_service_url}" + ) + response = requests.post( + retrieval_service_url, + headers=headers, + json=payload, + timeout=timeout, + ) + + # Check for Gateway Timeout (504) and other server errors for retrying + if response.status_code in [500, 502, 503, 504]: + last_error = ( + f"{log_prefix}API Request Error: Server Error ({response.status_code}) on attempt " + f"{attempt + 1}/{MAX_RETRIES}" + ) + logger.warning(last_error) + if attempt < MAX_RETRIES - 1: + delay = INITIAL_RETRY_DELAY * (attempt + 1) + logger.info(f"{log_prefix}Retrying after {delay} seconds...") + time.sleep(delay) + continue + + # Check for other HTTP errors (e.g., 4xx) + response.raise_for_status() + + # If successful (status code 2xx) + logger.info(f"{log_prefix}Search API call successful on attempt {attempt + 1}") + return response.json(), None + + except requests.exceptions.ConnectionError as e: + last_error = f"{log_prefix}Connection Error: {e}" + logger.warning(last_error) + if attempt < MAX_RETRIES - 1: + delay = INITIAL_RETRY_DELAY * (attempt + 1) + logger.info(f"{log_prefix}Retrying after {delay} seconds...") + time.sleep(delay) + continue + except requests.exceptions.Timeout as e: + last_error = f"{log_prefix}Timeout Error: {e}" + logger.warning(last_error) + if attempt < MAX_RETRIES - 1: + delay = INITIAL_RETRY_DELAY * (attempt + 1) + logger.info(f"{log_prefix}Retrying after {delay} seconds...") + time.sleep(delay) + continue + except requests.exceptions.RequestException as e: + last_error = f"{log_prefix}API Request Error: {e}" + break # Exit retry loop on other request errors + except json.JSONDecodeError as e: + raw_response_text = response.text if "response" in locals() else "N/A" + last_error = f"{log_prefix}API Response JSON Decode Error: {e}, Response: {raw_response_text[:200]}" + break # Exit retry loop on JSON decode errors + except Exception as e: + last_error = f"{log_prefix}Unexpected Error: {e}" + break # Exit retry loop on other unexpected errors + + # If loop finishes without returning success, return the last recorded error + logger.error(f"{log_prefix}Search API call failed. Last error: {last_error}") + return None, last_error.replace(log_prefix, "API Call Failed: ") if last_error else "API Call Failed after retries" + + +def _passages2string(retrieval_result): + """Convert retrieval results to formatted string.""" + format_reference = "" + for idx, doc_item in enumerate(retrieval_result): + content = doc_item["document"]["contents"] + title = content.split("\n")[0] + text = "\n".join(content.split("\n")[1:]) + format_reference += f"Doc {idx + 1} (Title: {title})\n{text}\n\n" + return format_reference.strip() + + +def perform_single_search_batch( + retrieval_service_url: str, + query_list: list[str], + topk: int = 3, + concurrent_semaphore: Optional[threading.Semaphore] = None, + timeout: int = DEFAULT_TIMEOUT, +) -> tuple[str, dict[str, Any]]: + """ + Performs a single batch search for multiple queries (original search tool behavior). + + Args: + retrieval_service_url: The URL of the retrieval service API. + query_list: List of search queries. + topk: Number of top results to return. + concurrent_semaphore: Optional semaphore for concurrency control. + timeout: Request timeout in seconds. + + Returns: + A tuple (result_text, metadata). + result_text: The search result JSON string. + metadata: Metadata dictionary for the batch search. + """ + logger.info(f"Starting batch search for {len(query_list)} queries.") + + api_response = None + error_msg = None + + try: + if concurrent_semaphore: + with concurrent_semaphore: + api_response, error_msg = call_search_api( + retrieval_service_url=retrieval_service_url, + query_list=query_list, + topk=topk, + return_scores=True, + timeout=timeout, + ) + else: + api_response, error_msg = call_search_api( + retrieval_service_url=retrieval_service_url, + query_list=query_list, + topk=topk, + return_scores=True, + timeout=timeout, + ) + except Exception as e: + error_msg = f"API Request Exception during batch search: {e}" + logger.error(f"Batch search: {error_msg}") + traceback.print_exc() + + metadata = { + "query_count": len(query_list), + "queries": query_list, + "api_request_error": error_msg, + "api_response": None, + "status": "unknown", + "total_results": 0, + "formatted_result": None, + } + + result_text = json.dumps({"result": "Search request failed or timed out after retries."}) + + if error_msg: + metadata["status"] = "api_error" + result_text = json.dumps({"result": f"Search error: {error_msg}"}) + logger.error(f"Batch search: API error occurred: {error_msg}") + elif api_response: + logger.debug(f"Batch search: API Response: {api_response}") + metadata["api_response"] = api_response + + try: + raw_results = api_response.get("result", []) + if raw_results: + pretty_results = [] + total_results = 0 + + for retrieval in raw_results: + formatted = _passages2string(retrieval) + pretty_results.append(formatted) + total_results += len(retrieval) if isinstance(retrieval, list) else 1 + + final_result = "\n---\n".join(pretty_results) + result_text = json.dumps({"result": final_result}) + metadata["status"] = "success" + metadata["total_results"] = total_results + metadata["formatted_result"] = final_result + logger.info(f"Batch search: Successful, got {total_results} total results") + else: + result_text = json.dumps({"result": "No search results found."}) + metadata["status"] = "no_results" + metadata["total_results"] = 0 + logger.info("Batch search: No results found") + except Exception as e: + error_msg = f"Error processing search results: {e}" + result_text = json.dumps({"result": error_msg}) + metadata["status"] = "processing_error" + logger.error(f"Batch search: {error_msg}") + else: + metadata["status"] = "unknown_api_state" + result_text = json.dumps({"result": "Unknown API state (no response and no error message)."}) + logger.error("Batch search: Unknown API state.") + + return result_text, metadata diff --git a/toolbox/verl/v0.5.0/verl/tools/utils/tool_registry.py b/toolbox/verl/v0.5.0/verl/tools/utils/tool_registry.py new file mode 100644 index 000000000..5c14d1016 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/tools/utils/tool_registry.py @@ -0,0 +1,107 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import importlib +import logging +import os +import sys +from enum import Enum + +from omegaconf import OmegaConf + +from verl.tools.schemas import OpenAIFunctionToolSchema + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class ToolType(Enum): + NATIVE = "native" + MCP = "mcp" + + +async def initialize_mcp_tool(tool_cls, tool_config) -> list: + from verl.tools.utils.mcp_clients.McpClientManager import ClientManager + + tool_list = [] + mcp_servers_config_path = tool_config.mcp.mcp_servers_config_path + tool_selected_list = tool_config.mcp.tool_selected_list if "tool_selected_list" in tool_config.mcp else None + await ClientManager.initialize(mcp_servers_config_path, tool_config.config.rate_limit) + # Wait for MCP client to be ready + max_retries = 10 + retry_interval = 2 # seconds + for i in range(max_retries): + tool_schemas = await ClientManager.fetch_tool_schemas(tool_selected_list) + if tool_schemas: + break + if i < max_retries - 1: + logger.debug(f"Waiting for MCP client to be ready, attempt {i + 1}/{max_retries}") + await asyncio.sleep(retry_interval) + else: + raise RuntimeError("Failed to initialize MCP tools after maximum retries") + # mcp registry + assert len(tool_schemas), "mcp tool is empty" + for tool_schema_dict in tool_schemas: + logger.debug(f"tool_schema_dict: {tool_schema_dict}") + tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict) + tool = tool_cls( + config=OmegaConf.to_container(tool_config.config, resolve=True), + tool_schema=tool_schema, + ) + tool_list.append(tool) + return tool_list + + +def get_tool_class(cls_name): + module_name, class_name = cls_name.rsplit(".", 1) + if module_name not in sys.modules: + spec = importlib.util.find_spec(module_name) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + else: + module = sys.modules[module_name] + + tool_cls = getattr(module, class_name) + return tool_cls + + +def initialize_tools_from_config(tools_config_file): + tools_config = OmegaConf.load(tools_config_file) + tool_list = [] + for tool_config in tools_config.tools: + cls_name = tool_config.class_name + tool_type = ToolType(tool_config.config.type) + tool_cls = get_tool_class(cls_name) + + match tool_type: + case ToolType.NATIVE: + if tool_config.get("tool_schema", None) is None: + tool_schema = None + else: + tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True) + tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict) + tool = tool_cls( + config=OmegaConf.to_container(tool_config.config, resolve=True), + tool_schema=tool_schema, + ) + tool_list.append(tool) + case ToolType.MCP: + loop = asyncio.get_event_loop() + mcp_tools = loop.run_until_complete(initialize_mcp_tool(tool_cls, tool_config)) + tool_list.extend(mcp_tools) + case _: + raise NotImplementedError + return tool_list diff --git a/toolbox/verl/v0.5.0/verl/trainer/__init__.py b/toolbox/verl/v0.5.0/verl/trainer/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/__init__.py b/toolbox/verl/v0.5.0/verl/trainer/config/__init__.py new file mode 100644 index 000000000..121e05d5e --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .algorithm import AlgoConfig, FilterGroupsConfig, KLControlConfig, PFPPOConfig +from .config import CriticConfig, FSDPCriticConfig, MegatronCriticConfig + +__all__ = [ + "AlgoConfig", + "CriticConfig", + "FilterGroupsConfig", + "FSDPCriticConfig", + "KLControlConfig", + "MegatronCriticConfig", + "PFPPOConfig", +] diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/_generated_ppo_megatron_trainer.yaml new file mode 100644 index 000000000..91371321e --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -0,0 +1,418 @@ +# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh' +# in which it invokes 'python3 scripts/print_cfg.py --cfg job --config-name=ppo_megatron_trainer.yaml' to flatten the 'verl/trainer/config/ppo_megatron_trainer.yaml' config fields into a single file. +# Do not modify this file directly. +# The file is usually only for reference and never used. + +actor_rollout_ref: + actor: + strategy: megatron + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: false + ppo_max_token_len_per_gpu: 16384 + clip_ratio: 0.2 + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + policy_loss: + loss_mode: vanilla + clip_cov_ratio: 0.0002 + clip_cov_lb: 1.0 + clip_cov_ub: 5.0 + kl_cov_ratio: 0.0002 + ppo_kl_coef: 0.1 + clip_ratio_c: 3.0 + loss_agg_mode: token-mean + entropy_coeff: 0 + use_kl_loss: false + use_torch_compile: true + kl_loss_coef: 0.001 + kl_loss_type: low_var_kl + ppo_epochs: 1 + shuffle: false + checkpoint: + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + optim: + lr: 1.0e-06 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + optimizer: adam + clip_grad: 1.0 + lr_warmup_init: 0.0 + lr_warmup_steps: null + lr_decay_steps: null + lr_decay_style: constant + min_lr: 0.0 + weight_decay_incr_style: constant + lr_wsd_decay_style: exponential + lr_wsd_decay_steps: null + use_checkpoint_opt_param_scheduler: false + data_loader_seed: null + load_weight: true + megatron: + param_offload: false + grad_offload: false + optimizer_offload: false + tensor_model_parallel_size: 1 + expert_model_parallel_size: 1 + expert_tensor_parallel_size: null + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + context_parallel_size: 1 + sequence_parallel: true + use_distributed_optimizer: true + use_dist_checkpointing: false + dist_checkpointing_path: null + seed: 42 + override_ddp_config: {} + override_transformer_config: + recompute_granularity: null + recompute_modules: + - core_attn + recompute_method: null + recompute_num_layers: null + use_mbridge: false + profile: + use_profile: false + profile_ranks: null + step_start: -1 + step_end: -1 + save_path: null + ref: + strategy: megatron + use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + megatron: + param_offload: false + tensor_model_parallel_size: 1 + expert_model_parallel_size: 1 + expert_tensor_parallel_size: None + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + context_parallel_size: 1 + sequence_parallel: true + use_distributed_optimizer: false + use_dist_checkpointing: false + dist_checkpointing_path: null + seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} + override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} + use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} + profile: + use_profile: false + profile_ranks: null + step_start: -1 + step_end: -1 + save_path: null + load_weight: true + rollout: + name: vllm + mode: sync + temperature: 1.0 + top_k: -1 + top_p: 1 + prompt_length: ${oc.select:data.max_prompt_length,512} + response_length: ${oc.select:data.max_response_length,512} + dtype: bfloat16 + gpu_memory_utilization: 0.5 + ignore_eos: false + enforce_eager: true + free_cache_engine: true + tensor_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + disable_log_stats: true + do_sample: true + 'n': 1 + multi_stage_wake_up: false + engine_kwargs: + vllm: + swap_space: null + disable_mm_preprocessor_cache: false + sglang: + attention_backend: null + val_kwargs: + top_k: -1 + top_p: 1.0 + temperature: 0 + 'n': 1 + do_sample: false + multi_turn: + enable: false + max_assistant_turns: null + tool_config_path: null + max_user_turns: null + max_parallel_calls: 1 + max_tool_response_length: 256 + tool_response_truncate_side: middle + interaction_config_path: null + completion_callback: null + use_inference_chat_template: false + tokenization_sanity_check_mode: strict + format: hermes + calculate_log_probs: false + agent: + num_workers: 8 + agent_loop_config_path: null + custom_async_server: + path: null + name: null + update_weights_bucket_megabytes: 512 + trace: + backend: null + token2text: false + enable_chunked_prefill: false + load_format: dummy_megatron + layer_name_map: + qkv_layer_name: qkv + gate_proj_layer_name: gate_up + hybrid_engine: true + nccl_timeout: 600 + model: + path: ~/models/deepseek-llm-7b-chat + custom_chat_template: null + external_lib: null + override_config: + model_config: {} + moe_config: + freeze_moe_router: false + use_fused_kernels: false + trust_remote_code: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + discrete: false + all_ranks: false + ranks: [] +trainer: + npu_profile: + options: + save_path: ./profiler_data + level: level1 + with_memory: false + record_shapes: false + with_npu: true + with_cpu: true + with_module: false + with_stack: false + analysis: true + balance_batch: true + total_epochs: 30 + total_training_steps: null + profile_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: + - console + - wandb + log_val_generations: 0 + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + esi_redundant_time: 0 + resume_mode: auto + resume_from_path: null + del_local_ckpt_after_load: false + val_before_train: true + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + ray_wait_register_center_timeout: 300 + device: cuda + controller_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + worker_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + capture-range: cudaProfilerApi + capture-range-end: null + kill: none +data: + tokenizer: null + use_shm: false + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: null + return_raw_input_ids: false + return_raw_chat: false + return_full_prompt: false + shuffle: true + dataloader_num_workers: 8 + validation_shuffle: false + filter_overlong_prompts: false + filter_overlong_prompts_workers: 1 + truncation: error + image_key: images + video_key: videos + trust_remote_code: false + custom_cls: + path: null + name: null + return_multi_modal_inputs: true + sampler: + class_path: null + class_name: null + datagen: + path: null + name: null +critic: + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: megatron + optim: + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + optimizer: adam + lr: 1.0e-06 + clip_grad: 1.0 + lr_warmup_init: 0.0 + lr_warmup_steps: null + lr_decay_steps: null + lr_decay_style: linear + min_lr: 0.0 + weight_decay_incr_style: constant + lr_wsd_decay_style: exponential + lr_wsd_decay_steps: null + use_checkpoint_opt_param_scheduler: false + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"} + override_config: + model_config: {} + moe_config: + freeze_moe_router: false + external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null} + trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false} + ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256} + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null} + use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + ppo_max_token_len_per_gpu: 32768 + forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1} + shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false} + cliprange_value: 0.5 + loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean} + checkpoint: + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + discrete: false + all_ranks: false + ranks: [] + _target_: verl.trainer.config.MegatronCriticConfig + nccl_timeout: 600 + megatron: + param_offload: false + grad_offload: false + optimizer_offload: false + tensor_model_parallel_size: 1 + expert_model_parallel_size: 1 + expert_tensor_parallel_size: null + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + context_parallel_size: 1 + sequence_parallel: true + use_distributed_optimizer: true + use_dist_checkpointing: false + dist_checkpointing_path: null + seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} + override_ddp_config: ${oc.select:actor_rollout_ref.actor.megatron.override_ddp_config,{}} + override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} + use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} + load_weight: true + data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null} +reward_model: + enable: false + strategy: megatron + model: + input_tokenizer: ${actor_rollout_ref.model.path} + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + trust_remote_code: false + micro_batch_size: null + micro_batch_size_per_gpu: null + max_length: null + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + reward_manager: naive + launch_reward_fn_async: false + sandbox_fusion: + url: null + max_concurrent: 64 + memory_limit_mb: 1024 + profiler: + _target_: verl.utils.profiler.ProfilerConfig + discrete: false + all_ranks: false + ranks: [] + nccl_timeout: 600 + megatron: + param_offload: false + tensor_model_parallel_size: 1 + expert_model_parallel_size: 1 + expert_tensor_parallel_size: null + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + context_parallel_size: 1 + sequence_parallel: true + use_distributed_optimizer: false + use_dist_checkpointing: false + dist_checkpointing_path: null + seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} + override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} + use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} + load_weight: true +custom_reward_function: + path: null + name: compute_score +algorithm: + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + norm_adv_by_std_in_grpo: true + use_kl_in_reward: false + kl_penalty: kl + kl_ctrl: + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: false + pf_ppo: + _target_: verl.trainer.config.PFPPOConfig + reweight_method: pow + weight_pow: 2.0 +ray_init: + num_cpus: null + timeline_json_file: null diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/_generated_ppo_trainer.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/_generated_ppo_trainer.yaml new file mode 100644 index 000000000..d7c71736f --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/_generated_ppo_trainer.yaml @@ -0,0 +1,368 @@ +# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh' +# in which it invokes 'python3 scripts/print_cfg.py --cfg job ' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file. +# Do not modify this file directly. +# The file is usually only for reference and never used. + +actor_rollout_ref: + actor: + strategy: fsdp + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: false + ppo_max_token_len_per_gpu: 16384 + clip_ratio: 0.2 + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + policy_loss: + loss_mode: vanilla + clip_cov_ratio: 0.0002 + clip_cov_lb: 1.0 + clip_cov_ub: 5.0 + kl_cov_ratio: 0.0002 + ppo_kl_coef: 0.1 + clip_ratio_c: 3.0 + loss_agg_mode: token-mean + entropy_coeff: 0 + use_kl_loss: false + use_torch_compile: true + kl_loss_coef: 0.001 + kl_loss_type: low_var_kl + ppo_epochs: 1 + shuffle: false + checkpoint: + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + optim: + lr: 1.0e-06 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + min_lr_ratio: 0.0 + num_cycles: 0.5 + warmup_style: constant + grad_clip: 1.0 + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + fsdp_config: + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + ref: + strategy: ${actor_rollout_ref.actor.strategy} + use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + fsdp_config: + param_offload: false + reshard_after_forward: true + forward_prefetch: false + wrap_policy: + min_num_params: 0 + ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1} + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + rollout: + name: vllm + mode: sync + temperature: 1.0 + top_k: -1 + top_p: 1 + prompt_length: ${oc.select:data.max_prompt_length,512} + response_length: ${oc.select:data.max_response_length,512} + dtype: bfloat16 + gpu_memory_utilization: 0.5 + ignore_eos: false + enforce_eager: true + free_cache_engine: true + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + disable_log_stats: true + do_sample: true + 'n': 1 + multi_stage_wake_up: false + engine_kwargs: + vllm: + swap_space: null + disable_mm_preprocessor_cache: false + sglang: + attention_backend: null + val_kwargs: + top_k: -1 + top_p: 1.0 + temperature: 0 + 'n': 1 + do_sample: false + multi_turn: + enable: false + max_assistant_turns: null + tool_config_path: null + max_user_turns: null + max_parallel_calls: 1 + max_tool_response_length: 256 + tool_response_truncate_side: middle + interaction_config_path: null + completion_callback: null + use_inference_chat_template: false + tokenization_sanity_check_mode: strict + format: hermes + calculate_log_probs: false + agent: + num_workers: 8 + agent_loop_config_path: null + custom_async_server: + path: null + name: null + update_weights_bucket_megabytes: 512 + trace: + backend: null + token2text: false + enable_chunked_prefill: true + load_format: dummy_dtensor + layered_summon: false + hybrid_engine: true + model: + path: ~/models/deepseek-llm-7b-chat + custom_chat_template: null + use_shm: false + external_lib: null + override_config: {} + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: false + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + exclude_modules: null + use_liger: false + use_fused_kernels: false + fused_kernel_options: + impl_backend: torch + trust_remote_code: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + discrete: false + all_ranks: false + ranks: [] +trainer: + npu_profile: + options: + save_path: ./profiler_data + level: level1 + with_memory: false + record_shapes: false + with_npu: true + with_cpu: true + with_module: false + with_stack: false + analysis: true + balance_batch: true + total_epochs: 30 + total_training_steps: null + profile_steps: null + controller_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + worker_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + capture-range: cudaProfilerApi + capture-range-end: null + kill: none + project_name: verl_examples + experiment_name: gsm8k + logger: + - console + - wandb + log_val_generations: 0 + rollout_data_dir: null + validation_data_dir: null + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + esi_redundant_time: 0 + resume_mode: auto + resume_from_path: null + val_before_train: true + val_only: false + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + del_local_ckpt_after_load: false + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + ray_wait_register_center_timeout: 300 + device: cuda + use_legacy_worker_impl: auto +data: + tokenizer: null + use_shm: false + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: null + return_raw_input_ids: false + return_raw_chat: false + return_full_prompt: false + shuffle: true + dataloader_num_workers: 8 + validation_shuffle: false + filter_overlong_prompts: false + filter_overlong_prompts_workers: 1 + truncation: error + image_key: images + video_key: videos + trust_remote_code: false + custom_cls: + path: null + name: null + return_multi_modal_inputs: true + sampler: + class_path: null + class_name: null + datagen: + path: null + name: null +critic: + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: fsdp + optim: + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr: 1.0e-05 + min_lr_ratio: null + warmup_style: constant + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"} + override_config: {} + external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null} + trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false} + use_shm: false + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: false + fsdp_config: + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + wrap_policy: + min_num_params: 0 + fsdp_size: -1 + forward_prefetch: false + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256} + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null} + use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + ppo_max_token_len_per_gpu: 32768 + forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1} + shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false} + cliprange_value: 0.5 + loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean} + checkpoint: + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + discrete: false + all_ranks: false + ranks: [] + _target_: verl.trainer.config.FSDPCriticConfig + forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null} + forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null} + ulysses_sequence_parallel_size: 1 + grad_clip: 1.0 +reward_model: + enable: false + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + trust_remote_code: false + use_shm: false + use_remove_padding: false + use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} + fsdp_config: + wrap_policy: + min_num_params: 0 + param_offload: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + micro_batch_size: null + micro_batch_size_per_gpu: null + max_length: null + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + reward_manager: naive + launch_reward_fn_async: false + sandbox_fusion: + url: null + max_concurrent: 64 + memory_limit_mb: 1024 + profiler: + _target_: verl.utils.profiler.ProfilerConfig + discrete: false + all_ranks: false + ranks: [] + ulysses_sequence_parallel_size: 1 +custom_reward_function: + path: null + name: compute_score +algorithm: + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + norm_adv_by_std_in_grpo: true + use_kl_in_reward: false + kl_penalty: kl + kl_ctrl: + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: false + pf_ppo: + _target_: verl.trainer.config.PFPPOConfig + reweight_method: pow + weight_pow: 2.0 +ray_init: + num_cpus: null + timeline_json_file: null diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/actor/actor.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/actor/actor.yaml new file mode 100644 index 000000000..d5402d870 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/actor/actor.yaml @@ -0,0 +1,111 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# the abstract actor configs +# fsdp, fsdp2 or megatron. must be set. +strategy: ??? + +# Split each sample into sub-batches of this size for PPO +ppo_mini_batch_size: 256 + +# [Deprecated] Global micro batch size +ppo_micro_batch_size: null + +# Local per-GPU micro batch size +ppo_micro_batch_size_per_gpu: null + +# Whether to automatically adjust batch size at runtime +# oc.select: the default val for ref.log_prob_use_dynamic_bsz +use_dynamic_bsz: false + +# Max tokens per GPU in one PPO batch; affects gradient accumulation +# Typically it should be: n * ${data.max_prompt_length} + ${data.max_response_length} +# oc.select: the default val for ref.log_prob_max_token_len_per_gpu +ppo_max_token_len_per_gpu: 16384 + +# PPO clip ratio +clip_ratio: 0.2 + +# Lower bound for asymmetric clipping (used in dual-clip PPO) +clip_ratio_low: 0.2 + +# Upper bound for asymmetric clipping (used in dual-clip PPO) +clip_ratio_high: 0.2 + +# policy loss config +policy_loss: + + # Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617 + loss_mode: "vanilla" + + # Ratio of tokens to be clipped for clip-cov loss + clip_cov_ratio: 0.0002 + + # Lower bound for clip-cov loss + clip_cov_lb: 1.0 + + # Upper bound for clip-cov loss + clip_cov_ub: 5.0 + + # Ratio of tokens to be applied kl penalty for kl-cov loss + kl_cov_ratio: 0.0002 + + # KL divergence penalty coefficient + ppo_kl_coef: 0.1 + +# Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C +clip_ratio_c: 3.0 + +# Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean" +loss_agg_mode: token-mean + +# Entropy regularization coefficient in PPO loss +entropy_coeff: 0 + +# Whether to use KL loss instead of KL reward penalty. True for GRPO +use_kl_loss: false + +# Whether to use torch.compile() +# oc.select: the default val for ref.use_torch_compile +use_torch_compile: true + +# KL loss coefficient when use_kl_loss is enabled. For GRPO +kl_loss_coef: 0.001 + +# Type of KL divergence loss. Options: "kl"(k1), "abs", "mse"(k2), "low_var_kl"(k3), "full" +kl_loss_type: low_var_kl + +# Number of PPO epochs per batch +ppo_epochs: 1 + +# Shuffle training data across PPO epochs +shuffle: false + +# checkpoint configs +checkpoint: + + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + + # For more flexibility, you can specify the contents to load from the checkpoint. + # .xxx refers to the local variable xxx from the same level of hierarchy similar to python pkg + load_contents: ${.save_contents} + +# optimizer configs +optim: + + # Learning rate + lr: 1e-6 + + # Warmup steps ratio (used if lr_warmup_steps is negative) + lr_warmup_steps_ratio: 0.0 + + # Total training steps (must be overridden at runtime) + total_training_steps: -1 + + # Weight decay + weight_decay: 0.01 diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/actor/dp_actor.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/actor/dp_actor.yaml new file mode 100644 index 000000000..f298c3cfa --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/actor/dp_actor.yaml @@ -0,0 +1,73 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# defaults specify the default config from each component +defaults: + + # dp actor config, inheriting from trainer/config/actor/actor.yaml + - actor + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +# TODO(haibin.lin): switch to fsdp2 +strategy: fsdp + +# Gradient clipping for actor updates, specific to the strategy. +grad_clip: 1.0 + +# Sequence parallelism size for Ulysses-style model parallelism +# oc.select: the default val for ref.ulysses_sequence_parallel_size +ulysses_sequence_parallel_size: 1 + +# calculate entropy with chunking to reduce memory peak +entropy_from_logits_with_chunking: False + +# recompute entropy +entropy_checkpointing: False + +# optimizer configs +optim: + + # Warmup steps; negative value delegates to lr_warmup_steps_ratio + lr_warmup_steps: -1 + + # Minimum LR ratio for cosine schedule + min_lr_ratio: 0.0 + + # Number of cosine cycles in LR schedule + num_cycles: 0.5 + + # LR warmup style: "constant" or "cosine" + warmup_style: constant + +# configs for FSDP +fsdp_config: + + # policy for wrapping the model + wrap_policy: + + # Minimum number of parameters to trigger wrapping a layer with FSDP + min_num_params: 0 + + # Whether to offload model parameters to CPU (trades speed for memory) + param_offload: false + + # Whether to offload optimizer state to CPU + optimizer_offload: false + + # Only for FSDP2: offload param/grad/optimizer during train + offload_policy: false + + # Only for FSDP2: Reshard after forward pass to reduce memory footprint + reshard_after_forward: true + + # Number of GPUs in each FSDP shard group; -1 means auto + fsdp_size: -1 + + # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather + # before the current forward computation. + forward_prefetch: False diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/actor/megatron_actor.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/actor/megatron_actor.yaml new file mode 100644 index 000000000..dca4d3b58 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/actor/megatron_actor.yaml @@ -0,0 +1,120 @@ +# megatron actor config, inheriting from trainer/config/actor/actor.yaml +defaults: + - actor + # load the reference default config, then apply the fields in the current yaml + - _self_ + +strategy: megatron + +data_loader_seed: null + +load_weight: True + +checkpoint: + + async_save: False + +optim: + + optimizer: adam + + clip_grad: 1.0 + + # initial learning rate for warmup, default to 0.0 + lr_warmup_init: 0.0 + + # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps: null + + lr_decay_steps: null + + # select from constant/linear/cosine/inverse_square_root + lr_decay_style: constant + + # minimum learning rate, default to 0.0 + min_lr: 0.0 + + # select from constant/linear/cosine + weight_decay_incr_style: constant + + # select from constant/exponential/cosine + lr_wsd_decay_style: exponential + + lr_wsd_decay_steps: null + + # use checkpoint optimizer parameter scheduler + use_checkpoint_opt_param_scheduler: False + +megatron: + + param_offload: False + + grad_offload: False + + optimizer_offload: False + + tensor_model_parallel_size: 1 + + expert_model_parallel_size: 1 + + expert_tensor_parallel_size: null + + pipeline_model_parallel_size: 1 + + virtual_pipeline_model_parallel_size: null + + context_parallel_size: 1 + + sequence_parallel: True + + use_distributed_optimizer: True + + use_dist_checkpointing: False + + dist_checkpointing_path: null + + # oc.select: default val for ref.megatron.seed + seed: 42 + + # Allow to override Distributed Data Parallel (DDP) config + override_ddp_config: {} + + # additional transformer config like: num_layers_in_first(/last)_pipeline_stage + # oc.select: default val for ref.megatron.override_transformer_config + override_transformer_config: + # Recompute configuration, same as in megatron.training.arguments + # default use minimal performance-interference recompute methods + # Recompute granualarity, choices: ["full", "selective"] + recompute_granularity: null + + # Recompute modules, multiple choices: ["core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe"] + # Please use correct module in matched model + recompute_modules: ["core_attn"] + + # 'uniform', 'block' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + recompute_method: null + + # 'full' will checkpoint the entire transformer layer and 'selective' only checkpoints memory intensive part of attention + recompute_num_layers: null + + # oc.select: default val for ref.megatron.use_mbridge + use_mbridge: False + +# profile the actor model in `update_policy` +profile: + # turn it on when you want to profile the actor model + use_profile: False + + # list, you can specify the ranks to profile + profile_ranks: null + + # start step in update_policy + step_start: -1 + + # end step + step_end: -1 + + # the path to save the profile result + save_path: null diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/algorithm.py b/toolbox/verl/v0.5.0/verl/trainer/config/algorithm.py new file mode 100644 index 000000000..5bc6cf943 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/algorithm.py @@ -0,0 +1,114 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from verl.base_config import BaseConfig + + +@dataclass +class KLControlConfig(BaseConfig): + """Configuration for KL control. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + type (str): Type of KL control. Can be "fixed" or "adaptive". + kl_coef (float): Initial coefficient for KL penalty. + horizon (int): Horizon value for adaptive controller. + target_kl (float): Target KL divergence for adaptive controller. + """ + + _frozen_fields = ["type", "kl_coef", "horizon", "target_kl"] + type: str = "fixed" + kl_coef: float = 0.001 + horizon: int = 10000 + target_kl: float = 0.1 + + +@dataclass +class PFPPOConfig(BaseConfig): + """Configuration for preference feedback PPO. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + reweight_method (str): Method for reweighting samples. Can be "pow", "max_min", or "max_random". + weight_pow (float): Power used for weight scaling in "pow" method. + """ + + _frozen_fields = ["reweight_method", "weight_pow"] + reweight_method: str = "pow" + weight_pow: float = 2.0 + + +@dataclass +class FilterGroupsConfig(BaseConfig): + """Configuration for filter groups (used in DAPO and Entropy). + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + enable (bool): Whether to enable filter groups. + metric (Optional[str]): Metric to use for filtering: "acc", "score", "seq_reward", "seq_final_reward", etc. + max_num_gen_batches (int): Non-positive values mean no upper limit. + """ + + _frozen_fields = ["enable", "metric", "max_num_gen_batches"] + + enable: bool = False + metric: Optional[str] = None + max_num_gen_batches: int = 0 + + +@dataclass +class AlgoConfig(BaseConfig): + """Configuration for the algorithm. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + gamma (float): Discount factor for future rewards. + lam (float): Trade-off between bias and variance in the GAE estimator. + adv_estimator (str): Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc. + norm_adv_by_std_in_grpo (bool): Whether to normalize advantages by std (specific to GRPO). + use_kl_in_reward (bool): Whether to enable in-reward KL penalty. + kl_penalty (str): How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full". + kl_ctrl (KLControlConfig): KL control configuration. + use_pf_ppo (bool): Whether to enable preference feedback PPO. + pf_ppo (Optional[PFPPOConfig]): Preference feedback PPO settings. + filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy + """ + + _frozen_fields = [ + "gamma", + "lam", + "adv_estimator", + "norm_adv_by_std_in_grpo", + "use_kl_in_reward", + "kl_penalty", + "use_pf_ppo", + ] + + gamma: float = 1.0 + lam: float = 1.0 + adv_estimator: str = "gae" + norm_adv_by_std_in_grpo: bool = True + use_kl_in_reward: bool = False + kl_penalty: str = "kl" + kl_ctrl: KLControlConfig = field(default_factory=KLControlConfig) + use_pf_ppo: bool = False + pf_ppo: Optional[PFPPOConfig] = None + filter_groups: Optional[FilterGroupsConfig] = None diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/config.py b/toolbox/verl/v0.5.0/verl/trainer/config/config.py new file mode 100644 index 000000000..63979d7d4 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/config.py @@ -0,0 +1,126 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Optional + +from verl.base_config import BaseConfig + + +@dataclass +class CriticConfig(BaseConfig): + """Configuration for critic model training. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + rollout_n (int): Number of rollouts per update (mirrors actor rollout_n). + strategy (str): Strategy used for critic model training (fsdp, fsdp2, megatron). + optim (Dict[str, Any]): Optimizer configuration including lr, weight_decay, etc. + model (Dict[str, Any]): Model configuration including path, tokenizer_path, etc. + ppo_mini_batch_size (int): PPO mini-batch size per update. + ppo_micro_batch_size (Optional[int]): Global micro batch size (deprecated). + ppo_micro_batch_size_per_gpu (Optional[int]): Local per-GPU micro batch size. + use_dynamic_bsz (bool): Whether to automatically adjust batch size at runtime. + ppo_max_token_len_per_gpu (int): Max tokens per GPU in one PPO batch. + forward_max_token_len_per_gpu (int): Max token length per GPU in forward pass. + ppo_epochs (int): Number of PPO epochs per batch. + shuffle (bool): Shuffle training data across PPO epochs. + cliprange_value (float): PPO value function clipping range. + loss_agg_mode (str): Loss aggregation mode. + checkpoint (Dict[str, Any]): Checkpoint configuration. + profiler (Dict[str, Any]): Profiler configuration. + """ + + # For legacy reason configs related to batch_size are mutated in each role + # In the future they will be added to frozen fields instead + _frozen_fields = [ + "rollout_n", + "strategy", + "use_dynamic_bsz", + "ppo_max_token_len_per_gpu", + "forward_max_token_len_per_gpu", + "ppo_epochs", + "shuffle", + "cliprange_value", + "loss_agg_mode", + ] + + rollout_n: int = 1 + strategy: str = "fsdp" + optim: dict[str, Any] = field(default_factory=dict) + model: dict[str, Any] = field(default_factory=dict) + ppo_mini_batch_size: int = 1 + ppo_micro_batch_size: Optional[int] = None + ppo_micro_batch_size_per_gpu: Optional[int] = None + use_dynamic_bsz: bool = False + ppo_max_token_len_per_gpu: int = 32768 + forward_max_token_len_per_gpu: int = 32768 + ppo_epochs: int = 1 + shuffle: bool = True + cliprange_value: float = 0.5 + loss_agg_mode: str = "token-mean" + checkpoint: dict[str, Any] = field(default_factory=dict) + profiler: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class MegatronCriticConfig(CriticConfig): + """Configuration for Megatron-based critic model training. + + The inheritance from CriticConfig provides all base critic configuration plus Megatron-specific settings. + + Args: + nccl_timeout (int): NCCL timeout in seconds for distributed operations. + megatron (Dict[str, Any]): Megatron-specific parallelism settings. + load_weight (bool): Whether to load initial weights. + data_loader_seed (Optional[int]): Seed for data loader. + """ + + _frozen_fields = CriticConfig._frozen_fields + [ + "nccl_timeout", + "load_weight", + "data_loader_seed", + ] + + strategy: str = "megatron" + nccl_timeout: int = 600 + megatron: dict[str, Any] = field(default_factory=dict) + load_weight: bool = True + data_loader_seed: Optional[int] = None + + +@dataclass +class FSDPCriticConfig(CriticConfig): + """Configuration for FSDP-based critic model training. + + The inheritance from CriticConfig provides all base critic configuration plus FSDP-specific settings. + + Args: + forward_micro_batch_size (int): Forward-only batch size during inference (global). + forward_micro_batch_size_per_gpu (int): Forward-only batch size during inference (per GPU). + ulysses_sequence_parallel_size (int): Sequence parallelism size for Ulysses-style model parallelism. + grad_clip (float): Gradient clipping for critic updates. + """ + + _frozen_fields = CriticConfig._frozen_fields + [ + "ulysses_sequence_parallel_size", + "grad_clip", + ] + + strategy: str = "fsdp" + forward_micro_batch_size: int = 1 + forward_micro_batch_size_per_gpu: int = 1 + ulysses_sequence_parallel_size: int = 1 + grad_clip: float = 1.0 diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/critic/critic.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/critic/critic.yaml new file mode 100644 index 000000000..80aa0fb73 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/critic/critic.yaml @@ -0,0 +1,94 @@ +# Number of rollouts per update (mirrors actor rollout_n) +rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + +# fsdp or fsdp2 strategy used for critic model training +strategy: ??? + +# optimizer configs +optim: + + # Warmup steps ratio; total steps will be injected at runtime + lr_warmup_steps_ratio: 0.0 + + # Total training steps (must be overridden at runtime) + total_training_steps: -1 + + # Weight decay + weight_decay: 0.01 + +# model config for the critic +model: + + # Path to pretrained model weights + path: ~/models/deepseek-llm-7b-chat + + # Tokenizer path (defaults to actor's model path) + tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"} + + # Hugging Face config override + override_config: {} + + # External model implementation (optional) + external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null} + + # Whether to trust remote code from Hugging Face models + trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false} + +# PPO mini-batch size per update +ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256} + +# [Deprecated] Global micro batch size +ppo_micro_batch_size: null + +# Local per-GPU micro batch size +ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null} + +# Whether to automatically adjust batch size at runtime +use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + +# Max tokens per GPU in one PPO batch (doubled for critic) +ppo_max_token_len_per_gpu: 32768 + +# Max token length per GPU in forward pass +forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + +# Number of PPO epochs per batch +ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1} + +# Shuffle training data across PPO epochs +shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false} + +# PPO value function clipping range +cliprange_value: 0.5 + +# Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean" +loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean} + +# checkpoint configs +checkpoint: + + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + + # What to include when loading checkpoints + load_contents: ${.save_contents} + +# profiler configs +# the corresponding dataclass is verl.utils.profiler.ProfilerConfig. +profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] + +# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint +_target_: verl.trainer.config.CriticConfig diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/critic/dp_critic.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/critic/dp_critic.yaml new file mode 100644 index 000000000..f080dad51 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/critic/dp_critic.yaml @@ -0,0 +1,95 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# defaults specify the default config from each component +defaults: + + # dp actor config, inheriting from trainer/config/critic/critic.yaml + - critic + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +strategy: fsdp + +# optimizer configs +optim: + + # Learning rate + lr: 1e-5 + + # Minimum LR ratio for cosine schedule + min_lr_ratio: null + + # LR warmup style: "constant" or "cosine" + warmup_style: constant + +# model config for the critic +model: + + # Whether to use shared memory for loading the model + use_shm: False + + # Enable gradient checkpointing to save memory + enable_gradient_checkpointing: True + + # Offload activations to CPU to reduce GPU memory usage + enable_activation_offload: False + + # Use remove padding optimization (saves compute) + use_remove_padding: False + + # FSDP-specific config + fsdp_config: + + # Whether to offload model parameters to CPU + param_offload: False + + # Whether to offload optimizer state to CPU + optimizer_offload: False + + # Only for FSDP2: offload param/grad/optimizer during train + offload_policy: False + + # Only for FSDP2: Reshard after forward pass to reduce memory footprint + reshard_after_forward: True + + # Policy for wrapping layers with FSDP + wrap_policy: + + # Minimum number of parameters to trigger wrapping + min_num_params: 0 + + # Number of GPUs in each FSDP shard group; -1 means auto + fsdp_size: -1 + + # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather + # before the current forward computation. + forward_prefetch: False + + # Set to positive value to enable LoRA (e.g., 32) + lora_rank: 0 + + # LoRA scaling factor + lora_alpha: 16 + + # LoRA target modules: "all-linear" or list of linear projection layers + target_modules: all-linear + +# Forward-only batch size during inference (global) +forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null} + +# Forward-only batch size during inference (per GPU) +forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null} + +# Sequence parallelism size for Ulysses-style model parallelism +ulysses_sequence_parallel_size: 1 + +# Gradient clipping for critic updates +grad_clip: 1.0 + +# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint +_target_: verl.trainer.config.FSDPCriticConfig diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/critic/megatron_critic.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/critic/megatron_critic.yaml new file mode 100644 index 000000000..63d15d777 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/critic/megatron_critic.yaml @@ -0,0 +1,130 @@ +# defaults specify the default config from each component +defaults: + + # dp actor config, inheriting from trainer/config/critic/critic.yaml + - critic + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +strategy: megatron + +# seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron +nccl_timeout: 600 + +# optimizer configs +optim: + + # select optimizer, default is Adam + optimizer: adam + + # Learning rate + lr: 1e-6 + + # Clip gradients norm + clip_grad: 1.0 + + # initial learning rate for warmup, default to 0.0 + lr_warmup_init: 0.0 + + # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps: null + + lr_decay_steps: null + + # select from constant/linear/cosine/inverse_square_root + lr_decay_style: linear + + # minimum learning rate, default to 0.0 + min_lr: 0.0 + + # select from constant/linear/cosine + weight_decay_incr_style: constant + + # select from constant/exponential/cosine + lr_wsd_decay_style: exponential + + # number of steps for weight std decay + lr_wsd_decay_steps: null + + # use checkpoint optimizer parameter scheduler + use_checkpoint_opt_param_scheduler: False + +# model config for the critic +model: + + # override default empty mapping + override_config: + + model_config: {} + + moe_config: + + freeze_moe_router: False + +# megatron-specific parallelism settings +megatron: + + # Whether to offload model parameters to CPU + param_offload: False + + # Whether to offload gradients to CPU + grad_offload: False + + # Whether to offload optimizer state to CPU + optimizer_offload: False + + # size of tensor model parallel group + tensor_model_parallel_size: 1 + + # size of expert model parallel group + expert_model_parallel_size: 1 + + # size of expert tensor parallel group + expert_tensor_parallel_size: null + + # size of pipeline model parallel group + pipeline_model_parallel_size: 1 + + # size of virtual pipeline model parallel group + virtual_pipeline_model_parallel_size: null + + # size of context parallel group + context_parallel_size: 1 + + # Whether to use sequence parallelism + sequence_parallel: True + + # Whether to use distributed optimizer + use_distributed_optimizer: True + + # Whether to use distributed checkpointing + use_dist_checkpointing: False + + # Path for distributed checkpointing + dist_checkpointing_path: null + + # Random seed for Megatron + seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} + + # Allow to override Distributed Data Parallel (DDP) config + override_ddp_config: ${oc.select:actor_rollout_ref.actor.megatron.override_ddp_config,{}} + + # Transformer config overrides for Megatron + override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} + + # Whether to use mBridge communications + use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} + +# Whether to load initial weights +load_weight: True + +# seed for data loader +data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null} + +# Asynchronous checkpoint saving +checkpoint: + async_save: False + +# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint +_target_: verl.trainer.config.MegatronCriticConfig diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/data/legacy_data.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/data/legacy_data.yaml new file mode 100644 index 000000000..9a5ce8f0d --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/data/legacy_data.yaml @@ -0,0 +1,109 @@ +# Tokenizer class or path. If null, it will be inferred from the model. +tokenizer: null + +# Whether to use shared memory for data loading. +use_shm: False + +# Training set parquet. Can be a list or a single file. +# The program will read all files into memory, so it can't be too large (< 100GB). +# The path can be either a local path or an HDFS path. +# For HDFS path, we provide utils to download it to DRAM and convert it to a local path. +train_files: ~/data/rlhf/gsm8k/train.parquet + +# Validation parquet. Can be a list or a single file. +val_files: ~/data/rlhf/gsm8k/test.parquet + +# The field in the dataset where the prompt is located. Default is 'prompt'. +prompt_key: prompt + +# The field used to select the reward function (if using different ones per example). +reward_fn_key: data_source + +# Maximum prompt length. All prompts will be left-padded to this length. +# An error will be reported if the length is too long. +# oc.select: default val for rollout.prompt_length +max_prompt_length: 512 + +# Maximum response length. Rollout in RL algorithms (e.g. PPO) generates up to this length. +# oc.select: default val for rollout.response_length +max_response_length: 512 + +# Batch size sampled for one training iteration of different RL algorithms. +train_batch_size: 1024 + +# Batch size used during validation. Can be null. +val_batch_size: null + +# Whether to return the original input_ids without adding chat template. +# This is used when the reward model's chat template differs from the policy. +# If using a model-based RM with different templates, this should be True. +return_raw_input_ids: False + +# Whether to return the original chat (prompt) without applying chat template. +return_raw_chat: False + +# Whether to return the full prompt with chat template. +return_full_prompt: False + +# Whether to shuffle the data in the dataloader. +shuffle: True + +# num dataloader workers +dataloader_num_workers: 8 + +# Whether to shuffle the validation set. +validation_shuffle: False + +# Whether to filter overlong prompts. +filter_overlong_prompts: False + +# Number of workers for filtering overlong prompts. +# For large-scale datasets, filtering can be time-consuming. +# Use multiprocessing to speed up. Default is 1. +filter_overlong_prompts_workers: 1 + +# Truncate the input_ids or prompt if they exceed max_prompt_length. +# Options: 'error', 'left', 'right', 'middle'. Default is 'error'. +truncation: error + +# The field in the multi-modal dataset where the image is located. Default is 'images'. +image_key: images + +# The field in the multi-modal dataset where the video is located. +video_key: videos + +# If the remote tokenizer has a Python file, this flag determines whether to allow using it. +trust_remote_code: False + +# Optional: specify a custom dataset class path and name if overriding default loading behavior. +custom_cls: + + # The path to the file containing your customized dataset class. If not specified, pre-implemented dataset will be used. + path: null + + # The name of the dataset class within the specified file. + name: null + +# Whether to return multi-modal inputs in the dataset. Set to False if rollout generates new multi-modal inputs. +return_multi_modal_inputs: True + +# settings related to data sampler +sampler: + + # the path to the module containing a curriculum class which implements the + # AbstractSampler interface + class_path: null + + # the name of the curriculum class like `MySampler` + class_name: null + +# Data generation configuration for augmenting the dataset. +datagen: + + # The path to the file containing your customized data generation class. + # E.g. 'pkg://verl.experimental.dynamic_dataset.dynamicgen_dataset' + path: null + + # The class name of the data generation class within the specified file. + # E.g. 'MockDataGenerator' + name: null \ No newline at end of file diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/evaluation.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/evaluation.yaml new file mode 100644 index 000000000..efca03da4 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/evaluation.yaml @@ -0,0 +1,14 @@ +data: + path: /tmp/math_Qwen2-7B-Instruct.parquet + prompt_key: prompt + response_key: responses + data_source_key: data_source + reward_model_key: reward_model + +custom_reward_function: + path: null + name: compute_score + +ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. + timeline_json_file: null diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/generation.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/generation.yaml new file mode 100644 index 000000000..6ac43b5dc --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/generation.yaml @@ -0,0 +1,55 @@ +trainer: + nnodes: 1 + n_gpus_per_node: 8 + device: cuda + +data: + path: ~/data/rlhf/math/test.parquet + prompt_key: prompt + n_samples: 5 + output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet + batch_size: 128 + +model: + path: ~/models/Qwen2-7B-Instruct + external_lib: null +rollout: + name: vllm + mode: sync # sync: LLM, async: AsyncLLM + temperature: 1.0 + top_k: 50 # 0 for hf rollout, -1 for vllm rollout + top_p: 0.7 + prompt_length: 1536 + response_length: 512 + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.5 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 8 + # for hf rollout + do_sample: True + disable_log_stats: True + enable_chunked_prefill: True + n: 1 + # support logging rollout prob for debugging purpose + calculate_log_probs: False +actor: + strategy: fsdp # This is for backward-compatibility + ulysses_sequence_parallel_size: 1 # sp size + entropy_from_logits_with_chunking: False # calculate entropy with chunking to reduce memory peak + entropy_checkpointing: False # recompute entropy + fsdp_config: + fsdp_size: -1 + forward_prefetch: False # FSDP1 forward_prefetch configuration + +ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. + timeline_json_file: null diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/npu_profile/npu_profile.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/npu_profile/npu_profile.yaml new file mode 100644 index 000000000..b61260375 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/npu_profile/npu_profile.yaml @@ -0,0 +1,29 @@ +# Options for the npu profiler +options: + + # Storage path of collected data. + save_path: ./profiler_data + + # Collection level, optional values: level_none, level0, level1, level2. + level: level1 + + # Whether to enable memory analysis. + with_memory: False + + # Whether to record tensor shape. + record_shapes: False + + # Whether to record Device-side performance data. + with_npu: True + + # Whether to record Host-side performance data. + with_cpu: True + + # Whether to record Python call stack information. + with_module: False + + # Whether to record operator call stack information. + with_stack: False + + # Whether to automatically parse the data. + analysis: True \ No newline at end of file diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/ppo_megatron_trainer.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/ppo_megatron_trainer.yaml new file mode 100644 index 000000000..75e62d996 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/ppo_megatron_trainer.yaml @@ -0,0 +1,135 @@ +# specify the default per-component configs +defaults: + + # @.: + # actor_rollout_ref.actor: trainer/config/actor/megatron_actor.yaml + - actor@actor_rollout_ref.actor: megatron_actor + # trainer.npu_profile: trainer/config/npu_profile/npu_profile.yaml + - npu_profile@trainer.npu_profile: npu_profile + # data: trainer/config/data/legacy_data.yaml + - data@data: legacy_data + # load the reference default config, then apply the fields in the current yaml + # Reference model config. + # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True. + - ref@actor_rollout_ref.ref: megatron_ref + # Rollout model config. + - rollout@actor_rollout_ref.rollout: rollout + # Critic model config. + - critic@critic: megatron_critic + # Reward model config. + - reward_model@reward_model: megatron_reward_model + - _self_ + +actor_rollout_ref: + hybrid_engine: True + + nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron + + model: + + path: ~/models/deepseek-llm-7b-chat + + custom_chat_template: null + + external_lib: null + + override_config: + + model_config: {} + + moe_config: + + freeze_moe_router: False + + use_fused_kernels: False # Whether to use custom fused kernels (PostProcessing, for memory efficiency) + + trust_remote_code: False + + rollout: + # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. + enable_chunked_prefill: False + + load_format: dummy_megatron + + tensor_model_parallel_size: 1 + + layer_name_map: + qkv_layer_name: qkv + gate_proj_layer_name: gate_up + + profiler: + _target_: verl.utils.profiler.ProfilerConfig + discrete: False + all_ranks: False + ranks: [] + +custom_reward_function: + path: null + name: compute_score + +algorithm: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + norm_adv_by_std_in_grpo: True + use_kl_in_reward: False + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: False + pf_ppo: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.PFPPOConfig + reweight_method: pow # ["pow", "max_min", "max_random"] + weight_pow: 2.0 + +trainer: + balance_batch: True + total_epochs: 30 + total_training_steps: null + profile_steps: null # [1,2,5] or [] or null + project_name: verl_examples + experiment_name: gsm8k + logger: ['console', 'wandb'] + log_val_generations: 0 + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + esi_redundant_time: 0 + + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or disable or resume_path if resume_from_path is set + resume_from_path: null + del_local_ckpt_after_load: False + val_before_train: True + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + # The timeout for ray worker group to wait for the register center to be ready + ray_wait_register_center_timeout: 300 + device: cuda + # see ppo_trainer.yaml for more details + controller_nsight_options: + trace: "cuda,nvtx,cublas,ucx" + cuda-memory-usage: "true" + cuda-graph-trace: "graph" + worker_nsight_options: + trace: "cuda,nvtx,cublas,ucx" + cuda-memory-usage: "true" + cuda-graph-trace: "graph" + capture-range: "cudaProfilerApi" + capture-range-end: null + kill: none +ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. + timeline_json_file: null diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/ppo_trainer.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/ppo_trainer.yaml new file mode 100644 index 000000000..57d424836 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/ppo_trainer.yaml @@ -0,0 +1,336 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# specify the default per-component configs +defaults: + + # @.: + # actor_rollout_ref.actor: trainer/config/actor/dp_actor.yaml + - actor@actor_rollout_ref.actor: dp_actor + + # trainer.npu_profile: trainer/config/npu_profile/npu_profile.yaml + - npu_profile@trainer.npu_profile: npu_profile + + # data: trainer/config/data/legacy_data.yaml + - data@data: legacy_data + + # Reference model config. + # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True. + - ref@actor_rollout_ref.ref: dp_ref + + # Rollout model config. + - rollout@actor_rollout_ref.rollout: rollout + + # Critic model config. + - critic@critic: dp_critic + + # Reward model config. + - reward_model@reward_model: dp_reward_model + + # load the reference default config, then apply the fields in the current yaml + # self config override anything above + - _self_ + +# config for actor, rollout and reference model +actor_rollout_ref: + + # Whether it's a hybrid engine, currently only supports hybrid engine + hybrid_engine: true + + # common configs for the model + model: + + # Huggingface model path. This can be either local path or HDFS path. + path: ~/models/deepseek-llm-7b-chat + + # Custom chat template for the model. + custom_chat_template: null + + # Whether to use shared memory (SHM) for accelerating the loading of model weights + use_shm: false + + # Additional Python packages to register huggingface models/tokenizers. + external_lib: null + + # Used to override model's original configurations, mainly dropout + override_config: {} + + # Enable gradient checkpointing for actor + enable_gradient_checkpointing: true + + # Enable activation offloading for actor + enable_activation_offload: false + + # Whether to remove padding tokens in inputs during training + use_remove_padding: false + + # Set to positive value to enable LoRA (e.g., 32) + lora_rank: 0 + + # LoRA scaling factor + lora_alpha: 16 + + # Target modules to apply LoRA. Options: "all-linear" (not recommended for VLMs) or + # [q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj] + target_modules: all-linear + + # Exclude modules from applying Lora. Similar usage to target_modules and Peft. + # Example: '.*visual.*' for excluding the ViT in Qwen2.5-VL, as currently vllm does not support ViT Lora. + exclude_modules: null + + # Whether to use Liger for linear layer fusion + use_liger: false + + # Whether to use custom fused kernels (e.g., FlashAttention, fused MLP) + use_fused_kernels: false + + # Options for fused kernels. If use_fused_kernels is true, this will be used. + fused_kernel_options: + + # Implementation backend for fused kernels. Options: "triton" or "torch". + impl_backend: torch + + # Whether to enable loading a remote code model + trust_remote_code: false + + # Rollout model config. + rollout: + + # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. + enable_chunked_prefill: True + + # Which loader to use for rollout model weights: dummy_dtensor, hf, megatron, etc. + # safetensors (for huge model, and set use_shm=True); dummy_dtensor: randomly init model weight + load_format: dummy_dtensor + + # for huge model, layered summon can save memory (prevent OOM) but make it slower + layered_summon: False + + # profiler configs + profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] + +# custom reward function definition +custom_reward_function: + + # The path to the file containing your customized reward function. + # If not specified, pre-implemented reward functions will be used. + path: null + + # The name of the reward function within the specified file. Default is 'compute_score'. + name: compute_score + +# config for the algorithm +algorithm: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.AlgoConfig + + # Discount factor for future rewards + gamma: 1.0 + + # Trade-off between bias and variance in the GAE estimator + lam: 1.0 + + # Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc. + adv_estimator: gae + + # Whether to normalize advantages by std (specific to GRPO) + norm_adv_by_std_in_grpo: True + + # Whether to enable in-reward KL penalty + use_kl_in_reward: False + + # How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full" + kl_penalty: kl + + # KL control configuration + kl_ctrl: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.KLControlConfig + + # KL control type: "fixed" or "adaptive" + type: fixed + + # Initial coefficient for KL penalty + kl_coef: 0.001 + + # Horizon value for adaptive controller (if enabled) + horizon: 10000 + + # Target KL divergence (used for adaptive controller) + target_kl: 0.1 + + # Whether to enable preference feedback PPO + use_pf_ppo: False + + # Preference feedback PPO settings + pf_ppo: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.PFPPOConfig + + # Method for reweighting samples: "pow", "max_min", or "max_random" + reweight_method: pow + + # Power used for weight scaling in "pow" method + weight_pow: 2.0 + +# config for the trainer +trainer: + + # Whether to balance batch sizes across distributed workers + balance_batch: True + + # Number of epochs in training + total_epochs: 30 + + # Total training steps (can be set explicitly or derived from epochs) + total_training_steps: null + + # The steps that will be profiled. null means no profiling. null or [1,2,5,...] + profile_steps: null + + # controller Nvidia Nsight Systems Options. Must set when profile_steps is not None. + ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html + ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html + controller_nsight_options: + + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # worker Nvidia Nsight Systems Options. Must set when profile_steps is not None. + worker_nsight_options: + + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config. + capture-range: "cudaProfilerApi" + + # Specify the desired behavior when a capture range ends. + # In verl we need the orch.cuda.profiler.start/stop pair to repeats n times. + # valid values are "repeat-shutdown:n" or null. + # For normal whole step profiling, n = len(profile_steps); + # but for discrete profiling, n = len(profile_steps) * Number(subtasks). + # Or you can just leave it null and the program will use n = len(profile_steps) * 6; + capture-range-end: null + + # Send signal to the target application's process group. We let the program to exit by itself. + kill: none + + # Project name for experiment tracking (e.g., wandb) + project_name: verl_examples + + # Experiment name for run identification in tracking tools + experiment_name: gsm8k + + # Logging backends to use: "console", "wandb", etc. + logger: [ 'console', 'wandb' ] + + # Number of generations to log during validation + log_val_generations: 0 + + # Directory for logging rollout data; no dump if null + rollout_data_dir: null + + # Directory for logging validation data; no dump if null + validation_data_dir: null + + # Number of nodes used in the training + nnodes: 1 + + # Number of GPUs per node + n_gpus_per_node: 8 + + # Save frequency (by iteration) for model checkpoints + save_freq: -1 + + # ESI refers to the elastic server instance used during training, similar to the training plan. For example, + # if you purchase 10 hours of computing power, the ESI will automatically shut down after 10 hours of training. + # To ensure a checkpoint is saved before ESI shuts down, the system will start saving a checkpoint in advance. + # The advance time is calculated as: Advance Time = Longest historical step duration + Checkpoint save duration + esi_redundant_time. + # Here, esi_redundant_time is a user-defined value that further extends the advance time for added safety. + esi_redundant_time: 0 + + # Resume mode: "auto", "disable", or "resume_path" + # "auto": resume from last checkpoint if available + # "disable": start from scratch + # "resume_path": resume from a user-defined path + resume_mode: auto + + # Path to resume training from (only used when resume_mode is "resume_path") + resume_from_path: null + + # Whether to run validation before training begins + val_before_train: True + + # Whether to run validation only + val_only: False + + # Validation frequency (in training iterations) + test_freq: -1 + + # Number of iterations to warm up the critic before updating policy + critic_warmup: 0 + + # Default path to distributed filesystem for saving checkpoints + default_hdfs_dir: null + + # Whether to delete local checkpoints after loading + del_local_ckpt_after_load: False + + # Default local directory for saving checkpoints + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + + # Maximum number of actor checkpoints to keep + max_actor_ckpt_to_keep: null + + # Maximum number of critic checkpoints to keep + max_critic_ckpt_to_keep: null + + # Timeout (in seconds) for Ray worker to wait for registration + ray_wait_register_center_timeout: 300 + + # Device to run training on (e.g., "cuda", "cpu") + device: cuda + + # whether to use legacy worker implementation + # mode: "auto", "enable", or "disable" + use_legacy_worker_impl: auto + +# configs related to ray initialization +ray_init: + + # Number of CPUs for Ray. Use a fixed number instead of null when using SLURM. + num_cpus: null + + # Path to save Ray timeline JSON for performance profiling + timeline_json_file: null diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/ref/dp_ref.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/ref/dp_ref.yaml new file mode 100644 index 000000000..13b604718 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/ref/dp_ref.yaml @@ -0,0 +1,38 @@ +# defaults specify the default config from each component +defaults: + + # dp ref config, inheriting from trainer/config/ref/ref.yaml + - ref + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +# config for FSDP strategy +fsdp_config: + + # whether to offload parameters in FSDP + param_offload: False + + # whether to perform reshard after model forward to save memory. + # only for fsdp2, [True, False, int between 1 and fsdp_size] + reshard_after_forward: True + + # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather + # before the current forward computation. + forward_prefetch: False + + # the wrap policy for FSDP model + wrap_policy: + + # minimum number of params in a wrapped module + min_num_params: 0 + +# sequence parallel size +# same as actor_rollout_ref.actor.ulysses_sequence_parallel_size if it exists, otherwise 1 +ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1} + +# calculate entropy with chunking to reduce memory peak +entropy_from_logits_with_chunking: False + +# recompute entropy +entropy_checkpointing: False diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/ref/megatron_ref.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/ref/megatron_ref.yaml new file mode 100644 index 000000000..6a75d68e3 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/ref/megatron_ref.yaml @@ -0,0 +1,51 @@ +# megatron ref config, inheriting from trainer/config/ref/ref.yaml +defaults: + - ref + # load the reference default config, then apply the fields in the current yaml + - _self_ + +strategy: megatron + +megatron: + + param_offload: False + + tensor_model_parallel_size: 1 + + expert_model_parallel_size: 1 + + expert_tensor_parallel_size: None + + pipeline_model_parallel_size: 1 + + virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests + + context_parallel_size: 1 + + sequence_parallel: True + + use_distributed_optimizer: False + + use_dist_checkpointing: False + + dist_checkpointing_path: null + + seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} + + override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} + + use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} + +profile: + + use_profile: False + + profile_ranks: null + + step_start: -1 + + step_end: -1 + + save_path: null + +load_weight: True \ No newline at end of file diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/ref/ref.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/ref/ref.yaml new file mode 100644 index 000000000..7d9157b3e --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/ref/ref.yaml @@ -0,0 +1,21 @@ +# actor_rollout_ref.ref: FSDP config same as actor. For models larger than 7B, it’s recommended to turn on offload for ref by default +strategy: ${actor_rollout_ref.actor.strategy} + +# whether to enable torch.compile +# same as actor_rollout_ref.actor.use_torch_compile if it exists, otherwise 1 +use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + +# [Will be deprecated, use log_prob_micro_batch_size_per_gpu] +# The batch size for one forward pass in the computation of log_prob. Global batch size. +log_prob_micro_batch_size: null + +# The batch size for one forward pass in the computation of log_prob. Local batch size per GPU. +log_prob_micro_batch_size_per_gpu: null + +# enable dynamic batch size (sequence packing) for log_prob computation +# same as actor_rollout_ref.actor.use_dynamic_bsz if it exists, otherwise false +log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + +# the max token length per GPU +# same as actor_rollout_ref.actor.ppo_max_token_len_per_gpu if it exists, otherwise 16384 +log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/reward_model/dp_reward_model.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/reward_model/dp_reward_model.yaml new file mode 100644 index 000000000..d9a837032 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/reward_model/dp_reward_model.yaml @@ -0,0 +1,51 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# defaults specify the default config from each component +defaults: + + # dp actor config, inheriting from trainer/config/reward_model/reward_model.yaml + - reward_model + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +strategy: fsdp + +model: + + # Whether to use shared memory for loading the model + use_shm: False + + # Use remove padding optimization (saves compute) + use_remove_padding: False + + # Whether to use fused reward kernels for speedup + use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} + + # FSDP-specific config + fsdp_config: + + # Policy for wrapping layers with FSDP + wrap_policy: + # Minimum number of parameters to trigger wrapping + min_num_params: 0 + + # Whether to offload model parameters to CPU + param_offload: False + + # Only for FSDP2: Reshard after forward pass to reduce memory footprint + reshard_after_forward: True + + # Number of GPUs in each FSDP shard group; -1 means auto + fsdp_size: -1 + + # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather + # before the current forward computation. + forward_prefetch: False + +# Sequence parallelism size for Ulysses-style model parallelism +ulysses_sequence_parallel_size: 1 \ No newline at end of file diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/reward_model/megatron_reward_model.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/reward_model/megatron_reward_model.yaml new file mode 100644 index 000000000..41525aa1d --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/reward_model/megatron_reward_model.yaml @@ -0,0 +1,61 @@ +# defaults specify the default config from each component +defaults: + + # dp actor config, inheriting from trainer/config/reward_model/reward_model.yaml + - reward_model + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +strategy: megatron + +# seconds, default is 10 minutes for torch, you can set it to a larger value +# if you have long-running operations like 32B or 72B model using megatron +nccl_timeout: 600 + +# Megatron parallelism & checkpointing config +megatron: + # Whether to offload model parameters to CPU + param_offload: False + + # Number of GPUs in tensor model parallel group + tensor_model_parallel_size: 1 + + # Number of GPUs in expert model parallel group + expert_model_parallel_size: 1 + + # Expert tensor parallel size + expert_tensor_parallel_size: null + + # Number of pipeline model parallel stages + pipeline_model_parallel_size: 1 + + # change VPP interface for parallelism tests + virtual_pipeline_model_parallel_size: null + + # Context parallel size + context_parallel_size: 1 + + # Whether to use sequence parallelism + sequence_parallel: True + + # Whether to use distributed optimizer + use_distributed_optimizer: False + + # Whether to enable distributed checkpointing + use_dist_checkpointing: False + + # Path for distributed checkpoints + dist_checkpointing_path: null + + # RNG seed for megatron + seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} + + # Any overrides to transformer config + override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} + + # Whether to use mbridge for faster comms + use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} + +# Whether to load weights (default True) +load_weight: True \ No newline at end of file diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/reward_model/reward_model.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/reward_model/reward_model.yaml new file mode 100644 index 000000000..698343955 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/reward_model/reward_model.yaml @@ -0,0 +1,81 @@ +# configs for the reward model + +# Whether to enable reward model. If False, we compute the reward only with the user-defined reward functions. +# In GSM8K and Math examples, we disable reward model. +# For RLHF alignment example using full_hh_rlhf, we utilize reward model to assess the responses. +# If False, the following parameters are not effective +enable: False + +# FSDP strategy: "fsdp" or "fsdp2" +strategy: ??? + +# model config for reward scoring +model: + + # Input tokenizer. If the reward model’s chat template is inconsistent with the policy, + # we need to first decode to plaintext, then apply the rm’s chat_template. + # Then score with RM. If chat_templates are consistent, it can be set to null. + # set this to null if the chat template is identical + input_tokenizer: ${actor_rollout_ref.model.path} + + # RM’s HDFS path or local path. Note that RM only supports AutoModelForSequenceClassification. + # Other model types need to define their own RewardModelWorker and pass it from the code. + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + + # External model implementation (optional) + external_lib: ${actor_rollout_ref.model.external_lib} + + # Whether to enable loading a remote code model, default to False + trust_remote_code: False + +# [Deprecated] Global micro batch size +# will be deprecated, use micro_batch_size_per_gpu +micro_batch_size: null + +# Local per-GPU micro batch size +micro_batch_size_per_gpu: null + +# Maximum sequence length to process for scoring +max_length: null + +# Whether to dynamically adjust batch size at runtime +use_dynamic_bsz: ${critic.use_dynamic_bsz} + +# Maximum number of tokens per GPU in one forward pass +forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + +# Reward Manager. This defines the mechanism of computing rule-based reward and handling different reward sources. +# Default is naive. If all verification functions are multiprocessing-safe, +# the reward manager can be set to prime for parallel verification. +reward_manager: naive + +# Whether to launch custom reward function asynchronously during log_prob +# custom reward function executed async on CPU, during log_prob +launch_reward_fn_async: False + +# Cloud/local sandbox fusion configuration for custom reward logic +sandbox_fusion: + + # Cloud /local function URL for sandbox execution + url: null + + # Max concurrent requests allowed to sandbox + max_concurrent: 64 + + # Max memory limit for each sandbox process in MB + memory_limit_mb: 1024 + +# profiler configs +profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] \ No newline at end of file diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/rollout/rollout.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/rollout/rollout.yaml new file mode 100644 index 000000000..fc3af80d4 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/rollout/rollout.yaml @@ -0,0 +1,215 @@ +# actor_rollout_ref.rollout.name: hf/vllm/sglang. The default value will be removed in the future +name: vllm + +# sync: LLM, async: AsyncLLM +mode: sync + +# Sampling temperature for rollout. +temperature: 1.0 + +# Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout. +top_k: -1 + +# Top-p sampling parameter. Default 1.0. +top_p: 1 + +# typically the same as data max prompt length +# same as data.max_prompt_length if it exists +prompt_length: ${oc.select:data.max_prompt_length,512} + +# typically the same as data max response length +# same as data.max_response_length if it exists +response_length: ${oc.select:data.max_response_length,512} + +# for vllm rollout +# Rollout model parameters type. Align with actor model's FSDP/Megatron type. +dtype: bfloat16 + +# Fraction of GPU memory used by vLLM/SGLang for KV cache. +gpu_memory_utilization: 0.5 + +# Whether to ignore EOS and continue generating after EOS is hit. +ignore_eos: False + +# Whether to disable CUDA graph. Default True to allow cache freeing. +enforce_eager: True + +# Whether to free engine KVCache after generation. Set enforce_eager=True when enabled. +free_cache_engine: True + +# TP size for rollout. Not effective for hf +tensor_model_parallel_size: 2 + +# max number of tokens in a batch +max_num_batched_tokens: 8192 + +# max length for rollout +max_model_len: null + +# max length of sequences +max_num_seqs: 1024 + +# [Will be deprecated, use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size. +log_prob_micro_batch_size: null + +# The batch size for one forward pass in the computation of log_prob. Local batch size per GPU. +log_prob_micro_batch_size_per_gpu: null + +# enable dynamic batch size (sequence packing) for log_prob computation +# same as actor_rollout_ref.actor.use_dynamic_bsz if it exists, otherwise false +log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + +# max token length for log_prob computation +# same as actor_rollout_ref.actor.ppo_max_token_len_per_gpu if it exists, otherwise 16384 +log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + +# disable logging statistics +disable_log_stats: True + +# for hf rollout +# Whether to sample during training rollout. False uses greedy sampling. +do_sample: True + +# number of responses (i.e. num sample times). > 1 for grpo +n: 1 + +# Whether to wake up inference engine in multi-stage. (Wake up model weights first, then resume kv cache) +multi_stage_wake_up: false + +# Extra inference engine arguments (vllm, sglang). +engine_kwargs: + + # for vllm + vllm: + + # Swap space (in GB) used by inference engine. null uses default (e.g., 4 GB). + swap_space: null + + # Whether to disable the preprocessor cache for multimodel models. + disable_mm_preprocessor_cache: False + + # for sglang + sglang: + + # The attention backend for sglang engine. Options: flashinfer, triton, flashmla, null for default. + attention_backend: null + +# Sampling parameters used during validation. +val_kwargs: + + # sampling parameters for validation + # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout. + top_k: -1 + + # Top-p sampling parameter. Default 1.0. + top_p: 1.0 + + # Sampling temperature for rollout. + temperature: 0 + + # whether to repeat n times for validation + n: 1 + + # Whether to sample during training rollout. False uses greedy sampling. + do_sample: False + +# Multi-turn interaction config for tools or chat. +multi_turn: + + # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well + enable: False + + # null for no limit (default max_length // 3) + max_assistant_turns: null + + # null for no tool + tool_config_path: null + + # null for no limit (default max_length // 3) + max_user_turns: null + + # max parallel call for tools in single turn + max_parallel_calls: 1 + + # max length of tool response + max_tool_response_length: 256 + + # truncate side of tool response: left, middle, right + tool_response_truncate_side: middle + + # null for no interaction + interaction_config_path: null + + # null for default callback + completion_callback: null + + # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior. + # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output, + # which may contain additional content such as reasoning content. This maintains the consistency between training and rollout, but it will lead to longer prompts. + use_inference_chat_template: False + + # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation. + # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids. + # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. aThis behavior has already been validated for them. + # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template: + # Qwen/QwQ-32B, Qwen/Qwen3-xxB + # - disable: disable tokenization sanity check + # - strict: enable strict tokenization sanity check (default) + # - ignore_strippable: ignore strippable tokens when checking tokenization sanity + tokenization_sanity_check_mode: strict + + # Format of the multi-turn interaction. Options: hermes, llama3_json, ... + format: hermes + +# support logging rollout prob for debugging purpose +calculate_log_probs: False + +# [Experimental] agent loop based rollout configs +agent: + + # Number of agent loop workers + num_workers: 8 + + # custom agent loop config path, which should contain list of configs to intialize AgentLoop instances. + # https://hydra.cc/docs/advanced/instantiate_objects/overview/ + # + # - name: react_agent + # _target_: recipe.langgraph_agent.react_agent_loop.ReactAgentLoop + # tools: ["get_current_temperature"] + # - name: math_expression + # _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop + # min_terms: 2 + # max_terms: 6 + agent_loop_config_path: null + + # custom async server configs + custom_async_server: + + # Path to the custom async server implementation + path: null + + # Class name of the custom async server class (e.g. AsyncvLLMServer) + name: null + +# Specifies the tensor bucket size (in megabytes) for batch weight updates during rollout operations. +# This parameter controls the maximum payload size for a single weight update request. +# Reference: https://github.com/volcengine/verl/pull/2418 +# Currently only supported in SGLang rollout implementations +# Larger values may improve throughput but increase memory overhead +# Detailed performance comparison: +# https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/issues/169#issuecomment-3070686720 +# Default value (512MB) is optimized for typical GPU memory configurations +# For the best performance of `rebuild_cuda_tensor`, it is recommended to: +# 1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES` +# 2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` +# when using Tensor Parallelism (TP) >= 8. +update_weights_bucket_megabytes: 512 + +# trace rollout data +trace: + + # trace backend, support mlflow, weave + backend: null + + # whether translate token id to text in output + token2text: False diff --git a/toolbox/verl/v0.5.0/verl/trainer/config/sft_trainer.yaml b/toolbox/verl/v0.5.0/verl/trainer/config/sft_trainer.yaml new file mode 100644 index 000000000..52ade2b42 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/config/sft_trainer.yaml @@ -0,0 +1,85 @@ +data: + train_batch_size: 256 + micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu + micro_batch_size_per_gpu: 4 # this is also val batch size + train_files: ~/data/gsm8k/train.parquet + val_files: ~/data/gsm8k/test.parquet + # Single-turn settings + prompt_key: question + response_key: answer + prompt_dict_keys: null + response_dict_keys: null + # Multi-turn settings + multiturn: + enable: false # Set to true to use multi-turn dataset + messages_key: messages # Key for messages list in multi-turn mode + tools_key: tools # Key for tools list in multi-turn mode + enable_thinking_key: enable_thinking # Whether to enable thinking in multi-turn mode + max_length: 1024 + truncation: error + balance_dp_token: False + chat_template: null + custom_cls: + path: null + name: null + use_shm: False +model: + partial_pretrain: ~/models/gemma-1.1-7b-it + use_shm: False + fsdp_config: + model_dtype: fp32 + wrap_policy: + min_num_params: 0 + cpu_offload: False + offload_params: False + external_lib: null + enable_gradient_checkpointing: True + trust_remote_code: False + lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32) + lora_alpha: 16 # LoRA scaling factor + target_modules: all-linear # Target modules for LoRA adaptation + use_liger: False + strategy: fsdp2 +optim: + lr: 1e-5 + betas: [0.9, 0.95] + weight_decay: 0.01 + warmup_steps_ratio: 0.1 + clip_grad: 1.0 + lr_scheduler: cosine +ulysses_sequence_parallel_size: 1 +use_remove_padding: False +trainer: + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + default_hdfs_dir: null + project_name: gsm8k-sft + experiment_name: test + total_epochs: 4 + total_training_steps: null + logger: [ 'console', 'wandb' ] + seed: 1 + + save_freq: -1 + test_freq: -1 + nnodes: 1 + n_gpus_per_node: 8 + max_ckpt_to_keep: null # Maximum number of checkpoints to keep, set to null to keep all + + # Resume mode: "auto", "disable", or "resume_path" + # "auto": resume from last checkpoint if available + # "disable": start from scratch + # "resume_path": resume from a user-defined path + resume_mode: auto + + # Path to resume training from (used when resume_mode is "resume_path" or "auto") + resume_from_path: null + + # Checkpoint configuration + checkpoint: + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ["model", "optimizer", "extra"] + + # For more flexibility, you can specify the contents to load from the checkpoint. + load_contents: ${trainer.checkpoint.save_contents} + device: cuda diff --git a/toolbox/verl/v0.5.0/verl/trainer/constants_ppo.py b/toolbox/verl/v0.5.0/verl/trainer/constants_ppo.py new file mode 100644 index 000000000..1181133e1 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/constants_ppo.py @@ -0,0 +1,37 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +PPO_RAY_RUNTIME_ENV = { + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "WARN", + "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + }, +} + + +def get_ppo_ray_runtime_env(): + """ + A filter function to return the PPO Ray runtime environment. + To avoid repeat of some environment variables that are already set. + """ + runtime_env = {"env_vars": PPO_RAY_RUNTIME_ENV["env_vars"].copy()} + for key in list(runtime_env["env_vars"].keys()): + if os.environ.get(key) is not None: + runtime_env["env_vars"].pop(key, None) + return runtime_env diff --git a/toolbox/verl/v0.5.0/verl/trainer/fsdp_sft_trainer.py b/toolbox/verl/v0.5.0/verl/trainer/fsdp_sft_trainer.py new file mode 100644 index 000000000..8b8cc4a46 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/fsdp_sft_trainer.py @@ -0,0 +1,825 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A lightweight one-file FSDP SFT Trainer +TODO(zhangchi.usc1992) +- Add calculation of mfu +- Add validation +""" + +import os + +os.environ["NCCL_DEBUG"] = "WARN" +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +import logging +import re +from contextlib import nullcontext + +import hydra +import torch +import torch.distributed +from omegaconf import DictConfig +from peft import LoraConfig, TaskType, get_peft_model +from tensordict import TensorDict +from torch import nn, optim +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.utils.data import Dataset, DistributedSampler +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel + +import verl.utils.hdfs_io as hdfs_io +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, get_checkpoint_tracker_filename +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager +from verl.utils.dataset import SFTDataset +from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset +from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available +from verl.utils.distributed import destroy_global_process_group, initialize_global_process_group +from verl.utils.fs import copy_to_local +from verl.utils.fsdp_utils import ( + CPUOffloadPolicy, + MixedPrecisionPolicy, + apply_fsdp2, + fsdp2_clip_grad_norm_, + fsdp2_load_full_state_dict, + get_fsdp_wrap_policy, + get_init_weight_context_manager, + init_fn, +) +from verl.utils.logger import log_with_rank +from verl.utils.profiler import log_gpu_memory_usage +from verl.utils.py_functional import convert_to_regular_types +from verl.utils.torch_dtypes import PrecisionType +from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup +from verl.utils.tracking import Tracking +from verl.utils.ulysses import ( + gather_outputs_and_unpad, + get_ulysses_sequence_parallel_world_size, + ulysses_pad_and_slice_inputs, +) +from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager + +if is_cuda_available: + from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input +elif is_npu_available: + from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) + + +def extract_step(path): + match = re.search(r"global_step_(\d+)", path) + if match: + return int(match.group(1)) + return None + + +class FSDPSFTTrainer: + def __init__( + self, + config, + device_mesh: DeviceMesh, + ulysses_device_mesh: DeviceMesh, + tokenizer, + train_dataset: Dataset, + val_dataset: Dataset, + ): + self.config = config + self.device_mesh = device_mesh + self.ulysses_device_mesh = ulysses_device_mesh + self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + self.tokenizer = tokenizer + if self.config.data.chat_template is not None: + raise ValueError("Apply Chat template from config is not supported yet.") + + # normalize dp size + self._normalize_config_bsz() + + # Set sequence parallel size + self.config.ulysses_sequence_parallel_size = getattr(self.config, "ulysses_sequence_parallel_size", 1) + self.use_remove_padding = getattr(self.config, "use_remove_padding", False) + if self.device_mesh.get_rank() == 0: + print(f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}") + print(f"Using remove padding: {self.use_remove_padding}") + + self._build_dataloader(train_dataset, val_dataset) + + # Initialize resume-related variables + self.resume_global_step = 0 + + # build model + self._build_model_optimizer() + + # Initialize checkpoint manager + self._init_checkpoint_manager() + + self.load_checkpoint() + + if self.device_mesh.get_rank() == 0: + print(self.config) + self.device_name = self.config.trainer.device + + def _normalize_config_bsz(self): + dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0) + if self.device_mesh.get_rank() == 0: + print(f"Normalize batch size by dp {dp_size}") + + assert self.config.data.train_batch_size % dp_size == 0, ( + f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}" + ) + + self.config.data.train_batch_size //= dp_size + + assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0 + + def _build_dataloader(self, train_dataset, val_dataset): + # build dataset + config = self.config + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + # build dataloader + # Use data parallel rank and size instead of global rank and world size + + # If doing SP, we need to use the local rank and size + if self.config.ulysses_sequence_parallel_size > 1: + rank = self.ulysses_device_mesh.get_local_rank("dp") + world_size = self.ulysses_device_mesh.size(0) + if self.ulysses_device_mesh.get_rank() == 0: + print(f"Using SP rank {rank} and size {world_size} for data distribution") + print("Each SP rank gets different data, but the same data WITHIN the same rank") + else: + rank = self.device_mesh.get_rank() + world_size = self.device_mesh.size() + if self.device_mesh.get_rank() == 0: + print(f"Using FSDP rank {rank} and size {world_size} for data distribution") + + self.train_sampler = DistributedSampler( + self.train_dataset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True + ) + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=config.data.train_batch_size, + sampler=self.train_sampler, + num_workers=8, + pin_memory=True, + drop_last=True, + ) + + self.val_sampler = DistributedSampler( + self.val_dataset, shuffle=False, num_replicas=world_size, rank=rank, drop_last=True + ) + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=config.data.micro_batch_size_per_gpu, + sampler=self.val_sampler, + num_workers=8, + pin_memory=True, + drop_last=True, + ) + + def _build_model_optimizer(self): + # TODO (zhangchi.usc1992): + # 1. support pretrain from random weights + # 2. support init directly from sharded weights + local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True) + + if self.config.model.get("external_lib", None) is not None: + # This is used to import external_lib into the huggingface systems + import importlib + + importlib.import_module(self.config.model.external_lib) + + log_gpu_memory_usage("Before model allocation", logger=logger) + + trust_remote_code = self.config.model.trust_remote_code + torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") + torch_dtype = PrecisionType.to_dtype(torch_dtype) + # load config first + config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code) + self.model_config = config + if hasattr(self.model_config, "max_position_embeddings"): + self.model_config.max_position_embeddings = max( + self.model_config.max_position_embeddings, self.config.data.max_length + ) + if self.config.ulysses_sequence_parallel_size > 1: + assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled" + + # This may be very large + init_context = get_init_weight_context_manager( + use_meta_tensor=not config.tie_word_embeddings, mesh=self.device_mesh + ) + + with init_context(): + self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + local_model_path, + config=config, + torch_dtype=torch_dtype, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) + + if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1: + from verl.models.transformers.monkey_patch import apply_monkey_patch + + apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size) + + # Apply Liger kernel if use_liger is enabled + if self.config.model.get("use_liger", False): + from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance + + _apply_liger_kernel_to_instance(model=self.model) + + if self.config.model.get("lora_rank", 0) > 0: + self.model.enable_input_require_grads() + # Convert config to regular Python types before creating PEFT model + lora_config = { + "task_type": TaskType.CAUSAL_LM, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types(self.config.model.target_modules), + "bias": "none", + } + self.model = get_peft_model(self.model, LoraConfig(**lora_config)) + self.model = self.model.to(torch_dtype) + + if self.config.model.enable_gradient_checkpointing: + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + + log_gpu_memory_usage("After model allocation", logger=logger) + + mixed_precision = MixedPrecision( + param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32 + ) + + auto_wrap_policy = get_fsdp_wrap_policy( + self.model, + config=self.config.model.fsdp_config.wrap_policy, + is_lora=self.config.model.get("lora_rank", 0) > 0, + ) + if self.device_mesh.get_rank() == 0: + print(auto_wrap_policy) + + if not self.config.model.fsdp_config.cpu_offload: + cpu_offload = None + else: + cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params) + + fsdp_strategy = self.config.model.strategy + if fsdp_strategy == "fsdp": + self.fsdp_model = FSDP( + self.model, + cpu_offload=cpu_offload, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=get_device_id(), + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=mixed_precision, + sync_module_states=True, + device_mesh=self.device_mesh, + forward_prefetch=False, + ) + elif fsdp_strategy == "fsdp2": + assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True + ) + + fsdp_kwargs = { + "mesh": self.device_mesh, + "mp_policy": mp_policy, + "offload_policy": cpu_offload, + "reshard_after_forward": True, + } + full_state = self.model.state_dict() + apply_fsdp2(self.model, fsdp_kwargs, self.config.model.fsdp_config) + fsdp2_load_full_state_dict(self.model, full_state, self.device_mesh, cpu_offload) + self.fsdp_model = self.model + else: + raise NotImplementedError(f"not implement {fsdp_strategy}") + + log_gpu_memory_usage("After FSDP wrapping", logger=logger) + + self.optimizer = optim.AdamW( + self.fsdp_model.parameters(), + lr=self.config.optim.lr, + betas=self.config.optim.betas, + weight_decay=self.config.optim.weight_decay, + ) + + log_gpu_memory_usage("After initialize optimizer", logger=logger) + + self.steps_per_epoch = len(self.train_dataloader) + self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs + + if self.device_mesh.get_rank() == 0: + print( + f"Number of steps/epoch {self.steps_per_epoch}, number of epochs " + f"{self.config.trainer.total_epochs}, total number of steps {self.total_steps}" + ) + + num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio) + + if not hasattr(self.config.optim, "lr_scheduler") or self.config.optim.lr_scheduler == "cosine": + self.lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps + ) + elif self.config.optim.lr_scheduler == "wsd": + self.lr_scheduler = get_wsd_schedule_with_warmup( + optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps + ) + else: + raise ValueError(f"Unknown lr scheduler: {self.config.optim.lr_scheduler}") + + def _compute_loss_and_backward(self, batch, do_backward=True): + """Compute loss with optional sequence parallelism and remove padding features""" + use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 + + # Move inputs to GPU and prepare loss mask + input_ids = batch["input_ids"].to(self.device_name) + attention_mask = batch["attention_mask"].to(self.device_name) + position_ids = batch["position_ids"].to(self.device_name) + loss_mask = batch.pop("loss_mask")[:, :-1].reshape(-1).to(self.device_name) + loss_fct = nn.CrossEntropyLoss(reduction="none") + + # Context manager for sequence parallel if needed + context = self.sharding_manager if use_sp else nullcontext() + with context, torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): + if not use_sp: + # Standard forward pass without sequence parallel + labels = input_ids[:, 1:].contiguous() + output = self.fsdp_model( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + ) + logits = output.logits + + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels.contiguous() + # Flatten the tokens + shift_logits = shift_logits.view(-1, self.model.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + loss = loss * loss_mask.to(loss.device) + else: + # IMPORTANT: We have a big assumption here, so we can shard the SAME sequence across SP ranks + # i.e., each GPU has <1 sequence, and each SP group has 1 sequence + # 1. All SP ranks will receive the *SAME* batch + # 2. Different SP groups will receive *DIFFERENT* batches + # This is implemented by the DistributedSampler + + batch_size, seqlen = input_ids.shape + # Remove padding + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # Unpad position_ids to align rotary + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) + + # Pad and slice inputs for sequence parallelism + input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size() + ) + # For computing loss + input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size() + ) + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) + + # Forward pass + output = self.fsdp_model( + input_ids=input_ids_rmpad_sliced, + attention_mask=None, # Not needed with flash attention varlen + position_ids=position_ids_rmpad_padded, + use_cache=False, + ) + + # Compute loss locally then aggregate + logits_rmpad = output.logits.squeeze(0) + input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device) + loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled) + # Gather and unpad for sequence parallelism + loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size) + + # This is the loss collected from all ulysses ranks + full_loss = pad_input( + hidden_states=loss.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ) + full_loss = full_loss.squeeze(-1)[:, :-1] # Remove last token's loss + full_loss = full_loss.reshape(-1) + loss_mask = loss_mask.to(full_loss.device) + loss = full_loss * loss_mask + + valid_token_this_rank = torch.sum(loss_mask) + + if self.config.data.balance_dp_token: + torch.distributed.all_reduce(valid_token_this_rank) + dp_size = self.ulysses_device_mesh.size("dp") if use_sp else torch.distributed.get_world_size() + else: + dp_size = 1 + + loss = torch.sum(loss) / (valid_token_this_rank + 1e-8) * dp_size + + if do_backward: + loss.backward() + return loss + + def training_step(self, batch: TensorDict): + self.fsdp_model.train() + + log_gpu_memory_usage("Before optimizer zero_grad", logger=logger) + + self.optimizer.zero_grad() + + log_gpu_memory_usage("After optimizer zero_grad", logger=logger) + + micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu) + n_micro_batches = len(micro_batches) + step_loss = 0 + for micro_batch in micro_batches: + loss = self._compute_loss_and_backward(batch=micro_batch) / n_micro_batches + step_loss += loss.item() + + if self.config.model.strategy == "fsdp": + grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad) + elif self.config.model.strategy == "fsdp2": + grad_norm = fsdp2_clip_grad_norm_(self.fsdp_model.parameters(), max_norm=self.config.optim.clip_grad) + else: + raise NotImplementedError(f"not implement {self.config.model.strategy}") + + log_gpu_memory_usage("Before optimizer step", logger=logger) + + # if grad_norm is not finite, skip the update + if not torch.isfinite(grad_norm): + print(f"WARN: grad_norm is not finite: {grad_norm}") + self.optimizer.zero_grad() + else: + self.optimizer.step() + + log_gpu_memory_usage("After optimizer step", logger=logger) + + self.lr_scheduler.step() + + # reduce loss across dp ranks + lr = self.lr_scheduler.get_last_lr()[0] + + log_gpu_memory_usage("After offload weights", logger=logger) + + step_loss = torch.tensor(step_loss).to(self.device_name) + if is_cuda_available: + torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) + elif is_npu_available: + torch.distributed.all_reduce(step_loss) + step_loss /= self.device_mesh.size(0) + return {"train/loss": step_loss.detach().item(), "train/lr(1e-3)": lr * 1e3} + + def validation_step(self, batch: TensorDict): + self.fsdp_model.eval() + with torch.no_grad(): + loss = self._compute_loss_and_backward(batch, do_backward=False) + if is_cuda_available: + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + elif is_npu_available: + torch.distributed.all_reduce(loss) + loss /= self.device_mesh.size(0) + return loss + + def save_checkpoint(self, step): + """Save checkpoint using FSDPCheckpointManager with improved tracking""" + from verl.utils.fs import local_mkdir_safe + + # Determine checkpoint path + local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, f"global_step_{step}") + + if self.device_mesh.get_rank() == 0: + print(f"Saving checkpoint to: {local_global_step_folder}") + + # Get max checkpoints to keep + max_ckpt_to_keep = getattr(self.config.trainer, "max_ckpt_to_keep", None) + + # Use checkpoint manager to save + self.checkpoint_manager.save_checkpoint( + local_path=local_global_step_folder, global_step=step, max_ckpt_to_keep=max_ckpt_to_keep + ) + + # Save dataloader state + if self.device_mesh.get_rank() == 0: + local_mkdir_safe(local_global_step_folder) + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") + + # Use StatefulDataLoader's built-in state dict functionality + dataloader_state_dict = self.train_dataloader.state_dict() + torch.save(dataloader_state_dict, dataloader_local_path) + print(f"Saved dataloader state to: {dataloader_local_path}") + + # Update latest checkpoint tracker (atomic write) + tracker_file = get_checkpoint_tracker_filename(self.config.trainer.default_local_dir) + temp_tracker_file = tracker_file + ".tmp" + with open(temp_tracker_file, "w") as f: + f.write(str(step)) + os.rename(temp_tracker_file, tracker_file) + print(f"Updated checkpoint tracker: {tracker_file}") + + # Copy to HDFS if configured + if self.device_mesh.get_rank() == 0 and getattr(self.config.trainer, "default_hdfs_dir", None): + hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True) + hdfs_io.copy(src=local_global_step_folder, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True) + + torch.distributed.barrier() + + def _init_checkpoint_manager(self): + """Initialize checkpoint manager with proper configuration""" + # Get checkpoint configuration from config, with defaults + checkpoint_config = getattr(self.config.trainer, "checkpoint", {}) + + # Set default values if not specified + save_contents = checkpoint_config.get("save_contents", ["model", "optimizer", "extra"]) + load_contents = checkpoint_config.get("load_contents", save_contents) + + # Create checkpoint config dict + checkpoint_config_dict = { + "load_contents": load_contents, + "save_contents": save_contents, + } + + # Convert to DictConfig for compatibility + checkpoint_config_dict = DictConfig(checkpoint_config_dict) + + # Initialize checkpoint manager + self.checkpoint_manager = FSDPCheckpointManager( + model=self.fsdp_model, + optimizer=self.optimizer, + lr_scheduler=self.lr_scheduler, + processing_class=self.tokenizer, + checkpoint_config=checkpoint_config_dict, + ) + + def load_checkpoint(self): + # Determine resume path based on configuration + checkpoint_path = self._determine_resume_path() + + if checkpoint_path is None: + return 0 + + # extract resume step from checkpoint path + resume_step = extract_step(checkpoint_path) + if resume_step is None: + log_with_rank( + f"Warning: Could not extract step number from {checkpoint_path}, starting from step 0", + logger=logger, + rank=self.device_mesh.get_rank(), + level=logging.WARNING, + log_only_rank_0=True, + ) + return 0 + self.resume_global_step = resume_step + + # Use checkpoint manager to load model state + self.checkpoint_manager.load_checkpoint(checkpoint_path) + log_with_rank( + f"Successfully loaded model checkpoint from {checkpoint_path} (step {resume_step})", + logger=logger, + rank=self.device_mesh.get_rank(), + log_only_rank_0=True, + ) + + # Always load dataloader state for StatefulDataLoader + self._load_dataloader_state(checkpoint_path) + + return resume_step + + def _load_dataloader_state(self, checkpoint_path: str): + """Load dataloader state from checkpoint""" + dataloader_path = os.path.join(checkpoint_path, "data.pt") + + if os.path.exists(dataloader_path): + # Use StatefulDataLoader's built-in state dict functionality + dataloader_state_dict = torch.load(dataloader_path, map_location="cpu", weights_only=False) + self.train_dataloader.load_state_dict(dataloader_state_dict) + + log_with_rank( + f"Successfully loaded dataloader state from {dataloader_path}", + logger=logger, + rank=self.device_mesh.get_rank(), + log_only_rank_0=True, + ) + + else: + log_with_rank( + f"Warning: No dataloader state found at {dataloader_path}, will start from scratch", + logger=logger, + rank=self.device_mesh.get_rank(), + level=logging.WARNING, + log_only_rank_0=True, + ) + + def _determine_resume_path(self): + """Determine the path to resume from based on resume_mode configuration""" + resume_mode = getattr(self.config.trainer, "resume_mode", "auto") + resume_from_path = getattr(self.config.trainer, "resume_from_path", None) + + if resume_mode == "disable": + return None + elif resume_mode == "auto": + if resume_from_path is not None: + assert os.path.exists(resume_from_path), ( + "resume_from_path must be null or an existing path when resume_mode is 'auto'" + ) + assert "global_step_" in resume_from_path, "resume_from_path must specify the global_steps" + return resume_from_path + # Try to find the latest checkpoint in the default directory + return self._find_latest_checkpoint() + elif resume_mode == "resume_path": + assert os.path.exists(resume_from_path), ( + "resume_from_path must be an existing path when resume_mode is 'resume_path'" + ) + assert "global_step_" in resume_from_path, "resume_from_path must specify the global_steps" + return resume_from_path + else: + raise ValueError(f"Invalid resume_mode: {resume_mode}. Must be 'auto', 'disable', or 'resume_path'") + + def _find_latest_checkpoint(self): + """Find the latest checkpoint in the default local directory""" + checkpoint_dir = self.config.trainer.default_local_dir + + if not os.path.exists(checkpoint_dir): + return None + + latest_checkpoint = find_latest_ckpt_path(checkpoint_dir) + + if latest_checkpoint and self.device_mesh.get_rank() == 0: + step_num = extract_step(latest_checkpoint) + print(f"Found latest checkpoint: {latest_checkpoint} (step {step_num})") + + return latest_checkpoint + + def fit(self): + rank = self.device_mesh.get_rank() + + # TODO: add a unified tracking + if rank == 0: + tracking = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + ) + + global_step = self.resume_global_step # Start from resumed step + last_valid_metric = None + # compute the total training steps. + # the total training steps in SFT is mainly for early exit + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + log_with_rank( + f"Total training steps: {self.total_training_steps},", + logger=logger, + rank=self.device_mesh.get_rank(), + log_only_rank_0=True, + ) + + # With StatefulDataLoader, we don't need to manually calculate epochs and steps + # The dataloader will automatically resume from where it left off + if global_step > 0: + log_with_rank( + f"StatefulDataLoader will automatically resume from global step: {global_step}", + logger=logger, + rank=self.device_mesh.get_rank(), + log_only_rank_0=True, + ) + + # Calculate which epoch we're starting from for sampler.set_epoch() + start_epoch = global_step // self.steps_per_epoch + + for epoch in range(start_epoch, self.config.trainer.total_epochs): + self.train_sampler.set_epoch(epoch=epoch) + + for step_in_epoch, data in enumerate( + tqdm( + self.train_dataloader, + initial=global_step % self.steps_per_epoch if epoch == start_epoch else 0, + total=self.steps_per_epoch, + desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}", + disable=rank != 0, + ) + ): + global_step += 1 + data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device_name) + metric = self.training_step(data) + if rank == 0: + tracking.log(data=metric, step=global_step) + + is_last_step = global_step >= self.total_training_steps + is_valid_step = global_step % self.config.trainer.test_freq == 0 + is_save_step = global_step % self.config.trainer.save_freq == 0 + + # early exit or validation step + if is_last_step or (self.config.trainer.test_freq > 0 and is_valid_step): + # Perform validation + val_losses = [] + for val_data in self.val_dataloader: + val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to( + self.device_name + ) + val_loss = self.validation_step(val_data) + val_losses.append(val_loss) + if rank == 0: + val_loss = torch.mean(torch.stack(val_losses)) + metric = {"val/loss": val_loss.detach().item()} + tracking.log(data=metric, step=global_step) + last_valid_metric = metric + torch.distributed.barrier() + + if is_last_step or (self.config.trainer.save_freq > 0 and is_save_step): + self.save_checkpoint(step=global_step) + + if is_last_step: + if rank == 0: + print(f"Final validation metrics: {last_valid_metric}") + return + + +def run_sft(config): + device_name = get_device_name() + local_rank, rank, world_size = initialize_global_process_group() + + device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) + dp_size = world_size // config.ulysses_sequence_parallel_size + ulysses_device_mesh = init_device_mesh( + device_type=device_name, + mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), + mesh_dim_names=("dp", "sp"), + ) + # build tokenizer and datasets first + from verl.utils import hf_tokenizer + + local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True) + tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code) + train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer) + val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer) + + trainer = FSDPSFTTrainer( + config=config, + device_mesh=device_mesh, + ulysses_device_mesh=ulysses_device_mesh, + tokenizer=tokenizer, + train_dataset=train_dataset, + val_dataset=val_dataset, + ) + + trainer.fit() + + destroy_global_process_group() + + +@hydra.main(config_path="config", config_name="sft_trainer", version_base=None) +def main(config): + run_sft(config) + + +def create_sft_dataset(data_paths, data_config, tokenizer): + """Create a dataset.""" + # build dataset + # First check if a custom dataset class is specified + if data_config.custom_cls.get("path", None): + from verl.utils.import_utils import load_extern_type + + dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + # Then check if multi-turn dataset should be used + elif data_config.get("multiturn", {}).get("enable", False): + dataset_cls = MultiTurnSFTDataset + # Default to single-turn dataset + else: + dataset_cls = SFTDataset + + # Create datasets based on the selected class + dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config) + return dataset + + +if __name__ == "__main__": + main() diff --git a/toolbox/verl/v0.5.0/verl/trainer/main_eval.py b/toolbox/verl/v0.5.0/verl/trainer/main_eval.py new file mode 100644 index 000000000..0a5c58177 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/main_eval.py @@ -0,0 +1,80 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Offline evaluate the performance of a generated file using reward model and ground truth verifier. +The input is a parquet file that contains N generated sequences and (optional) the ground truth. + +""" + +from collections import defaultdict + +import hydra +import numpy as np +import pandas as pd +import ray +from tqdm import tqdm + +from verl.trainer.ppo.reward import get_custom_reward_fn +from verl.utils.fs import copy_to_local + + +@ray.remote +def process_item(reward_fn, data_source, response_lst, reward_data): + ground_truth = reward_data["ground_truth"] + score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst] + return data_source, np.mean(score_lst) + + +@hydra.main(config_path="config", config_name="evaluation", version_base=None) +def main(config): + local_path = copy_to_local(config.data.path, use_shm=config.data.get("use_shm", False)) + dataset = pd.read_parquet(local_path) + responses = dataset[config.data.response_key] + data_sources = dataset[config.data.data_source_key] + reward_model_data = dataset[config.data.reward_model_key] + + total = len(dataset) + + # Initialize Ray + if not ray.is_initialized(): + ray.init(num_cpus=config.ray_init.num_cpus) + + # evaluate test_score based on data source + data_source_reward = defaultdict(list) + compute_score = get_custom_reward_fn(config) + + # Create remote tasks + remote_tasks = [ + process_item.remote(compute_score, data_sources[i], responses[i], reward_model_data[i]) for i in range(total) + ] + + # Process results as they come in + with tqdm(total=total) as pbar: + while len(remote_tasks) > 0: + # Use ray.wait to get completed tasks + done_ids, remote_tasks = ray.wait(remote_tasks) + for result_id in done_ids: + data_source, score = ray.get(result_id) + data_source_reward[data_source].append(score) + pbar.update(1) + + metric_dict = {} + for data_source, rewards in data_source_reward.items(): + metric_dict[f"test_score/{data_source}"] = np.mean(rewards) + + print(metric_dict) + + +if __name__ == "__main__": + main() diff --git a/toolbox/verl/v0.5.0/verl/trainer/main_generation.py b/toolbox/verl/v0.5.0/verl/trainer/main_generation.py new file mode 100644 index 000000000..b8174ade5 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/main_generation.py @@ -0,0 +1,148 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Generate responses given a dataset of prompts +""" + +import os + +import hydra +import numpy as np +import ray + +os.environ["NCCL_DEBUG"] = "WARN" +os.environ["TOKENIZERS_PARALLELISM"] = "true" +# os.environ['TORCH_COMPILE_DISABLE'] = '1' + +from pprint import pprint + +import pandas as pd +from omegaconf import OmegaConf + +from verl import DataProto +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_to_local +from verl.utils.hdfs_io import makedirs +from verl.utils.model import compute_position_id_with_mask +from verl.workers.fsdp_workers import ActorRolloutRefWorker + + +@hydra.main(config_path="config", config_name="generation", version_base=None) +def main(config): + run_generation(config) + + +def run_generation(config) -> None: + if not ray.is_initialized(): + # this is for local ray cluster + ray.init( + runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}, + num_cpus=config.ray_init.num_cpus, + ) + + ray.get(main_task.remote(config)) + + +@ray.remote(num_cpus=1) +def main_task(config): + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + local_path = copy_to_local(config.model.path) + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + + if config.rollout.temperature == 0.0: + assert config.data.n_samples == 1, "When temperature=0, n_samples must be 1." + assert config.data.n_samples >= 1, "n_samples should always >= 1" + + # read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary) + dataset = pd.read_parquet(config.data.path) + chat_lst = dataset[config.data.prompt_key].tolist() + + chat_lst = [chat.tolist() for chat in chat_lst] + + tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout") + resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes) + wg = RayWorkerGroup( + resource_pool=resource_pool, + ray_cls_with_init=ray_cls_with_init, + device_name=config.trainer.device, + ) + wg.init_model() + + total_samples = len(dataset) + config_batch_size = config.data.batch_size + num_batch = -(-total_samples // config_batch_size) + output_lst = [[] for _ in range(config.data.n_samples)] + + for batch_idx in range(num_batch): + print(f"[{batch_idx + 1}/{num_batch}] Start to process.") + batch_chat_lst = chat_lst[batch_idx * config_batch_size : (batch_idx + 1) * config_batch_size] + inputs = tokenizer.apply_chat_template( + batch_chat_lst, + add_generation_prompt=True, + padding=True, + truncation=True, + max_length=config.rollout.prompt_length, + return_tensors="pt", + return_dict=True, + tokenize=True, + ) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + position_ids = compute_position_id_with_mask(attention_mask) + batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids} + + data = DataProto.from_dict(batch_dict) + data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size) + + # START TO GENERATE FOR n_samples TIMES + print(f"[{batch_idx + 1}/{num_batch}] Start to generate.") + for n_sample in range(config.data.n_samples): + output_padded = wg.generate_sequences(data_padded) + output = unpad_dataproto(output_padded, pad_size=pad_size) + + output_texts = [] + for i in range(len(output)): + data_item = output[i] + prompt_length = data_item.batch["prompts"].shape[-1] + valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() + valid_response_ids = data_item.batch["responses"][:valid_response_length] + response_str = tokenizer.decode(valid_response_ids, skip_special_tokens=True) + output_texts.append(response_str) + + output_lst[n_sample].extend(output_texts) + + # convert output_lst from (n_samples, n_data) to (n_data, n_sampels) + output_lst = np.array(output_lst, dtype=object) + output_lst = np.transpose(output_lst, axes=(1, 0)).tolist() + + # add to the data frame + dataset["responses"] = output_lst + + # write to a new parquet + output_dir = os.path.dirname(config.data.output_path) + makedirs(output_dir, exist_ok=True) + dataset.to_parquet(config.data.output_path) + + +if __name__ == "__main__": + main() diff --git a/toolbox/verl/v0.5.0/verl/trainer/main_ppo.py b/toolbox/verl/v0.5.0/verl/trainer/main_ppo.py new file mode 100644 index 000000000..75ddaa621 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/main_ppo.py @@ -0,0 +1,338 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" + +import os +import socket + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.experimental.dataset.sampler import AbstractSampler +from verl.trainer.constants_ppo import get_ppo_ray_runtime_env +from verl.trainer.ppo.ray_trainer import RayPPOTrainer +from verl.trainer.ppo.reward import load_reward_manager +from verl.utils.device import is_cuda_available +from verl.utils.import_utils import load_extern_type + + +@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) +def main(config): + """Main entry point for PPO training with Hydra configuration management. + + Args: + config_dict: Hydra configuration dictionary containing training parameters. + """ + run_ppo(config) + + +# Define a function to run the PPO-like training process +def run_ppo(config) -> None: + """Initialize Ray cluster and run distributed PPO training process. + + Args: + config: Training configuration object containing all necessary parameters + for distributed PPO training including Ray initialization settings, + model paths, and training hyperparameters. + """ + # Check if Ray is not initialized + if not ray.is_initialized(): + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration + ray.init( + runtime_env=get_ppo_ray_runtime_env(), + num_cpus=config.ray_init.num_cpus, + num_gpus=config.trainer.n_gpus_per_node, + ) + + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete + if ( + is_cuda_available + and config.trainer.get("profile_steps") is not None + and len(config.trainer.get("profile_steps", [])) > 0 + ): + nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options) + runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis + timeline_json_file = config.ray_init.get("timeline_json_file", None) + if timeline_json_file: + ray.timeline(filename=timeline_json_file) + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + """Ray remote class for executing distributed PPO training tasks. + + This class encapsulates the main training logic and runs as a Ray remote actor + to enable distributed execution across multiple nodes and GPUs. + """ + + def run(self, config): + """Execute the main PPO training workflow. + + This method sets up the distributed training environment, initializes + workers, datasets, and reward functions, then starts the training process. + + Args: + config: Training configuration object containing all parameters needed + for setting up and running the PPO training process. + """ + # Print the initial configuration. `resolve=True` will evaluate symbolic values. + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + + # Instantiate the tokenizer and processor. + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + # Define worker classes based on the actor strategy. + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker + + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl in ["auto", "enable"]: + # import warnings + # warnings.warn(f"Legacy worker impl is going to be deprecated, will be removed in the future. \ + # Please set trainer.use_legacy_worker_impl = false to switch to the new worker implementation.") + from verl.workers.fsdp_workers import CriticWorker + elif use_legacy_worker_impl == "disable": + from verl.workers.roles import CriticWorker + + print("Using new worker implementation") + else: + raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = NVMegatronRayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + # Map roles to their corresponding remote worker classes. + role_worker_mapping = { + Role.ActorRollout: ray.remote(actor_rollout_cls), + Role.Critic: ray.remote(CriticWorker), + } + + # Define the resource pool specification. + # Map roles to the resource pool. + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + } + + # We should adopt a multi-source reward function here: + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # finally, we combine all the rewards together + # The reward type depends on the tag of the data + if config.reward_model.enable: + if config.reward_model.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + mapping[Role.RewardModel] = global_pool_id + + # Add a reference policy worker if KL loss or KL reward is used. + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) + mapping[Role.RefPolicy] = global_pool_id + + # Load the reward manager for training and validation. + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + val_reward_fn = load_reward_manager( + config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) + ) + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + from verl.utils.dataset.rl_dataset import collate_fn + + # Create training and validation datasets. + train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True) + val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False) + train_sampler = create_rl_sampler(config.data, train_dataset) + + # Initialize the PPO trainer. + trainer = RayPPOTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + ) + # Initialize the workers of the trainer. + trainer.init_workers() + # Start the training process. + trainer.fit() + + +def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True): + """Create a dataset. + + Arguments: + data_paths: List of paths to data files. + data_config: The data config. + tokenizer (Tokenizer): The tokenizer. + processor (Processor): The processor. + + Returns: + dataset (Dataset): The dataset. + """ + from torch.utils.data import Dataset + + from verl.utils.dataset.rl_dataset import RLHFDataset + + # Check if a custom dataset class is specified in the data configuration + # and if the path to the custom class is provided + if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: + # Dynamically load the custom dataset class + dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + # Verify that the custom dataset class inherits from torch.utils.data.Dataset + if not issubclass(dataset_cls, Dataset): + raise TypeError( + f"The custom dataset class '{data_config.custom_cls.name}' from " + f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset" + ) + elif "datagen" in data_config and data_config.datagen.get("path", None) is not None and is_train: + # If a data generation strategy is specified, use the DynamicGenDataset class + from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset + + dataset_cls = DynamicGenDataset + print("Using DynamicGenDataset for data generation.") + + else: + # Use the default RLHFDataset class if no custom class is specified + dataset_cls = RLHFDataset + print(f"Using dataset class: {dataset_cls.__name__}") + + # Instantiate the dataset using the determined dataset class + dataset = dataset_cls( + data_files=data_paths, + tokenizer=tokenizer, + processor=processor, + config=data_config, + ) + + return dataset + + +def create_rl_sampler(data_config, dataset): + """Create a sampler for the dataset. + + Arguments: + data_config: The data config. + dataset (Dataset): The dataset. + + Returns: + sampler (Sampler): The sampler. + """ + import torch + from torch.utils.data import RandomSampler, SequentialSampler + + if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None: + curriculum_class = load_extern_type( + data_config.sampler.class_path, + data_config.sampler.class_name, + ) + sampler = curriculum_class( + data_source=dataset, + data_config=data_config, + ) + assert isinstance(sampler, AbstractSampler) + assert data_config.get("dataloader_num_workers", 8) == 0, ( + "If using curriculum, num_workers must be 0 to prevent data caching. " + "If the dataloader caches data before the batch is done the " + "curriculum sampler won't have the opportunity to reorder it. " + ) + + # Use a sampler to facilitate checkpoint resumption. + # If shuffling is enabled in the data configuration, create a random sampler. + elif data_config.shuffle: + train_dataloader_generator = torch.Generator() + train_dataloader_generator.manual_seed(data_config.get("seed", 1)) + sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) + else: + # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. + sampler = SequentialSampler(data_source=dataset) + + return sampler + + +if __name__ == "__main__": + main() diff --git a/toolbox/verl/v0.5.0/verl/trainer/ppo/__init__.py b/toolbox/verl/v0.5.0/verl/trainer/ppo/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/ppo/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/trainer/ppo/core_algos.py b/toolbox/verl/v0.5.0/verl/trainer/ppo/core_algos.py new file mode 100644 index 000000000..5f0267581 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/ppo/core_algos.py @@ -0,0 +1,1148 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Core functions to implement PPO algorithms. +The function implemented in this file should be used by trainer with different distributed strategies to +implement PPO-like algorithms. +""" + +__all__ = ["register_adv_est", "get_adv_estimator_fn", "AdvantageEstimator"] + +from collections import defaultdict +from enum import Enum +from typing import Optional + +import numpy as np +import torch + +import verl.utils.torch_functional as verl_F +from verl.trainer.config import AlgoConfig + +POLICY_LOSS_REGISTRY = {} + + +def register_policy_loss(name): + """Register a policy loss function with the given name. + + Args: + name (str): The name to register the policy loss function under. + + Returns: + function: Decorator function that registers the policy loss function. + """ + + def decorator(func): + POLICY_LOSS_REGISTRY[name] = func + return func + + return decorator + + +def get_policy_loss_fn(name): + """Get the policy loss with a given name. + + Args: + name: `(str)` + The name of the policy loss. + + Returns: + `(callable)`: The policy loss function. + """ + loss_name = name + if loss_name not in POLICY_LOSS_REGISTRY: + raise ValueError( + f"Unsupported loss mode: {loss_name}. Supported modes are: {list(POLICY_LOSS_REGISTRY.keys())}" + ) + return POLICY_LOSS_REGISTRY[loss_name] + + +ADV_ESTIMATOR_REGISTRY = {} + + +def register_adv_est(name_or_enum): + """Decorator to register a advantage estimator function with a given name. + + Args: + name_or_enum: `(str)` or `(AdvantageEstimator)` + The name or enum of the advantage estimator. + + """ + + def decorator(fn): + name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum + if name in ADV_ESTIMATOR_REGISTRY and ADV_ESTIMATOR_REGISTRY[name] != fn: + raise ValueError( + f"Adv estimator {name} has already been registered: {ADV_ESTIMATOR_REGISTRY[name]} vs {fn}" + ) + ADV_ESTIMATOR_REGISTRY[name] = fn + return fn + + return decorator + + +def get_adv_estimator_fn(name_or_enum): + """Get the advantage estimator function with a given name. + + Args: + name_or_enum: `(str)` or `(AdvantageEstimator)` + The name or enum of the advantage estimator. + + Returns: + `(callable)`: The advantage estimator function. + """ + name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum + if name not in ADV_ESTIMATOR_REGISTRY: + raise ValueError(f"Unknown advantage estimator simply: {name}") + return ADV_ESTIMATOR_REGISTRY[name] + + +class AdvantageEstimator(str, Enum): + """Using an enumeration class to avoid spelling errors in adv_estimator. + + Note(haibin.lin): this enum class is immutable after creation. Extending this + enum for new estimators may not be necessary since users can always just call + `verl.trainer.ppo.core_algos.register` with string name for a custom advantage + estimator instead. + """ + + GAE = "gae" + GRPO = "grpo" + REINFORCE_PLUS_PLUS = "reinforce_plus_plus" + REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline" + REMAX = "remax" + RLOO = "rloo" + OPO = "opo" + GRPO_PASSK = "grpo_passk" + GPG = "gpg" + + +class AdaptiveKLController: + """ + Adaptive KL controller described in the paper: + https://arxiv.org/pdf/1909.08593.pdf + """ + + def __init__(self, init_kl_coef, target_kl, horizon): + self.value = init_kl_coef + self.target = target_kl + self.horizon = horizon + + def update(self, current_kl, n_steps): + """Update the KL coefficient based on current KL divergence. + + Args: + current_kl (float): Current KL divergence value. + n_steps (int): Number of steps taken. + """ + target = self.target + proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.horizon + self.value *= mult + + +class FixedKLController: + """Fixed KL controller.""" + + def __init__(self, kl_coef): + self.value = kl_coef + + def update(self, current_kl, n_steps): + """Update method for fixed KL controller (no-op). + + Args: + current_kl (float): Current KL divergence value (unused). + n_steps (int): Number of steps taken (unused). + """ + pass + + +def get_kl_controller(kl_ctrl): + """Factory function to create appropriate KL controller based on configuration. + + Args: + kl_ctrl: Configuration object containing KL controller settings. + + Returns: + KL controller instance (FixedKLController or AdaptiveKLController). + + Raises: + NotImplementedError: If controller type is not supported. + AssertionError: If adaptive controller horizon is not positive. + """ + if kl_ctrl.type == "fixed": + return FixedKLController(kl_coef=kl_ctrl.kl_coef) + elif kl_ctrl.type == "adaptive": + assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}" + return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon) + else: + raise NotImplementedError + + +@register_adv_est(AdvantageEstimator.GAE) # or simply: @register_adv_est("gae") +def compute_gae_advantage_return( + token_level_rewards: torch.Tensor, + values: torch.Tensor, + response_mask: torch.Tensor, + gamma: torch.Tensor, + lam: torch.Tensor, +): + """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py + + Args: + token_level_rewards: `(torch.Tensor)` + shape is (bs, response_length) + values: `(torch.Tensor)` + shape is (bs, response_length) + response_mask: `(torch.Tensor)` + shape is (bs, response_length). [EOS] mask. The token after [EOS] have mask zero. + gamma is `(float)` + discounted factor used in RL + lam: `(float)` + lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + + """ + with torch.no_grad(): + nextvalues = 0 + lastgaelam = 0 + advantages_reversed = [] + gen_len = token_level_rewards.shape[-1] + + for t in reversed(range(gen_len)): + delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t] + lastgaelam_ = delta + gamma * lam * lastgaelam + + # skip values and TD-error on observation tokens + nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues + lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam + + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], dim=1) + + returns = advantages + values + advantages = verl_F.masked_whiten(advantages, response_mask) + return advantages, returns + + +# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar. +@register_adv_est(AdvantageEstimator.GRPO) # or simply: @register_adv_est("grpo") +def compute_grpo_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for GRPO, operating only on Outcome reward + (with only one scalar reward for each response). + + Args: + token_level_rewards: `(torch.Tensor)` + shape is (bs, response_length) + response_mask: `(torch.Tensor)` + shape is (bs, response_length) + index: `(np.ndarray)` + index array for grouping + epsilon: `(float)` + small value to avoid division by zero + norm_adv_by_std_in_grpo: `(bool)` + whether to scale the GRPO advantage + config: `(Optional[AlgoConfig])` + algorithm configuration object + + Note: + If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO. + If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783). + + Returns: + advantages: `(torch.Tensor)` + shape is (bs, response_length) + Returns: `(torch.Tensor)` + shape is (bs, response_length) + """ + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) + id2std[idx] = torch.std(torch.tensor([id2score[idx]])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + if norm_adv_by_std_in_grpo: + scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) + else: + scores[i] = scores[i] - id2mean[index[i]] + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + +@register_adv_est(AdvantageEstimator.GRPO_PASSK) # or simply: @register_adv_est("grpo_passk") +def compute_grpo_passk_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for Pass@k using a GRPO-style outcome reward formulation. + Only the best response per group gets a non-zero advantage: r_max - r_second_max. + + Implemented as described in https://arxiv.org/abs/2503.19595. + + Args: + token_level_rewards: (bs, response_length) + response_mask: (bs, response_length) + index: (bs,) → group ID per sample + epsilon: float for numerical stability + config: (AlgoConfig) algorithm settings, which contains "norm_adv_by_std_in_grpo" + + Returns: + advantages: (bs, response_length) + returns: (bs, response_length) + """ + assert config is not None + # if True, normalize advantage by std within group + norm_adv_by_std_in_grpo = config.get("norm_adv_by_std_in_grpo", True) + scores = token_level_rewards.sum(dim=-1) # (bs,) + advantages = torch.zeros_like(scores) + + id2scores = defaultdict(list) + id2indices = defaultdict(list) + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + idx = index[i] + id2scores[idx].append(scores[i]) + id2indices[idx].append(i) + + for idx in id2scores: + rewards = torch.stack(id2scores[idx]) # (k,) + if rewards.numel() < 2: + raise ValueError( + f"Pass@k requires at least 2 samples per group. Got {rewards.numel()} for group {idx}." + ) + topk, topk_idx = torch.topk(rewards, 2) + r_max, r_second_max = topk[0], topk[1] + i_max = id2indices[idx][topk_idx[0].item()] + advantage = r_max - r_second_max + if norm_adv_by_std_in_grpo: + std = torch.std(rewards) + advantage = advantage / (std + epsilon) + advantages[i_max] = advantage + + advantages = advantages.unsqueeze(-1) * response_mask + return advantages, advantages + + +@register_adv_est( + AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE +) # or simply: @register_adv_est("reinforce_plus_plus_baseline") +def compute_reinforce_plus_plus_baseline_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: torch.Tensor, + epsilon: float = 1e-6, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward + (with only one scalar reward for each response). + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + response_length = token_level_rewards.shape[-1] + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = scores[i] - id2mean[index[i]] + + scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask + scores = verl_F.masked_whiten(scores, response_mask) * response_mask + + return scores, scores + + +@register_adv_est(AdvantageEstimator.RLOO) # or simply: @register_adv_est("rloo") +def compute_rloo_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740 + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + response_num = len(id2score[index[i]]) + if response_num > 1: + scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / ( + response_num - 1 + ) + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + +@register_adv_est(AdvantageEstimator.OPO) # or simply: @register_adv_est("opo") +def compute_opo_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for OPO based on https://arxiv.org/pdf/2505.23585 + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + response_length = response_mask.sum(dim=-1) + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2len = defaultdict(list) + id2bsl = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + id2len[index[i]].append(response_length[i]) + + for idx in id2score: + if len(id2score[idx]) == 1: + id2bsl[idx] = torch.tensor(0.0) + elif len(id2score[idx]) > 1: + score_tensor = torch.tensor(id2score[idx]) + len_tensor = torch.tensor(id2len[idx]) + id2bsl[idx] = (len_tensor * score_tensor).sum() / len_tensor.sum() + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = scores[i] - id2bsl[index[i]] + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + +@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS) # or simply: @register_adv_est("reinforce_plus_plus") +def compute_reinforce_plus_plus_outcome_advantage( + token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config: Optional[AlgoConfig] = None, **kwargs +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for REINFORCE++. + This implementation is based on the paper: https://arxiv.org/abs/2501.03262 + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + assert config is not None + gamma = config.gamma + with torch.no_grad(): + returns = torch.zeros_like(token_level_rewards) + running_return = 0 + + for t in reversed(range(token_level_rewards.shape[1])): + running_return = token_level_rewards[:, t] + gamma * running_return + returns[:, t] = running_return + # Reset after EOS + running_return = running_return * response_mask[:, t] + + advantages = verl_F.masked_whiten(returns, response_mask) + advantages = advantages * response_mask + + return advantages, returns + + +@register_adv_est(AdvantageEstimator.REMAX) # or simply: @register_adv_est("remax") +def compute_remax_outcome_advantage( + token_level_rewards: torch.Tensor, + reward_baselines: torch.Tensor, + response_mask: torch.Tensor, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for ReMax, operating only on Outcome reward + This implementation is based on the paper: https://arxiv.org/abs/2310.10505 + (with only one scalar reward for each response). + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + reward_baselines: `(torch.Tensor)` + shape: (bs,) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + + with torch.no_grad(): + returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) + advantages = returns - reward_baselines.unsqueeze(-1) * response_mask + + return advantages, returns + + +@register_adv_est(AdvantageEstimator.GPG) # or simply: @register_adv_est("gpg") +def compute_gpg_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + f_norm: float = 1.0, + alpha: float = 1.0, + config=None, + **kwargs, +): + """ + Compute advantage for GPG, operating only on Outcome reward + (with only one scalar reward for each response). + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + index: `(np.ndarray)` + shape: (bs,) + epsilon: (float) + f_norm: (float) + alpha: (float) + config: (dict) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + + with torch.no_grad(): + bsz = scores.shape[0] + m = torch.count_nonzero(scores) + alpha = bsz / m.clamp(min=1) + + for i in range(bsz): + id2score[index[i]].append(scores[i]) + + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) + id2std[idx] = torch.std(torch.tensor([id2score[idx]])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = alpha * (scores[i] - id2mean[index[i]]) / (f_norm) + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + +def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): + """Compute token-level rewards with KL penalty. + + Args: + token_level_scores (torch.Tensor): Token-level reward scores. + old_log_prob (torch.Tensor): Log probabilities from current policy. + ref_log_prob (torch.Tensor): Log probabilities from reference policy. + kl_ratio (float): KL penalty coefficient. + + Returns: + torch.Tensor: Token-level rewards with KL penalty applied. + """ + kl = old_log_prob - ref_log_prob + return token_level_scores - kl * kl_ratio + + +def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str): + """ + Aggregate the loss matrix into a scalar. + + Args: + loss_mat: `(torch.Tensor)`: + shape: (bs, response_length) + loss_mask: `(torch.Tensor)`: + shape: (bs, response_length) + loss_agg_mode: (str) choices: + method to aggregate the loss matrix into a scalar. + Returns: + loss: `a scalar torch.Tensor` + aggregated loss + """ + if loss_agg_mode == "token-mean": + loss = verl_F.masked_mean(loss_mat, loss_mask) + elif loss_agg_mode == "seq-mean-token-sum": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum + loss = torch.mean(seq_losses) # seq-mean + elif loss_agg_mode == "seq-mean-token-mean": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean + loss = torch.mean(seq_losses) # seq-mean + elif loss_agg_mode == "seq-mean-token-sum-norm": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) + loss = torch.sum(seq_losses) / loss_mask.shape[-1] # The divisor + # (loss_mask.shape[-1]) should ideally be constant + # throughout training to well-replicate the DrGRPO paper. + # TODO: Perhaps add user-defined normalizer argument to + # agg_loss to ensure divisor stays constant throughout. + else: + raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}") + + return loss + + +def compute_policy_loss( + old_log_prob, + log_prob, + advantages, + response_mask, + cliprange=None, + cliprange_low=None, + cliprange_high=None, + clip_ratio_c=3.0, + loss_agg_mode: str = "token-mean", +): + """ + Compute the clipped policy objective and related metrics for PPO. + + Adapted from + https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + cliprange (float, optional): + Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. + Defaults to None (must be provided). + cliprange_low (float, optional): + Lower clip range for dual-clip PPO. Defaults to same as `cliprange`. + cliprange_high (float, optional): + Upper clip range for dual-clip PPO. Defaults to same as `cliprange`. + clip_ratio_c (float, optional): + Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729. + Defaults to 3.0. + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + """ + assert clip_ratio_c > 1.0, ( + "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + + f" but get the value: {clip_ratio_c}." + ) + + negative_approx_kl = log_prob - old_log_prob + # Clamp negative_approx_kl for stability + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + pg_losses1 = -advantages * ratio + if cliprange_low is None: + cliprange_low = cliprange + if cliprange_high is None: + cliprange_high = cliprange + pg_losses2 = -advantages * torch.clamp( + ratio, 1 - cliprange_low, 1 + cliprange_high + ) # - clip(ratio, 1-cliprange, 1+cliprange) * A + clip_pg_losses1 = torch.maximum( + pg_losses1, pg_losses2 + ) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A) + pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask) + + pg_losses3 = -advantages * clip_ratio_c + clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) + pg_clipfrac_lower = verl_F.masked_mean( + torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask + ) + + pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower + + +@register_policy_loss("gpg") +def compute_policy_loss_gpg(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode="token-mean", config=None): + """Adapted from + https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495 + Args: + log_prob: `(torch.Tensor)` + shape: (bs, response_length) + advantages: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + return: + pg_loss: `a scalar torch.Tensor` + policy gradient loss computed via GPG + """ + pg_losses = -log_prob * advantages + + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0) + + +@register_policy_loss("clip_cov") +def compute_policy_loss_clip_cov( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[AlgoConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the clipped policy objective and related metrics for Clip-Cov. + + Adapted from + https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + cliprange (float, optional): + Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. + Defaults to None (must be provided). + cliprange_low (float, optional): + Lower clip range for dual-clip PPO. Defaults to same as `cliprange`. + cliprange_high (float, optional): + Upper clip range for dual-clip PPO. Defaults to same as `cliprange`. + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + clip_cvo_ratio (float, optional): + Ratio for clipping the covariance. Defaults to 0.0002. + clip_cov_lb (float, optional): + Lower bound for clipping covariance. Defaults to 1.0. + clip_cov_ub (float, optional): + Upper bound for clipping covariance. Defaults to 5.0. + """ + clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002 + cliprange = config.clip_ratio + cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange + cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange + clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0 + clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0 + + assert clip_cov_ratio > 0, "clip_ratio should be larger than 0." + + negative_approx_kl = log_prob - old_log_prob + ratio = torch.exp(negative_approx_kl) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + pg_losses1 = -advantages * ratio + + if cliprange_low is None: + cliprange_low = cliprange + if cliprange_high is None: + cliprange_high = cliprange + + corr = torch.ones_like(advantages) + pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) + clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0) + + cov_all = (advantages - verl_F.masked_mean(advantages, response_mask)) * ( + log_prob - verl_F.masked_mean(log_prob.detach(), response_mask) + ) + cov_all[response_mask == 0] = -torch.inf + cov_all[clip_by_origin] = -torch.inf + + clip_num = max(int(clip_cov_ratio * response_mask.sum().item()), 1) + top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0) + top_k_idx = torch.nonzero(top_k_idx) + + if len(top_k_idx) > 0: + perm = torch.randperm(len(top_k_idx)) + top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]] + else: + top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long) + + corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0 + + pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask) + + pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0) + + +@register_policy_loss("kl_cov") +def compute_policy_loss_kl_cov( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[AlgoConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the clipped policy objective and related metrics for Clip-Cov. + + Adapted from + https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + kl_cov_ratio (float, optional): + Ratio for selecting the top-k covariance values. Defaults to 0.0002. + ppo_kl_coef (float, optional): + Coefficient for the KL penalty term in the loss. Defaults to 1. + """ + kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002 + ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0 + + assert kl_cov_ratio > 0, "kl_cov_ratio should be larger than 0." + + negative_approx_kl = log_prob - old_log_prob + abs_kl = negative_approx_kl.abs() + ratio = torch.exp(negative_approx_kl) + ppo_kl_abs = verl_F.masked_mean(negative_approx_kl.abs(), response_mask) + pg_losses1 = -advantages * ratio + pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl + pg_losses = pg_losses1 + + all_valid = response_mask > 0 + all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0] + all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu() + all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu() + + k = min(kl_cov_ratio, len(all_valid_adv)) + + if k != 0: + cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean()) + k_percent_nums = max(1, int(len(cov_lst_all) * kl_cov_ratio)) + large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices + + if len(large_cov_idxs) != 0: + large_cov_idxs = all_valid_idx[large_cov_idxs] + pg_losses[large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]] = pg_losses_kl[ + large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1] + ] + + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0) + + +def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"): + """Compute categorical entropy loss (For backward compatibility) + + Args: + logits (torch.Tensor): shape is (bs, response_length, vocab_size) + response_mask (torch.Tensor): shape is (bs, response_length) + + Returns: + entropy: a scalar torch.Tensor + + """ + # compute entropy + token_entropy = verl_F.entropy_from_logits(logits) # (bs, response_len) + entropy_loss = agg_loss(loss_mat=token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + return entropy_loss + + +def compute_value_loss( + vpreds: torch.Tensor, + returns: torch.Tensor, + values: torch.Tensor, + response_mask: torch.Tensor, + cliprange_value: float, + loss_agg_mode: str = "token-mean", +): + """ + Compute the clipped value-function loss for PPO. + + Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151 + + Args: + vpreds (torch.FloatTensor): + Predicted values from the value head, shape (batch_size, response_length). + values (torch.FloatTensor): + Old (baseline) values from the value head, shape (batch_size, response_length). + returns (torch.FloatTensor): + Ground-truth returns, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the value loss calculation. + cliprange_value (float): + Clip range for value prediction updates. + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + + Returns: + vf_loss (torch.FloatTensor): + A scalar tensor containing the aggregated value-function loss. + vf_clipfrac (float): + Fraction of elements where the clipped loss was used. + """ + vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value) + vf_losses1 = (vpreds - returns) ** 2 + vf_losses2 = (vpredclipped - returns) ** 2 + clipped_vf_losses = torch.max(vf_losses1, vf_losses2) + vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask) + return vf_loss, vf_clipfrac + + +def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor: + """Compute KL divergence given logprob and ref_logprob. + Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104 + See more description in http://joschu.net/blog/kl-approx.html + + Args: + logprob: + ref_logprob: + + Returns: + + """ + if kl_penalty in ("kl", "k1"): + return logprob - ref_logprob + + if kl_penalty == "abs": + return (logprob - ref_logprob).abs() + + if kl_penalty in ("mse", "k2"): + return 0.5 * (logprob - ref_logprob).square() + + # J. Schulman. Approximating kl divergence, 2020. + # # URL http://joschu.net/blog/kl-approx.html. + if kl_penalty in ("low_var_kl", "k3"): + kl = ref_logprob - logprob + # For numerical stability + kl = torch.clamp(kl, min=-20, max=20) + ratio = torch.exp(kl) + kld = (ratio - kl - 1).contiguous() + return torch.clamp(kld, min=-10, max=10) + + if kl_penalty == "full": + # so, here logprob and ref_logprob should contain the logits for every token in vocabulary + raise NotImplementedError + + raise NotImplementedError + + +def compute_pf_ppo_reweight_data( + data, + reweight_method: str = "pow", + weight_pow: float = 2.0, +): + """Reweight the data based on the token_level_scores. + + Args: + data: DataProto object, containing batch, non_tensor_batch and meta_info + reweight_method: str, choices: "pow", "max_min", "max_random" + weight_pow: float, the power of the weight + + Returns: + + """ + + @torch.no_grad() + def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: float) -> torch.Tensor: + """Compute importance weights for resampling based on scores. + + Args: + scores (torch.Tensor): Tensor of scores to compute weights from. + reweight_method (str): Method for computing weights ('pow', 'max_min', 'max_random'). + weight_pow (float): Power exponent for 'pow' method. + + Returns: + torch.Tensor: Computed importance weights. + + Raises: + ValueError: If reweight_method is not supported. + """ + if reweight_method == "pow": + weights = torch.pow(torch.abs(scores), weight_pow) + elif reweight_method == "max_min": + max_score = torch.max(scores) + min_score = torch.min(scores) + weights = torch.where((scores == max_score) | (scores == min_score), 1.0, 0.0) + elif reweight_method == "max_random": + max_score = torch.max(scores) + weights = torch.where(scores == max_score, 0.4, 0.1) + else: + raise ValueError(f"Unsupported reweight_method: {reweight_method}") + return weights + + scores = data.batch["token_level_scores"].sum(dim=-1) + weights = compute_weights(scores, reweight_method, weight_pow) + weights = torch.clamp(weights + 1e-8, min=1e-8) + + batch_size = scores.shape[0] + sample_indices = torch.multinomial(weights, batch_size, replacement=True) + + resampled_batch = {key: tensor[sample_indices] for key, tensor in data.batch.items()} + + sample_indices_np = sample_indices.numpy() + resampled_non_tensor_batch = {} + for key, array in data.non_tensor_batch.items(): + if isinstance(array, np.ndarray): + resampled_non_tensor_batch[key] = array[sample_indices_np] + else: + resampled_non_tensor_batch[key] = [array[i] for i in sample_indices_np] + + resampled_meta_info = {} + for key, value in data.meta_info.items(): + if isinstance(value, list) and len(value) == batch_size: + resampled_meta_info[key] = [value[i] for i in sample_indices_np] + else: + resampled_meta_info[key] = value + + from copy import deepcopy + + resampled_data = deepcopy(data) + resampled_data.batch = type(data.batch)(resampled_batch) + resampled_data.batch.batch_size = data.batch.batch_size + resampled_data.non_tensor_batch = resampled_non_tensor_batch + resampled_data.meta_info = resampled_meta_info + + return resampled_data diff --git a/toolbox/verl/v0.5.0/verl/trainer/ppo/metric_utils.py b/toolbox/verl/v0.5.0/verl/trainer/ppo/metric_utils.py new file mode 100644 index 000000000..3b6b47bf0 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/ppo/metric_utils.py @@ -0,0 +1,446 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Metrics related to the PPO trainer. +""" + +from collections import defaultdict +from functools import partial +from typing import Any, Callable + +import numpy as np +import torch + +from verl import DataProto +from verl.utils.import_utils import deprecated + + +@deprecated("verl.utils.metric.reduce_metrics") +def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]: + """ + Reduces a dictionary of metric lists by computing the mean of each list. + + Args: + metrics: A dictionary mapping metric names to lists of metric values. + + Returns: + A dictionary with the same keys but with each list replaced by its mean value. + + Example: + >>> metrics = {"loss": [1.0, 2.0, 3.0], "accuracy": [0.8, 0.9, 0.7]} + >>> reduce_metrics(metrics) + {"loss": 2.0, "accuracy": 0.8} + """ + from verl.utils.metric import reduce_metrics + + return reduce_metrics(metrics) + + +def _compute_response_info(batch: DataProto) -> dict[str, Any]: + """ + Computes information about prompts and responses from a batch. + + This is an internal helper function that extracts masks and lengths for prompts and responses. + + Args: + batch: A DataProto object containing batch data with responses and attention masks. + + Returns: + A dictionary containing: + - response_mask: Attention mask for the response tokens + - prompt_length: Tensor of prompt lengths for each item in the batch + - response_length: Tensor of response lengths for each item in the batch + """ + response_length = batch.batch["responses"].shape[-1] + + prompt_mask = batch.batch["attention_mask"][:, :-response_length] + response_mask = batch.batch["attention_mask"][:, -response_length:] + + prompt_length = prompt_mask.sum(-1).float() + response_length = response_mask.sum(-1).float() # (batch_size,) + + return dict( + response_mask=response_mask, + prompt_length=prompt_length, + response_length=response_length, + ) + + +def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, Any]: + """ + Computes various metrics from a batch of data for PPO training. + + This function calculates metrics related to scores, rewards, advantages, returns, values, + and sequence lengths from a batch of data. It provides statistical information (mean, max, min) + for each metric category. + + Args: + batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc. + use_critic: Whether to include critic-specific metrics. Defaults to True. + + Returns: + A dictionary of metrics including: + - critic/score/mean, max, min: Statistics about sequence scores + - critic/rewards/mean, max, min: Statistics about sequence rewards + - critic/advantages/mean, max, min: Statistics about advantages + - critic/returns/mean, max, min: Statistics about returns + - critic/values/mean, max, min: Statistics about critic values (if use_critic=True) + - critic/vf_explained_var: Explained variance of the value function (if use_critic=True) + - response_length/mean, max, min, clip_ratio: Statistics about response lengths + - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths + - num_turns/mean, max, min: Statistics about the number of multi-turn conversations + """ + sequence_score = batch.batch["token_level_scores"].sum(-1) + sequence_reward = batch.batch["token_level_rewards"].sum(-1) + + advantages = batch.batch["advantages"] + returns = batch.batch["returns"] + + max_response_length = batch.batch["responses"].shape[-1] + + prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool() + response_mask = batch.batch["response_mask"].bool() + + max_prompt_length = prompt_mask.size(-1) + + response_info = _compute_response_info(batch) + prompt_length = response_info["prompt_length"] + response_length = response_info["response_length"] + + valid_adv = torch.masked_select(advantages, response_mask) + valid_returns = torch.masked_select(returns, response_mask) + + if use_critic: + values = batch.batch["values"] + valid_values = torch.masked_select(values, response_mask) + return_diff_var = torch.var(valid_returns - valid_values) + return_var = torch.var(valid_returns) + + metrics = { + # score + "critic/score/mean": torch.mean(sequence_score).detach().item(), + "critic/score/max": torch.max(sequence_score).detach().item(), + "critic/score/min": torch.min(sequence_score).detach().item(), + # reward + "critic/rewards/mean": torch.mean(sequence_reward).detach().item(), + "critic/rewards/max": torch.max(sequence_reward).detach().item(), + "critic/rewards/min": torch.min(sequence_reward).detach().item(), + # adv + "critic/advantages/mean": torch.mean(valid_adv).detach().item(), + "critic/advantages/max": torch.max(valid_adv).detach().item(), + "critic/advantages/min": torch.min(valid_adv).detach().item(), + # returns + "critic/returns/mean": torch.mean(valid_returns).detach().item(), + "critic/returns/max": torch.max(valid_returns).detach().item(), + "critic/returns/min": torch.min(valid_returns).detach().item(), + **( + { + # values + "critic/values/mean": torch.mean(valid_values).detach().item(), + "critic/values/max": torch.max(valid_values).detach().item(), + "critic/values/min": torch.min(valid_values).detach().item(), + # vf explained var + "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + } + if use_critic + else {} + ), + # response length + "response_length/mean": torch.mean(response_length).detach().item(), + "response_length/max": torch.max(response_length).detach().item(), + "response_length/min": torch.min(response_length).detach().item(), + "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) + .detach() + .item(), + # prompt length + "prompt_length/mean": torch.mean(prompt_length).detach().item(), + "prompt_length/max": torch.max(prompt_length).detach().item(), + "prompt_length/min": torch.min(prompt_length).detach().item(), + "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + } + + # multi-turn conversation + if "__num_turns__" in batch.non_tensor_batch: + num_turns = batch.non_tensor_batch["__num_turns__"] + metrics["num_turns/min"] = num_turns.min() + metrics["num_turns/max"] = num_turns.max() + metrics["num_turns/mean"] = num_turns.mean() + + return metrics + + +def compute_timing_metrics(batch: DataProto, timing_raw: dict[str, float]) -> dict[str, Any]: + """ + Computes timing metrics for different processing stages in PPO training. + + This function calculates both raw timing metrics (in seconds) and per-token timing metrics + (in milliseconds) for various processing stages like generation, reference computation, + value computation, advantage computation, and model updates. + + Args: + batch: A DataProto object containing batch data with responses and attention masks. + timing_raw: A dictionary mapping stage names to their execution times in seconds. + + Returns: + A dictionary containing: + - timing_s/{name}: Raw timing in seconds for each stage + - timing_per_token_ms/{name}: Per-token timing in milliseconds for each stage + + Note: + Different stages use different token counts for normalization: + - "gen" uses only response tokens + - Other stages ("ref", "values", "adv", "update_critic", "update_actor") use all tokens + (prompt + response) + """ + response_info = _compute_response_info(batch) + num_prompt_tokens = torch.sum(response_info["prompt_length"]).item() + num_response_tokens = torch.sum(response_info["response_length"]).item() + num_overall_tokens = num_prompt_tokens + num_response_tokens + + num_tokens_of_section = { + "gen": num_response_tokens, + **{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]}, + } + + return { + **{f"timing_s/{name}": value for name, value in timing_raw.items()}, + **{ + f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] + for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys()) + }, + } + + +def compute_throughout_metrics(batch: DataProto, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]: + """ + Computes throughput metrics for PPO training. + + This function calculates performance metrics related to token processing speed, + including the total number of tokens processed, time per step, and throughput + (tokens per second per GPU). + + Args: + batch: A DataProto object containing batch data with meta information about token counts. + timing_raw: A dictionary mapping stage names to their execution times in seconds. + Must contain a "step" key with the total step time. + n_gpus: Number of GPUs used for training. + + Returns: + A dictionary containing: + - perf/total_num_tokens: Total number of tokens processed in the batch + - perf/time_per_step: Time taken for the step in seconds + - perf/throughput: Tokens processed per second per GPU + + Note: + The throughput is calculated as total_tokens / (time * n_gpus) to normalize + across different GPU counts. + """ + total_num_tokens = sum(batch.meta_info["global_token_num"]) + time = timing_raw["step"] + # estimated_flops, promised_flops = flops_function.estimate_flops(num_tokens, time) + # f'Actual TFLOPs/s/GPU​': estimated_flops/(n_gpus), + # f'Theoretical TFLOPs/s/GPU​': promised_flops, + return { + "perf/total_num_tokens": total_num_tokens, + "perf/time_per_step": time, + "perf/throughput": total_num_tokens / (time * n_gpus), + } + + +def bootstrap_metric( + data: list[Any], + subset_size: int, + reduce_fns: list[Callable[[np.ndarray], float]], + n_bootstrap: int = 1000, + seed: int = 42, +) -> list[tuple[float, float]]: + """ + Performs bootstrap resampling to estimate statistics of metrics. + + This function uses bootstrap resampling to estimate the mean and standard deviation + of metrics computed by the provided reduction functions on random subsets of the data. + + Args: + data: List of data points to bootstrap from. + subset_size: Size of each bootstrap sample. + reduce_fns: List of functions that compute a metric from a subset of data. + n_bootstrap: Number of bootstrap iterations. Defaults to 1000. + seed: Random seed for reproducibility. Defaults to 42. + + Returns: + A list of tuples, where each tuple contains (mean, std) for a metric + corresponding to each reduction function in reduce_fns. + + Example: + >>> data = [1, 2, 3, 4, 5] + >>> reduce_fns = [np.mean, np.max] + >>> bootstrap_metric(data, 3, reduce_fns) + [(3.0, 0.5), (4.5, 0.3)] # Example values + """ + np.random.seed(seed) + + bootstrap_metric_lsts = [[] for _ in range(len(reduce_fns))] + for _ in range(n_bootstrap): + bootstrap_idxs = np.random.choice(len(data), size=subset_size, replace=True) + bootstrap_data = [data[i] for i in bootstrap_idxs] + for i, reduce_fn in enumerate(reduce_fns): + bootstrap_metric_lsts[i].append(reduce_fn(bootstrap_data)) + return [(np.mean(lst), np.std(lst)) for lst in bootstrap_metric_lsts] + + +def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> float: + """ + Calculate a value based on majority voting. + + This function identifies the most common value for a specified vote key + in the data, then returns the corresponding value for that majority vote. + + Args: + data: List of dictionaries, where each dictionary contains both vote_key and val_key. + vote_key: The key in each dictionary used for voting/counting. + val_key: The key in each dictionary whose value will be returned for the majority vote. + + Returns: + The value associated with the most common vote. + + Example: + >>> data = [ + ... {"pred": "A", "val": 0.9}, + ... {"pred": "B", "val": 0.8}, + ... {"pred": "A", "val": 0.7} + ... ] + >>> calc_maj_val(data, vote_key="pred", val_key="val") + 0.9 # Returns the first "val" for the majority vote "A" + """ + vote2vals = defaultdict(list) + for d in data: + vote2vals[d[vote_key]].append(d[val_key]) + + vote2cnt = {k: len(v) for k, v in vote2vals.items()} + maj_vote = max(vote2cnt, key=vote2cnt.get) + + maj_val = vote2vals[maj_vote][0] + + return maj_val + + +def process_validation_metrics( + data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], seed: int = 42 +) -> dict[str, dict[str, dict[str, float]]]: + """ + Process validation metrics into a structured format with statistical analysis. + + This function organizes validation metrics by data source and prompt, then computes + various statistical measures including means, standard deviations, best/worst values, + and majority voting results. It also performs bootstrap sampling to estimate statistics + for different sample sizes. + + Args: + data_sources: List of data source identifiers for each sample. + sample_inputs: List of input prompts corresponding to each sample. + infos_dict: Dictionary mapping variable names to lists of values for each sample. + seed: Random seed for bootstrap sampling. Defaults to 42. + + Returns: + A nested dictionary with the structure: + { + data_source: { + variable_name: { + metric_name: value + } + } + } + + Where metric_name includes: + - "mean@N": Mean value across N samples + - "std@N": Standard deviation across N samples + - "best@N/mean": Mean of the best values in bootstrap samples of size N + - "best@N/std": Standard deviation of the best values in bootstrap samples + - "worst@N/mean": Mean of the worst values in bootstrap samples + - "worst@N/std": Standard deviation of the worst values in bootstrap samples + - "maj@N/mean": Mean of majority voting results in bootstrap samples (if "pred" exists) + - "maj@N/std": Standard deviation of majority voting results (if "pred" exists) + + Example: + >>> data_sources = ["source1", "source1", "source2"] + >>> sample_inputs = ["prompt1", "prompt1", "prompt2"] + >>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]} + >>> result = process_validation_metrics(data_sources, sample_inputs, infos_dict) + >>> # result will contain statistics for each data source and variable + """ + # Group metrics by data source, prompt and variable + data_src2prompt2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for sample_idx, data_source in enumerate(data_sources): + prompt = sample_inputs[sample_idx] + var2vals = data_src2prompt2var2vals[data_source][prompt] + for var_name, var_vals in infos_dict.items(): + var2vals[var_name].append(var_vals[sample_idx]) + + # Calculate metrics for each group + data_src2prompt2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) + for data_source, prompt2var2vals in data_src2prompt2var2vals.items(): + for prompt, var2vals in prompt2var2vals.items(): + for var_name, var_vals in var2vals.items(): + if isinstance(var_vals[0], str): + continue + + metric = {} + n_resps = len(var_vals) + metric[f"mean@{n_resps}"] = np.mean(var_vals) + + if n_resps > 1: + metric[f"std@{n_resps}"] = np.std(var_vals) + + ns = [] + n = 2 + while n < n_resps: + ns.append(n) + n *= 2 + ns.append(n_resps) + + for n in ns: + [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric( + data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed + ) + metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std + metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std + if var2vals.get("pred", None) is not None: + vote_data = [ + {"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"], strict=True) + ] + [(maj_n_mean, maj_n_std)] = bootstrap_metric( + data=vote_data, + subset_size=n, + reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")], + seed=seed, + ) + metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std + + data_src2prompt2var2metric[data_source][prompt][var_name] = metric + + # Aggregate metrics across prompts + data_src2var2metric2prompt_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for data_source, prompt2var2metric in data_src2prompt2var2metric.items(): + for prompt, var2metric in prompt2var2metric.items(): + for var_name, metric in var2metric.items(): + for metric_name, metric_val in metric.items(): + data_src2var2metric2prompt_vals[data_source][var_name][metric_name].append(metric_val) + + data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float))) + for data_source, var2metric2prompt_vals in data_src2var2metric2prompt_vals.items(): + for var_name, metric2prompt_vals in var2metric2prompt_vals.items(): + for metric_name, prompt_vals in metric2prompt_vals.items(): + data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(prompt_vals) + + return data_src2var2metric2val diff --git a/toolbox/verl/v0.5.0/verl/trainer/ppo/ray_trainer.py b/toolbox/verl/v0.5.0/verl/trainer/ppo/ray_trainer.py new file mode 100644 index 000000000..45ed80a4d --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/ppo/ray_trainer.py @@ -0,0 +1,1421 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import json +import os +import uuid +from collections import defaultdict +from copy import deepcopy +from dataclasses import dataclass, field +from enum import Enum +from pprint import pprint +from typing import Optional + +import numpy as np +import ray +import torch +from omegaconf import OmegaConf, open_dict +from torch.utils.data import Dataset, Sampler +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm + +from verl import DataProto +from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.single_controller.base import Worker +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.config import AlgoConfig +from verl.trainer.ppo import core_algos +from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + process_validation_metrics, +) +from verl.trainer.ppo.reward import compute_reward, compute_reward_async +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi +from verl.utils.debug import marked_timer +from verl.utils.metric import ( + reduce_metrics, +) +from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance +from verl.utils.torch_functional import masked_mean +from verl.utils.tracking import ValidationGenerationsLogger + +WorkerType = type[Worker] + + +class Role(Enum): + """ + To create more roles dynamically, you can subclass Role and add new members + """ + + Actor = 0 + Rollout = 1 + ActorRollout = 2 + Critic = 3 + RefPolicy = 4 + RewardModel = 5 + ActorRolloutRef = 6 + + +@dataclass +class ResourcePoolManager: + """ + Define a resource pool specification. Resource pool will be initialized first. + """ + + resource_pool_spec: dict[str, list[int]] + mapping: dict[Role, str] + resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) + + def create_resource_pool(self): + """Create Ray resource pools for distributed training. + + Initializes resource pools based on the resource pool specification, + with each pool managing GPU resources across multiple nodes. + For FSDP backend, uses max_colocate_count=1 to merge WorkerGroups. + For Megatron backend, uses max_colocate_count>1 for different models. + """ + for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): + # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool + # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. + # For Megatron backend, we recommend using max_colocate_count>1 + # that can utilize different WorkerGroup for differnt models + resource_pool = RayResourcePool( + process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name + ) + self.resource_pool_dict[resource_pool_name] = resource_pool + + self._check_resource_available() + + def get_resource_pool(self, role: Role) -> RayResourcePool: + """Get the resource pool of the worker_cls""" + return self.resource_pool_dict[self.mapping[role]] + + def get_n_gpus(self) -> int: + """Get the number of gpus in this cluster.""" + return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) + + def _check_resource_available(self): + """Check if the resource pool can be satisfied in this ray cluster.""" + node_available_resources = ray.state.available_resources_per_node() + node_available_gpus = { + node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) + for node, node_info in node_available_resources.items() + } + + # check total required gpus can be satisfied + total_available_gpus = sum(node_available_gpus.values()) + total_required_gpus = sum( + [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] + ) + if total_available_gpus < total_required_gpus: + raise ValueError( + f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" + ) + + # check each resource pool can be satisfied, O(#resource_pools * #nodes) + for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): + num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes) + for node, available_gpus in node_available_gpus.items(): + if available_gpus >= num_gpus: + node_available_gpus[node] -= num_gpus + num_nodes -= 1 + if num_nodes == 0: + break + if num_nodes > 0: + raise ValueError( + f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes}" + + "cannot be satisfied in this ray cluster" + ) + + +def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): + """Apply KL penalty to the token-level rewards. + + This function computes the KL divergence between the reference policy and current policy, + then applies a penalty to the token-level rewards based on this divergence. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty. + kl_penalty (str, optional): Type of KL penalty to apply. Defaults to "kl". + multi_turn (bool, optional): Whether the data is from a multi-turn conversation. Defaults to False. + + Returns: + tuple: A tuple containing: + - The updated data with token-level rewards adjusted by KL penalty + - A dictionary of metrics related to the KL penalty + """ + response_mask = data.batch["response_mask"] + token_level_scores = data.batch["token_level_scores"] + batch_size = data.batch.batch_size[0] + + # compute kl between ref_policy and current policy + # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. + kld = core_algos.kl_penalty( + data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty + ) # (batch_size, response_length) + kld = kld * response_mask + beta = kl_ctrl.value + + token_level_rewards = token_level_scores - beta * kld + + current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence + current_kl = torch.mean(current_kl, dim=0).item() + + # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 + kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) + data.batch["token_level_rewards"] = token_level_rewards + + metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta} + + return data, metrics + + +def compute_response_mask(data: DataProto): + """Compute the attention mask for the response part of the sequence. + + This function extracts the portion of the attention mask that corresponds to the model's response, + which is used for masking computations that should only apply to response tokens. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + + Returns: + torch.Tensor: The attention mask for the response tokens. + """ + responses = data.batch["responses"] + response_length = responses.size(1) + attention_mask = data.batch["attention_mask"] + return attention_mask[:, -response_length:] + + +def compute_advantage( + data: DataProto, + adv_estimator: AdvantageEstimator, + gamma: float = 1.0, + lam: float = 1.0, + num_repeat: int = 1, + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, +) -> DataProto: + """Compute advantage estimates for policy optimization. + + This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc. + The advantage estimates are used to guide policy optimization in RL algorithms. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++). + gamma (float, optional): Discount factor for future rewards. Defaults to 1.0. + lam (float, optional): Lambda parameter for GAE. Defaults to 1.0. + num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1. + norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in + GRPO. Defaults to True. + config (dict, optional): Configuration dictionary for algorithm settings. Defaults to None. + + Returns: + DataProto: The updated data with computed advantages and returns. + """ + # Back-compatible with trainers that do not compute response mask in fit + if "response_mask" not in data.batch.keys(): + data.batch["response_mask"] = compute_response_mask(data) + # prepare response group + if adv_estimator == AdvantageEstimator.GAE: + # Compute advantages and returns using Generalized Advantage Estimation (GAE) + advantages, returns = core_algos.compute_gae_advantage_return( + token_level_rewards=data.batch["token_level_rewards"], + values=data.batch["values"], + response_mask=data.batch["response_mask"], + gamma=gamma, + lam=lam, + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + if config.get("use_pf_ppo", False): + data = core_algos.compute_pf_ppo_reweight_data( + data, + config.pf_ppo.reweight_method, + config.pf_ppo.weight_pow, + ) + elif adv_estimator == AdvantageEstimator.GRPO: + # Initialize the mask for GRPO calculation + grpo_calculation_mask = data.batch["response_mask"] + # Call compute_grpo_outcome_advantage with parameters matching its definition + advantages, returns = core_algos.compute_grpo_outcome_advantage( + token_level_rewards=data.batch["token_level_rewards"], + response_mask=grpo_calculation_mask, + index=data.non_tensor_batch["uid"], + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + else: + # handle all other adv estimator type other than GAE and GRPO + adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator) + adv_kwargs = { + "token_level_rewards": data.batch["token_level_rewards"], + "response_mask": data.batch["response_mask"], + "config": config, + } + if "uid" in data.non_tensor_batch: # optional + adv_kwargs["index"] = data.non_tensor_batch["uid"] + if "reward_baselines" in data.batch: # optional + adv_kwargs["reward_baselines"] = data.batch["reward_baselines"] + + # calculate advantage estimator + advantages, returns = adv_estimator_fn(**adv_kwargs) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + return data + + +class RayPPOTrainer: + """Distributed PPO trainer using Ray for scalable reinforcement learning. + + This trainer orchestrates distributed PPO training across multiple nodes and GPUs, + managing actor rollouts, critic training, and reward computation with Ray backend. + Supports various model architectures including FSDP, Megatron, and vLLM integration. + """ + + # TODO: support each role have individual ray_worker_group_cls, + # i.e., support different backend of different role + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + processor=None, + reward_fn=None, + val_reward_fn=None, + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + collate_fn=None, + train_sampler: Optional[Sampler] = None, + device_name=None, + ): + """ + Initialize distributed PPO trainer with Ray backend. + Note that this trainer runs on the driver process on a single CPU/GPU node. + + Args: + config: Configuration object containing training parameters. + tokenizer: Tokenizer used for encoding and decoding text. + role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes. + resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools. + ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup. + processor: Optional data processor, used for multimodal data + reward_fn: Function for computing rewards during training. + val_reward_fn: Function for computing rewards during validation. + train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None. + val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None. + collate_fn: Function to collate data samples into batches. + train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None. + device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to None. + """ + + # Store the tokenizer for text processing + self.tokenizer = tokenizer + self.processor = processor + self.config = config + self.reward_fn = reward_fn + self.val_reward_fn = val_reward_fn + + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + assert self.hybrid_engine, "Currently, only support hybrid engine" + + if self.hybrid_engine: + assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.use_reference_policy = Role.RefPolicy in role_worker_mapping + self.use_rm = Role.RewardModel in role_worker_mapping + self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name if device_name else self.config.trainer.device + self.validation_generations_logger = ValidationGenerationsLogger( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + ) + + # if ref_in_actor is True, the reference policy will be actor without lora applied + self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0 + + # define in-reward KL control + # kl loss control currently not suppoorted + if self.config.algorithm.use_kl_in_reward: + self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: + self.use_critic = True + elif self.config.algorithm.adv_estimator in [ + AdvantageEstimator.GRPO, + AdvantageEstimator.GRPO_PASSK, + AdvantageEstimator.REINFORCE_PLUS_PLUS, + AdvantageEstimator.REMAX, + AdvantageEstimator.RLOO, + AdvantageEstimator.OPO, + AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE, + AdvantageEstimator.GPG, + ]: + self.use_critic = False + else: + raise NotImplementedError + + self._validate_config() + self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + + def _validate_config(self): + config = self.config + # number of GPUs total + n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes + if config.actor_rollout_ref.actor.strategy == "megatron": + model_parallel_size = ( + config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size + * config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size + ) + assert ( + n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0 + ), ( + f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times " + f"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})" + ) + megatron_dp = n_gpus // ( + model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size + ) + minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu + else: + minimal_bsz = n_gpus + + # 1. Check total batch size for data correctness + real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n + assert real_train_batch_size % minimal_bsz == 0, ( + f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size " + f"({minimal_bsz})" + ) + + # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" + # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". + def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): + """Validate mutually exclusive micro batch size configuration options. + + Ensures that users don't set both deprecated micro_batch_size and + the new micro_batch_size_per_gpu parameters simultaneously. + + Args: + mbs: Deprecated micro batch size parameter value. + mbs_per_gpu: New micro batch size per GPU parameter value. + name (str): Configuration section name for error messages. + + Raises: + ValueError: If both parameters are set or neither is set. + """ + settings = { + "actor_rollout_ref.actor": "micro_batch_size", + "critic": "micro_batch_size", + "reward_model": "micro_batch_size", + "actor_rollout_ref.ref": "log_prob_micro_batch_size", + "actor_rollout_ref.rollout": "log_prob_micro_batch_size", + } + + if name in settings: + param = settings[name] + param_per_gpu = f"{param}_per_gpu" + + if mbs is None and mbs_per_gpu is None: + raise ValueError( + f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'." + ) + + if mbs is not None and mbs_per_gpu is not None: + raise ValueError( + f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove " + f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)." + ) + + if not config.actor_rollout_ref.actor.use_dynamic_bsz: + # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu + check_mutually_exclusive( + config.actor_rollout_ref.actor.ppo_micro_batch_size, + config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, + "actor_rollout_ref.actor", + ) + + if self.use_reference_policy: + # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu + check_mutually_exclusive( + config.actor_rollout_ref.ref.log_prob_micro_batch_size, + config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.ref", + ) + + # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu + check_mutually_exclusive( + config.actor_rollout_ref.rollout.log_prob_micro_batch_size, + config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.rollout", + ) + + if self.use_critic and not config.critic.use_dynamic_bsz: + # Check for critic micro-batch size conflicts + check_mutually_exclusive( + config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic" + ) + + # Check for reward model micro-batch size conflicts + if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: + check_mutually_exclusive( + config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model" + ) + + # Actor + # check if train_batch_size is larger than ppo_mini_batch_size + # if NOT dynamic_bsz, we must ensure: + # ppo_mini_batch_size is divisible by ppo_micro_batch_size + # ppo_micro_batch_size * sequence_parallel_size >= n_gpus + if not config.actor_rollout_ref.actor.use_dynamic_bsz: + assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size + sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) + if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: + assert ( + config.actor_rollout_ref.actor.ppo_mini_batch_size + % config.actor_rollout_ref.actor.ppo_micro_batch_size + == 0 + ) + assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus + + assert config.actor_rollout_ref.actor.loss_agg_mode in [ + "token-mean", + "seq-mean-token-sum", + "seq-mean-token-mean", + "seq-mean-token-sum-norm", + ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}" + + if self.config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: + print("NOTICE: You have both enabled in-reward kl and kl loss.") + + # critic + if self.use_critic and not config.critic.use_dynamic_bsz: + assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size + sp_size = config.critic.get("ulysses_sequence_parallel_size", 1) + if config.critic.ppo_micro_batch_size is not None: + assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 + assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus + + # Check if use_remove_padding is enabled when using sequence parallelism for fsdp + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"} and ( + config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 + or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1 + ): + assert config.actor_rollout_ref.model.use_remove_padding, ( + "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." + ) + + if self.use_critic and config.critic.strategy in {"fsdp", "fsdp2"}: + if config.critic.get("ulysses_sequence_parallel_size", 1) > 1: + assert config.critic.model.use_remove_padding, ( + "When using sequence parallelism for critic, you must enable `use_remove_padding`." + ) + + if config.data.get("val_batch_size", None) is not None: + print( + "WARNING: val_batch_size is deprecated." + + " Validation datasets are sent to inference engines as a whole batch," + + " which will schedule the memory themselves." + ) + + # check eval config + if config.actor_rollout_ref.rollout.val_kwargs.do_sample: + assert config.actor_rollout_ref.rollout.temperature > 0, ( + "validation gen temperature should be greater than 0 when enabling do_sample" + ) + + print("[validate_config] All configuration checks passed successfully!") + + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): + """ + Creates the train and validation dataloaders. + """ + # TODO: we have to make sure the batch size is divisible by the dp size + from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler + + if train_dataset is None: + train_dataset = create_rl_dataset( + self.config.data.train_files, self.config.data, self.tokenizer, self.processor + ) + if val_dataset is None: + val_dataset = create_rl_dataset( + self.config.data.val_files, self.config.data, self.tokenizer, self.processor + ) + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + if train_sampler is None: + train_sampler = create_rl_sampler(self.config.data, self.train_dataset) + if collate_fn is None: + from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn + + collate_fn = default_collate_fn + + num_workers = self.config.data["dataloader_num_workers"] + + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=num_workers, + drop_last=True, + collate_fn=collate_fn, + sampler=train_sampler, + ) + + val_batch_size = self.config.data.val_batch_size # Prefer config value if set + if val_batch_size is None: + val_batch_size = len(self.val_dataset) + + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=val_batch_size, + num_workers=num_workers, + shuffle=self.config.data.get("validation_shuffle", True), + drop_last=False, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" + assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" + + print( + f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: " + f"{len(self.val_dataloader)}" + ) + + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + try: + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + if OmegaConf.select(self.config, "critic.optim"): + self.config.critic.optim.total_training_steps = total_training_steps + except Exception as e: + print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") + + def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path): + """Dump rollout/validation samples as JSONL.""" + os.makedirs(dump_path, exist_ok=True) + filename = os.path.join(dump_path, f"{self.global_steps}.jsonl") + + n = len(inputs) + base_data = { + "input": inputs, + "output": outputs, + "score": scores, + "step": [self.global_steps] * n, + } + + for k, v in reward_extra_infos_dict.items(): + if len(v) == n: + base_data[k] = v + + lines = [] + for i in range(n): + entry = {k: v[i] for k, v in base_data.items()} + lines.append(json.dumps(entry, ensure_ascii=False)) + + with open(filename, "w") as f: + f.write("\n".join(lines) + "\n") + + print(f"Dumped generations to {filename}") + + def _maybe_log_val_generations(self, inputs, outputs, scores): + """Log a table of validation samples to the configured logger (wandb or swanlab)""" + + generations_to_log = self.config.trainer.log_val_generations + + if generations_to_log == 0: + return + + import numpy as np + + # Create tuples of (input, output, score) and sort by input text + samples = list(zip(inputs, outputs, scores, strict=True)) + samples.sort(key=lambda x: x[0]) # Sort by input text + + # Use fixed random seed for deterministic shuffling + rng = np.random.RandomState(42) + rng.shuffle(samples) + + # Take first N samples after shuffling + samples = samples[:generations_to_log] + + # Log to each configured logger + self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) + + def _validate(self): + data_source_lst = [] + reward_extra_infos_dict: dict[str, list] = defaultdict(list) + + # Lists to collect samples for the table + sample_inputs = [] + sample_outputs = [] + sample_scores = [] + sample_turns = [] + + for test_data in self.val_dataloader: + test_batch = DataProto.from_single_dict(test_data) + + # repeat test batch + test_batch = test_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True + ) + + # we only do validation on rule-based rm + if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": + return {} + + # Store original inputs + input_ids = test_batch.batch["input_ids"] + # TODO: Can we keep special tokens except for padding tokens? + input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + sample_inputs.extend(input_texts) + + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + if "multi_modal_data" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("multi_modal_data") + if "raw_prompt" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("raw_prompt") + if "tools_kwargs" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("tools_kwargs") + if "interaction_kwargs" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("interaction_kwargs") + if "agent_name" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("agent_name") + test_gen_batch = test_batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + + test_gen_batch.meta_info = { + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + "recompute_log_prob": False, + "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, + "validate": True, + "global_steps": self.global_steps, + } + print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") + + # pad to be divisible by dp_size + size_divisor = ( + self.actor_rollout_wg.world_size + if not self.async_rollout_mode + else self.config.actor_rollout_ref.rollout.agent.num_workers + ) + test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor) + if not self.async_rollout_mode: + test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + else: + test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + + # unpad + test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + + print("validation generation end") + + # Store generated outputs + output_ids = test_output_gen_batch.batch["responses"] + output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + sample_outputs.extend(output_texts) + + test_batch = test_batch.union(test_output_gen_batch) + test_batch.meta_info["validate"] = True + + # evaluate using reward_function + result = self.val_reward_fn(test_batch, return_dict=True) + reward_tensor = result["reward_tensor"] + scores = reward_tensor.sum(-1).cpu().tolist() + sample_scores.extend(scores) + + reward_extra_infos_dict["reward"].extend(scores) + print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}") + if "reward_extra_info" in result: + for key, lst in result["reward_extra_info"].items(): + reward_extra_infos_dict[key].extend(lst) + print(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}") + + # collect num_turns of each prompt + if "__num_turns__" in test_batch.non_tensor_batch: + sample_turns.append(test_batch.non_tensor_batch["__num_turns__"]) + + data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + + self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + + # dump generations + val_data_dir = self.config.trainer.get("validation_data_dir", None) + if val_data_dir: + self._dump_generations( + inputs=sample_inputs, + outputs=sample_outputs, + scores=sample_scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=val_data_dir, + ) + + for key_info, lst in reward_extra_infos_dict.items(): + assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" + + data_sources = np.concatenate(data_source_lst, axis=0) + + data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict) + metric_dict = {} + for data_source, var2metric2val in data_src2var2metric2val.items(): + core_var = "acc" if "acc" in var2metric2val else "reward" + for var_name, metric2val in var2metric2val.items(): + n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) + for metric_name, metric_val in metric2val.items(): + if ( + (var_name == core_var) + and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) + and (f"@{n_max}" in metric_name) + ): + metric_sec = "val-core" + else: + metric_sec = "val-aux" + pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" + metric_dict[pfx] = metric_val + + if len(sample_turns) > 0: + sample_turns = np.concatenate(sample_turns) + metric_dict["val-aux/num_turns/min"] = sample_turns.min() + metric_dict["val-aux/num_turns/max"] = sample_turns.max() + metric_dict["val-aux/num_turns/mean"] = sample_turns.mean() + + return metric_dict + + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + if self.hybrid_engine: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.ActorRollout], + config=self.config.actor_rollout_ref, + role="actor_rollout", + profile_option=self.config.trainer.npu_profile.options, + ) + self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) + self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls + + # create reference policy if needed + if self.use_reference_policy: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role="ref", + profile_option=self.config.trainer.npu_profile.options, + ) + self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls + + # create a reward model if reward_fn is None + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.trainer, "profile_steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.trainer, "profile_steps") + assert OmegaConf.select(self.config.trainer, "worker_nsight_options") is not None, ( + "worker_nsight_options must be set when profile_steps is set" + ) + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.trainer, "worker_nsight_options") + ) + wg_kwargs["device_name"] = self.device_name + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + + if self.use_critic: + self.critic_wg = all_wg["critic"] + self.critic_wg.init_model() + + if self.use_reference_policy and not self.ref_in_actor: + self.ref_policy_wg = all_wg["ref"] + self.ref_policy_wg.init_model() + + if self.use_rm: + self.rm_wg = all_wg["rm"] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg["actor_rollout"] + self.actor_rollout_wg.init_model() + + # create async rollout manager and request scheduler + self.async_rollout_mode = False + if self.config.actor_rollout_ref.rollout.mode == "async": + from verl.experimental.agent_loop import AgentLoopManager + + self.async_rollout_mode = True + self.async_rollout_manager = AgentLoopManager( + config=self.config, + worker_group=self.actor_rollout_wg, + ) + + def _save_checkpoint(self): + from verl.utils.fs import local_mkdir_safe + + # path: given_path + `/global_step_{global_steps}` + `/actor` + local_global_step_folder = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" + ) + + print(f"local_global_step_folder: {local_global_step_folder}") + actor_local_path = os.path.join(local_global_step_folder, "actor") + + actor_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + ) + + remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) + if remove_previous_ckpt_in_save: + print( + "Warning: remove_previous_ckpt_in_save is deprecated," + + " set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead" + ) + max_actor_ckpt_to_keep = ( + self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + max_critic_ckpt_to_keep = ( + self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + + self.actor_rollout_wg.save_checkpoint( + actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep + ) + + if self.use_critic: + critic_local_path = os.path.join(local_global_step_folder, "critic") + critic_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") + ) + self.critic_wg.save_checkpoint( + critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep + ) + + # save dataloader + local_mkdir_safe(local_global_step_folder) + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") + dataloader_state_dict = self.train_dataloader.state_dict() + torch.save(dataloader_state_dict, dataloader_local_path) + + # latest checkpointed iteration tracker (for atomic usage) + local_latest_checkpointed_iteration = os.path.join( + self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" + ) + with open(local_latest_checkpointed_iteration, "w") as f: + f.write(str(self.global_steps)) + + def _load_checkpoint(self): + if self.config.trainer.resume_mode == "disable": + return 0 + + # load from hdfs + if self.config.trainer.default_hdfs_dir is not None: + raise NotImplementedError("load from hdfs is not implemented yet") + else: + checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + + # find global_step_folder + if self.config.trainer.resume_mode == "auto": + if global_step_folder is None: + print("Training from scratch") + return 0 + else: + if self.config.trainer.resume_mode == "resume_path": + assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" + assert "global_step_" in self.config.trainer.resume_from_path, ( + "resume ckpt must specify the global_steps" + ) + global_step_folder = self.config.trainer.resume_from_path + if not os.path.isabs(global_step_folder): + working_dir = os.getcwd() + global_step_folder = os.path.join(working_dir, global_step_folder) + print(f"Load from checkpoint folder: {global_step_folder}") + # set global step + self.global_steps = int(global_step_folder.split("global_step_")[-1]) + + print(f"Setting global step to {self.global_steps}") + print(f"Resuming from {global_step_folder}") + + actor_path = os.path.join(global_step_folder, "actor") + critic_path = os.path.join(global_step_folder, "critic") + # load actor + self.actor_rollout_wg.load_checkpoint( + actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + # load critic + if self.use_critic: + self.critic_wg.load_checkpoint( + critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + + # load dataloader, + # TODO: from remote not implemented yet + dataloader_local_path = os.path.join(global_step_folder, "data.pt") + if os.path.exists(dataloader_local_path): + dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) + self.train_dataloader.load_state_dict(dataloader_state_dict) + else: + print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") + + def _start_profiling(self, do_profile: bool) -> None: + """Start profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) + if self.use_reference_policy: + self.ref_policy_wg.start_profile() + if self.use_critic: + self.critic_wg.start_profile() + if self.use_rm: + self.rm_wg.start_profile() + + def _stop_profiling(self, do_profile: bool) -> None: + """Stop profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.stop_profile() + if self.use_reference_policy: + self.ref_policy_wg.stop_profile() + if self.use_critic: + self.critic_wg.stop_profile() + if self.use_rm: + self.rm_wg.stop_profile() + + def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): + """Reorder the data on single controller such that each dp rank gets similar total tokens""" + attention_mask = batch.batch["attention_mask"] + batch_size = attention_mask.shape[0] + global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + world_size = self.actor_rollout_wg.world_size + global_partition_lst = get_seqlen_balanced_partitions( + global_seqlen_lst, k_partitions=world_size, equal_size=True + ) + # reorder based on index. The data will be automatically equally partitioned by dispatch function + global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) + batch.reorder(global_idx) + global_balance_stats = log_seqlen_unbalance( + seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix + ) + metrics.update(global_balance_stats) + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + self.max_steps_duration = 0 + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + timing_raw = {} + + do_profile = ( + self.global_steps in self.config.trainer.profile_steps + if self.config.trainer.profile_steps is not None + else False + ) + with marked_timer("start_profile", timing_raw): + self._start_profiling(do_profile) + + batch: DataProto = DataProto.from_single_dict(batch_dict) + + # pop those keys for generation + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + if "multi_modal_data" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("multi_modal_data") + if "raw_prompt" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("raw_prompt") + if "tools_kwargs" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("tools_kwargs") + if "interaction_kwargs" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("interaction_kwargs") + if "index" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("index") + if "agent_name" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("agent_name") + + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + + # pass global_steps to trace + gen_batch.meta_info["global_steps"] = self.global_steps + gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + + is_last_step = self.global_steps >= self.total_training_steps + + with marked_timer("step", timing_raw): + # generate a batch + with marked_timer("gen", timing_raw, color="red"): + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + else: + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + with marked_timer("gen_max", timing_raw, color="purple"): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + if not self.async_rollout_mode: + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + else: + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + batch = batch.union(gen_baseline_output) + reward_baseline_tensor = self.reward_fn(batch) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + + batch.batch["reward_baselines"] = reward_baseline_tensor + + del gen_baseline_batch, gen_baseline_output + + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + if "response_mask" not in batch.batch.keys(): + batch.batch["response_mask"] = compute_response_mask(batch) + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + # TODO: Decouple the DP balancing and mini-batching. + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + with marked_timer("reward", timing_raw, color="yellow"): + # compute reward model score + if self.use_rm: + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + if self.config.reward_model.launch_reward_fn_async: + future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn) + else: + reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) + + # recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, color="blue"): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + rollout_old_log_probs = batch.batch["rollout_log_probs"] + actor_old_log_probs = batch.batch["old_log_probs"] + attention_mask = batch.batch["attention_mask"] + responses = batch.batch["responses"] + response_length = responses.size(1) + response_mask = attention_mask[:, -response_length:] + + rollout_probs = torch.exp(rollout_old_log_probs) + actor_probs = torch.exp(actor_old_log_probs) + rollout_probs_diff = torch.abs(rollout_probs - actor_probs) + rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool()) + rollout_probs_diff_max = torch.max(rollout_probs_diff) + rollout_probs_diff_mean = torch.mean(rollout_probs_diff) + rollout_probs_diff_std = torch.std(rollout_probs_diff) + metrics.update( + { + "training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(), + "training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(), + "training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(), + } + ) + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer("ref", timing_raw, color="olive"): + if not self.ref_in_actor: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + else: + ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, color="cyan"): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with marked_timer("adv", timing_raw, color="brown"): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + if self.config.reward_model.launch_reward_fn_async: + reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # compute advantages, executed on the driver process + + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) # GRPO adv normalization factor + + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + config=self.config.algorithm, + ) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, color="pink"): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + with marked_timer("dump_rollout_generations", timing_raw, color="green"): + inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) + outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) + scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() + if "request_id" in batch.non_tensor_batch: + reward_extra_infos_dict.setdefault( + "request_id", + batch.non_tensor_batch["request_id"].tolist(), + ) + self._dump_generations( + inputs=inputs, + outputs=outputs, + scores=scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=rollout_data_dir, + ) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. It's the last training step. + # 3. The current step number is a multiple of the save frequency. + # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. + if self.config.trainer.save_freq > 0 and ( + is_last_step + or self.global_steps % self.config.trainer.save_freq == 0 + or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + + with marked_timer("stop_profile", timing_raw): + self._stop_profiling(do_profile) + + steps_duration = timing_raw["step"] + self.max_steps_duration = max(self.max_steps_duration, steps_duration) + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + + # this is experimental and may be changed/removed in the future in favor of a general-purpose one + if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): + self.train_dataloader.sampler.update(batch=batch) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + progress_bar.update(1) + self.global_steps += 1 + + if is_last_step: + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + # this is experimental and may be changed/removed in the future + # in favor of a general-purpose data buffer pool + if hasattr(self.train_dataset, "on_batch_end"): + # The dataset may be changed after each training batch + self.train_dataset.on_batch_end(batch=batch) diff --git a/toolbox/verl/v0.5.0/verl/trainer/ppo/reward.py b/toolbox/verl/v0.5.0/verl/trainer/ppo/reward.py new file mode 100644 index 000000000..6362f7856 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/ppo/reward.py @@ -0,0 +1,179 @@ +# Copyright 2025 Individual Contributor: Thibaut Barroyer +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +import os +from functools import partial + +import ray + +from verl import DataProto +from verl.utils.reward_score import default_compute_score + + +def _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs): + """Calls `raw_fn` by merging `extra_kwargs` into call-time `kwargs`, with `extra_kwargs` taking precedence. + + This function is used to merge additional keyword arguments with the original function's arguments. + """ + merged_kwargs = {**kwargs, **extra_kwargs} + return raw_fn(*args, **merged_kwargs) + + +def get_custom_reward_fn(config): + """Load and return a custom reward function from external file. + + Dynamically imports a reward function from a specified file path and wraps + it with additional keyword arguments from the configuration. + + Args: + config (dict): Configuration dictionary containing custom_reward_function + settings with 'path', 'name', and 'reward_kwargs' fields. + + Returns: + callable or None: Wrapped reward function with merged kwargs, or None + if no custom reward function is configured. + + Raises: + FileNotFoundError: If the specified reward function file doesn't exist. + RuntimeError: If there's an error loading the module from file. + AttributeError: If the specified function name isn't found in the module. + """ + import importlib.util + import sys + + reward_fn_config = config.get("custom_reward_function") or {} + file_path = reward_fn_config.get("path") + if not file_path: + return None + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Reward function file '{file_path}' not found.") + + spec = importlib.util.spec_from_file_location("custom_module", file_path) + module = importlib.util.module_from_spec(spec) + try: + sys.modules["custom_module"] = module + spec.loader.exec_module(module) + except Exception as e: + raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e + + function_name = reward_fn_config.get("name") + if not hasattr(module, function_name): + raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.") + + print(f"using customized reward function '{function_name}' from '{file_path}'") + raw_fn = getattr(module, function_name) + + reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {})) + + return partial(_call_with_kwargs, raw_fn, reward_kwargs) + + +def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs): + """ + Load and initialize a reward manager based on the configuration. + + Args: + config: PPO trainer configuration object containing reward_model fields. + tokenizer: Tokenizer object used for processing text. + num_examine: Number of samples to examine. + **reward_kwargs: Additional keyword arguments for the reward manager. + + Returns: + An instance of the specified reward manager class. + """ + from verl.workers.reward_manager import get_reward_manager_cls + + # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`: + # naive: NaiveRewardManager + # prime: PrimeRewardManager + # batch: BatchRewardManager + # dapo: DAPORewardManager + # Note(haibin.lin): For custom reward managers, please make sure they are imported and + # registered via `verl.workers.reward_manager.register` + # By default reward_manager is set to naive (NaiveRewardManager) + reward_manager_name = config.reward_model.get("reward_manager", "naive") + reward_manager_cls = get_reward_manager_cls(reward_manager_name) + + # Try to get a custom reward function based on the configuration + compute_score = get_custom_reward_fn(config) + final_compute_score = compute_score + + if compute_score is None: + sandbox_config = config.reward_model.get("sandbox_fusion") + sandbox_url = sandbox_config.get("url") if sandbox_config else None + memory_limit_mb = sandbox_config.get("memory_limit_mb", 1024) + if sandbox_url: + sandbox_manager = multiprocessing.Manager() + # Create a semaphore to control concurrent access to the sandbox + _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64)) + final_compute_score = partial( + default_compute_score, + sandbox_fusion_url=sandbox_url, + concurrent_semaphore=_concurrent_semaphore, + memory_limit_mb=memory_limit_mb, + ) + else: + final_compute_score = default_compute_score + + # Instantiate and return the reward manager with the specified parameters + return reward_manager_cls( + tokenizer=tokenizer, + num_examine=num_examine, + compute_score=final_compute_score, + reward_fn_key=config.data.reward_fn_key, + **reward_kwargs, + ) + + +def compute_reward(data: DataProto, reward_fn): + """ + Compute reward for a batch of data. + Args: + data: DataProto object containing the input data. + reward_fn: Reward function to compute the reward. + Returns: + Tuple of reward tensor and extra info dictionary. + """ + try: + reward_result = reward_fn(data, return_dict=True) + reward_tensor = reward_result["reward_tensor"] + reward_extra_infos_dict = reward_result.get("reward_extra_info", {}) + except Exception as e: + print(f"Error in reward_fn: {e}") + reward_tensor = reward_fn(data) + reward_extra_infos_dict = {} + + return reward_tensor, reward_extra_infos_dict + + +@ray.remote(num_cpus=1) +def compute_reward_async(data: DataProto, config=None, tokenizer=None, reward_fn=None): + """ + Load the reward manager and compute the reward for a batch of data. + This is meant to be run in a separate Ray worker. + """ + if reward_fn is None: + assert config is not None and tokenizer is not None, ( + "config and tokenizer must not be None when reward_fn is None" + ) + import warnings + + warnings.warn("using config and tokenizer with compute_reward_async is deprecated", stacklevel=2) + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + + return compute_reward(data, reward_fn) diff --git a/toolbox/verl/v0.5.0/verl/trainer/runtime_env.yaml b/toolbox/verl/v0.5.0/verl/trainer/runtime_env.yaml new file mode 100644 index 000000000..63750cd72 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/trainer/runtime_env.yaml @@ -0,0 +1,5 @@ +working_dir: ./ +excludes: ["/.git/"] +env_vars: + TORCH_NCCL_AVOID_RECORD_STREAMS: "1" + CUDA_DEVICE_MAX_CONNECTIONS: "1" diff --git a/toolbox/verl/v0.5.0/verl/utils/__init__.py b/toolbox/verl/v0.5.0/verl/utils/__init__.py new file mode 100644 index 000000000..034584945 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/utils/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import config, tokenizer +from .config import omega_conf_to_dataclass +from .tokenizer import hf_processor, hf_tokenizer + +__all__ = tokenizer.__all__ + config.__all__ + ["hf_processor", "hf_tokenizer", "omega_conf_to_dataclass"] diff --git a/toolbox/verl/v0.5.0/verl/utils/activation_offload.py b/toolbox/verl/v0.5.0/verl/utils/activation_offload.py new file mode 100644 index 000000000..73e2e83eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/utils/activation_offload.py @@ -0,0 +1,558 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functionality for CPU offloading of tensors saved for backward pass.""" + +from __future__ import annotations + +import functools +import logging +import os +from typing import Any, Optional + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from verl.utils.device import get_torch_device +from verl.utils.fsdp_utils import FSDPModule as FSDP2 + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def _get_unique_tensor_key(tensor): + key = (tensor.untyped_storage().data_ptr() + tensor.storage_offset(), tensor.dtype) + return key + + +class FSDPParameterFilter: + def __init__(self): + self.model_parameters_storage = set() + + def __call__(self, tensor): + return tensor.untyped_storage().data_ptr() not in self.model_parameters_storage + + def update_model_parameters(self, model): + new_storage = set() + for p in model.parameters(): + new_storage.add(p.data.untyped_storage().data_ptr()) + self.model_parameters_storage = new_storage + + +class CpuOffloadHookWithOffloadHandler: + """Context-manager that offloads/recovers tensors through an offload hander. + + The hook just offloads/recovers the tensor object to the handler through `tensor_push` + and `tensor_pop` interface. How the offload-handler manages the offloading, recovering + or prefetching timing is transparent to this hook. + """ + + def __init__( + self, + offload_handler: OffloadHandler, + handler_extra_kwargs: Optional[dict[str, Any]] = None, + ) -> None: + if handler_extra_kwargs is None: + handler_extra_kwargs = {} + self.offload_handler: OffloadHandler = offload_handler + self.handler_extra_kwargs: dict[str, Any] = handler_extra_kwargs + self.inside_context = False + + def __enter__(self): + self.inside_context = True + torch._C._autograd._push_saved_tensors_default_hooks(self.on_save_for_backward, self.on_get_saved_tensor) + + def __exit__(self, *args: Any): + self.inside_context = False + torch._C._autograd._pop_saved_tensors_default_hooks() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) + return retrieve_identifier + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs) + return tensor + + +class OffloadHandler: + """A base class for CPU offload-handler.""" + + def __init__(self) -> None: + pass + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + """Tensor push.""" + raise NotImplementedError( + "`tensor_push is not implented in OffloadHandler class. Inherit this class and implement your " + "custom tensor_push." + ) + + def tensor_pop(self, tensor_tag: Any, **kwargs): + """Tensor pop.""" + raise NotImplementedError( + "`tensor_pop is not implented in OffloadHandler class. Inherit this class and implement your " + "custom tensor_pop." + ) + + +class GroupCommitFunction(torch.autograd.Function): + """this is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to actions in both forward and backward. + """ + + @staticmethod + def forward(ctx, tensor, cpu_offload_handler): + # pylint: disable=missing-function-docstring + cpu_offload_handler.on_group_commit_forward() + ctx.cpu_offload_handler = cpu_offload_handler + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_group_commit_backward() + return grad_output, None + + +group_prefetch_offload_commit = GroupCommitFunction.apply + + +class SynchronizedGroupOffloadHandler(OffloadHandler): + """Offload Handler that offloads/reloads in a synchronized way. + The device-to-host and host-to-device copying happen in the same stream + as the computation kernels, thus the copying will block computation. + """ + + def __init__(self, num_offload_group, tensor_need_offloading_checker=(lambda _: True)) -> None: + super().__init__() + + self.num_offload_group = num_offload_group + self.tensor_need_offloading_checker = tensor_need_offloading_checker + + self.groupid_reset() + + def groupid_reset(self): + """Groupid reset.""" + # Data structures to label saved tensors and book-keep their cpu copies. + # Currently, on push, create a new cpu tensor and copies; on pop, copies + # the tensor back to gpu and deletes the cpu tensor. + # These will increment whenever `group_commit()` is invoked + self.current_group, self.tensor_count_current_group = (0, 0) + self.torch_tensor_count = 0 + self.tensor_tag_to_state = {} + + def on_group_commit_forward(self): + """On group commit forward.""" + # finishing up with updating current group and tensor count + self.current_group += 1 # increment + self.tensor_count_current_group = 0 # reset + + def on_group_commit_backward(self): + """On group commit backward.""" + self.current_group -= 1 + assert self.current_group >= 0 + + @staticmethod + def offload(src_tensor, pin_memory=True): + """Offload.""" + + cpu_backup = torch.empty( + src_tensor.size(), + dtype=src_tensor.dtype, + layout=src_tensor.layout, + device="cpu", + pin_memory=pin_memory, + ) + cpu_backup.copy_(src_tensor, non_blocking=True) + state = (src_tensor.device, cpu_backup) + return state + + @staticmethod + def reload(state, non_blocking=None): + """Reload.""" + dev, cpu_backup = state + if non_blocking is None: + non_blocking = cpu_backup.is_pinned() + return cpu_backup.to(dev, non_blocking=non_blocking) + + def tensor_push(self, tensor: torch.Tensor, **kwargs): + """Tensor push.""" + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(tensor): + state = SynchronizedGroupOffloadHandler.offload(tensor) + self.tensor_tag_to_state[tensor_tag] = state + else: + # will be offloaded together after group commit + self.tensor_tag_to_state[tensor_tag] = tensor + + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + assert tensor_tag in self.tensor_tag_to_state + state = self.tensor_tag_to_state.pop(tensor_tag) + if isinstance(state, tuple): + tensor = SynchronizedGroupOffloadHandler.reload(state) + else: + tensor = state + return tensor + + +class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): + """Compared to synchronize, this uses more memory because of the buffer but + achieves better performance due to the overlapping. D2h and h2d copying are + completely hidden behind computation if computation time of a layer is longer + than host-device communication time. Bulk offloading with delay and bulk reloading + with prefetch are implemented.""" + + def __init__( + self, + num_offload_group, # must be <= actual number of groups (number of commits) + num_model_group, + tensor_need_offloading_checker=(lambda t: True), + ) -> None: + super().__init__( + num_offload_group=num_offload_group, + tensor_need_offloading_checker=tensor_need_offloading_checker, + ) + # Number of layers in the model + self.num_layers = num_model_group + # Data Structure to maintain reference to activation tensors + self.tensor_tag_to_buf = {} + # Tracking the number of layers offloaded + self.offloaded_group_count = 0 + # Core data structure that decides the window for offloading + self.layer_window_map = {} + self.group_offload_mapping = {} + + # Logic to make offloading load balance across computation + # for optimal CPU/GPU interconnect usage + constant = 0 + for i in range(self.num_offload_group): + self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 + if i < (self.num_layers % self.num_offload_group): + self.layer_window_map[i] += i + 1 + constant = i + 1 + else: + self.layer_window_map[i] += constant + + # allocate streams and events for synchronization + self.d2h_stream = get_torch_device().Stream() + self.h2d_stream = get_torch_device().Stream() + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + torch_stray_tensor = isinstance( + tensor, + torch._subclasses.fake_tensor.FakeTensor | torch._subclasses.functional_tensor.FunctionalTensor, + ) + need_offload = not torch_stray_tensor + need_offload = need_offload and self.tensor_need_offloading_checker(tensor) + + if need_offload: + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + + assert tensor_tag not in self.tensor_tag_to_state + self.tensor_tag_to_state[tensor_tag] = tensor + + if self.current_group < self.num_offload_group: + self.tensor_tag_to_buf[tensor_tag] = tensor + else: + tensor_tag = tensor + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + if isinstance(tensor_tag, torch.Tensor): + return tensor_tag + assert tensor_tag in self.tensor_tag_to_state + tensor = self.tensor_tag_to_state.pop(tensor_tag) + self.tensor_tag_to_buf.pop(tensor_tag, None) + + # the tensor should have been copied back in on_group_commit_backward() + # which invokes bulk_reload_group. + assert not isinstance(tensor, tuple) + return tensor + + def bulk_offload_group(self, group_to_offload): + """Bulk offload group.""" + offload_mapping = {} + offload_size = 0 + with get_torch_device().stream(self.d2h_stream): + for tensor_tag, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_tag + if group_id == group_to_offload: + assert not isinstance(state, tuple) + key = _get_unique_tensor_key(state) + if key not in offload_mapping: + offload_mapping[key] = state + # if offload, return the reference to cpu copy + self.tensor_tag_to_state[tensor_tag] = (key, state.shape) + for key, tensor in offload_mapping.items(): + state = SynchronizedGroupOffloadHandler.offload(tensor) + offload_size += tensor.numel() * tensor.element_size() + offload_mapping[key] = state + + self.group_offload_mapping[group_to_offload] = offload_mapping + + def synchronize_on_group_commit_forward(self, current_group): + """Synchronize on group commit forward.""" + + # For the first group, kickstart the offload after we have + # the first compute completion + if current_group == 0: + self.d2h_stream.wait_stream(get_torch_device().current_stream()) + self.bulk_offload_group(current_group) + + # Window map data structure helps us synchronize based on number + # of layers offloaded + if self.layer_window_map[self.offloaded_group_count] == current_group: + # Stream synchronization both ways + self.d2h_stream.wait_stream(get_torch_device().current_stream()) + get_torch_device().current_stream().wait_stream(self.d2h_stream) + + # Time to free the activation memory after usage + for tensor_tag, _ in self.tensor_tag_to_buf.items(): + if tensor_tag[0] == self.offloaded_group_count: + self.tensor_tag_to_buf[tensor_tag] = None + + # Time to offload the next group + if self.offloaded_group_count < (self.num_offload_group - 1): + self.bulk_offload_group(self.offloaded_group_count + 1) + + # Increment the offload group count to keep track + self.offloaded_group_count += 1 + + def on_group_commit_forward(self): + """This function will cause host device synchronization""" + # handle synchronization events + self.synchronize_on_group_commit_forward(self.current_group) + + super().on_group_commit_forward() + + @torch.no_grad + def bulk_reload_group(self, group_to_reload): + """Bulk reload group.""" + assert group_to_reload < self.num_offload_group + + with get_torch_device().stream(self.h2d_stream): + # move back tensors + offload_mapping = self.group_offload_mapping.pop(group_to_reload) + assert offload_mapping is not None + for key, state in offload_mapping.items(): + offload_mapping[key] = SynchronizedGroupOffloadHandler.reload(state) + for tensor_label, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_label + if group_id == group_to_reload and not isinstance(state, torch.Tensor): + assert isinstance(state, tuple), f"{group_id} {state}" + key, shape = state + recovered_tensor = offload_mapping[key].view(shape) + self.tensor_tag_to_state[tensor_label] = recovered_tensor + + def on_group_commit_backward(self): + # first decrement the current group. + # after last commit in forward, the group will +1; in backward it -1. + # Finally it should be decremented to 0. + self.current_group -= 1 + assert self.current_group >= 0 + + # Layer window data structure helps us to reload at right times + if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: + # Stream synchronization both ways + self.h2d_stream.wait_stream(get_torch_device().current_stream()) + get_torch_device().current_stream().wait_stream(self.h2d_stream) + + # Time to reload the next group + self.bulk_reload_group(self.offloaded_group_count - 1) + + # Decrease the offloading group counter + self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 + + # Last group computation needs to wait till all the reloads complete + if self.current_group == 0: + get_torch_device().current_stream().wait_stream(self.h2d_stream) + self.offloaded_group_count = 0 + + +def get_activation_offload_context( + num_layers: int = 1, model_layers: int = 1, tensor_need_offloading_checker=(lambda t: True) +): + cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( + num_offload_group=num_layers, + num_model_group=model_layers, + tensor_need_offloading_checker=tensor_need_offloading_checker, + ) + + def group_prefetch_offload_commit_async(tensor): + return group_prefetch_offload_commit(tensor, cpu_offload_handler) + + return ( + CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), + group_prefetch_offload_commit_async, + ) + + +class ActivationHandler: + def __init__(self, offload_ctx, sync_func, tensor_filter, enable_ckpt): + self._offload_ctx = offload_ctx + self._sync_func = sync_func + self._enable_ckpt = enable_ckpt + self._tensor_filter = tensor_filter + if enable_ckpt: + self.checkpoint_fn = functools.partial( + torch.utils.checkpoint.checkpoint, + use_reentrant=True, + ) + + def pre_forward(self, module): + if module.training: + self._offload_ctx.__enter__() + self._tensor_filter.update_model_parameters(module) + + def post_forward(self, module): + if module.training: + self._offload_ctx.__exit__(None, None, None) + + def _pack_kwargs(self, *args, **kwargs): + kwarg_keys = [] + flat_args = list(args) + for k, v in kwargs.items(): + kwarg_keys.append(k) + flat_args.append(v) + + return tuple(flat_args), tuple(kwarg_keys) + + def _unpack_kwargs(self, flat_args, kwarg_keys): + assert len(kwarg_keys) <= len(flat_args), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}" + if len(kwarg_keys) == 0: + return flat_args, {} + args = flat_args[: -len(kwarg_keys)] + kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :], strict=True)) + return args, kwargs + + def _ckpt_forward(self, forward_method, *args, **kwargs): + flat_args, kwarg_keys = self._pack_kwargs(*args, **kwargs) + + def my_function(*inputs): + # unpack back into args and kwargs + nonlocal forward_method, kwarg_keys + unpacked_args, unpacked_kwargs = self._unpack_kwargs(inputs, kwarg_keys) + # run original module + return forward_method(*unpacked_args, **unpacked_kwargs) + + return self.checkpoint_fn( + my_function, + *flat_args, + ) + + def forward(self, module, forward_method, *args, **kwargs): + if not module.training: + return forward_method(*args, **kwargs) + if not self._enable_ckpt: + ret = forward_method(*args, **kwargs) + else: + ret = self._ckpt_forward(forward_method, *args, **kwargs) + binded_tensor = ret + if isinstance(ret, tuple): + binded_tensor = ret[0] + binded_tensor = self._sync_func(binded_tensor) + final_ret = binded_tensor + if isinstance(ret, tuple): + final_ret = (final_ret,) + ret[1:] + return final_ret + + def wrap_module_forward_method(self, module): + orig_method = module.forward + handler = self + + @functools.wraps(orig_method) + def wrapped_method(model_self, *args, **kwargs): + nonlocal handler + handler.pre_forward(model_self) + out = handler.forward(model_self, orig_method, *args, **kwargs) + handler.post_forward(model_self) + return out + + module.forward = wrapped_method.__get__(module, type(module)) + + +def enable_activation_offloading(model, strategy, enable_ckpt=False): + """ + Enable activation offloading for the model. It groups activations by TransformerLayer and offloads activation + groups asynchronously. This means that the offloading of the i-th activation group and the computation of the i+1-th + activation group happen at the same time, and there are at most two activation groups in GPU memory. + + Args: + model: the model to enable activation offloading + strategy: the training strategy of the model, such as "fsdp" + enable_ckpt: whether activation checkpointing(also called gradient checkpointing) has been enabled for the model + + Note: + For best efficiency, activation offloading is usually combined with activation checkpointing. However, this + implementation of activation offloading is conflicted with the implementation of activation checkpointing in + some training strategies. This function resolves this conflict, and therefore requires the "strategy" and + "enable_ckpt" arguments. + + Returns: + + """ + + assert strategy == "fsdp" or strategy == "fsdp2", "activation offloading only supports fsdp strategy" + layers = [] + + def get_layers(module): + for name, child in module.named_children(): + if not isinstance(child, FSDP | FSDP2): + get_layers(child) + else: + wrapped_module = child + if isinstance(child, FSDP): + wrapped_module = child._fsdp_wrapped_module + # In some cases, torch.nn.Embedding is wrapped with FSDP alone. However, the activation + # size of torch.nn.Embedding is small, so it's not necessary to offload it. + if not isinstance(wrapped_module, torch.nn.Embedding): + layers.append(child) + + get_layers(model) + if len(layers) < 3: + logger.warning(f"Find only {len(layers)} fsdp layers, not neccessary to enable async activation offloading") + return + + tensor_filter = FSDPParameterFilter() + context, sync_func = get_activation_offload_context(len(layers) - 1, len(layers), tensor_filter) + if enable_ckpt: + # The implementation of activation checkpointing in transformers library is incompatible with + # activation offloading, + # so it will be disabled, but this implementation supports another version of activation checkpointing, so that + # these two features can be enabled at the same time. + for module in model.modules(): + if hasattr(module, "gradient_checkpointing_disable"): + module.gradient_checkpointing_disable() + + handler = ActivationHandler(context, sync_func, tensor_filter, enable_ckpt) + for layer in layers: + module = layer + if isinstance(layer, FSDP): + module = module._fsdp_wrapped_module + handler.wrap_module_forward_method(module) diff --git a/toolbox/verl/v0.5.0/verl/utils/checkpoint/__init__.py b/toolbox/verl/v0.5.0/verl/utils/checkpoint/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/utils/checkpoint/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/toolbox/verl/v0.5.0/verl/utils/checkpoint/checkpoint_manager.py b/toolbox/verl/v0.5.0/verl/utils/checkpoint/checkpoint_manager.py new file mode 100644 index 000000000..ff861abf3 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/utils/checkpoint/checkpoint_manager.py @@ -0,0 +1,237 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import shutil + +import numpy as np +import torch +import torch.distributed +from omegaconf import DictConfig +from transformers import PreTrainedTokenizer, ProcessorMixin + +from verl.utils.device import get_device_name, get_torch_device + + +class BaseCheckpointManager: + """ + A checkpoint manager that saves and loads + - model + - optimizer + - lr_scheduler + - extra_states + in a SPMD way. + + We save + - sharded model states and optimizer states + - full lr_scheduler states + - huggingface tokenizer and config for ckpt merge + """ + + def __init__( + self, + model, + optimizer: torch.optim.Optimizer, + lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None, + processing_class: PreTrainedTokenizer | ProcessorMixin = None, + checkpoint_config: DictConfig = None, + ): + self.checkpoint_config = checkpoint_config + checkpoint_load_contents = checkpoint_config.get("load_contents", None) if checkpoint_config else None + checkpoint_save_contents = checkpoint_config.get("save_contents", None) if checkpoint_config else None + if checkpoint_load_contents is None: + checkpoint_load_contents = ["model", "optimizer", "extra"] + if checkpoint_save_contents is None: + checkpoint_save_contents = ["model", "optimizer", "extra"] + self.previous_global_step = None + self.previous_saved_paths = [] + + self.model = model + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.processing_class = processing_class + self.checkpoint_load_contents = checkpoint_load_contents + self.checkpoint_save_contents = checkpoint_save_contents + + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + @property + def should_save_model(self) -> bool: + """ + Returns True if 'model' is in checkpoint_save_contents, indicating the model state should be saved. + """ + return "model" in self.checkpoint_save_contents + + @property + def should_save_optimizer(self) -> bool: + """ + Returns True if 'optimizer' is in checkpoint_save_contents, indicating the optimizer state should be saved. + """ + return "optimizer" in self.checkpoint_save_contents + + @property + def should_save_extra(self) -> bool: + """ + Returns True if 'extra' is in checkpoint_save_contents, indicating the extra state should be saved. + """ + return "extra" in self.checkpoint_save_contents + + @property + def should_save_hf_model(self) -> bool: + """ + Returns True if 'hf_model' is in checkpoint_save_contents, indicating the model should be converted to hf + model and saved. + """ + return "hf_model" in self.checkpoint_save_contents + + @property + def should_load_model(self) -> bool: + """ + Returns True if 'model' is in checkpoint_load_contents, indicating the model state should be loaded. + """ + return "model" in self.checkpoint_load_contents + + @property + def should_load_optimizer(self) -> bool: + """ + Returns True if 'optimizer' is in checkpoint_load_contents, indicating the optimizer state should be loaded. + """ + return "optimizer" in self.checkpoint_load_contents + + @property + def should_load_extra(self) -> bool: + """ + Returns True if 'extra' is in checkpoint_load_contents, indicating the extra state should be loaded. + """ + return "extra" in self.checkpoint_load_contents + + def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False): + raise NotImplementedError + + def save_checkpoint( + self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep: int = None + ): + raise NotImplementedError + + @staticmethod + def checkpath(local_path: str, hdfs_path: str): + assert local_path is not None or hdfs_path is not None, "local_path and hdfs_path cannot be both None" + return local_path is not None, local_path if local_path is not None else hdfs_path + + def remove_previous_save_local_path(self, path): + if isinstance(path, str): + path = [path] + for p in path: + abs_path = os.path.abspath(p) + print(f"Checkpoint manager remove previous save local path: {abs_path}") + if not os.path.exists(abs_path): + continue + shutil.rmtree(abs_path, ignore_errors=True) + + @staticmethod + def get_rng_state(): + rng_state = { + "cpu": torch.get_rng_state(), + "numpy": np.random.get_state(), + "random": random.getstate(), + } + + if get_device_name() != "cpu": + rng_state[get_device_name()] = get_torch_device().get_rng_state() + + return rng_state + + @staticmethod + def load_rng_state(rng_state): + torch.set_rng_state(rng_state["cpu"]) + np.random.set_state(rng_state["numpy"]) + random.setstate(rng_state["random"]) + + if get_device_name() != "cpu": + get_torch_device().set_rng_state(rng_state[get_device_name()]) + + +def find_latest_ckpt_path(path, directory_format="global_step_{}"): + """ + Return the most recent checkpoint directory based on a tracker file. + + Args: + path (str): Base directory containing the checkpoint tracker. + directory_format (str): Template for checkpoint subfolders with one + placeholder for the iteration number (default "global_step_{}"). + + Returns: + str or None: Full path to the latest checkpoint directory, or + None if the tracker or checkpoint folder is missing. + """ + if path is None: + return None + + tracker_file = get_checkpoint_tracker_filename(path) + if not os.path.exists(tracker_file): + print(f"Checkpoint tracker file does not exist: {tracker_file}") + return None + + with open(tracker_file, "rb") as f: + iteration = int(f.read().decode()) + ckpt_path = os.path.join(path, directory_format.format(iteration)) + if not os.path.exists(ckpt_path): + print("Checkpoint does not exist: %s", ckpt_path) + return None + + print("Found checkpoint: %s", ckpt_path) + return ckpt_path + + +def get_checkpoint_tracker_filename(root_path: str): + """ + Tracker file rescords the latest chckpoint during training to restart from. + """ + return os.path.join(root_path, "latest_checkpointed_iteration.txt") + + +def should_save_ckpt_esi(max_steps_duration: float, save_ckpt_duration: float = 60, redundant_time: float = 0) -> bool: + """ + Determine if checkpoint should be saved based on capacity esi expiration. + + Args: + max_steps_duration: Max estimated time (seconds) required to complete one training step + save_ckpt_duration: Estimated time (seconds) required to save checkpoint (default: 60) + redundant_time: Additional buffer time (seconds) for unexpected delays (default: 0) + """ + exp_ts_mlp = os.getenv("MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP") # vemlp + exp_ts_aws = os.getenv("SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP") # aws + if exp_ts_mlp: + try: + import time + + remaining = float(exp_ts_mlp) - time.time() + except ValueError: + return False + return ( + remaining > 0 + and max_steps_duration > 0 + and remaining <= save_ckpt_duration + max_steps_duration + redundant_time + ) + elif exp_ts_aws: + from datetime import datetime, timedelta + + expiration_time = datetime.fromtimestamp(int(exp_ts_aws)) + time_difference = expiration_time - datetime.now() + threshold_minutes = (save_ckpt_duration + max_steps_duration + redundant_time) / 60 + return time_difference < timedelta(minutes=threshold_minutes) + else: + return False diff --git a/toolbox/verl/v0.5.0/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/toolbox/verl/v0.5.0/verl/utils/checkpoint/fsdp_checkpoint_manager.py new file mode 100644 index 000000000..14e448da6 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -0,0 +1,350 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import warnings +from dataclasses import asdict, dataclass +from typing import Optional + +import torch +import torch.distributed +from accelerate import init_empty_weights +from omegaconf import DictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType +from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin + +from verl.utils.device import is_cuda_available +from verl.utils.fs import copy_to_local, is_non_local, local_mkdir_safe +from verl.utils.fsdp_utils import fsdp_version, get_fsdp_full_state_dict, get_fsdp_state_ctx +from verl.utils.logger import log_with_rank + +from .checkpoint_manager import BaseCheckpointManager + +# Setup logging +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) + + +@dataclass +class FSDPConfig: + """Configuration for FSDP checkpointing. + + Args: + FSDP_version (int): Version of FSDP being used. + world_size (int): Number of processes in the distributed training setup. + """ + + FSDP_version: int + world_size: int + + +class FSDPCheckpointManager(BaseCheckpointManager): + """ + Manage FSDP checkpointing in SPMD training. + + - Saves/loads per-rank sharded model & optimizer states + - Persists full lr_scheduler and RNG state + - Stores HF tokenizer/processor and model/config for unified restore + + Args: + model (FSDP): Wrapped model instance. + optimizer (Optimizer): Training optimizer. + lr_scheduler (LRScheduler): Learning-rate scheduler. + processing_class (PreTrainedTokenizer or ProcessorMixin, optional): + Pre-/post-processing artifact handler. + checkpoint_contents DictConfig: Configuration for checkpoint contents. + - 'load': Components to load; must contain 'model'. Defaults to ['model', 'optimizer', 'extra']. + - 'save': Components to save; must contain 'model'. Defaults to ['model', 'optimizer', 'extra']. + """ + + def __init__( + self, + model: FSDP, + optimizer: Optional[torch.optim.Optimizer] = None, + lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, + processing_class: PreTrainedTokenizer | ProcessorMixin = None, + checkpoint_config: DictConfig = None, + **kwargs, + ): + if processing_class is None: + assert "tokenizer" in kwargs, "tokenizer or processor must be provided" + warnings.warn( + "`tokenizer` is deprecated. use `processing_class` instead.", DeprecationWarning, stacklevel=2 + ) + processing_class = kwargs.pop("tokenizer") + + super().__init__( + model, + optimizer, + lr_scheduler=lr_scheduler, + processing_class=processing_class, + checkpoint_config=checkpoint_config, + ) + + def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False): + """ + Load an FSDP checkpoint for this rank. + + Downloads and loads: + - model and optimizer shards + - extra state dict (scheduler + RNG) + + Args: + local_path: Directory with per-rank checkpoint files. + hdfs_path: Unused (for API compatibility). + del_local_after_load: Remove local files after loading. + """ + if local_path is None: + return + + # check if the checkpoint_load_contents is valid + if self.should_load_model: + assert self.model is not None, "model must be provided when checkpoint_contents.load includes ['model']" + if self.should_load_optimizer: + assert self.optimizer is not None, ( + "optimizer must be provided when checkpoint_contents.load includes ['optimizer']" + ) + + # every rank download its own checkpoint + state_dict_cfg = ( + ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + if self.should_load_model + else None + ) + optim_cfg = ( + ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + if self.should_load_optimizer + else None + ) + with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): + if self.should_load_model: + remote_model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") + local_model_path = copy_to_local(remote_model_path) + model_state_dict = torch.load(local_model_path, weights_only=False) + self.model.load_state_dict(model_state_dict) + log_with_rank(f"Loaded model from {remote_model_path}", rank=self.rank, logger=logger) + + if self.should_load_optimizer: + remote_optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") + local_optim_path = copy_to_local(remote_optim_path) + optimizer_state_dict = torch.load(local_optim_path, weights_only=False) + self.optimizer.load_state_dict(optimizer_state_dict) + log_with_rank(f"Loaded optimizer from {remote_optim_path}", rank=self.rank, logger=logger) + + if self.should_load_extra: + remote_extra_state_path = os.path.join( + local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt" + ) + local_extra_state_path = copy_to_local(remote_extra_state_path) + extra_state_dict = torch.load(local_extra_state_path, weights_only=False) + # recover random state + if "rng" in extra_state_dict: + # 'rng' may not exist for backward compatibility + self.load_rng_state(extra_state_dict["rng"]) + log_with_rank(f"Loaded rng from {remote_extra_state_path}", rank=self.rank, logger=logger) + + lr_scheduler_state_dict = extra_state_dict["lr_scheduler"] + if lr_scheduler_state_dict is not None and self.lr_scheduler is not None: + self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) + log_with_rank(f"Loaded lr_scheduler from {remote_extra_state_path}", rank=self.rank, logger=logger) + + if self.rank == 0 and del_local_after_load: + try: + os.remove(local_model_path) if is_non_local(local_model_path) else None + os.remove(local_optim_path) if is_non_local(local_optim_path) else None + os.remove(local_extra_state_path) if is_non_local(local_extra_state_path) else None + except Exception as e: + log_with_rank( + f"remove local resume ckpt file after loading failed, exception {e} will be ignored", + rank=self.rank, + logger=logger, + ) + + # wait for everyone to load checkpoints + torch.distributed.barrier() + + def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None): + """ + Save an FSDP checkpoint for this rank. + + Writes: + - model & optimizer shard files + - extra state dict (scheduler + RNG) + - HF tokenizer/processor and model/config on rank 0 + - optional full HF model under 'huggingface/' if requested + + Rotates old checkpoints, keeping at most `max_ckpt_to_keep`. + + Args: + local_path: Target directory for checkpoint files. + hdfs_path: Unused (for API compatibility). + global_step: Current training step (used for bookkeeping). + max_ckpt_to_keep: Number of recent checkpoints to retain. + """ + if local_path is None: + return + + # record the previous global step + self.previous_global_step = global_step + + # remove previous local_path, only rank 0 should do this + if ( + self.rank == 0 + and max_ckpt_to_keep + and isinstance(max_ckpt_to_keep, int) + and max_ckpt_to_keep > 0 + and len(self.previous_saved_paths) >= max_ckpt_to_keep + ): + keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 + self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) + self.previous_saved_paths = self.previous_saved_paths[keep_start:] + + local_path = local_mkdir_safe(local_path) + torch.distributed.barrier() + + # check if the checkpoint_save_contents is valid + if self.should_save_model: + assert self.model is not None, "model must be provided when checkpoint_contents.save includes ['model']" + if self.should_save_optimizer: + assert self.optimizer is not None, ( + "optimizer must be provided when checkpoint_contents.save includes ['optimizer']" + ) + + # every rank will save its own model and optim shard + state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): + model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") + optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") + extra_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") + + if self.should_save_model: + model_state_dict = self.model.state_dict() + torch.save(model_state_dict, model_path) + log_with_rank(f"Saved model to {os.path.abspath(model_path)}", rank=self.rank, logger=logger) + + if self.should_save_optimizer: + optimizer_state_dict = self.optimizer.state_dict() + torch.save(optimizer_state_dict, optim_path) + log_with_rank(f"Saved optim to {os.path.abspath(optim_path)}", rank=self.rank, logger=logger) + + if self.should_save_extra: + lr_scheduler_state_dict = self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None + extra_state_dict = { + "lr_scheduler": lr_scheduler_state_dict, + "rng": self.get_rng_state(), + } + torch.save(extra_state_dict, extra_path) + log_with_rank(f"Saved extra_state to {os.path.abspath(extra_path)}", rank=self.rank, logger=logger) + + if self.rank == 0: + # Save HF tokenizer/processor and model config on rank 0 to huggingface/ directory, no matter whether + # huggingface model is requested to be saved or not. + + if fsdp_version(self.model) == 1: + unwrap_model = self.model._fsdp_wrapped_module + else: + unwrap_model = self.model + + hf_config_tokenizer_path = os.path.join(local_path, "huggingface") + local_mkdir_safe(hf_config_tokenizer_path) + model_config = unwrap_model.config + generation_config = None + if unwrap_model.can_generate() and hasattr(model_config, "name_or_path") and model_config.name_or_path: + try: + # Some model's name_or_path is empty if not initialized from pretrained, + # in this cases, we don't save generation config. + generation_config = GenerationConfig.from_pretrained(model_config.name_or_path) + generation_config.save_pretrained(hf_config_tokenizer_path) + except Exception: + # if the generation config isn't available, we don't save it + pass + + model_config.save_pretrained(hf_config_tokenizer_path) + self.processing_class.save_pretrained(hf_config_tokenizer_path) + log_with_rank( + f"Saved model config and tokenizer class to {os.path.abspath(hf_config_tokenizer_path)}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) + + # Also save runtime FSDP config + fsdp_config_path = os.path.join(local_path, "fsdp_config.json") + fsdp_config = FSDPConfig( + FSDP_version=fsdp_version(self.model), + world_size=self.world_size, + ) + with open(fsdp_config_path, "w") as f: + json.dump(asdict(fsdp_config), f, indent=4) + + # wait for everyone to dump to local + torch.distributed.barrier() + + if self.should_save_hf_model: + # Only rank 0 will save hf model and, + # offload to cpu to save LLMs which may be too large to fit in one GPU + state_dict = get_fsdp_full_state_dict(self.model, offload_to_cpu=True, rank0_only=True) + + if self.rank == 0: + hf_local_path = os.path.join(local_path, "huggingface") + os.makedirs(hf_local_path, exist_ok=True) + + if "ForTokenClassification" in model_config.architectures[0]: + from transformers import AutoModelForTokenClassification + + auto_model_cls = AutoModelForTokenClassification + elif "ForCausalLM" in model_config.architectures[0]: + from transformers import AutoModelForCausalLM + + auto_model_cls = AutoModelForCausalLM + elif "ForConditionalGeneration" in model_config.architectures[0]: + from transformers import AutoModelForVision2Seq + + auto_model_cls = AutoModelForVision2Seq + else: + raise NotImplementedError(f"Unknown architecture {model_config['architectures']}") + + with init_empty_weights(): + save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16) + save_model.to_empty(device="cpu") + + if save_model.can_generate(): + if generation_config is not None: + save_model.generation_config = generation_config + else: + print( + f"Warning: {self.__class__.__name__}.save_checkpoint: Generation config file not found " + f"in, using a generation config created from the model config when saving hf_model." + ) + + save_model.save_pretrained(hf_local_path, state_dict=state_dict) + log_with_rank( + f"Saved hf_model to {os.path.abspath(hf_local_path)}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) + del state_dict + del save_model + + # wait for rank0 to dump hf_model to local + torch.distributed.barrier() + + self.previous_saved_paths.append(local_path) diff --git a/toolbox/verl/v0.5.0/verl/utils/checkpoint/megatron_checkpoint_manager.py b/toolbox/verl/v0.5.0/verl/utils/checkpoint/megatron_checkpoint_manager.py new file mode 100644 index 000000000..f0071b8ca --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/utils/checkpoint/megatron_checkpoint_manager.py @@ -0,0 +1,525 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import random +from collections.abc import Callable +from dataclasses import asdict + +import numpy as np +import torch +import torch.distributed +from megatron.core import mpu, tensor_parallel +from megatron.core.dist_checkpointing.mapping import ShardedObject +from megatron.core.transformer.enums import AttnBackend +from transformers import GenerationConfig + +from verl.models.weight_loader_registry import get_weight_saver +from verl.utils.device import get_device_name, get_torch_device +from verl.utils.fs import is_non_local, local_mkdir_safe +from verl.utils.logger import log_with_rank +from verl.utils.megatron.dist_checkpointing import load_dist_checkpointing, save_dist_checkpointing +from verl.utils.megatron_utils import ( + get_dist_checkpoint_path, + get_hf_model_checkpoint_path, + get_transformer_config_checkpoint_path, +) + +from .checkpoint_manager import BaseCheckpointManager + +# Setup logging +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) + + +class MegatronCheckpointManager(BaseCheckpointManager): + """ + Checkpoint manager for Megatron-LM distributed training. + + This class manages the saving and loading of model checkpoints in a Megatron-LM + distributed training environment. It handles various aspects of checkpointing + including model states, optimizer states, learning rate schedulers, and random + number generator states, ensuring compatibility with HuggingFace formats. + + Key features: + - Distributed checkpoint saving and loading using Megatron's dist_checkpointing + - Support for tensor parallel, pipeline parallel, and data parallel configurations + - Automatic handling of model state dictionaries across multiple pipeline stages + - Integration with HuggingFace model configurations and tokenizers + - Random number generator state management for reproducibility + - Support for both synchronous and asynchronous checkpoint operations + + The manager automatically handles: + - Directory structure creation based on global steps and process ranks + - Model configuration and tokenizer saving in HuggingFace format + - Optimizer and scheduler state persistence + - CUDA RNG state management for deterministic training + - Checkpoint cleanup and retention policies + + Args: + model: The Megatron model instance to checkpoint + optimizer: The optimizer instance (optional) + lr_scheduler: The learning rate scheduler instance (optional) + + Attributes: + model: Reference to the Megatron model being checkpointed + optimizer: Reference to the optimizer (if provided) + lr_scheduler: Reference to the learning rate scheduler (if provided) + rank: Current process rank in the distributed setup + + Example: + ```python + checkpoint_manager = MegatronCheckpointManager( + model=megatron_model, + optimizer=optimizer, + lr_scheduler=scheduler + ) + + checkpoint_manager.save_checkpoint( + local_path="checkpoints/step_1000", + global_step=1000 + ) + + checkpoint_manager.load_checkpoint( + local_path="checkpoints/step_1000" + ) + ``` + """ + + def __init__( + self, + config, + checkpoint_config, + model_config, + transformer_config, + role, + model: torch.nn.ModuleList, + arch: str, + hf_config, + param_dtype: torch.dtype, + share_embeddings_and_output_weights: bool, + processing_class, + optimizer, + optimizer_scheduler, + use_distributed_optimizer: bool, + use_checkpoint_opt_param_scheduler: bool = False, + use_dist_checkpointing: bool = True, + bridge=None, + **kwargs, + ): + super().__init__( + model, + optimizer=optimizer, + lr_scheduler=optimizer_scheduler, + processing_class=processing_class, + checkpoint_config=checkpoint_config, + ) + self.arch = arch + self.config = config + self.transformer_config = transformer_config + self.role = role + self.is_value_model = False + if self.role in ["reward", "critic"]: + self.is_value_model = True + self.model_config = model_config + self.hf_config = hf_config + self.param_dtype = param_dtype + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.model_path = self.config.model.path + self.use_distributed_optimizer = use_distributed_optimizer + self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler + self.bridge = bridge + self.rank = torch.distributed.get_rank() + self.use_dist_checkpointing = use_dist_checkpointing or not self.bridge or self.is_value_model + self.use_hf_checkpoint = not self.use_dist_checkpointing + + self.weight_saver = get_weight_saver(self.arch) + + def get_rng_state(self, use_dist_ckpt: bool = True, data_parallel_random_init: bool = False): + """collect rng state across data parallel ranks""" + rng_state = { + "random_rng_state": random.getstate(), + "np_rng_state": np.random.get_state(), + "torch_rng_state": torch.get_rng_state(), + "rng_tracker_states": tensor_parallel.get_cuda_rng_tracker().get_states(), + } + + if get_device_name() != "cpu": + rng_state[f"{get_device_name()}_rng_state"] = get_torch_device().get_rng_state() + + rng_state_list = None + if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init: + rng_state_list = [None for i in range(mpu.get_data_parallel_world_size())] + torch.distributed.all_gather_object(rng_state_list, rng_state, group=mpu.get_data_parallel_group()) + else: + rng_state_list = [rng_state] + + if use_dist_ckpt: + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + rng_state_list = ShardedObject( + "rng_state", + rng_state_list, + (pp_size, tp_size), + (pp_rank, tp_rank), + replica_id=mpu.get_data_parallel_rank(with_context_parallel=True), + ) + + return rng_state_list + + def get_checkpoint_name( + self, + checkpoints_path, + pipeline_parallel=None, + tensor_rank=None, + pipeline_rank=None, + cp_rank=None, + expert_parallel=None, + expert_rank=None, + return_base_dir=True, + basename="model.pt", + ): + """Determine the directory name for this rank's checkpoint.""" + # Use both the tensor and pipeline MP rank. + if pipeline_parallel is None: + pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1 + if tensor_rank is None: + tensor_rank = mpu.get_tensor_model_parallel_rank() + if pipeline_rank is None: + pipeline_rank = mpu.get_pipeline_model_parallel_rank() + if cp_rank is None: + cp_rank = mpu.get_context_parallel_rank() + if expert_parallel is None: + expert_parallel = mpu.get_expert_model_parallel_world_size() > 1 + if expert_rank is None: + expert_rank = mpu.get_expert_model_parallel_rank() + + # Use both the tensor and pipeline MP rank. If using the distributed + # optimizer, then the optimizer's path must additionally include the + # data parallel rank. + + # due to the fact that models are identical across cp ranks, cp rank is not used in the checkpoint path + if not pipeline_parallel: + common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}") + else: + common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}") + + if expert_parallel: + common_path = common_path + f"_{expert_rank:03d}" + + os.makedirs(common_path, exist_ok=True) + + if return_base_dir: + return common_path + return os.path.join(common_path, basename) + + def generate_state_dict(self): + # For save dist checkpointing + state_dict = {} + + # All ranks Save Model to reduce memory pressure + if self.should_save_model or self.should_load_model: + # Get sharded state dict, notice that state_dict will collect among dp groups, causing memory pressure + for vpp_rank, model in enumerate(self.model): + if len(self.model) > 1: + mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) + key = f"model{vpp_rank}" if len(self.model) > 1 else "model" + else: + key = "model" + if hasattr(model, "module"): + model = model.module + state_dict[key] = model.sharded_state_dict() + + # Optimizer State Dict + if self.should_save_optimizer or self.should_load_optimizer: + torch.distributed.barrier() + optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict) + state_dict["optimizer"] = optimizer_sharded_states + + if self.lr_scheduler is not None: + lr_state_dict = self.lr_scheduler.state_dict() + state_dict["lr_scheduler"] = lr_state_dict + + # RNG States State Dict + if self.should_save_extra or self.should_load_extra: + torch.distributed.barrier() + rng_state = self.get_rng_state() + state_dict["rng_state"] = rng_state + + return state_dict + + def load_rng_states(self, rng_states, data_parallel_random_init=False, use_dist_ckpt=True): + # access rng_state for data parallel rank + if data_parallel_random_init: + rng_states = rng_states[mpu.get_data_parallel_rank()] + else: + rng_states = rng_states[0] + random.setstate(rng_states["random_rng_state"]) + np.random.set_state(rng_states["np_rng_state"]) + torch.set_rng_state(rng_states["torch_rng_state"]) + + if get_device_name() != "cpu": + get_torch_device().set_rng_state(rng_states[f"{get_device_name()}_rng_state"]) + + # Check for empty states array + if not rng_states["rng_tracker_states"]: + raise KeyError + tensor_parallel.get_cuda_rng_tracker().set_states(rng_states["rng_tracker_states"]) + + def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False): + if local_path is not None: + assert os.path.exists(local_path), f"Checkpoint path {local_path} does not exist." + + dist_checkpoint_path = get_dist_checkpoint_path(local_path) + + # Get State Dict for loading + sharded_state_dict = self.generate_state_dict() + log_with_rank(f"Generated state dict for saving: {sharded_state_dict.keys()}", rank=self.rank, logger=logger) + for vpp_rank, model in enumerate(self.model): + if len(self.model) > 1: + model_i_keys = sharded_state_dict[f"model{vpp_rank}"].keys() + log_with_rank(f"Generated state dict for saving: {model_i_keys}", rank=self.rank, logger=logger) + else: + log_with_rank( + f"Generated state dict for saving: {sharded_state_dict['model'].keys()}", + rank=self.rank, + logger=logger, + ) + + # Load Dist Checkpointing + state_dict = load_dist_checkpointing( + sharded_state_dict=sharded_state_dict, + ckpt_dir=dist_checkpoint_path, + ) + + if self.should_load_model and self.use_dist_checkpointing: + assert "model" in state_dict or any( + f"model{vpp_rank}" in state_dict for vpp_rank in range(len(self.model)) + ), f"Model state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." + for vpp_rank, model in enumerate(self.model): + if len(self.model) == 1: + model_state_dict = state_dict["model"] + else: + assert f"model{vpp_rank}" in state_dict, f"model{vpp_rank} not found in state_dict" + model_state_dict = state_dict[f"model{vpp_rank}"] + mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) + self.model[vpp_rank].load_state_dict(model_state_dict) + log_with_rank(f"Loaded sharded model checkpoint from {local_path}", rank=self.rank, logger=logger) + elif self.should_load_model and self.use_hf_checkpoint: + hf_model_path = get_hf_model_checkpoint_path(local_path) + self.bridge.load_weights(self.model, hf_model_path) + log_with_rank(f"Loaded HF model checkpoint from {hf_model_path} with bridge", rank=self.rank, logger=logger) + + if self.should_load_optimizer: + assert "optimizer" in state_dict, ( + f"Optimizer state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." + ) + optimizer_state_dict = state_dict["optimizer"] + self.optimizer.load_state_dict(optimizer_state_dict) + log_with_rank(f"Loaded optimizer checkpoint from {local_path}", rank=self.rank, logger=logger) + if self.use_checkpoint_opt_param_scheduler: + assert "lr_scheduler" in state_dict, ( + f"LR scheduler state dict not found in {state_dict.keys()}. Please check the checkpoint file " + f"{local_path}." + ) + lr_scheduler_state_dict = state_dict["lr_scheduler"] + if self.lr_scheduler is not None: + self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) + log_with_rank(f"Loaded LR scheduler checkpoint from {local_path}", rank=self.rank, logger=logger) + + if self.should_load_extra: + assert "rng_state" in state_dict, ( + f"RNG state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." + ) + rng_state = state_dict["rng_state"] + self.load_rng_states(rng_state) + log_with_rank(f"Loaded RNG states from {local_path}", rank=self.rank, logger=logger) + + if del_local_after_load: + try: + os.remove(local_path) if is_non_local(local_path) else None + except Exception as e: + log_with_rank( + f"remove local resume ckpt file after loading failed, exception {e} will be ignored", + rank=self.rank, + logger=logger, + ) + + def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None): + # record the previous global step + self.previous_global_step = global_step + + # remove previous local_path + if ( + max_ckpt_to_keep + and isinstance(max_ckpt_to_keep, int) + and max_ckpt_to_keep > 0 + and len(self.previous_saved_paths) >= max_ckpt_to_keep + ): + keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 + self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) + self.previous_saved_paths = self.previous_saved_paths[keep_start:] + + local_path = local_mkdir_safe(local_path) + dist_checkpoint_path = get_dist_checkpoint_path(local_path) + + if self.use_dist_checkpointing: + # Generate state dict for saving + state_dict = self.generate_state_dict() + log_with_rank(f"Generated state dict for saving: {state_dict.keys()}", rank=self.rank, logger=logger) + for vpp_rank, model in enumerate(self.model): + if len(self.model) > 1: + model_i_keys = state_dict[f"model{vpp_rank}"].keys() + log_with_rank(f"Generated state dict for saving: {model_i_keys}", rank=self.rank, logger=logger) + else: + log_with_rank( + f"Generated state dict for saving: {state_dict['model'].keys()}", rank=self.rank, logger=logger + ) + # Start Async save if enabled + async_save_request = save_dist_checkpointing( + sharded_state_dict=state_dict, + ckpt_path=dist_checkpoint_path, + async_save=self.checkpoint_config.async_save, + ) + + # Synchronize all async save requests + if not self.checkpoint_config.async_save: + assert async_save_request is None, "Async save request should be None when not using async save." + torch.distributed.barrier() + else: + assert self.use_hf_checkpoint, "use_hf_checkpoint should be True when not using dist checkpointing" + log_with_rank(f"Saving HF model checkpoint to {local_path} with bridge", rank=self.rank, logger=logger) + hf_ckpt_path = get_hf_model_checkpoint_path(local_path) + self.bridge.save_weights(self.model, hf_ckpt_path) + log_with_rank(f"Saved bridge checkpoint to {hf_ckpt_path}", rank=self.rank, logger=logger) + + if self.should_save_model: + # Only rank 0 saves the hf config and tokenizer to huggingface path + # No matter whether we save hf model or not + if self.rank == 0: + # Save tokenizer + hf_config_tokenizer_path = get_hf_model_checkpoint_path(local_path) + self.processing_class.save_pretrained(hf_config_tokenizer_path) + # Save huggingface config + self.hf_config.save_pretrained(hf_config_tokenizer_path) + if hasattr(self.hf_config, "name_or_path") and self.hf_config.name_or_path: + try: + generation_config = GenerationConfig.from_pretrained(self.hf_config.name_or_path) + generation_config.save_pretrained(hf_config_tokenizer_path) + except Exception: + # if the generation config isn't available, we don't save it + pass + log_with_rank( + f"Saved Huggingface config and tokenizer to {hf_config_tokenizer_path}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) + + if self.should_save_extra: + if self.rank == 0: + # Save transformer config + print(self.transformer_config) + transformer_config_dict = asdict(self.transformer_config) + to_convert_types = {torch.dtype: str, AttnBackend: str} + ignore_types = [Callable] + pop_keys = [] + for key, value in transformer_config_dict.items(): + if type(value) in to_convert_types: + transformer_config_dict[key] = to_convert_types[type(value)](value) + if type(value) in ignore_types: + pop_keys.append(key) + if callable(value): + pop_keys.append(key) + for key in pop_keys: + transformer_config_dict.pop(key) + transformer_config_path = get_transformer_config_checkpoint_path(local_path) + with open(transformer_config_path, "w") as f: + json.dump(transformer_config_dict, f, indent=2) + + if self.should_save_hf_model: + # wait for everyone to dump to local + state_dict = self.weight_saver( + self.model, + self.hf_config, + dtype=self.param_dtype, + is_value_model=self.is_value_model, + tie_word_embeddings=self.share_embeddings_and_output_weights, + ) + + torch.distributed.barrier() + if self.rank == 0: + hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) + import warnings + + from accelerate import init_empty_weights + + with init_empty_weights(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + if "mistral7b-rm" in self.config.model.path: + from transformers import MistralForSequenceClassification + + model = MistralForSequenceClassification.from_pretrained( + self.config.model.path + ) # use score head instead of lm_head + state_dict["score.weight"] = state_dict["score.weight"] + else: + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype="auto") + model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict) + log_with_rank( + f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) + + if hdfs_path is not None: + log_with_rank( + f"Uploading checkpoint to {hdfs_path}", rank=self.rank, logger=logger, log_only_rank_0=True + ) + from verl.utils import hdfs_io + + hdfs_io.makedirs(hdfs_path, exist_ok=True) + hdfs_io.copy(src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True) + log_with_rank( + f"HDFS checkpoint uploaded to {hdfs_path}", rank=self.rank, logger=logger, log_only_rank_0=True + ) + + def finalize_save_fn(): + # Rank 0 uploads checkpoint to HDFS if hdfs_path is provided + log_with_rank( + f"Dist checkpointing save completed for {dist_checkpoint_path}", rank=self.rank, logger=logger + ) + if self.rank == 0: + if hdfs_path is not None: + log_with_rank(f"Uploading checkpoint to {hdfs_path}", rank=self.rank, logger=logger) + from verl.utils import hdfs_io + + hdfs_io.makedirs(hdfs_path, exist_ok=True) + hdfs_io.copy(src=dist_checkpoint_path, dst=hdfs_path, dirs_exist_ok=True) + hdfs_io.copy(src=hf_config_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True) + + if self.checkpoint_config.async_save: + assert async_save_request is not None, "Async save request should not be None when using async save." + async_save_request.add_finalize_fn(finalize_save_fn) + else: + finalize_save_fn() + + self.previous_saved_paths.append(local_path) diff --git a/toolbox/verl/v0.5.0/verl/utils/config.py b/toolbox/verl/v0.5.0/verl/utils/config.py new file mode 100644 index 000000000..f1c301f24 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/utils/config.py @@ -0,0 +1,65 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import is_dataclass +from typing import Any, Optional + +from omegaconf import DictConfig, ListConfig, OmegaConf + +__all__ = ["omega_conf_to_dataclass"] + + +def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[type[Any]] = None) -> Any: + """ + Convert an OmegaConf DictConfig to a dataclass. + + Args: + config: The OmegaConf DictConfig or dict to convert. + dataclass_type: The dataclass type to convert to. When dataclass_type is None, + the DictConfig must contain _target_ to be instantiated via hydra.instantiate API. + + Returns: + The dataclass instance. + """ + # Got an empty config + if not config: + return dataclass_type if dataclass_type is None else dataclass_type() + # Got an object + if not isinstance(config, DictConfig | ListConfig | dict | list): + return config + + if dataclass_type is None: + assert "_target_" in config, ( + "When dataclass_type is not provided, config must contain _target_." + "See trainer/config/ppo_trainer.yaml algorithm section for an example." + ) + from hydra.utils import instantiate + + return instantiate(config, _convert_="partial") + + if not is_dataclass(dataclass_type): + raise ValueError(f"{dataclass_type} must be a dataclass") + cfg = OmegaConf.create(config) # in case it's a dict + cfg_from_dataclass = OmegaConf.structured(dataclass_type) + # let cfg override the existing vals in `cfg_from_dataclass` + cfg_merged = OmegaConf.merge(cfg_from_dataclass, cfg) + # now convert to `dataclass_type` + config_object = OmegaConf.to_object(cfg_merged) + return config_object + + +def update_dict_with_config(dictionary: dict, config: DictConfig): + for key in dictionary: + if hasattr(config, key): + dictionary[key] = getattr(config, key) diff --git a/toolbox/verl/v0.5.0/verl/utils/dataset/README.md b/toolbox/verl/v0.5.0/verl/utils/dataset/README.md new file mode 100644 index 000000000..f886a70aa --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/utils/dataset/README.md @@ -0,0 +1,16 @@ +# Dataset Format +## RLHF dataset +We combine all the data sources into a single parquet files. We directly organize the prompt into the chat format so that multi-turn chats can be easily incorporated. In the prompt, we may add instruction following texts to guide the model output the answers in a particular format so that we can extract the answers. + +Math problems +```json +{ + "data_source": "openai/gsm8k", + "prompt": [{"role": "user", "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let's think step by step and output the final answer after \"####\""}], + "ability": "math", + "reward_model": { + "style": "rule", + "ground_truth": ["72"] + }, +} +``` diff --git a/toolbox/verl/v0.5.0/verl/utils/dataset/__init__.py b/toolbox/verl/v0.5.0/verl/utils/dataset/__init__.py new file mode 100644 index 000000000..6032d68c8 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/utils/dataset/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .rl_dataset import RLHFDataset +from .rm_dataset import RMDataset +from .sft_dataset import SFTDataset + +__all__ = ["RLHFDataset", "RMDataset", "SFTDataset"] diff --git a/toolbox/verl/v0.5.0/verl/utils/dataset/multiturn_sft_dataset.py b/toolbox/verl/v0.5.0/verl/utils/dataset/multiturn_sft_dataset.py new file mode 100644 index 000000000..ef050c987 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/utils/dataset/multiturn_sft_dataset.py @@ -0,0 +1,334 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2025 ModelBest Inc. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Multi-turn SFT dataset that supports training on conversation data with multiple turns +""" + +import logging +from typing import Any, Optional + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer + +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_local_path_from_hdfs + + +def convert_nested_value_to_list_recursive(data_item): + if isinstance(data_item, dict): + return {k: convert_nested_value_to_list_recursive(v) for k, v in data_item.items()} + elif isinstance(data_item, list): + return [convert_nested_value_to_list_recursive(elem) for elem in data_item] + elif isinstance(data_item, np.ndarray): + # Convert to list, then recursively process the elements of the new list + return convert_nested_value_to_list_recursive(data_item.tolist()) + else: + # Base case: item is already a primitive type (int, str, float, bool, etc.) + return data_item + + +class MultiTurnSFTDataset(Dataset): + """ + Dataset for multi-turn conversations where each assistant response should be trained + """ + + def __init__(self, parquet_files: str | list[str], tokenizer, config=None): + # Set defaults and extract parameters from config if provided + config = config or {} + self.truncation = config.get("truncation", "error") + self.max_length = config.get("max_length", 1024) + # Get messages_key from the new multiturn config structure + multiturn_config = config.get("multiturn", {}) + self.messages_key = multiturn_config.get("messages_key", "messages") + self.tools_key = multiturn_config.get("tools_key", "tools") + self.enable_thinking_key = multiturn_config.get("enable_thinking_key", "enable_thinking") + assert self.truncation in ["error", "left", "right"] + + if not isinstance(parquet_files, list): + parquet_files = [parquet_files] + + self.parquet_files = parquet_files + if isinstance(tokenizer, str): + tokenizer = hf_tokenizer(tokenizer) + self.tokenizer: PreTrainedTokenizer = tokenizer + + self._download() + self._read_files_and_process() + + def _download(self): + for i, parquet_file in enumerate(self.parquet_files): + self.parquet_files[i] = copy_local_path_from_hdfs(parquet_file, verbose=True) + + def _read_files_and_process(self): + def series_to_item(ls): + import numpy + import pandas + + while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1: + ls = ls[0] + return ls + + dataframes = [] + for parquet_file in self.parquet_files: + dataframe = pd.read_parquet(parquet_file) + dataframes.append(dataframe) + self.dataframe = pd.concat(dataframes) + + # Extract messages list from dataframe + self.messages = self.dataframe[self.messages_key].apply(series_to_item).tolist() + + # Extract tools list from dataframe + if self.tools_key in self.dataframe.columns: + self.tools = self.dataframe[self.tools_key].apply(convert_nested_value_to_list_recursive).tolist() + else: + self.tools = None + # Extract enable_thinking list from dataframe + if self.enable_thinking_key in self.dataframe.columns: + self.enable_thinking = self.dataframe[self.enable_thinking_key].tolist() + else: + self.enable_thinking = None + + def __len__(self): + return len(self.messages) + + def _process_message_tokens( + self, + messages: list[dict[str, Any]], + start_idx: int, + end_idx: int, + is_assistant: bool = False, + enable_thinking: Optional[bool] = None, + tools: Optional[list[dict[str, Any]]] = None, + ) -> tuple[list[int], list[int], list[int]]: + """ + Process tokens for a single message or a group of messages. + + Args: + messages: List of message dictionaries + start_idx: Start index in messages list + end_idx: End index in messages list + is_assistant: Whether this is an assistant message + enable_thinking: Whether to enable thinking mode + + Returns: + Tuple of (tokens, loss_mask, attention_mask) + """ + if start_idx > 0: + prev_applied_text = self.tokenizer.apply_chat_template( + messages[:start_idx], + tokenize=False, + add_generation_prompt=False, + enable_thinking=enable_thinking, + tools=tools, + ) + if is_assistant: + prev_applied_text_w_generation_prompt = self.tokenizer.apply_chat_template( + messages[:start_idx], + tokenize=False, + add_generation_prompt=True, + enable_thinking=enable_thinking, + tools=tools, + ) + + else: + prev_applied_text = "" + + cur_applied_text = self.tokenizer.apply_chat_template( + messages[:end_idx], + tokenize=False, + add_generation_prompt=False, + enable_thinking=enable_thinking, + tools=tools, + ) + # Get tokens for the current message only + if is_assistant: + generation_prompt_text = prev_applied_text_w_generation_prompt[len(prev_applied_text) :] + generation_prompt_tokens = self.tokenizer.encode( + generation_prompt_text, + add_special_tokens=False, + ) + _message_tokens = self.tokenizer.encode( + cur_applied_text[len(prev_applied_text_w_generation_prompt) :], + add_special_tokens=False, + ) + message_tokens = generation_prompt_tokens + _message_tokens + loss_mask = [0] * (len(generation_prompt_tokens)) + [1] * ( + len(message_tokens) - len(generation_prompt_tokens) + ) + else: + message_tokens = self.tokenizer.encode( + cur_applied_text[len(prev_applied_text) :], + add_special_tokens=False, + ) + loss_mask = [0] * len(message_tokens) + + attention_mask = [1] * len(message_tokens) + + return message_tokens, loss_mask, attention_mask + + def _validate_and_convert_tokens( + self, + full_tokens: torch.Tensor, + concat_tokens: list[int], + concat_loss_mask: list[int], + concat_attention_mask: list[int], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Validate tokenization and convert to tensors. + + Args: + full_tokens: Full conversation tokens + concat_tokens: Concatenated tokens + concat_loss_mask: Concatenated loss mask + concat_attention_mask: Concatenated attention mask + + Returns: + Tuple of (input_ids, loss_mask, attention_mask) as tensors + """ + full_tokens_list = full_tokens.tolist() + + if len(concat_tokens) != len(full_tokens_list) or not all( + a == b for a, b in zip(concat_tokens, full_tokens_list, strict=True) + ): + logging.warning( + f"Token mismatch detected! Full tokenization length: {len(full_tokens_list)}, Concatenated tokens " + f"length: {len(concat_tokens)}. Using concatenated version." + # f"full tokens text: {self.tokenizer.decode(full_tokens_list)}" + # f"concat tokens text: {self.tokenizer.decode(concat_tokens)}" + ) + return ( + torch.tensor(concat_tokens, dtype=torch.long), + torch.tensor(concat_loss_mask, dtype=torch.long), + torch.tensor(concat_attention_mask, dtype=torch.long), + ) + + return ( + full_tokens, + torch.tensor(concat_loss_mask, dtype=torch.long), + torch.tensor(concat_attention_mask, dtype=torch.long), + ) + + def __getitem__(self, item): + tokenizer = self.tokenizer + messages = self.messages[item] + tools = self.tools[item] if self.tools is not None else None + enable_thinking = self.enable_thinking[item] if self.enable_thinking is not None else None + + # First, get the full conversation tokens + try: + full_tokens = tokenizer.apply_chat_template( + messages, + tools=tools, + tokenize=True, + return_tensors="pt", + add_generation_prompt=False, + enable_thinking=enable_thinking, + ) + except Exception as e: + logging.error( + f"Error applying chat template: {e}\nMessages: {messages}\nTools: {tools}\nEnable thinking: " + f"{enable_thinking}" + ) + raise + + # Track concatenated tokens for validation + concat_tokens = [] + concat_loss_mask = [] + concat_attention_mask = [] + + i = 0 + while i < len(messages): + cur_messages = messages[i] + if cur_messages["role"] == "assistant": + # Process assistant message + tokens, loss_mask, attention_mask = self._process_message_tokens( + messages, i, i + 1, is_assistant=True, enable_thinking=enable_thinking, tools=tools + ) + concat_tokens.extend(tokens) + concat_loss_mask.extend(loss_mask) + concat_attention_mask.extend(attention_mask) + i += 1 + elif cur_messages["role"] == "tool": + # Process consecutive tool messages + st = i + ed = i + 1 + while ed < len(messages) and messages[ed]["role"] == "tool": + ed += 1 + tokens, loss_mask, attention_mask = self._process_message_tokens( + messages, st, ed, enable_thinking=enable_thinking, tools=tools + ) + concat_tokens.extend(tokens) + concat_loss_mask.extend(loss_mask) + concat_attention_mask.extend(attention_mask) + i = ed + elif cur_messages["role"] in ["user", "system"]: + # Process user or system message + if cur_messages["role"] == "system" and i != 0: + raise ValueError("System message should be the first message") + tokens, loss_mask, attention_mask = self._process_message_tokens( + messages, i, i + 1, enable_thinking=enable_thinking, tools=tools + ) + concat_tokens.extend(tokens) + concat_loss_mask.extend(loss_mask) + concat_attention_mask.extend(attention_mask) + i += 1 + else: + raise ValueError(f"Unknown role: {cur_messages['role']}") + + # Validate and convert tokens + input_ids, loss_mask, attention_mask = self._validate_and_convert_tokens( + full_tokens[0], concat_tokens, concat_loss_mask, concat_attention_mask + ) + + # Handle sequence length + sequence_length = input_ids.shape[0] + if sequence_length < self.max_length: + # Pad sequences + pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 + padded_input_ids = torch.full((self.max_length - sequence_length,), pad_token_id, dtype=input_ids.dtype) + padded_attention_mask = torch.zeros((self.max_length - sequence_length,), dtype=attention_mask.dtype) + padded_loss_mask = torch.zeros((self.max_length - sequence_length,), dtype=loss_mask.dtype) + + input_ids = torch.cat((input_ids, padded_input_ids)) + attention_mask = torch.cat((attention_mask, padded_attention_mask)) + loss_mask = torch.cat((loss_mask, padded_loss_mask)) + elif sequence_length > self.max_length: + if self.truncation == "left": + input_ids = input_ids[-self.max_length :] + attention_mask = attention_mask[-self.max_length :] + loss_mask = loss_mask[-self.max_length :] + elif self.truncation == "right": + input_ids = input_ids[: self.max_length] + attention_mask = attention_mask[: self.max_length] + loss_mask = loss_mask[: self.max_length] + elif self.truncation == "error": + raise ValueError(f"{sequence_length=} is larger than {self.max_length=}") + else: + raise ValueError(f"Unknown truncation method {self.truncation}") + + # Create position IDs + position_ids = torch.arange(len(input_ids), dtype=torch.long) + # Zero out position IDs for padding + position_ids = position_ids * attention_mask + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "loss_mask": loss_mask, + } diff --git a/toolbox/verl/v0.5.0/verl/utils/dataset/rl_dataset.py b/toolbox/verl/v0.5.0/verl/utils/dataset/rl_dataset.py new file mode 100644 index 000000000..fcbdd38b6 --- /dev/null +++ b/toolbox/verl/v0.5.0/verl/utils/dataset/rl_dataset.py @@ -0,0 +1,338 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import os +import re +from collections import defaultdict +from typing import Optional + +import datasets +import numpy as np +import torch +from omegaconf import DictConfig, ListConfig +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer, ProcessorMixin + +import verl.utils.torch_functional as verl_F +from verl.utils.model import compute_position_id_with_mask + +logger = logging.getLogger(__name__) + + +def collate_fn(data_list: list[dict]) -> dict: + """ + Collate a batch of sample dicts into batched tensors and arrays. + + Args: + data_list: List of dicts mapping feature names to torch.Tensor or other values. + + Returns: + Dict where tensor entries are stacked into a torch.Tensor of shape + (batch_size, \*dims) and non-tensor entries are converted to + np.ndarray of dtype object with shape (batch_size,). + """ + tensors = defaultdict(list) + non_tensors = defaultdict(list) + + for data in data_list: + for key, val in data.items(): + if isinstance(val, torch.Tensor): + tensors[key].append(val) + else: + non_tensors[key].append(val) + + for key, val in tensors.items(): + tensors[key] = torch.stack(val, dim=0) + + for key, val in non_tensors.items(): + non_tensors[key] = np.array(val, dtype=object) + + return {**tensors, **non_tensors} + + +class RLHFDataset(Dataset): + """ + Load and preprocess RLHF data from Parquet files. + + - Caches files locally. + - Reads into a HuggingFace Dataset and tokenizes prompts. + - Optionally handles images/videos via a ProcessorMixin. + - Filters prompts over a max length. + - Supports resuming from checkpoints. + + Args: + data_files (str or list): Path(s) to Parquet file(s). + tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs. + config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc. + processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos. + """ + + def __init__( + self, + data_files: str | list[str], + tokenizer: PreTrainedTokenizer, + config: DictConfig, + processor: Optional[ProcessorMixin] = None, + ): + if not isinstance(data_files, list | ListConfig): + data_files = [data_files] + + self.data_files = copy.deepcopy(data_files) + self.original_data_files = copy.deepcopy(data_files) # use for resume + self.tokenizer = tokenizer + self.processor = processor + self.config = config + + self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf")) + self.prompt_key = config.get("prompt_key", "prompt") + self.image_key = config.get("image_key", "images") + self.video_key = config.get("video_key", "videos") + self.max_prompt_length = config.get("max_prompt_length", 1024) + self.return_raw_chat = config.get("return_raw_chat", False) + self.return_full_prompt = config.get("return_full_prompt", False) + self.truncation = config.get("truncation", "error") + self.filter_overlong_prompts = config.get("filter_overlong_prompts", True) + + self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4)) + self.num_workers = min(self.num_workers, os.cpu_count()) + self.use_shm = config.get("use_shm", False) + self.chat_template_func = config.get("chat_template_func", None) + self.need_tools_kwargs = config.get("need_tools_kwargs", False) + self.filter_prompts = config.get("filter_prompts", True) + self.serialize_dataset = False + self.return_multi_modal_inputs = config.get("return_multi_modal_inputs", True) + + self._download() + self._read_files_and_tokenize() + + def _download(self, use_origin_parquet=False): + from verl.utils.fs import copy_to_local + + data_files = self.data_files if not use_origin_parquet else self.original_data_files + for i, parquet_file in enumerate(data_files): + self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm) + + def _read_files_and_tokenize(self): + dataframes = [] + for parquet_file in self.data_files: + # read parquet files and cache + dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"] + dataframes.append(dataframe) + self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) + + print(f"dataset len: {len(self.dataframe)}") + + self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe) + + def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset = None): + # filter out too long prompts + if self.filter_overlong_prompts: + tokenizer = self.tokenizer + processor = self.processor + prompt_key = self.prompt_key + image_key = self.image_key + video_key = self.video_key + + if processor is not None: + from verl.utils.dataset.vision_utils import process_image, process_video + + def doc2len(doc) -> int: + messages = self._build_messages(doc) + raw_prompt = self.processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + images = [process_image(image) for image in doc[image_key]] if image_key in doc else None + videos = [process_video(video) for video in doc[video_key]] if video_key in doc else None + + return len(processor(text=[raw_prompt], images=images, videos=videos)["input_ids"][0]) + + else: + + def doc2len(doc) -> int: + return len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) + + dataframe = dataframe.filter( + lambda doc: doc2len(doc) <= self.max_prompt_length, + num_proc=self.num_workers, + desc=f"Filtering prompts longer than {self.max_prompt_length} tokens", + ) + + print(f"filter dataset len: {len(dataframe)}") + return dataframe + + def resume_dataset_state(self): + self.serialize_dataset = not hasattr(self, "original_data_files") + # resume dataframe if not it's serialized in data.pt + if not self.serialize_dataset: + self._download(use_origin_parquet=True) # download and resume from original parquet files + self._read_files_and_tokenize() + else: + print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance") + + def __len__(self): + return len(self.dataframe) + + def _build_messages(self, example: dict): + messages: list = example.pop(self.prompt_key) + + if self.image_key in example or self.video_key in example: + for message in messages: + content = message["content"] + content_list = [] + segments = re.split("(|