[src/MaxText/{configs/{rl.yml,types.py},examples/rl_llama3_demo.ipynb,rl/train_rl.py}] Use pydantic natively in RL ; nest GRPO #2815
Conversation
@A9isha let me know if you need a hand
@SamuelMarks thanks a lot for the offer to help! Do you think it is possible to make the following more pythonic, to become something like
@A9isha Sure, in #2775 I introduced a split of the config into semantic sub-configs:

```python
# Import path assumed from this PR's file list (src/MaxText/configs/types.py);
# the original snippet's bare `from types import ...` would collide with the
# stdlib `types` module.
from MaxText.configs.types import (
    MaxTextConfig,
    RunInfo,
    Tokenizer,
    Checkpointing,
    HfDataset,
    Optimizer,
    TrainingLoop,
    GRPO,
    RLHardware,
)

# 1. Construct semantic groupings using specific subclasses
run_config = RunInfo(
    base_config=config_file,
    model_name=MODEL_NAME,
    run_name=RUN_NAME,
    base_output_directory=OUTPUT_DIRECTORY,
)
tokenizer_config = Tokenizer(
    tokenizer_path=HF_REPO_ID,
    chat_template_path=CHAT_TEMPLATE_PATH,
)
checkpoint_config = Checkpointing(
    load_parameters_path=MODEL_CHECKPOINT_PATH,
)
dataset_config = HfDataset(
    hf_access_token=HF_TOKEN,
)
optimizer_config = Optimizer(
    learning_rate=LEARNING_RATE,
)
training_config = TrainingLoop(
    steps=STEPS,
)

# Note: the 'grpo.' prefix from argv is removed because these are direct
# fields of the GRPO class.
grpo_config = GRPO(
    num_generations=NUM_GENERATIONS,
    grpo_beta=GRPO_BETA,
    grpo_epsilon=GRPO_EPSILON,
    loss_algo=LOSS_ALGO,
)
hardware_config = RLHardware(
    chips_per_vm=CHIPS_PER_VM,
    use_pathways=False,
)

# 2. Merge into the final MaxTextConfig.
# We assume standard defaults for any mixins not explicitly initialized above.
final_config_dict = {}

# Merge order does not matter here because the sub-configs cover disjoint
# fields; merging them all ensures every explicit override above is captured.
sub_configs = [
    run_config,
    tokenizer_config,
    checkpoint_config,
    dataset_config,
    optimizer_config,
    training_config,
    grpo_config,
    hardware_config,
]
for cfg in sub_configs:
    # exclude_unset=True keeps only the values explicitly provided above,
    # letting MaxTextConfig resolve all remaining defaults itself.
    final_config_dict.update(cfg.model_dump(exclude_unset=True))

max_text_config = MaxTextConfig(**final_config_dict)
```
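A minimal, self-contained sketch (toy models, not MaxText's actual classes) of why `model_dump(exclude_unset=True)` makes this dict-merge safe: only fields the caller explicitly set survive the dump, so untouched defaults from one sub-config never clobber another. Assumes pydantic v2.

```python
from pydantic import BaseModel

class Optimizer(BaseModel):
    learning_rate: float = 3e-4
    weight_decay: float = 0.0

class TrainingLoop(BaseModel):
    steps: int = 1000

opt = Optimizer(learning_rate=1e-5)  # weight_decay left at its default
loop = TrainingLoop(steps=50)

merged = {}
for cfg in (opt, loop):
    # exclude_unset=True drops weight_decay: it was never explicitly set,
    # so the final config is free to resolve it from its own defaults.
    merged.update(cfg.model_dump(exclude_unset=True))

print(merged)  # {'learning_rate': 1e-05, 'steps': 50}
```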
@SamuelMarks thank you for thinking about this. Do you think you could create a separate PR to change the
| " raise RuntimeError(\"OUTPUT_DIRECTORY is not set\")\n", | ||
| " \n", | ||
| "os.environ[\"HF_TOKEN\"] = HF_TOKEN\n", | ||
| "if \"MAXTEXT_PKG_DIR\" not in os.environ:\n", |
What is the need to set this env?
Great question. The path to vllm.yml is built from MAXTEXT_PKG_DIR, and when it is not set we fall back to the current file path. In Colab that resolves to /MaxText/examples, which causes problems.
@NicoGrande for visibility
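The fallback described above can be sketched as follows (the function name and paths are illustrative, not MaxText's actual implementation):

```python
import os

def resolve_pkg_dir(env: dict, fallback: str) -> str:
    """Prefer the MAXTEXT_PKG_DIR env var; otherwise use the fallback path."""
    return env.get("MAXTEXT_PKG_DIR", fallback)

# In Colab the fallback lands under .../MaxText/examples, where vllm.yml is
# not found -- which is why the notebook exports MAXTEXT_PKG_DIR up front.
print(resolve_pkg_dir({}, "/content/MaxText/examples"))
print(resolve_pkg_dir({"MAXTEXT_PKG_DIR": "/content/MaxText/src/MaxText"},
                      "/content/MaxText/examples"))
```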
Review comment on these config lines:

```python
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")

class GRPO(BaseModel):
```
Should we call this RL instead?
Review comment on this change:

```diff
 # Start training
-max_logging.log("Starting RL training...")
+max_logging.warning("Starting RL training...")
```
It should be info, right?
My goal is to make the logs from train_rl.py show up in the notebook's output cell while filtering out most of the logs from tpu-inference (before this PR, the train_rl.py logs were not displayed in notebooks at all, which is bad). The problem is that tpu-inference prints a lot of info/warning logs, and even with NoisyLogFilter not all of them get filtered out. Additionally, for some reason, the train_rl.py max_logging.log calls also get hidden in the notebook. This complication arises from the various import orders of absl/logging (including the indirection through max_logging.py), from max_logging using absl while other dependencies use logging, and from the general volume of logs from tpu-inference. Resolving b/473703277 would let us revert this later.
So we could either use print or use max_logging.warning (to bump up the priority), which gets shown for sure. wdyt?
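The filtering approach under discussion can be sketched with the standard logging module (a minimal illustration; MaxText's real NoisyLogFilter and the absl indirection may differ, and the substring pattern below is assumed):

```python
import logging

class NoisyLogFilter(logging.Filter):
    # Assumed pattern for illustration: drop records mentioning tpu_inference.
    NOISY_SUBSTRINGS = ("tpu_inference",)

    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False suppresses the record; True lets it through.
        msg = record.getMessage()
        return not any(s in msg for s in self.NOISY_SUBSTRINGS)

handler = logging.StreamHandler()
handler.addFilter(NoisyLogFilter())
logger = logging.getLogger("train_rl_demo")
logger.addHandler(handler)
logger.setLevel(logging.WARNING)

logger.warning("Starting RL training...")          # shown in the notebook cell
logger.warning("tpu_inference: compiling kernel")  # filtered out
```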
Description
Contributions by both @SamuelMarks and @A9isha to address the following:
Fix rl_llama3_demo.ipynb, where vllm.yml was not found from the Colab because of MAXTEXT_PKG_DIR
Ran rl_llama3_demo.ipynb for both grpo and gspo-token on a v5p-8
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.